diff --git a/api/admin/accounts.go b/api/admin/accounts.go index fd0859c..d6db154 100644 --- a/api/admin/accounts.go +++ b/api/admin/accounts.go @@ -6,9 +6,9 @@ import ( "net/http" "strconv" + "github.com/dhax/go-base/auth/pwdless" "github.com/go-ozzo/ozzo-validation" - "github.com/dhax/go-base/auth" "github.com/go-chi/chi" "github.com/go-chi/render" ) @@ -20,11 +20,11 @@ var ( // AccountStore defines database operations for account management. type AccountStore interface { - List(f auth.AccountFilter) ([]auth.Account, int, error) - Create(*auth.Account) error - Get(id int) (*auth.Account, error) - Update(*auth.Account) error - Delete(*auth.Account) error + List(f pwdless.AccountFilter) ([]pwdless.Account, int, error) + Create(*pwdless.Account) error + Get(id int) (*pwdless.Account, error) + Update(*pwdless.Account) error + Delete(*pwdless.Account) error } // AccountResource implements account management handler. @@ -70,7 +70,7 @@ func (rs *AccountResource) accountCtx(next http.Handler) http.Handler { } type accountRequest struct { - *auth.Account + *pwdless.Account } func (d *accountRequest) Bind(r *http.Request) error { @@ -78,20 +78,20 @@ func (d *accountRequest) Bind(r *http.Request) error { } type accountResponse struct { - *auth.Account + *pwdless.Account } -func newAccountResponse(a *auth.Account) *accountResponse { +func newAccountResponse(a *pwdless.Account) *accountResponse { resp := &accountResponse{Account: a} return resp } type accountListResponse struct { - Accounts []auth.Account `json:"accounts"` - Count int `json:"count"` + Accounts []pwdless.Account `json:"accounts"` + Count int `json:"count"` } -func newAccountListResponse(a []auth.Account, count int) *accountListResponse { +func newAccountListResponse(a []pwdless.Account, count int) *accountListResponse { resp := &accountListResponse{ Accounts: a, Count: count, @@ -100,7 +100,7 @@ func newAccountListResponse(a []auth.Account, count int) *accountListResponse { } func (rs *AccountResource) list(w http.ResponseWriter, r *http.Request) { - f := auth.NewAccountFilter(r.URL.Query()) + f := pwdless.NewAccountFilter(r.URL.Query()) al, count, err := rs.Store.List(f) if err != nil { render.Render(w, r, ErrRender(err)) @@ -129,12 +129,12 @@ func (rs *AccountResource) create(w http.ResponseWriter, r *http.Request) { } func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) { - acc := r.Context().Value(ctxAccount).(*auth.Account) + acc := r.Context().Value(ctxAccount).(*pwdless.Account) render.Respond(w, r, newAccountResponse(acc)) } func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) { - acc := r.Context().Value(ctxAccount).(*auth.Account) + acc := r.Context().Value(ctxAccount).(*pwdless.Account) data := &accountRequest{Account: acc} if err := render.Bind(r, data); err != nil { render.Render(w, r, ErrInvalidRequest(err)) @@ -155,7 +155,7 @@ func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) { } func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) { - acc := r.Context().Value(ctxAccount).(*auth.Account) + acc := r.Context().Value(ctxAccount).(*pwdless.Account) if err := rs.Store.Delete(acc); err != nil { render.Render(w, r, ErrInvalidRequest(err)) return diff --git a/api/admin/api.go b/api/admin/api.go index a18bab7..c91a433 100644 --- a/api/admin/api.go +++ b/api/admin/api.go @@ -9,7 +9,7 @@ import ( "github.com/go-chi/chi" "github.com/go-pg/pg" - "github.com/dhax/go-base/auth" + "github.com/dhax/go-base/auth/authorize" "github.com/dhax/go-base/database" "github.com/dhax/go-base/logging" ) @@ -44,7 +44,7 @@ func NewAPI(db *pg.DB) (*API, error) { // Router provides admin application routes. func (a *API) Router() *chi.Mux { r := chi.NewRouter() - r.Use(auth.RequiresRole(roleAdmin)) + r.Use(authorize.RequiresRole(roleAdmin)) r.Get("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello Admin")) diff --git a/api/api.go b/api/api.go index abfde1b..eb3de06 100644 --- a/api/api.go +++ b/api/api.go @@ -10,7 +10,8 @@ import ( "github.com/dhax/go-base/api/admin" "github.com/dhax/go-base/api/app" - "github.com/dhax/go-base/auth" + "github.com/dhax/go-base/auth/jwt" + "github.com/dhax/go-base/auth/pwdless" "github.com/dhax/go-base/database" "github.com/dhax/go-base/email" "github.com/dhax/go-base/logging" @@ -37,7 +38,7 @@ func New() (*chi.Mux, error) { } authStore := database.NewAuthStore(db) - authResource, err := auth.NewResource(authStore, mailer) + authResource, err := pwdless.NewResource(authStore, mailer) if err != nil { logger.WithField("module", "auth").Error(err) return nil, err @@ -70,8 +71,8 @@ func New() (*chi.Mux, error) { r.Mount("/auth", authResource.Router()) r.Group(func(r chi.Router) { - r.Use(authResource.Token.Verifier()) - r.Use(auth.Authenticator) + r.Use(authResource.TokenAuth.Verifier()) + r.Use(jwt.Authenticator) r.Mount("/admin", adminAPI.Router()) r.Mount("/api", appAPI.Router()) }) diff --git a/api/app/account.go b/api/app/account.go index 2b3adfd..e76e5de 100644 --- a/api/app/account.go +++ b/api/app/account.go @@ -10,16 +10,17 @@ import ( "github.com/go-chi/render" validation "github.com/go-ozzo/ozzo-validation" - "github.com/dhax/go-base/auth" + "github.com/dhax/go-base/auth/jwt" + "github.com/dhax/go-base/auth/pwdless" ) // AccountStore defines database operations for account. type AccountStore interface { - Get(id int) (*auth.Account, error) - Update(*auth.Account) error - Delete(*auth.Account) error - UpdateToken(*auth.Token) error - DeleteToken(*auth.Token) error + Get(id int) (*pwdless.Account, error) + Update(*pwdless.Account) error + Delete(*pwdless.Account) error + UpdateToken(*jwt.Token) error + DeleteToken(*jwt.Token) error } // AccountResource implements account management handler. @@ -49,7 +50,7 @@ func (rs *AccountResource) router() *chi.Mux { func (rs *AccountResource) accountCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims := auth.ClaimsFromCtx(r.Context()) + claims := jwt.ClaimsFromCtx(r.Context()) log(r).WithField("account_id", claims.ID) account, err := rs.Store.Get(claims.ID) if err != nil { @@ -63,7 +64,7 @@ func (rs *AccountResource) accountCtx(next http.Handler) http.Handler { } type accountRequest struct { - *auth.Account + *pwdless.Account // override protected data here, although not really necessary here // as we limit updated database columns in store as well ProtectedID int `json:"id"` @@ -78,21 +79,21 @@ func (d *accountRequest) Bind(r *http.Request) error { } type accountResponse struct { - *auth.Account + *pwdless.Account } -func newAccountResponse(a *auth.Account) *accountResponse { +func newAccountResponse(a *pwdless.Account) *accountResponse { resp := &accountResponse{Account: a} return resp } func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) { - acc := r.Context().Value(ctxAccount).(*auth.Account) + acc := r.Context().Value(ctxAccount).(*pwdless.Account) render.Respond(w, r, newAccountResponse(acc)) } func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) { - acc := r.Context().Value(ctxAccount).(*auth.Account) + acc := r.Context().Value(ctxAccount).(*pwdless.Account) data := &accountRequest{Account: acc} if err := render.Bind(r, data); err != nil { render.Render(w, r, ErrInvalidRequest(err)) @@ -113,7 +114,7 @@ func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) { } func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) { - acc := r.Context().Value(ctxAccount).(*auth.Account) + acc := r.Context().Value(ctxAccount).(*pwdless.Account) if err := rs.Store.Delete(acc); err != nil { render.Render(w, r, ErrRender(err)) return @@ -142,10 +143,10 @@ func (rs *AccountResource) updateToken(w http.ResponseWriter, r *http.Request) { render.Render(w, r, ErrInvalidRequest(err)) return } - acc := r.Context().Value(ctxAccount).(*auth.Account) + acc := r.Context().Value(ctxAccount).(*pwdless.Account) for _, t := range acc.Token { if t.ID == id { - if err := rs.Store.UpdateToken(&auth.Token{ + if err := rs.Store.UpdateToken(&jwt.Token{ ID: t.ID, Identifier: data.Identifier, }); err != nil { @@ -163,10 +164,10 @@ func (rs *AccountResource) deleteToken(w http.ResponseWriter, r *http.Request) { render.Render(w, r, ErrBadRequest) return } - acc := r.Context().Value(ctxAccount).(*auth.Account) + acc := r.Context().Value(ctxAccount).(*pwdless.Account) for _, t := range acc.Token { if t.ID == id { - rs.Store.DeleteToken(&auth.Token{ID: t.ID}) + rs.Store.DeleteToken(&jwt.Token{ID: t.ID}) } } render.Respond(w, r, http.NoBody) diff --git a/api/app/profile.go b/api/app/profile.go index 00aa200..3744efa 100644 --- a/api/app/profile.go +++ b/api/app/profile.go @@ -4,7 +4,7 @@ import ( "context" "net/http" - "github.com/dhax/go-base/auth" + "github.com/dhax/go-base/auth/jwt" "github.com/dhax/go-base/models" "github.com/go-chi/chi" "github.com/go-chi/render" @@ -39,7 +39,7 @@ func (rs *ProfileResource) router() *chi.Mux { func (rs *ProfileResource) profileCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims := auth.ClaimsFromCtx(r.Context()) + claims := jwt.ClaimsFromCtx(r.Context()) p, err := rs.Store.Get(claims.ID) if err != nil { log(r).WithField("profileCtx", claims.Sub).Error(err) diff --git a/auth/api.go b/auth/api.go deleted file mode 100644 index 0c063b3..0000000 --- a/auth/api.go +++ /dev/null @@ -1,92 +0,0 @@ -// Package auth provides JSON Web Token (JWT) authentication and authorization middleware. -// It implements a passwordless authentication flow by sending login tokens vie email which are then exchanged for JWT access and refresh tokens. -package auth - -import ( - "net/http" - "time" - - "github.com/dhax/go-base/email" - "github.com/dhax/go-base/logging" - "github.com/go-chi/chi" - "github.com/go-chi/render" - "github.com/sirupsen/logrus" -) - -// Storer defines database operations on account and token data. -type Storer interface { - GetByID(id int) (*Account, error) - GetByEmail(email string) (*Account, error) - GetByRefreshToken(token string) (*Account, *Token, error) - UpdateAccount(a *Account) error - SaveRefreshToken(t *Token) error - DeleteRefreshToken(t *Token) error - PurgeExpiredToken() error -} - -// Mailer defines methods to send account emails. -type Mailer interface { - LoginToken(name, email string, c email.ContentLoginToken) error -} - -// Resource implements passwordless token authentication against a database. -type Resource struct { - Login *LoginTokenAuth - Token *TokenAuth - store Storer - mailer Mailer -} - -// NewResource returns a configured authentication resource. -func NewResource(store Storer, mailer Mailer) (*Resource, error) { - loginAuth, err := NewLoginTokenAuth() - if err != nil { - return nil, err - } - - tokenAuth, err := NewTokenAuth() - if err != nil { - return nil, err - } - - resource := &Resource{ - Login: loginAuth, - Token: tokenAuth, - store: store, - mailer: mailer, - } - - resource.cleanupTicker() - - return resource, nil -} - -// Router provides necessary routes for passwordless authentication flow. -func (rs *Resource) Router() *chi.Mux { - r := chi.NewRouter() - r.Use(render.SetContentType(render.ContentTypeJSON)) - r.Post("/login", rs.login) - r.Post("/token", rs.token) - r.Group(func(r chi.Router) { - r.Use(rs.Token.Verifier()) - r.Use(AuthenticateRefreshJWT) - r.Post("/refresh", rs.refresh) - r.Post("/logout", rs.logout) - }) - return r -} - -func (rs *Resource) cleanupTicker() { - ticker := time.NewTicker(time.Hour * 1) - go func() { - for range ticker.C { - if err := rs.store.PurgeExpiredToken(); err != nil { - logging.Logger.WithField("auth", "cleanup").Error(err) - } - } - }() -} - -func log(r *http.Request) logrus.FieldLogger { - return logging.GetLogEntry(r) -} diff --git a/auth/authorize/errors.go b/auth/authorize/errors.go new file mode 100644 index 0000000..392094f --- /dev/null +++ b/auth/authorize/errors.go @@ -0,0 +1,31 @@ +package authorize + +import ( + "net/http" + + "github.com/go-chi/render" +) + +// ErrResponse renderer type for handling all sorts of errors. +type ErrResponse struct { + Err error `json:"-"` // low-level runtime error + HTTPStatusCode int `json:"-"` // http response status code + + StatusText string `json:"status"` // user-level status message + AppCode int64 `json:"code,omitempty"` // application-specific error code + ErrorText string `json:"error,omitempty"` // application-level error message, for debugging +} + +// Render sets the application-specific error code in AppCode. +func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error { + render.Status(r, e.HTTPStatusCode) + return nil +} + +// The list of default error types without specific error message. +var ( + ErrForbidden = &ErrResponse{ + HTTPStatusCode: http.StatusForbidden, + StatusText: http.StatusText(http.StatusForbidden), + } +) diff --git a/auth/authorizer.go b/auth/authorize/roles.go similarity index 86% rename from auth/authorizer.go rename to auth/authorize/roles.go index ada3de2..ad29b99 100644 --- a/auth/authorizer.go +++ b/auth/authorize/roles.go @@ -1,16 +1,18 @@ -package auth +package authorize import ( "net/http" "github.com/go-chi/render" + + "github.com/dhax/go-base/auth/jwt" ) // RequiresRole middleware restricts access to accounts having role parameter in their jwt claims. func RequiresRole(role string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { hfn := func(w http.ResponseWriter, r *http.Request) { - claims := ClaimsFromCtx(r.Context()) + claims := jwt.ClaimsFromCtx(r.Context()) if !hasRole(role, claims.Roles) { render.Render(w, r, ErrForbidden) return diff --git a/auth/crypto.go b/auth/crypto.go deleted file mode 100644 index d16b448..0000000 --- a/auth/crypto.go +++ /dev/null @@ -1,19 +0,0 @@ -package auth - -import ( - "crypto/rand" -) - -const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - -func randStringBytes(n int) string { - buf := make([]byte, n) - if _, err := rand.Read(buf); err != nil { - panic(err) - } - - for k, v := range buf { - buf[k] = letterBytes[v%byte(len(letterBytes))] - } - return string(buf) -} diff --git a/auth/errors.go b/auth/errors.go deleted file mode 100644 index 70d6263..0000000 --- a/auth/errors.go +++ /dev/null @@ -1,64 +0,0 @@ -package auth - -import ( - "net/http" - - "github.com/go-chi/render" -) - -// ErrResponse renderer type for handling all sorts of errors. -type ErrResponse struct { - Err error `json:"-"` // low-level runtime error - HTTPStatusCode int `json:"-"` // http response status code - - StatusText string `json:"status"` // user-level status message - AppCode int64 `json:"code,omitempty"` // application-specific error code - ErrorText string `json:"error,omitempty"` // application-level error message, for debugging -} - -// Render sets the application-specific error code in AppCode. -func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error { - render.Status(r, e.HTTPStatusCode) - return nil -} - -// ErrUnauthorized renders status 401 Unauthorized with custom error message. -func ErrUnauthorized(err error) render.Renderer { - return &ErrResponse{ - Err: err, - HTTPStatusCode: http.StatusUnauthorized, - StatusText: http.StatusText(http.StatusUnauthorized), - ErrorText: err.Error(), - } -} - -// ErrRender returns status 422 Unprocessable Entity for invalid request body -func ErrRender(err error) render.Renderer { - return &ErrResponse{ - Err: err, - HTTPStatusCode: http.StatusUnprocessableEntity, - StatusText: http.StatusText(http.StatusUnprocessableEntity), - ErrorText: err.Error(), - } -} - -// ErrInvalidRequest returns status 422 Unprocessable Entity with validation errors -func ErrInvalidRequest(err error) render.Renderer { - return &ErrResponse{ - Err: err, - HTTPStatusCode: http.StatusUnprocessableEntity, - StatusText: http.StatusText(http.StatusUnprocessableEntity), - ErrorText: err.Error(), - } -} - -// The list of default error types without specific error message. -var ( - ErrBadRequest = &ErrResponse{HTTPStatusCode: http.StatusBadRequest, StatusText: http.StatusText(http.StatusBadRequest)} - - ErrForbidden = &ErrResponse{HTTPStatusCode: http.StatusForbidden, StatusText: http.StatusText(http.StatusForbidden)} - - ErrNotFound = &ErrResponse{HTTPStatusCode: http.StatusNotFound, StatusText: http.StatusText(http.StatusNotFound)} - - ErrInternalServerError = &ErrResponse{HTTPStatusCode: http.StatusInternalServerError, StatusText: http.StatusText(http.StatusInternalServerError)} -) diff --git a/auth/authenticator.go b/auth/jwt/authenticator.go similarity index 63% rename from auth/authenticator.go rename to auth/jwt/authenticator.go index a2a0508..74bcb23 100644 --- a/auth/authenticator.go +++ b/auth/jwt/authenticator.go @@ -1,12 +1,13 @@ -package auth +package jwt import ( "context" - "errors" "net/http" "github.com/go-chi/jwtauth" "github.com/go-chi/render" + + "github.com/dhax/go-base/logging" ) type ctxKey int @@ -16,13 +17,6 @@ const ( ctxRefreshToken ) -var ( - errTokenUnauthorized = errors.New("token unauthorized") - errTokenExpired = errors.New("token expired") - errInvalidAccessToken = errors.New("invalid access token") - errInvalidRefreshToken = errors.New("invalid refresh token") -) - // ClaimsFromCtx retrieves the parsed AppClaims from request context. func ClaimsFromCtx(ctx context.Context) AppClaims { return ctx.Value(ctxClaims).(AppClaims) @@ -41,23 +35,27 @@ func Authenticator(next http.Handler) http.Handler { token, claims, err := jwtauth.FromContext(r.Context()) if err != nil { - log(r).Warn(err) - render.Render(w, r, ErrUnauthorized(errTokenUnauthorized)) + logging.GetLogEntry(r).Warn(err) + render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized)) return } if !token.Valid { - render.Render(w, r, ErrUnauthorized(errTokenExpired)) + render.Render(w, r, ErrUnauthorized(ErrTokenExpired)) return } // Token is authenticated, parse claims - pc, ok := parseClaims(claims) - if !ok { - render.Render(w, r, ErrUnauthorized(errInvalidAccessToken)) + var c AppClaims + err = c.ParseClaims(claims) + if err != nil { + logging.GetLogEntry(r).Error(err) + render.Render(w, r, ErrUnauthorized(ErrInvalidAccessToken)) return } - ctx := context.WithValue(r.Context(), ctxClaims, pc) + + // Set AppClaims on context + ctx := context.WithValue(r.Context(), ctxClaims, c) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -67,21 +65,25 @@ func AuthenticateRefreshJWT(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token, claims, err := jwtauth.FromContext(r.Context()) if err != nil { - log(r).Warn(err) - render.Render(w, r, ErrUnauthorized(errTokenUnauthorized)) + logging.GetLogEntry(r).Warn(err) + render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized)) return } if !token.Valid { - render.Render(w, r, ErrUnauthorized(errTokenExpired)) + render.Render(w, r, ErrUnauthorized(ErrTokenExpired)) return } - refreshToken, ok := parseRefreshClaims(claims) - if !ok { - render.Render(w, r, ErrUnauthorized(errInvalidRefreshToken)) + + // Token is authenticated, parse refresh token string + var c RefreshClaims + err = c.ParseClaims(claims) + if err != nil { + logging.GetLogEntry(r).Error(err) + render.Render(w, r, ErrUnauthorized(ErrInvalidRefreshToken)) return } - // Token is authenticated, set on context - ctx := context.WithValue(r.Context(), ctxRefreshToken, refreshToken) + // Set refresh token string on context + ctx := context.WithValue(r.Context(), ctxRefreshToken, c.Token) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/auth/jwt/claims.go b/auth/jwt/claims.go new file mode 100644 index 0000000..4d4e285 --- /dev/null +++ b/auth/jwt/claims.go @@ -0,0 +1,59 @@ +package jwt + +import ( + "errors" + + "github.com/go-chi/jwtauth" +) + +// AppClaims represent the claims parsed from JWT access token. +type AppClaims struct { + ID int + Sub string + Roles []string +} + +// ParseClaims parses JWT claims into AppClaims. +func (c *AppClaims) ParseClaims(claims jwtauth.Claims) error { + id, ok := claims.Get("id") + if !ok { + return errors.New("could not parse claim id") + } + c.ID = int(id.(float64)) + + sub, ok := claims.Get("sub") + if !ok { + return errors.New("could not parse claim sub") + } + c.Sub = sub.(string) + + rl, ok := claims.Get("roles") + if !ok { + return errors.New("could not parse claims roles") + } + + var roles []string + if rl != nil { + for _, v := range rl.([]interface{}) { + roles = append(roles, v.(string)) + } + } + c.Roles = roles + + return nil +} + +// RefreshClaims represent the claims parsed from JWT refresh token. +type RefreshClaims struct { + Token string +} + +// ParseClaims parses the JWT claims into RefreshClaims. +func (c *RefreshClaims) ParseClaims(claims jwtauth.Claims) error { + token, ok := claims.Get("token") + if !ok { + return errors.New("could not parse claim token") + } + c.Token = token.(string) + return nil +} diff --git a/auth/jwt/errors.go b/auth/jwt/errors.go new file mode 100644 index 0000000..6535489 --- /dev/null +++ b/auth/jwt/errors.go @@ -0,0 +1,42 @@ +package jwt + +import ( + "errors" + "net/http" + + "github.com/go-chi/render" +) + +// The list of jwt token errors presented to the end user. +var ( + ErrTokenUnauthorized = errors.New("token unauthorized") + ErrTokenExpired = errors.New("token expired") + ErrInvalidAccessToken = errors.New("invalid access token") + ErrInvalidRefreshToken = errors.New("invalid refresh token") +) + +// ErrResponse renderer type for handling all sorts of errors. +type ErrResponse struct { + Err error `json:"-"` // low-level runtime error + HTTPStatusCode int `json:"-"` // http response status code + + StatusText string `json:"status"` // user-level status message + AppCode int64 `json:"code,omitempty"` // application-specific error code + ErrorText string `json:"error,omitempty"` // application-level error message, for debugging +} + +// Render sets the application-specific error code in AppCode. +func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error { + render.Status(r, e.HTTPStatusCode) + return nil +} + +// ErrUnauthorized renders status 401 Unauthorized with custom error message. +func ErrUnauthorized(err error) render.Renderer { + return &ErrResponse{ + Err: err, + HTTPStatusCode: http.StatusUnauthorized, + StatusText: http.StatusText(http.StatusUnauthorized), + ErrorText: err.Error(), + } +} diff --git a/auth/token.go b/auth/jwt/token.go similarity index 98% rename from auth/token.go rename to auth/jwt/token.go index bfbc0ef..48aed23 100644 --- a/auth/token.go +++ b/auth/jwt/token.go @@ -1,4 +1,4 @@ -package auth +package jwt import ( "time" diff --git a/auth/jwt.go b/auth/jwt/tokenauth.go similarity index 59% rename from auth/jwt.go rename to auth/jwt/tokenauth.go index 67f0efb..ce69f5e 100644 --- a/auth/jwt.go +++ b/auth/jwt/tokenauth.go @@ -1,6 +1,7 @@ -package auth +package jwt import ( + "crypto/rand" "net/http" "time" @@ -8,18 +9,11 @@ import ( "github.com/spf13/viper" ) -// AppClaims represent the claims extracted from JWT token. -type AppClaims struct { - ID int - Sub string - Roles []string -} - // TokenAuth implements JWT authentication flow. type TokenAuth struct { JwtAuth *jwtauth.JwtAuth - jwtExpiry time.Duration - jwtRefreshExpiry time.Duration + JwtExpiry time.Duration + JwtRefreshExpiry time.Duration } // NewTokenAuth configures and returns a JWT authentication instance. @@ -31,8 +25,8 @@ func NewTokenAuth() (*TokenAuth, error) { a := &TokenAuth{ JwtAuth: jwtauth.New("HS256", []byte(secret), nil), - jwtExpiry: viper.GetDuration("auth_jwt_expiry"), - jwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"), + JwtExpiry: viper.GetDuration("auth_jwt_expiry"), + JwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"), } return a, nil @@ -59,7 +53,7 @@ func (a *TokenAuth) GenTokenPair(ca jwtauth.Claims, cr jwtauth.Claims) (string, // CreateJWT returns an access token for provided account claims. func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) { c.SetIssuedNow() - c.SetExpiryIn(a.jwtExpiry) + c.SetExpiryIn(a.JwtExpiry) _, tokenString, err := a.JwtAuth.Encode(c) return tokenString, err } @@ -67,46 +61,21 @@ func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) { // CreateRefreshJWT returns a refresh token for provided token Claims. func (a *TokenAuth) CreateRefreshJWT(c jwtauth.Claims) (string, error) { c.SetIssuedNow() - c.SetExpiryIn(a.jwtRefreshExpiry) + c.SetExpiryIn(a.JwtRefreshExpiry) _, tokenString, err := a.JwtAuth.Encode(c) return tokenString, err } -func parseClaims(c jwtauth.Claims) (AppClaims, bool) { - var claims AppClaims - allOK := true - id, ok := c.Get("id") - if !ok { - allOK = false - } - claims.ID = int(id.(float64)) +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - sub, ok := c.Get("sub") - if !ok { - allOK = false - } - claims.Sub = sub.(string) - - rl, ok := c.Get("roles") - if !ok { - allOK = false +func randStringBytes(n int) string { + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + panic(err) } - var roles []string - if rl != nil { - for _, v := range rl.([]interface{}) { - roles = append(roles, v.(string)) - } + for k, v := range buf { + buf[k] = letterBytes[v%byte(len(letterBytes))] } - claims.Roles = roles - - return claims, allOK -} - -func parseRefreshClaims(c jwtauth.Claims) (string, bool) { - token, ok := c.Get("token") - if !ok { - return "", false - } - return token.(string), ok + return string(buf) } diff --git a/auth/mockStorer.go b/auth/mockStorer.go deleted file mode 100644 index 1938add..0000000 --- a/auth/mockStorer.go +++ /dev/null @@ -1,67 +0,0 @@ -package auth - -// MockStorer mocks Storer interface. -type MockStorer struct { - GetByIDFn func(id int) (*Account, error) - GetByIDInvoked bool - - GetByEmailFn func(email string) (*Account, error) - GetByEmailInvoked bool - - GetByRefreshTokenFn func(token string) (*Account, *Token, error) - GetByRefreshTokenInvoked bool - - UpdateAccountFn func(a *Account) error - UpdateAccountInvoked bool - - SaveRefreshTokenFn func(t *Token) error - SaveRefreshTokenInvoked bool - - DeleteRefreshTokenFn func(t *Token) error - DeleteRefreshTokenInvoked bool - - PurgeExpiredTokenFn func() error - PurgeExpiredTokenInvoked bool -} - -// GetByID mock returns an account by ID. -func (s *MockStorer) GetByID(id int) (*Account, error) { - s.GetByIDInvoked = true - return s.GetByIDFn(id) -} - -// GetByEmail mock returns an account by email. -func (s *MockStorer) GetByEmail(email string) (*Account, error) { - s.GetByEmailInvoked = true - return s.GetByEmailFn(email) -} - -// GetByRefreshToken mock returns an account and refresh token by token identifier. -func (s *MockStorer) GetByRefreshToken(token string) (*Account, *Token, error) { - s.GetByRefreshTokenInvoked = true - return s.GetByRefreshTokenFn(token) -} - -// UpdateAccount mock upates account data related to authentication. -func (s *MockStorer) UpdateAccount(a *Account) error { - s.UpdateAccountInvoked = true - return s.UpdateAccountFn(a) -} - -// SaveRefreshToken mock creates or updates a refresh token. -func (s *MockStorer) SaveRefreshToken(t *Token) error { - s.SaveRefreshTokenInvoked = true - return s.SaveRefreshTokenFn(t) -} - -// DeleteRefreshToken mock deletes a refresh token. -func (s *MockStorer) DeleteRefreshToken(t *Token) error { - s.DeleteRefreshTokenInvoked = true - return s.DeleteRefreshTokenFn(t) -} - -// PurgeExpiredToken mock deletes expired refresh token. -func (s *MockStorer) PurgeExpiredToken() error { - s.PurgeExpiredTokenInvoked = true - return s.PurgeExpiredTokenFn() -} diff --git a/auth/account.go b/auth/pwdless/account.go similarity index 95% rename from auth/account.go rename to auth/pwdless/account.go index 3b43125..e8f5e2b 100644 --- a/auth/account.go +++ b/auth/pwdless/account.go @@ -1,10 +1,11 @@ -package auth +package pwdless import ( "net/url" "strings" "time" + "github.com/dhax/go-base/auth/jwt" "github.com/go-chi/jwtauth" validation "github.com/go-ozzo/ozzo-validation" "github.com/go-ozzo/ozzo-validation/is" @@ -23,7 +24,7 @@ type Account struct { Active bool `sql:",notnull" json:"active"` Roles []string `pg:",array" json:"roles,omitempty"` - Token []Token `json:"token,omitempty"` + Token []jwt.Token `json:"token,omitempty"` } // BeforeInsert hook executed before database insert operation. @@ -38,11 +39,8 @@ func (a *Account) BeforeInsert(db orm.DB) error { // BeforeUpdate hook executed before database update operation. func (a *Account) BeforeUpdate(db orm.DB) error { - if err := a.Validate(); err != nil { - return err - } a.UpdatedAt = time.Now() - return nil + return a.Validate() } // BeforeDelete hook executed before database delete operation. diff --git a/auth/handler.go b/auth/pwdless/api.go similarity index 51% rename from auth/handler.go rename to auth/pwdless/api.go index 5919eec..4816824 100644 --- a/auth/handler.go +++ b/auth/pwdless/api.go @@ -1,29 +1,93 @@ -package auth +// Package pwdless provides JSON Web Token (JWT) authentication and authorization middleware. +// It implements a passwordless authentication flow by sending login tokens vie email which are then exchanged for JWT access and refresh tokens. +package pwdless import ( - "errors" "fmt" "net/http" "path" "strings" "time" + "github.com/dhax/go-base/auth/jwt" + "github.com/dhax/go-base/email" + "github.com/dhax/go-base/logging" + "github.com/go-chi/chi" "github.com/go-chi/render" validation "github.com/go-ozzo/ozzo-validation" "github.com/go-ozzo/ozzo-validation/is" "github.com/mssola/user_agent" uuid "github.com/satori/go.uuid" - - "github.com/dhax/go-base/email" + "github.com/sirupsen/logrus" ) -// The list of error types presented to the end user as error message. -var ( - ErrInvalidLogin = errors.New("invalid email address") - ErrUnknownLogin = errors.New("email not registered") - ErrLoginDisabled = errors.New("login for account disabled") - ErrLoginToken = errors.New("invalid or expired login token") -) +// AuthStorer defines database operations on accounts and tokens. +type AuthStorer interface { + GetAccount(id int) (*Account, error) + GetAccountByEmail(email string) (*Account, error) + UpdateAccount(a *Account) error + + GetToken(token string) (*jwt.Token, error) + CreateOrUpdateToken(t *jwt.Token) error + DeleteToken(t *jwt.Token) error + PurgeExpiredToken() error +} + +// Mailer defines methods to send account emails. +type Mailer interface { + LoginToken(name, email string, c email.ContentLoginToken) error +} + +// Resource implements passwordless account authentication against a database. +type Resource struct { + LoginAuth *LoginTokenAuth + TokenAuth *jwt.TokenAuth + Store AuthStorer + Mailer Mailer +} + +// NewResource returns a configured authentication resource. +func NewResource(authStore AuthStorer, mailer Mailer) (*Resource, error) { + loginAuth, err := NewLoginTokenAuth() + if err != nil { + return nil, err + } + + tokenAuth, err := jwt.NewTokenAuth() + if err != nil { + return nil, err + } + + resource := &Resource{ + LoginAuth: loginAuth, + TokenAuth: tokenAuth, + Store: authStore, + Mailer: mailer, + } + + resource.choresTicker() + + return resource, nil +} + +// Router provides necessary routes for passwordless authentication flow. +func (rs *Resource) Router() *chi.Mux { + r := chi.NewRouter() + r.Use(render.SetContentType(render.ContentTypeJSON)) + r.Post("/login", rs.login) + r.Post("/token", rs.token) + r.Group(func(r chi.Router) { + r.Use(rs.TokenAuth.Verifier()) + r.Use(jwt.AuthenticateRefreshJWT) + r.Post("/refresh", rs.refresh) + r.Post("/logout", rs.logout) + }) + return r +} + +func log(r *http.Request) logrus.FieldLogger { + return logging.GetLogEntry(r) +} type loginRequest struct { Email string @@ -46,7 +110,7 @@ func (rs *Resource) login(w http.ResponseWriter, r *http.Request) { return } - acc, err := rs.store.GetByEmail(body.Email) + acc, err := rs.Store.GetAccountByEmail(body.Email) if err != nil { log(r).WithField("email", body.Email).Warn(err) render.Render(w, r, ErrUnauthorized(ErrUnknownLogin)) @@ -58,17 +122,17 @@ func (rs *Resource) login(w http.ResponseWriter, r *http.Request) { return } - lt := rs.Login.CreateToken(acc.ID) + lt := rs.LoginAuth.CreateToken(acc.ID) go func() { content := email.ContentLoginToken{ Email: acc.Email, Name: acc.Name, - URL: path.Join(rs.Login.loginURL, lt.Token), + URL: path.Join(rs.LoginAuth.loginURL, lt.Token), Token: lt.Token, Expiry: lt.Expiry, } - if err := rs.mailer.LoginToken(acc.Name, acc.Email, content); err != nil { + if err := rs.Mailer.LoginToken(acc.Name, acc.Email, content); err != nil { log(r).WithField("module", "email").Error(err) } }() @@ -101,13 +165,13 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) { return } - id, err := rs.Login.GetAccountID(body.Token) + id, err := rs.LoginAuth.GetAccountID(body.Token) if err != nil { render.Render(w, r, ErrUnauthorized(ErrLoginToken)) return } - acc, err := rs.store.GetByID(id) + acc, err := rs.Store.GetAccount(id) if err != nil { // account deleted before login token expired render.Render(w, r, ErrUnauthorized(ErrUnknownLogin)) @@ -122,22 +186,22 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) { ua := user_agent.New(r.UserAgent()) browser, _ := ua.Browser() - token := &Token{ + token := &jwt.Token{ Token: uuid.NewV4().String(), - Expiry: time.Now().Add(rs.Token.jwtRefreshExpiry), + Expiry: time.Now().Add(rs.TokenAuth.JwtRefreshExpiry), UpdatedAt: time.Now(), AccountID: acc.ID, Mobile: ua.Mobile(), Identifier: fmt.Sprintf("%s on %s", browser, ua.OS()), } - if err := rs.store.SaveRefreshToken(token); err != nil { + if err := rs.Store.CreateOrUpdateToken(token); err != nil { log(r).Error(err) render.Render(w, r, ErrInternalServerError) return } - access, refresh, err := rs.Token.GenTokenPair(acc.Claims(), token.Claims()) + access, refresh, err := rs.TokenAuth.GenTokenPair(acc.Claims(), token.Claims()) if err != nil { log(r).Error(err) render.Render(w, r, ErrInternalServerError) @@ -145,7 +209,7 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) { } acc.LastLogin = time.Now() - if err := rs.store.UpdateAccount(acc); err != nil { + if err := rs.Store.UpdateAccount(acc); err != nil { log(r).Error(err) render.Render(w, r, ErrInternalServerError) return @@ -158,17 +222,23 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) { } func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) { - rt := RefreshTokenFromCtx(r.Context()) + rt := jwt.RefreshTokenFromCtx(r.Context()) - acc, token, err := rs.store.GetByRefreshToken(rt) + token, err := rs.Store.GetToken(rt) if err != nil { - render.Render(w, r, ErrUnauthorized(errTokenExpired)) + render.Render(w, r, ErrUnauthorized(jwt.ErrTokenExpired)) return } if time.Now().After(token.Expiry) { - rs.store.DeleteRefreshToken(token) - render.Render(w, r, ErrUnauthorized(errTokenExpired)) + rs.Store.DeleteToken(token) + render.Render(w, r, ErrUnauthorized(jwt.ErrTokenExpired)) + return + } + + acc, err := rs.Store.GetAccount(token.AccountID) + if err != nil { + render.Render(w, r, ErrUnauthorized(ErrUnknownLogin)) return } @@ -178,24 +248,24 @@ func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) { } token.Token = uuid.NewV4().String() - token.Expiry = time.Now().Add(rs.Token.jwtRefreshExpiry) + token.Expiry = time.Now().Add(rs.TokenAuth.JwtRefreshExpiry) token.UpdatedAt = time.Now() - access, refresh, err := rs.Token.GenTokenPair(acc.Claims(), token.Claims()) + access, refresh, err := rs.TokenAuth.GenTokenPair(acc.Claims(), token.Claims()) if err != nil { log(r).Error(err) render.Render(w, r, ErrInternalServerError) return } - if err := rs.store.SaveRefreshToken(token); err != nil { + if err := rs.Store.CreateOrUpdateToken(token); err != nil { log(r).Error(err) render.Render(w, r, ErrInternalServerError) return } acc.LastLogin = time.Now() - if err := rs.store.UpdateAccount(acc); err != nil { + if err := rs.Store.UpdateAccount(acc); err != nil { log(r).Error(err) render.Render(w, r, ErrInternalServerError) return @@ -208,13 +278,13 @@ func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) { } func (rs *Resource) logout(w http.ResponseWriter, r *http.Request) { - rt := RefreshTokenFromCtx(r.Context()) - _, token, err := rs.store.GetByRefreshToken(rt) + rt := jwt.RefreshTokenFromCtx(r.Context()) + token, err := rs.Store.GetToken(rt) if err != nil { - render.Render(w, r, ErrUnauthorized(errTokenExpired)) + render.Render(w, r, ErrUnauthorized(jwt.ErrTokenExpired)) return } - rs.store.DeleteRefreshToken(token) + rs.Store.DeleteToken(token) render.Respond(w, r, http.NoBody) } diff --git a/auth/handler_test.go b/auth/pwdless/api_test.go similarity index 72% rename from auth/handler_test.go rename to auth/pwdless/api_test.go index e1101d7..11ce752 100644 --- a/auth/handler_test.go +++ b/auth/pwdless/api_test.go @@ -1,9 +1,10 @@ -package auth +package pwdless import ( "bytes" "encoding/json" "errors" + "fmt" "io" "io/ioutil" "net/http" @@ -13,16 +14,18 @@ import ( "testing" "time" - "github.com/dhax/go-base/email" - "github.com/dhax/go-base/logging" "github.com/go-chi/chi" "github.com/go-chi/jwtauth" "github.com/spf13/viper" + + "github.com/dhax/go-base/auth/jwt" + "github.com/dhax/go-base/email" + "github.com/dhax/go-base/logging" ) var ( auth *Resource - authstore MockStorer + authStore MockAuthStore mailer email.MockMailer ts *httptest.Server ) @@ -34,8 +37,9 @@ func TestMain(m *testing.M) { viper.SetDefault("log_level", "error") var err error - auth, err = NewResource(&authstore, &mailer) + auth, err = NewResource(&authStore, &mailer) if err != nil { + fmt.Println(err) os.Exit(1) } @@ -51,7 +55,7 @@ func TestMain(m *testing.M) { } func TestAuthResource_login(t *testing.T) { - authstore.GetByEmailFn = func(email string) (*Account, error) { + authStore.GetAccountByEmailFn = func(email string) (*Account, error) { var err error a := Account{ ID: 1, @@ -100,20 +104,20 @@ func TestAuthResource_login(t *testing.T) { if tc.err != nil && !strings.Contains(body, tc.err.Error()) { t.Errorf(" got: %s, expected to contain: %s", body, tc.err.Error()) } - if tc.err == ErrInvalidLogin && authstore.GetByEmailInvoked { + if tc.err == ErrInvalidLogin && authStore.GetAccountByEmailInvoked { t.Error("GetByLoginToken invoked for invalid email") } if tc.err == nil && !mailer.LoginTokenInvoked { t.Error("emailService.LoginToken not invoked") } - authstore.GetByEmailInvoked = false + authStore.GetAccountByEmailInvoked = false mailer.LoginTokenInvoked = false }) } } func TestAuthResource_token(t *testing.T) { - authstore.GetByIDFn = func(id int) (*Account, error) { + authStore.GetAccountFn = func(id int) (*Account, error) { var err error a := Account{ ID: id, @@ -130,11 +134,11 @@ func TestAuthResource_token(t *testing.T) { } return &a, err } - authstore.UpdateAccountFn = func(a *Account) error { + authStore.UpdateAccountFn = func(a *Account) error { a.LastLogin = time.Now() return nil } - authstore.SaveRefreshTokenFn = func(a *Token) error { + authStore.CreateOrUpdateTokenFn = func(a *jwt.Token) error { return nil } @@ -154,7 +158,7 @@ func TestAuthResource_token(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - token := auth.Login.CreateToken(tc.id) + token := auth.LoginAuth.CreateToken(tc.id) if tc.token != "" { token.Token = tc.token } @@ -171,25 +175,37 @@ func TestAuthResource_token(t *testing.T) { if tc.err != nil && !strings.Contains(body, tc.err.Error()) { t.Errorf("got: %s, expected to contain: %s", body, tc.err.Error()) } - if tc.err == ErrLoginToken && authstore.SaveRefreshTokenInvoked { - t.Errorf("SaveRefreshToken invoked despite error %s", tc.err.Error()) + if tc.err == ErrLoginToken && authStore.CreateOrUpdateTokenInvoked { + t.Errorf("CreateOrUpdate invoked despite error %s", tc.err.Error()) } - if tc.err == nil && !authstore.SaveRefreshTokenInvoked { - t.Error("SaveRefreshToken not invoked") + if tc.err == nil && !authStore.CreateOrUpdateTokenInvoked { + t.Error("CreateOrUpdate not invoked") } - authstore.SaveRefreshTokenInvoked = false + authStore.CreateOrUpdateTokenInvoked = false }) } } func TestAuthResource_refresh(t *testing.T) { - authstore.GetByRefreshTokenFn = func(token string) (*Account, *Token, error) { - var err error + authStore.GetAccountFn = func(id int) (*Account, error) { a := Account{ Active: true, Name: "Test", } - var t Token + switch id { + case 999: + a.Active = false + } + return &a, nil + } + authStore.UpdateAccountFn = func(a *Account) error { + a.LastLogin = time.Now() + return nil + } + + authStore.GetTokenFn = func(token string) (*jwt.Token, error) { + var err error + var t jwt.Token t.Expiry = time.Now().Add(1 * time.Minute) switch token { @@ -198,20 +214,14 @@ func TestAuthResource_refresh(t *testing.T) { case "expired": t.Expiry = time.Now().Add(-1 * time.Minute) case "disabled": - a.Active = false - case "valid": - // unmodified + t.AccountID = 999 } - return &a, &t, err + return &t, err } - authstore.UpdateAccountFn = func(a *Account) error { - a.LastLogin = time.Now() + authStore.CreateOrUpdateTokenFn = func(a *jwt.Token) error { return nil } - authstore.SaveRefreshTokenFn = func(a *Token) error { - return nil - } - authstore.DeleteRefreshTokenFn = func(t *Token) error { + authStore.DeleteTokenFn = func(t *jwt.Token) error { return nil } @@ -222,8 +232,8 @@ func TestAuthResource_refresh(t *testing.T) { status int err error }{ - {"notfound", "notfound", 1, http.StatusUnauthorized, errTokenExpired}, - {"expired", "expired", -1, http.StatusUnauthorized, errTokenUnauthorized}, + {"notfound", "notfound", 1, http.StatusUnauthorized, jwt.ErrTokenExpired}, + {"expired", "expired", -1, http.StatusUnauthorized, jwt.ErrTokenUnauthorized}, {"disabled", "disabled", 1, http.StatusUnauthorized, ErrLoginDisabled}, {"valid", "valid", 1, http.StatusOK, nil}, } @@ -238,32 +248,31 @@ func TestAuthResource_refresh(t *testing.T) { if tc.err != nil && !strings.Contains(body, tc.err.Error()) { t.Errorf("got: %s, expected error to contain: %s", body, tc.err.Error()) } - if tc.status == http.StatusUnauthorized && authstore.SaveRefreshTokenInvoked { - t.Errorf("SaveRefreshToken invoked for status %d", tc.status) + if tc.status == http.StatusUnauthorized && authStore.CreateOrUpdateTokenInvoked { + t.Errorf("CreateOrUpdate invoked for status %d", tc.status) } if tc.status == http.StatusOK { - if !authstore.GetByRefreshTokenInvoked { - t.Errorf("GetRefreshToken not invoked") + if !authStore.GetTokenInvoked { + t.Errorf("GetByToken not invoked") } - if !authstore.SaveRefreshTokenInvoked { - t.Errorf("SaveRefreshToken not invoked") + if !authStore.CreateOrUpdateTokenInvoked { + t.Errorf("CreateOrUpdate not invoked") } - if authstore.DeleteRefreshTokenInvoked { - t.Errorf("DeleteRefreshToken should not be invoked") + if authStore.DeleteTokenInvoked { + t.Errorf("Delete should not be invoked") } } - authstore.GetByRefreshTokenInvoked = false - authstore.SaveRefreshTokenInvoked = false - authstore.DeleteRefreshTokenInvoked = false + authStore.GetTokenInvoked = false + authStore.CreateOrUpdateTokenInvoked = false + authStore.DeleteTokenInvoked = false }) } } func TestAuthResource_logout(t *testing.T) { - authstore.GetByRefreshTokenFn = func(token string) (*Account, *Token, error) { + authStore.GetTokenFn = func(token string) (*jwt.Token, error) { var err error - var a Account - t := Token{ + t := jwt.Token{ Expiry: time.Now().Add(1 * time.Minute), } @@ -271,9 +280,9 @@ func TestAuthResource_logout(t *testing.T) { case "notfound": err = errors.New("sql no rows") } - return &a, &t, err + return &t, err } - authstore.DeleteRefreshTokenFn = func(a *Token) error { + authStore.DeleteTokenFn = func(a *jwt.Token) error { return nil } @@ -284,8 +293,8 @@ func TestAuthResource_logout(t *testing.T) { status int err error }{ - {"notfound", "notfound", 1, http.StatusUnauthorized, errTokenExpired}, - {"expired", "valid", -1, http.StatusUnauthorized, errTokenUnauthorized}, + {"notfound", "notfound", 1, http.StatusUnauthorized, jwt.ErrTokenExpired}, + {"expired", "valid", -1, http.StatusUnauthorized, jwt.ErrTokenUnauthorized}, {"valid", "valid", 1, http.StatusOK, nil}, } @@ -299,10 +308,10 @@ func TestAuthResource_logout(t *testing.T) { if tc.err != nil && !strings.Contains(body, tc.err.Error()) { t.Errorf("got: %x, expected error to contain %s", body, tc.err.Error()) } - if tc.status == http.StatusUnauthorized && authstore.DeleteRefreshTokenInvoked { - t.Errorf("DeleteRefreshToken invoked for status %d", tc.status) + if tc.status == http.StatusUnauthorized && authStore.DeleteTokenInvoked { + t.Errorf("Delete invoked for status %d", tc.status) } - authstore.DeleteRefreshTokenInvoked = false + authStore.DeleteTokenInvoked = false }) } } @@ -335,7 +344,7 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io } func genJWT(c jwtauth.Claims) string { - _, tokenString, _ := auth.Token.JwtAuth.Encode(c) + _, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(c) return tokenString } diff --git a/auth/pwdless/chores.go b/auth/pwdless/chores.go new file mode 100644 index 0000000..24d0274 --- /dev/null +++ b/auth/pwdless/chores.go @@ -0,0 +1,18 @@ +package pwdless + +import ( + "time" + + "github.com/dhax/go-base/logging" +) + +func (rs *Resource) choresTicker() { + ticker := time.NewTicker(time.Hour * 1) + go func() { + for range ticker.C { + if err := rs.Store.PurgeExpiredToken(); err != nil { + logging.Logger.WithField("chore", "purgeExpiredToken").Error(err) + } + } + }() +} diff --git a/auth/pwdless/errors.go b/auth/pwdless/errors.go new file mode 100644 index 0000000..75829e4 --- /dev/null +++ b/auth/pwdless/errors.go @@ -0,0 +1,50 @@ +package pwdless + +import ( + "errors" + "net/http" + + "github.com/go-chi/render" +) + +// The list of error types presented to the end user as error message. +var ( + ErrInvalidLogin = errors.New("invalid email address") + ErrUnknownLogin = errors.New("email not registered") + ErrLoginDisabled = errors.New("login for account disabled") + ErrLoginToken = errors.New("invalid or expired login token") +) + +// ErrResponse renderer type for handling all sorts of errors. +type ErrResponse struct { + Err error `json:"-"` // low-level runtime error + HTTPStatusCode int `json:"-"` // http response status code + + StatusText string `json:"status"` // user-level status message + AppCode int64 `json:"code,omitempty"` // application-specific error code + ErrorText string `json:"error,omitempty"` // application-level error message, for debugging +} + +// Render sets the application-specific error code in AppCode. +func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error { + render.Status(r, e.HTTPStatusCode) + return nil +} + +// ErrUnauthorized renders status 401 Unauthorized with custom error message. +func ErrUnauthorized(err error) render.Renderer { + return &ErrResponse{ + Err: err, + HTTPStatusCode: http.StatusUnauthorized, + StatusText: http.StatusText(http.StatusUnauthorized), + ErrorText: err.Error(), + } +} + +// The list of default error types without specific error message. +var ( + ErrInternalServerError = &ErrResponse{ + HTTPStatusCode: http.StatusInternalServerError, + StatusText: http.StatusText(http.StatusInternalServerError), + } +) diff --git a/auth/logintoken.go b/auth/pwdless/logintoken.go similarity index 86% rename from auth/logintoken.go rename to auth/pwdless/logintoken.go index 7e65730..a2a410a 100644 --- a/auth/logintoken.go +++ b/auth/pwdless/logintoken.go @@ -1,6 +1,7 @@ -package auth +package pwdless import ( + "crypto/rand" "errors" "sync" "time" @@ -87,3 +88,17 @@ func (a *LoginTokenAuth) purgeExpired() { } } } + +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +func randStringBytes(n int) string { + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + panic(err) + } + + for k, v := range buf { + buf[k] = letterBytes[v%byte(len(letterBytes))] + } + return string(buf) +} diff --git a/auth/pwdless/mockAuthStore.go b/auth/pwdless/mockAuthStore.go new file mode 100644 index 0000000..15521ab --- /dev/null +++ b/auth/pwdless/mockAuthStore.go @@ -0,0 +1,69 @@ +package pwdless + +import "github.com/dhax/go-base/auth/jwt" + +// MockAuthStore mocks AuthStorer interface. +type MockAuthStore struct { + GetAccountFn func(id int) (*Account, error) + GetAccountInvoked bool + + GetAccountByEmailFn func(email string) (*Account, error) + GetAccountByEmailInvoked bool + + UpdateAccountFn func(a *Account) error + UpdateAccountInvoked bool + + GetTokenFn func(token string) (*jwt.Token, error) + GetTokenInvoked bool + + CreateOrUpdateTokenFn func(t *jwt.Token) error + CreateOrUpdateTokenInvoked bool + + DeleteTokenFn func(t *jwt.Token) error + DeleteTokenInvoked bool + + PurgeExpiredTokenFn func() error + PurgeExpiredTokenInvoked bool +} + +// GetAccount mock returns an account by ID. +func (s *MockAuthStore) GetAccount(id int) (*Account, error) { + s.GetAccountInvoked = true + return s.GetAccountFn(id) +} + +// GetAccountByEmail mock returns an account by email. +func (s *MockAuthStore) GetAccountByEmail(email string) (*Account, error) { + s.GetAccountByEmailInvoked = true + return s.GetAccountByEmailFn(email) +} + +// UpdateAccount mock upates account data related to authentication. +func (s *MockAuthStore) UpdateAccount(a *Account) error { + s.UpdateAccountInvoked = true + return s.UpdateAccountFn(a) +} + +// GetToken mock returns an account and refresh token by token identifier. +func (s *MockAuthStore) GetToken(token string) (*jwt.Token, error) { + s.GetTokenInvoked = true + return s.GetTokenFn(token) +} + +// CreateOrUpdateToken mock creates or updates a refresh token. +func (s *MockAuthStore) CreateOrUpdateToken(t *jwt.Token) error { + s.CreateOrUpdateTokenInvoked = true + return s.CreateOrUpdateTokenFn(t) +} + +// DeleteToken mock deletes a refresh token. +func (s *MockAuthStore) DeleteToken(t *jwt.Token) error { + s.DeleteTokenInvoked = true + return s.DeleteTokenFn(t) +} + +// PurgeExpiredToken mock deletes expired refresh token. +func (s *MockAuthStore) PurgeExpiredToken() error { + s.PurgeExpiredTokenInvoked = true + return s.PurgeExpiredTokenFn() +} diff --git a/database/accountStore.go b/database/accountStore.go index 88a6f09..0f4fdc9 100644 --- a/database/accountStore.go +++ b/database/accountStore.go @@ -1,7 +1,8 @@ package database import ( - "github.com/dhax/go-base/auth" + "github.com/dhax/go-base/auth/jwt" + "github.com/dhax/go-base/auth/pwdless" "github.com/dhax/go-base/models" "github.com/go-pg/pg" ) @@ -19,8 +20,8 @@ func NewAccountStore(db *pg.DB) *AccountStore { } // Get an account by ID. -func (s *AccountStore) Get(id int) (*auth.Account, error) { - a := auth.Account{ID: id} +func (s *AccountStore) Get(id int) (*pwdless.Account, error) { + a := pwdless.Account{ID: id} err := s.db.Model(&a). Where("account.id = ?id"). Column("account.*", "Token"). @@ -29,7 +30,7 @@ func (s *AccountStore) Get(id int) (*auth.Account, error) { } // Update an account. -func (s *AccountStore) Update(a *auth.Account) error { +func (s *AccountStore) Update(a *pwdless.Account) error { _, err := s.db.Model(a). Column("email", "name"). Update() @@ -37,9 +38,9 @@ func (s *AccountStore) Update(a *auth.Account) error { } // Delete an account. -func (s *AccountStore) Delete(a *auth.Account) error { +func (s *AccountStore) Delete(a *pwdless.Account) error { err := s.db.RunInTransaction(func(tx *pg.Tx) error { - if _, err := tx.Model(&auth.Token{}). + if _, err := tx.Model(&jwt.Token{}). Where("account_id = ?", a.ID). Delete(); err != nil { return err @@ -55,7 +56,7 @@ func (s *AccountStore) Delete(a *auth.Account) error { } // UpdateToken updates a jwt refresh token. -func (s *AccountStore) UpdateToken(t *auth.Token) error { +func (s *AccountStore) UpdateToken(t *jwt.Token) error { _, err := s.db.Model(t). Column("identifier"). Update() @@ -63,7 +64,7 @@ func (s *AccountStore) UpdateToken(t *auth.Token) error { } // DeleteToken deletes a jwt refresh token. -func (s *AccountStore) DeleteToken(t *auth.Token) error { +func (s *AccountStore) DeleteToken(t *jwt.Token) error { err := s.db.Delete(t) return err } diff --git a/database/admAccountStore.go b/database/admAccountStore.go index 3e58c5c..47571a7 100644 --- a/database/admAccountStore.go +++ b/database/admAccountStore.go @@ -3,7 +3,8 @@ package database import ( "errors" - "github.com/dhax/go-base/auth" + "github.com/dhax/go-base/auth/jwt" + "github.com/dhax/go-base/auth/pwdless" "github.com/dhax/go-base/models" "github.com/go-pg/pg" ) @@ -26,8 +27,8 @@ func NewAdmAccountStore(db *pg.DB) *AdmAccountStore { } // List applies a filter and returns paginated array of matching results and total count. -func (s *AdmAccountStore) List(f auth.AccountFilter) ([]auth.Account, int, error) { - a := []auth.Account{} +func (s *AdmAccountStore) List(f pwdless.AccountFilter) ([]pwdless.Account, int, error) { + a := []pwdless.Account{} count, err := s.db.Model(&a). Apply(f.Filter). SelectAndCount() @@ -38,7 +39,7 @@ func (s *AdmAccountStore) List(f auth.AccountFilter) ([]auth.Account, int, error } // Create creates a new account. -func (s *AdmAccountStore) Create(a *auth.Account) error { +func (s *AdmAccountStore) Create(a *pwdless.Account) error { count, _ := s.db.Model(a). Where("email = ?email"). Count() @@ -62,22 +63,22 @@ func (s *AdmAccountStore) Create(a *auth.Account) error { } // Get account by ID. -func (s *AdmAccountStore) Get(id int) (*auth.Account, error) { - a := auth.Account{ID: id} +func (s *AdmAccountStore) Get(id int) (*pwdless.Account, error) { + a := pwdless.Account{ID: id} err := s.db.Select(&a) return &a, err } // Update account. -func (s *AdmAccountStore) Update(a *auth.Account) error { +func (s *AdmAccountStore) Update(a *pwdless.Account) error { err := s.db.Update(a) return err } // Delete account. -func (s *AdmAccountStore) Delete(a *auth.Account) error { +func (s *AdmAccountStore) Delete(a *pwdless.Account) error { err := s.db.RunInTransaction(func(tx *pg.Tx) error { - if _, err := tx.Model(&auth.Token{}). + if _, err := tx.Model(&jwt.Token{}). Where("account_id = ?", a.ID). Delete(); err != nil { return err diff --git a/database/authStore.go b/database/authStore.go index 2f60efe..fafc4a1 100644 --- a/database/authStore.go +++ b/database/authStore.go @@ -3,11 +3,12 @@ package database import ( "time" - "github.com/dhax/go-base/auth" + "github.com/dhax/go-base/auth/jwt" + "github.com/dhax/go-base/auth/pwdless" "github.com/go-pg/pg" ) -// AuthStore implements database operations for account authentication. +// AuthStore implements database operations for account pwdlessentication. type AuthStore struct { db *pg.DB } @@ -19,9 +20,9 @@ func NewAuthStore(db *pg.DB) *AuthStore { } } -// GetByID returns an account by ID. -func (s *AuthStore) GetByID(id int) (*auth.Account, error) { - a := auth.Account{ID: id} +// GetAccount returns an account by ID. +func (s *AuthStore) GetAccount(id int) (*pwdless.Account, error) { + a := pwdless.Account{ID: id} err := s.db.Model(&a). Column("account.*"). Where("id = ?id"). @@ -29,9 +30,9 @@ func (s *AuthStore) GetByID(id int) (*auth.Account, error) { return &a, err } -// GetByEmail returns an account by email. -func (s *AuthStore) GetByEmail(e string) (*auth.Account, error) { - a := auth.Account{Email: e} +// GetAccountByEmail returns an account by email. +func (s *AuthStore) GetAccountByEmail(e string) (*pwdless.Account, error) { + a := pwdless.Account{Email: e} err := s.db.Model(&a). Column("id", "active", "email", "name"). Where("email = ?email"). @@ -39,35 +40,26 @@ func (s *AuthStore) GetByEmail(e string) (*auth.Account, error) { return &a, err } -// GetByRefreshToken returns an account and refresh token by token identifier. -func (s *AuthStore) GetByRefreshToken(t string) (*auth.Account, *auth.Token, error) { - token := auth.Token{Token: t} - err := s.db.Model(&token). - Where("token = ?token"). - First() - if err != nil { - return nil, nil, err - } - - a := auth.Account{ID: token.AccountID} - err = s.db.Model(&a). - Column("account.*"). - Where("id = ?id"). - First() - - return &a, &token, err -} - -// UpdateAccount upates account data related to authentication. -func (s *AuthStore) UpdateAccount(a *auth.Account) error { +// UpdateAccount upates account data related to pwdlessentication. +func (s *AuthStore) UpdateAccount(a *pwdless.Account) error { _, err := s.db.Model(a). Column("last_login"). Update() return err } -// SaveRefreshToken creates or updates a refresh token. -func (s *AuthStore) SaveRefreshToken(t *auth.Token) error { +// GetToken returns refresh token by token identifier. +func (s *AuthStore) GetToken(t string) (*jwt.Token, error) { + token := jwt.Token{Token: t} + err := s.db.Model(&token). + Where("token = ?token"). + First() + + return &token, err +} + +// CreateOrUpdateToken creates or updates an existing refresh token. +func (s *AuthStore) CreateOrUpdateToken(t *jwt.Token) error { var err error if t.ID == 0 { err = s.db.Insert(t) @@ -77,15 +69,15 @@ func (s *AuthStore) SaveRefreshToken(t *auth.Token) error { return err } -// DeleteRefreshToken deletes a refresh token. -func (s *AuthStore) DeleteRefreshToken(t *auth.Token) error { +// DeleteToken deletes a refresh token. +func (s *AuthStore) DeleteToken(t *jwt.Token) error { err := s.db.Delete(t) return err } // PurgeExpiredToken deletes expired refresh token. func (s *AuthStore) PurgeExpiredToken() error { - _, err := s.db.Model(&auth.Token{}). + _, err := s.db.Model(&jwt.Token{}). Where("expiry < ?", time.Now()). Delete()