refactor auth pkg into libraries
This commit is contained in:
parent
521f081ba0
commit
aaf0a0928d
26 changed files with 592 additions and 504 deletions
|
|
@ -6,9 +6,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/dhax/go-base/auth/pwdless"
|
||||||
"github.com/go-ozzo/ozzo-validation"
|
"github.com/go-ozzo/ozzo-validation"
|
||||||
|
|
||||||
"github.com/dhax/go-base/auth"
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
)
|
)
|
||||||
|
|
@ -20,11 +20,11 @@ var (
|
||||||
|
|
||||||
// AccountStore defines database operations for account management.
|
// AccountStore defines database operations for account management.
|
||||||
type AccountStore interface {
|
type AccountStore interface {
|
||||||
List(f auth.AccountFilter) ([]auth.Account, int, error)
|
List(f pwdless.AccountFilter) ([]pwdless.Account, int, error)
|
||||||
Create(*auth.Account) error
|
Create(*pwdless.Account) error
|
||||||
Get(id int) (*auth.Account, error)
|
Get(id int) (*pwdless.Account, error)
|
||||||
Update(*auth.Account) error
|
Update(*pwdless.Account) error
|
||||||
Delete(*auth.Account) error
|
Delete(*pwdless.Account) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountResource implements account management handler.
|
// AccountResource implements account management handler.
|
||||||
|
|
@ -70,7 +70,7 @@ func (rs *AccountResource) accountCtx(next http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
type accountRequest struct {
|
type accountRequest struct {
|
||||||
*auth.Account
|
*pwdless.Account
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *accountRequest) Bind(r *http.Request) error {
|
func (d *accountRequest) Bind(r *http.Request) error {
|
||||||
|
|
@ -78,20 +78,20 @@ func (d *accountRequest) Bind(r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type accountResponse struct {
|
type accountResponse struct {
|
||||||
*auth.Account
|
*pwdless.Account
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAccountResponse(a *auth.Account) *accountResponse {
|
func newAccountResponse(a *pwdless.Account) *accountResponse {
|
||||||
resp := &accountResponse{Account: a}
|
resp := &accountResponse{Account: a}
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
type accountListResponse struct {
|
type accountListResponse struct {
|
||||||
Accounts []auth.Account `json:"accounts"`
|
Accounts []pwdless.Account `json:"accounts"`
|
||||||
Count int `json:"count"`
|
Count int `json:"count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAccountListResponse(a []auth.Account, count int) *accountListResponse {
|
func newAccountListResponse(a []pwdless.Account, count int) *accountListResponse {
|
||||||
resp := &accountListResponse{
|
resp := &accountListResponse{
|
||||||
Accounts: a,
|
Accounts: a,
|
||||||
Count: count,
|
Count: count,
|
||||||
|
|
@ -100,7 +100,7 @@ func newAccountListResponse(a []auth.Account, count int) *accountListResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *AccountResource) list(w http.ResponseWriter, r *http.Request) {
|
func (rs *AccountResource) list(w http.ResponseWriter, r *http.Request) {
|
||||||
f := auth.NewAccountFilter(r.URL.Query())
|
f := pwdless.NewAccountFilter(r.URL.Query())
|
||||||
al, count, err := rs.Store.List(f)
|
al, count, err := rs.Store.List(f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Render(w, r, ErrRender(err))
|
render.Render(w, r, ErrRender(err))
|
||||||
|
|
@ -129,12 +129,12 @@ func (rs *AccountResource) create(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) {
|
func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) {
|
||||||
acc := r.Context().Value(ctxAccount).(*auth.Account)
|
acc := r.Context().Value(ctxAccount).(*pwdless.Account)
|
||||||
render.Respond(w, r, newAccountResponse(acc))
|
render.Respond(w, r, newAccountResponse(acc))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
|
func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
|
||||||
acc := r.Context().Value(ctxAccount).(*auth.Account)
|
acc := r.Context().Value(ctxAccount).(*pwdless.Account)
|
||||||
data := &accountRequest{Account: acc}
|
data := &accountRequest{Account: acc}
|
||||||
if err := render.Bind(r, data); err != nil {
|
if err := render.Bind(r, data); err != nil {
|
||||||
render.Render(w, r, ErrInvalidRequest(err))
|
render.Render(w, r, ErrInvalidRequest(err))
|
||||||
|
|
@ -155,7 +155,7 @@ func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) {
|
func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) {
|
||||||
acc := r.Context().Value(ctxAccount).(*auth.Account)
|
acc := r.Context().Value(ctxAccount).(*pwdless.Account)
|
||||||
if err := rs.Store.Delete(acc); err != nil {
|
if err := rs.Store.Delete(acc); err != nil {
|
||||||
render.Render(w, r, ErrInvalidRequest(err))
|
render.Render(w, r, ErrInvalidRequest(err))
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/go-pg/pg"
|
"github.com/go-pg/pg"
|
||||||
|
|
||||||
"github.com/dhax/go-base/auth"
|
"github.com/dhax/go-base/auth/authorize"
|
||||||
"github.com/dhax/go-base/database"
|
"github.com/dhax/go-base/database"
|
||||||
"github.com/dhax/go-base/logging"
|
"github.com/dhax/go-base/logging"
|
||||||
)
|
)
|
||||||
|
|
@ -44,7 +44,7 @@ func NewAPI(db *pg.DB) (*API, error) {
|
||||||
// Router provides admin application routes.
|
// Router provides admin application routes.
|
||||||
func (a *API) Router() *chi.Mux {
|
func (a *API) Router() *chi.Mux {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Use(auth.RequiresRole(roleAdmin))
|
r.Use(authorize.RequiresRole(roleAdmin))
|
||||||
|
|
||||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("Hello Admin"))
|
w.Write([]byte("Hello Admin"))
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,8 @@ import (
|
||||||
|
|
||||||
"github.com/dhax/go-base/api/admin"
|
"github.com/dhax/go-base/api/admin"
|
||||||
"github.com/dhax/go-base/api/app"
|
"github.com/dhax/go-base/api/app"
|
||||||
"github.com/dhax/go-base/auth"
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
|
"github.com/dhax/go-base/auth/pwdless"
|
||||||
"github.com/dhax/go-base/database"
|
"github.com/dhax/go-base/database"
|
||||||
"github.com/dhax/go-base/email"
|
"github.com/dhax/go-base/email"
|
||||||
"github.com/dhax/go-base/logging"
|
"github.com/dhax/go-base/logging"
|
||||||
|
|
@ -37,7 +38,7 @@ func New() (*chi.Mux, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
authStore := database.NewAuthStore(db)
|
authStore := database.NewAuthStore(db)
|
||||||
authResource, err := auth.NewResource(authStore, mailer)
|
authResource, err := pwdless.NewResource(authStore, mailer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithField("module", "auth").Error(err)
|
logger.WithField("module", "auth").Error(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -70,8 +71,8 @@ func New() (*chi.Mux, error) {
|
||||||
|
|
||||||
r.Mount("/auth", authResource.Router())
|
r.Mount("/auth", authResource.Router())
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(authResource.Token.Verifier())
|
r.Use(authResource.TokenAuth.Verifier())
|
||||||
r.Use(auth.Authenticator)
|
r.Use(jwt.Authenticator)
|
||||||
r.Mount("/admin", adminAPI.Router())
|
r.Mount("/admin", adminAPI.Router())
|
||||||
r.Mount("/api", appAPI.Router())
|
r.Mount("/api", appAPI.Router())
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -10,16 +10,17 @@ import (
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
validation "github.com/go-ozzo/ozzo-validation"
|
validation "github.com/go-ozzo/ozzo-validation"
|
||||||
|
|
||||||
"github.com/dhax/go-base/auth"
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
|
"github.com/dhax/go-base/auth/pwdless"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountStore defines database operations for account.
|
// AccountStore defines database operations for account.
|
||||||
type AccountStore interface {
|
type AccountStore interface {
|
||||||
Get(id int) (*auth.Account, error)
|
Get(id int) (*pwdless.Account, error)
|
||||||
Update(*auth.Account) error
|
Update(*pwdless.Account) error
|
||||||
Delete(*auth.Account) error
|
Delete(*pwdless.Account) error
|
||||||
UpdateToken(*auth.Token) error
|
UpdateToken(*jwt.Token) error
|
||||||
DeleteToken(*auth.Token) error
|
DeleteToken(*jwt.Token) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountResource implements account management handler.
|
// AccountResource implements account management handler.
|
||||||
|
|
@ -49,7 +50,7 @@ func (rs *AccountResource) router() *chi.Mux {
|
||||||
|
|
||||||
func (rs *AccountResource) accountCtx(next http.Handler) http.Handler {
|
func (rs *AccountResource) accountCtx(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := auth.ClaimsFromCtx(r.Context())
|
claims := jwt.ClaimsFromCtx(r.Context())
|
||||||
log(r).WithField("account_id", claims.ID)
|
log(r).WithField("account_id", claims.ID)
|
||||||
account, err := rs.Store.Get(claims.ID)
|
account, err := rs.Store.Get(claims.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -63,7 +64,7 @@ func (rs *AccountResource) accountCtx(next http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
type accountRequest struct {
|
type accountRequest struct {
|
||||||
*auth.Account
|
*pwdless.Account
|
||||||
// override protected data here, although not really necessary here
|
// override protected data here, although not really necessary here
|
||||||
// as we limit updated database columns in store as well
|
// as we limit updated database columns in store as well
|
||||||
ProtectedID int `json:"id"`
|
ProtectedID int `json:"id"`
|
||||||
|
|
@ -78,21 +79,21 @@ func (d *accountRequest) Bind(r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type accountResponse struct {
|
type accountResponse struct {
|
||||||
*auth.Account
|
*pwdless.Account
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAccountResponse(a *auth.Account) *accountResponse {
|
func newAccountResponse(a *pwdless.Account) *accountResponse {
|
||||||
resp := &accountResponse{Account: a}
|
resp := &accountResponse{Account: a}
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) {
|
func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) {
|
||||||
acc := r.Context().Value(ctxAccount).(*auth.Account)
|
acc := r.Context().Value(ctxAccount).(*pwdless.Account)
|
||||||
render.Respond(w, r, newAccountResponse(acc))
|
render.Respond(w, r, newAccountResponse(acc))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
|
func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
|
||||||
acc := r.Context().Value(ctxAccount).(*auth.Account)
|
acc := r.Context().Value(ctxAccount).(*pwdless.Account)
|
||||||
data := &accountRequest{Account: acc}
|
data := &accountRequest{Account: acc}
|
||||||
if err := render.Bind(r, data); err != nil {
|
if err := render.Bind(r, data); err != nil {
|
||||||
render.Render(w, r, ErrInvalidRequest(err))
|
render.Render(w, r, ErrInvalidRequest(err))
|
||||||
|
|
@ -113,7 +114,7 @@ func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) {
|
func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) {
|
||||||
acc := r.Context().Value(ctxAccount).(*auth.Account)
|
acc := r.Context().Value(ctxAccount).(*pwdless.Account)
|
||||||
if err := rs.Store.Delete(acc); err != nil {
|
if err := rs.Store.Delete(acc); err != nil {
|
||||||
render.Render(w, r, ErrRender(err))
|
render.Render(w, r, ErrRender(err))
|
||||||
return
|
return
|
||||||
|
|
@ -142,10 +143,10 @@ func (rs *AccountResource) updateToken(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Render(w, r, ErrInvalidRequest(err))
|
render.Render(w, r, ErrInvalidRequest(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acc := r.Context().Value(ctxAccount).(*auth.Account)
|
acc := r.Context().Value(ctxAccount).(*pwdless.Account)
|
||||||
for _, t := range acc.Token {
|
for _, t := range acc.Token {
|
||||||
if t.ID == id {
|
if t.ID == id {
|
||||||
if err := rs.Store.UpdateToken(&auth.Token{
|
if err := rs.Store.UpdateToken(&jwt.Token{
|
||||||
ID: t.ID,
|
ID: t.ID,
|
||||||
Identifier: data.Identifier,
|
Identifier: data.Identifier,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
|
@ -163,10 +164,10 @@ func (rs *AccountResource) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Render(w, r, ErrBadRequest)
|
render.Render(w, r, ErrBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acc := r.Context().Value(ctxAccount).(*auth.Account)
|
acc := r.Context().Value(ctxAccount).(*pwdless.Account)
|
||||||
for _, t := range acc.Token {
|
for _, t := range acc.Token {
|
||||||
if t.ID == id {
|
if t.ID == id {
|
||||||
rs.Store.DeleteToken(&auth.Token{ID: t.ID})
|
rs.Store.DeleteToken(&jwt.Token{ID: t.ID})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
render.Respond(w, r, http.NoBody)
|
render.Respond(w, r, http.NoBody)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/dhax/go-base/auth"
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
"github.com/dhax/go-base/models"
|
"github.com/dhax/go-base/models"
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
|
|
@ -39,7 +39,7 @@ func (rs *ProfileResource) router() *chi.Mux {
|
||||||
|
|
||||||
func (rs *ProfileResource) profileCtx(next http.Handler) http.Handler {
|
func (rs *ProfileResource) profileCtx(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := auth.ClaimsFromCtx(r.Context())
|
claims := jwt.ClaimsFromCtx(r.Context())
|
||||||
p, err := rs.Store.Get(claims.ID)
|
p, err := rs.Store.Get(claims.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log(r).WithField("profileCtx", claims.Sub).Error(err)
|
log(r).WithField("profileCtx", claims.Sub).Error(err)
|
||||||
|
|
|
||||||
92
auth/api.go
92
auth/api.go
|
|
@ -1,92 +0,0 @@
|
||||||
// Package auth provides JSON Web Token (JWT) authentication and authorization middleware.
|
|
||||||
// It implements a passwordless authentication flow by sending login tokens vie email which are then exchanged for JWT access and refresh tokens.
|
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/dhax/go-base/email"
|
|
||||||
"github.com/dhax/go-base/logging"
|
|
||||||
"github.com/go-chi/chi"
|
|
||||||
"github.com/go-chi/render"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Storer defines database operations on account and token data.
|
|
||||||
type Storer interface {
|
|
||||||
GetByID(id int) (*Account, error)
|
|
||||||
GetByEmail(email string) (*Account, error)
|
|
||||||
GetByRefreshToken(token string) (*Account, *Token, error)
|
|
||||||
UpdateAccount(a *Account) error
|
|
||||||
SaveRefreshToken(t *Token) error
|
|
||||||
DeleteRefreshToken(t *Token) error
|
|
||||||
PurgeExpiredToken() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mailer defines methods to send account emails.
|
|
||||||
type Mailer interface {
|
|
||||||
LoginToken(name, email string, c email.ContentLoginToken) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resource implements passwordless token authentication against a database.
|
|
||||||
type Resource struct {
|
|
||||||
Login *LoginTokenAuth
|
|
||||||
Token *TokenAuth
|
|
||||||
store Storer
|
|
||||||
mailer Mailer
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewResource returns a configured authentication resource.
|
|
||||||
func NewResource(store Storer, mailer Mailer) (*Resource, error) {
|
|
||||||
loginAuth, err := NewLoginTokenAuth()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenAuth, err := NewTokenAuth()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resource := &Resource{
|
|
||||||
Login: loginAuth,
|
|
||||||
Token: tokenAuth,
|
|
||||||
store: store,
|
|
||||||
mailer: mailer,
|
|
||||||
}
|
|
||||||
|
|
||||||
resource.cleanupTicker()
|
|
||||||
|
|
||||||
return resource, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Router provides necessary routes for passwordless authentication flow.
|
|
||||||
func (rs *Resource) Router() *chi.Mux {
|
|
||||||
r := chi.NewRouter()
|
|
||||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
|
||||||
r.Post("/login", rs.login)
|
|
||||||
r.Post("/token", rs.token)
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(rs.Token.Verifier())
|
|
||||||
r.Use(AuthenticateRefreshJWT)
|
|
||||||
r.Post("/refresh", rs.refresh)
|
|
||||||
r.Post("/logout", rs.logout)
|
|
||||||
})
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *Resource) cleanupTicker() {
|
|
||||||
ticker := time.NewTicker(time.Hour * 1)
|
|
||||||
go func() {
|
|
||||||
for range ticker.C {
|
|
||||||
if err := rs.store.PurgeExpiredToken(); err != nil {
|
|
||||||
logging.Logger.WithField("auth", "cleanup").Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func log(r *http.Request) logrus.FieldLogger {
|
|
||||||
return logging.GetLogEntry(r)
|
|
||||||
}
|
|
||||||
31
auth/authorize/errors.go
Normal file
31
auth/authorize/errors.go
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
package authorize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-chi/render"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrResponse renderer type for handling all sorts of errors.
|
||||||
|
type ErrResponse struct {
|
||||||
|
Err error `json:"-"` // low-level runtime error
|
||||||
|
HTTPStatusCode int `json:"-"` // http response status code
|
||||||
|
|
||||||
|
StatusText string `json:"status"` // user-level status message
|
||||||
|
AppCode int64 `json:"code,omitempty"` // application-specific error code
|
||||||
|
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render sets the application-specific error code in AppCode.
|
||||||
|
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
render.Status(r, e.HTTPStatusCode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The list of default error types without specific error message.
|
||||||
|
var (
|
||||||
|
ErrForbidden = &ErrResponse{
|
||||||
|
HTTPStatusCode: http.StatusForbidden,
|
||||||
|
StatusText: http.StatusText(http.StatusForbidden),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
@ -1,16 +1,18 @@
|
||||||
package auth
|
package authorize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
|
|
||||||
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequiresRole middleware restricts access to accounts having role parameter in their jwt claims.
|
// RequiresRole middleware restricts access to accounts having role parameter in their jwt claims.
|
||||||
func RequiresRole(role string) func(next http.Handler) http.Handler {
|
func RequiresRole(role string) func(next http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
hfn := func(w http.ResponseWriter, r *http.Request) {
|
hfn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := ClaimsFromCtx(r.Context())
|
claims := jwt.ClaimsFromCtx(r.Context())
|
||||||
if !hasRole(role, claims.Roles) {
|
if !hasRole(role, claims.Roles) {
|
||||||
render.Render(w, r, ErrForbidden)
|
render.Render(w, r, ErrForbidden)
|
||||||
return
|
return
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
)
|
|
||||||
|
|
||||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
|
||||||
|
|
||||||
func randStringBytes(n int) string {
|
|
||||||
buf := make([]byte, n)
|
|
||||||
if _, err := rand.Read(buf); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range buf {
|
|
||||||
buf[k] = letterBytes[v%byte(len(letterBytes))]
|
|
||||||
}
|
|
||||||
return string(buf)
|
|
||||||
}
|
|
||||||
|
|
@ -1,64 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/go-chi/render"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrResponse renderer type for handling all sorts of errors.
|
|
||||||
type ErrResponse struct {
|
|
||||||
Err error `json:"-"` // low-level runtime error
|
|
||||||
HTTPStatusCode int `json:"-"` // http response status code
|
|
||||||
|
|
||||||
StatusText string `json:"status"` // user-level status message
|
|
||||||
AppCode int64 `json:"code,omitempty"` // application-specific error code
|
|
||||||
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
|
|
||||||
}
|
|
||||||
|
|
||||||
// Render sets the application-specific error code in AppCode.
|
|
||||||
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
render.Status(r, e.HTTPStatusCode)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrUnauthorized renders status 401 Unauthorized with custom error message.
|
|
||||||
func ErrUnauthorized(err error) render.Renderer {
|
|
||||||
return &ErrResponse{
|
|
||||||
Err: err,
|
|
||||||
HTTPStatusCode: http.StatusUnauthorized,
|
|
||||||
StatusText: http.StatusText(http.StatusUnauthorized),
|
|
||||||
ErrorText: err.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrRender returns status 422 Unprocessable Entity for invalid request body
|
|
||||||
func ErrRender(err error) render.Renderer {
|
|
||||||
return &ErrResponse{
|
|
||||||
Err: err,
|
|
||||||
HTTPStatusCode: http.StatusUnprocessableEntity,
|
|
||||||
StatusText: http.StatusText(http.StatusUnprocessableEntity),
|
|
||||||
ErrorText: err.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrInvalidRequest returns status 422 Unprocessable Entity with validation errors
|
|
||||||
func ErrInvalidRequest(err error) render.Renderer {
|
|
||||||
return &ErrResponse{
|
|
||||||
Err: err,
|
|
||||||
HTTPStatusCode: http.StatusUnprocessableEntity,
|
|
||||||
StatusText: http.StatusText(http.StatusUnprocessableEntity),
|
|
||||||
ErrorText: err.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The list of default error types without specific error message.
|
|
||||||
var (
|
|
||||||
ErrBadRequest = &ErrResponse{HTTPStatusCode: http.StatusBadRequest, StatusText: http.StatusText(http.StatusBadRequest)}
|
|
||||||
|
|
||||||
ErrForbidden = &ErrResponse{HTTPStatusCode: http.StatusForbidden, StatusText: http.StatusText(http.StatusForbidden)}
|
|
||||||
|
|
||||||
ErrNotFound = &ErrResponse{HTTPStatusCode: http.StatusNotFound, StatusText: http.StatusText(http.StatusNotFound)}
|
|
||||||
|
|
||||||
ErrInternalServerError = &ErrResponse{HTTPStatusCode: http.StatusInternalServerError, StatusText: http.StatusText(http.StatusInternalServerError)}
|
|
||||||
)
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
package auth
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/jwtauth"
|
"github.com/go-chi/jwtauth"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
|
|
||||||
|
"github.com/dhax/go-base/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ctxKey int
|
type ctxKey int
|
||||||
|
|
@ -16,13 +17,6 @@ const (
|
||||||
ctxRefreshToken
|
ctxRefreshToken
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errTokenUnauthorized = errors.New("token unauthorized")
|
|
||||||
errTokenExpired = errors.New("token expired")
|
|
||||||
errInvalidAccessToken = errors.New("invalid access token")
|
|
||||||
errInvalidRefreshToken = errors.New("invalid refresh token")
|
|
||||||
)
|
|
||||||
|
|
||||||
// ClaimsFromCtx retrieves the parsed AppClaims from request context.
|
// ClaimsFromCtx retrieves the parsed AppClaims from request context.
|
||||||
func ClaimsFromCtx(ctx context.Context) AppClaims {
|
func ClaimsFromCtx(ctx context.Context) AppClaims {
|
||||||
return ctx.Value(ctxClaims).(AppClaims)
|
return ctx.Value(ctxClaims).(AppClaims)
|
||||||
|
|
@ -41,23 +35,27 @@ func Authenticator(next http.Handler) http.Handler {
|
||||||
token, claims, err := jwtauth.FromContext(r.Context())
|
token, claims, err := jwtauth.FromContext(r.Context())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log(r).Warn(err)
|
logging.GetLogEntry(r).Warn(err)
|
||||||
render.Render(w, r, ErrUnauthorized(errTokenUnauthorized))
|
render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !token.Valid {
|
if !token.Valid {
|
||||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token is authenticated, parse claims
|
// Token is authenticated, parse claims
|
||||||
pc, ok := parseClaims(claims)
|
var c AppClaims
|
||||||
if !ok {
|
err = c.ParseClaims(claims)
|
||||||
render.Render(w, r, ErrUnauthorized(errInvalidAccessToken))
|
if err != nil {
|
||||||
|
logging.GetLogEntry(r).Error(err)
|
||||||
|
render.Render(w, r, ErrUnauthorized(ErrInvalidAccessToken))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(r.Context(), ctxClaims, pc)
|
|
||||||
|
// Set AppClaims on context
|
||||||
|
ctx := context.WithValue(r.Context(), ctxClaims, c)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -67,21 +65,25 @@ func AuthenticateRefreshJWT(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
token, claims, err := jwtauth.FromContext(r.Context())
|
token, claims, err := jwtauth.FromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log(r).Warn(err)
|
logging.GetLogEntry(r).Warn(err)
|
||||||
render.Render(w, r, ErrUnauthorized(errTokenUnauthorized))
|
render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !token.Valid {
|
if !token.Valid {
|
||||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
refreshToken, ok := parseRefreshClaims(claims)
|
|
||||||
if !ok {
|
// Token is authenticated, parse refresh token string
|
||||||
render.Render(w, r, ErrUnauthorized(errInvalidRefreshToken))
|
var c RefreshClaims
|
||||||
|
err = c.ParseClaims(claims)
|
||||||
|
if err != nil {
|
||||||
|
logging.GetLogEntry(r).Error(err)
|
||||||
|
render.Render(w, r, ErrUnauthorized(ErrInvalidRefreshToken))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Token is authenticated, set on context
|
// Set refresh token string on context
|
||||||
ctx := context.WithValue(r.Context(), ctxRefreshToken, refreshToken)
|
ctx := context.WithValue(r.Context(), ctxRefreshToken, c.Token)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
59
auth/jwt/claims.go
Normal file
59
auth/jwt/claims.go
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/go-chi/jwtauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AppClaims represent the claims parsed from JWT access token.
|
||||||
|
type AppClaims struct {
|
||||||
|
ID int
|
||||||
|
Sub string
|
||||||
|
Roles []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseClaims parses JWT claims into AppClaims.
|
||||||
|
func (c *AppClaims) ParseClaims(claims jwtauth.Claims) error {
|
||||||
|
id, ok := claims.Get("id")
|
||||||
|
if !ok {
|
||||||
|
return errors.New("could not parse claim id")
|
||||||
|
}
|
||||||
|
c.ID = int(id.(float64))
|
||||||
|
|
||||||
|
sub, ok := claims.Get("sub")
|
||||||
|
if !ok {
|
||||||
|
return errors.New("could not parse claim sub")
|
||||||
|
}
|
||||||
|
c.Sub = sub.(string)
|
||||||
|
|
||||||
|
rl, ok := claims.Get("roles")
|
||||||
|
if !ok {
|
||||||
|
return errors.New("could not parse claims roles")
|
||||||
|
}
|
||||||
|
|
||||||
|
var roles []string
|
||||||
|
if rl != nil {
|
||||||
|
for _, v := range rl.([]interface{}) {
|
||||||
|
roles = append(roles, v.(string))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Roles = roles
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshClaims represent the claims parsed from JWT refresh token.
|
||||||
|
type RefreshClaims struct {
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseClaims parses the JWT claims into RefreshClaims.
|
||||||
|
func (c *RefreshClaims) ParseClaims(claims jwtauth.Claims) error {
|
||||||
|
token, ok := claims.Get("token")
|
||||||
|
if !ok {
|
||||||
|
return errors.New("could not parse claim token")
|
||||||
|
}
|
||||||
|
c.Token = token.(string)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
42
auth/jwt/errors.go
Normal file
42
auth/jwt/errors.go
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-chi/render"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The list of jwt token errors presented to the end user.
|
||||||
|
var (
|
||||||
|
ErrTokenUnauthorized = errors.New("token unauthorized")
|
||||||
|
ErrTokenExpired = errors.New("token expired")
|
||||||
|
ErrInvalidAccessToken = errors.New("invalid access token")
|
||||||
|
ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrResponse renderer type for handling all sorts of errors.
|
||||||
|
type ErrResponse struct {
|
||||||
|
Err error `json:"-"` // low-level runtime error
|
||||||
|
HTTPStatusCode int `json:"-"` // http response status code
|
||||||
|
|
||||||
|
StatusText string `json:"status"` // user-level status message
|
||||||
|
AppCode int64 `json:"code,omitempty"` // application-specific error code
|
||||||
|
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render sets the application-specific error code in AppCode.
|
||||||
|
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
render.Status(r, e.HTTPStatusCode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrUnauthorized renders status 401 Unauthorized with custom error message.
|
||||||
|
func ErrUnauthorized(err error) render.Renderer {
|
||||||
|
return &ErrResponse{
|
||||||
|
Err: err,
|
||||||
|
HTTPStatusCode: http.StatusUnauthorized,
|
||||||
|
StatusText: http.StatusText(http.StatusUnauthorized),
|
||||||
|
ErrorText: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package auth
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package auth
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -8,18 +9,11 @@ import (
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AppClaims represent the claims extracted from JWT token.
|
|
||||||
type AppClaims struct {
|
|
||||||
ID int
|
|
||||||
Sub string
|
|
||||||
Roles []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenAuth implements JWT authentication flow.
|
// TokenAuth implements JWT authentication flow.
|
||||||
type TokenAuth struct {
|
type TokenAuth struct {
|
||||||
JwtAuth *jwtauth.JwtAuth
|
JwtAuth *jwtauth.JwtAuth
|
||||||
jwtExpiry time.Duration
|
JwtExpiry time.Duration
|
||||||
jwtRefreshExpiry time.Duration
|
JwtRefreshExpiry time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTokenAuth configures and returns a JWT authentication instance.
|
// NewTokenAuth configures and returns a JWT authentication instance.
|
||||||
|
|
@ -31,8 +25,8 @@ func NewTokenAuth() (*TokenAuth, error) {
|
||||||
|
|
||||||
a := &TokenAuth{
|
a := &TokenAuth{
|
||||||
JwtAuth: jwtauth.New("HS256", []byte(secret), nil),
|
JwtAuth: jwtauth.New("HS256", []byte(secret), nil),
|
||||||
jwtExpiry: viper.GetDuration("auth_jwt_expiry"),
|
JwtExpiry: viper.GetDuration("auth_jwt_expiry"),
|
||||||
jwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
|
JwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
|
||||||
}
|
}
|
||||||
|
|
||||||
return a, nil
|
return a, nil
|
||||||
|
|
@ -59,7 +53,7 @@ func (a *TokenAuth) GenTokenPair(ca jwtauth.Claims, cr jwtauth.Claims) (string,
|
||||||
// CreateJWT returns an access token for provided account claims.
|
// CreateJWT returns an access token for provided account claims.
|
||||||
func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) {
|
func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) {
|
||||||
c.SetIssuedNow()
|
c.SetIssuedNow()
|
||||||
c.SetExpiryIn(a.jwtExpiry)
|
c.SetExpiryIn(a.JwtExpiry)
|
||||||
_, tokenString, err := a.JwtAuth.Encode(c)
|
_, tokenString, err := a.JwtAuth.Encode(c)
|
||||||
return tokenString, err
|
return tokenString, err
|
||||||
}
|
}
|
||||||
|
|
@ -67,46 +61,21 @@ func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) {
|
||||||
// CreateRefreshJWT returns a refresh token for provided token Claims.
|
// CreateRefreshJWT returns a refresh token for provided token Claims.
|
||||||
func (a *TokenAuth) CreateRefreshJWT(c jwtauth.Claims) (string, error) {
|
func (a *TokenAuth) CreateRefreshJWT(c jwtauth.Claims) (string, error) {
|
||||||
c.SetIssuedNow()
|
c.SetIssuedNow()
|
||||||
c.SetExpiryIn(a.jwtRefreshExpiry)
|
c.SetExpiryIn(a.JwtRefreshExpiry)
|
||||||
_, tokenString, err := a.JwtAuth.Encode(c)
|
_, tokenString, err := a.JwtAuth.Encode(c)
|
||||||
return tokenString, err
|
return tokenString, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseClaims(c jwtauth.Claims) (AppClaims, bool) {
|
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
var claims AppClaims
|
|
||||||
allOK := true
|
|
||||||
id, ok := c.Get("id")
|
|
||||||
if !ok {
|
|
||||||
allOK = false
|
|
||||||
}
|
|
||||||
claims.ID = int(id.(float64))
|
|
||||||
|
|
||||||
sub, ok := c.Get("sub")
|
func randStringBytes(n int) string {
|
||||||
if !ok {
|
buf := make([]byte, n)
|
||||||
allOK = false
|
if _, err := rand.Read(buf); err != nil {
|
||||||
}
|
panic(err)
|
||||||
claims.Sub = sub.(string)
|
|
||||||
|
|
||||||
rl, ok := c.Get("roles")
|
|
||||||
if !ok {
|
|
||||||
allOK = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var roles []string
|
for k, v := range buf {
|
||||||
if rl != nil {
|
buf[k] = letterBytes[v%byte(len(letterBytes))]
|
||||||
for _, v := range rl.([]interface{}) {
|
|
||||||
roles = append(roles, v.(string))
|
|
||||||
}
|
}
|
||||||
}
|
return string(buf)
|
||||||
claims.Roles = roles
|
|
||||||
|
|
||||||
return claims, allOK
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseRefreshClaims(c jwtauth.Claims) (string, bool) {
|
|
||||||
token, ok := c.Get("token")
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return token.(string), ok
|
|
||||||
}
|
}
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
// MockStorer mocks Storer interface.
|
|
||||||
type MockStorer struct {
|
|
||||||
GetByIDFn func(id int) (*Account, error)
|
|
||||||
GetByIDInvoked bool
|
|
||||||
|
|
||||||
GetByEmailFn func(email string) (*Account, error)
|
|
||||||
GetByEmailInvoked bool
|
|
||||||
|
|
||||||
GetByRefreshTokenFn func(token string) (*Account, *Token, error)
|
|
||||||
GetByRefreshTokenInvoked bool
|
|
||||||
|
|
||||||
UpdateAccountFn func(a *Account) error
|
|
||||||
UpdateAccountInvoked bool
|
|
||||||
|
|
||||||
SaveRefreshTokenFn func(t *Token) error
|
|
||||||
SaveRefreshTokenInvoked bool
|
|
||||||
|
|
||||||
DeleteRefreshTokenFn func(t *Token) error
|
|
||||||
DeleteRefreshTokenInvoked bool
|
|
||||||
|
|
||||||
PurgeExpiredTokenFn func() error
|
|
||||||
PurgeExpiredTokenInvoked bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByID mock returns an account by ID.
|
|
||||||
func (s *MockStorer) GetByID(id int) (*Account, error) {
|
|
||||||
s.GetByIDInvoked = true
|
|
||||||
return s.GetByIDFn(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByEmail mock returns an account by email.
|
|
||||||
func (s *MockStorer) GetByEmail(email string) (*Account, error) {
|
|
||||||
s.GetByEmailInvoked = true
|
|
||||||
return s.GetByEmailFn(email)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByRefreshToken mock returns an account and refresh token by token identifier.
|
|
||||||
func (s *MockStorer) GetByRefreshToken(token string) (*Account, *Token, error) {
|
|
||||||
s.GetByRefreshTokenInvoked = true
|
|
||||||
return s.GetByRefreshTokenFn(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateAccount mock upates account data related to authentication.
|
|
||||||
func (s *MockStorer) UpdateAccount(a *Account) error {
|
|
||||||
s.UpdateAccountInvoked = true
|
|
||||||
return s.UpdateAccountFn(a)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveRefreshToken mock creates or updates a refresh token.
|
|
||||||
func (s *MockStorer) SaveRefreshToken(t *Token) error {
|
|
||||||
s.SaveRefreshTokenInvoked = true
|
|
||||||
return s.SaveRefreshTokenFn(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRefreshToken mock deletes a refresh token.
|
|
||||||
func (s *MockStorer) DeleteRefreshToken(t *Token) error {
|
|
||||||
s.DeleteRefreshTokenInvoked = true
|
|
||||||
return s.DeleteRefreshTokenFn(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PurgeExpiredToken mock deletes expired refresh token.
|
|
||||||
func (s *MockStorer) PurgeExpiredToken() error {
|
|
||||||
s.PurgeExpiredTokenInvoked = true
|
|
||||||
return s.PurgeExpiredTokenFn()
|
|
||||||
}
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
package auth
|
package pwdless
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
"github.com/go-chi/jwtauth"
|
"github.com/go-chi/jwtauth"
|
||||||
validation "github.com/go-ozzo/ozzo-validation"
|
validation "github.com/go-ozzo/ozzo-validation"
|
||||||
"github.com/go-ozzo/ozzo-validation/is"
|
"github.com/go-ozzo/ozzo-validation/is"
|
||||||
|
|
@ -23,7 +24,7 @@ type Account struct {
|
||||||
Active bool `sql:",notnull" json:"active"`
|
Active bool `sql:",notnull" json:"active"`
|
||||||
Roles []string `pg:",array" json:"roles,omitempty"`
|
Roles []string `pg:",array" json:"roles,omitempty"`
|
||||||
|
|
||||||
Token []Token `json:"token,omitempty"`
|
Token []jwt.Token `json:"token,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeforeInsert hook executed before database insert operation.
|
// BeforeInsert hook executed before database insert operation.
|
||||||
|
|
@ -38,11 +39,8 @@ func (a *Account) BeforeInsert(db orm.DB) error {
|
||||||
|
|
||||||
// BeforeUpdate hook executed before database update operation.
|
// BeforeUpdate hook executed before database update operation.
|
||||||
func (a *Account) BeforeUpdate(db orm.DB) error {
|
func (a *Account) BeforeUpdate(db orm.DB) error {
|
||||||
if err := a.Validate(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
a.UpdatedAt = time.Now()
|
a.UpdatedAt = time.Now()
|
||||||
return nil
|
return a.Validate()
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeforeDelete hook executed before database delete operation.
|
// BeforeDelete hook executed before database delete operation.
|
||||||
|
|
@ -1,29 +1,93 @@
|
||||||
package auth
|
// Package pwdless provides JSON Web Token (JWT) authentication and authorization middleware.
|
||||||
|
// It implements a passwordless authentication flow by sending login tokens vie email which are then exchanged for JWT access and refresh tokens.
|
||||||
|
package pwdless
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
|
"github.com/dhax/go-base/email"
|
||||||
|
"github.com/dhax/go-base/logging"
|
||||||
|
"github.com/go-chi/chi"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
validation "github.com/go-ozzo/ozzo-validation"
|
validation "github.com/go-ozzo/ozzo-validation"
|
||||||
"github.com/go-ozzo/ozzo-validation/is"
|
"github.com/go-ozzo/ozzo-validation/is"
|
||||||
"github.com/mssola/user_agent"
|
"github.com/mssola/user_agent"
|
||||||
uuid "github.com/satori/go.uuid"
|
uuid "github.com/satori/go.uuid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/dhax/go-base/email"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// The list of error types presented to the end user as error message.
|
// AuthStorer defines database operations on accounts and tokens.
|
||||||
var (
|
type AuthStorer interface {
|
||||||
ErrInvalidLogin = errors.New("invalid email address")
|
GetAccount(id int) (*Account, error)
|
||||||
ErrUnknownLogin = errors.New("email not registered")
|
GetAccountByEmail(email string) (*Account, error)
|
||||||
ErrLoginDisabled = errors.New("login for account disabled")
|
UpdateAccount(a *Account) error
|
||||||
ErrLoginToken = errors.New("invalid or expired login token")
|
|
||||||
)
|
GetToken(token string) (*jwt.Token, error)
|
||||||
|
CreateOrUpdateToken(t *jwt.Token) error
|
||||||
|
DeleteToken(t *jwt.Token) error
|
||||||
|
PurgeExpiredToken() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mailer defines methods to send account emails.
|
||||||
|
type Mailer interface {
|
||||||
|
LoginToken(name, email string, c email.ContentLoginToken) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resource implements passwordless account authentication against a database.
|
||||||
|
type Resource struct {
|
||||||
|
LoginAuth *LoginTokenAuth
|
||||||
|
TokenAuth *jwt.TokenAuth
|
||||||
|
Store AuthStorer
|
||||||
|
Mailer Mailer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResource returns a configured authentication resource.
|
||||||
|
func NewResource(authStore AuthStorer, mailer Mailer) (*Resource, error) {
|
||||||
|
loginAuth, err := NewLoginTokenAuth()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenAuth, err := jwt.NewTokenAuth()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resource := &Resource{
|
||||||
|
LoginAuth: loginAuth,
|
||||||
|
TokenAuth: tokenAuth,
|
||||||
|
Store: authStore,
|
||||||
|
Mailer: mailer,
|
||||||
|
}
|
||||||
|
|
||||||
|
resource.choresTicker()
|
||||||
|
|
||||||
|
return resource, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router provides necessary routes for passwordless authentication flow.
|
||||||
|
func (rs *Resource) Router() *chi.Mux {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||||
|
r.Post("/login", rs.login)
|
||||||
|
r.Post("/token", rs.token)
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(rs.TokenAuth.Verifier())
|
||||||
|
r.Use(jwt.AuthenticateRefreshJWT)
|
||||||
|
r.Post("/refresh", rs.refresh)
|
||||||
|
r.Post("/logout", rs.logout)
|
||||||
|
})
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func log(r *http.Request) logrus.FieldLogger {
|
||||||
|
return logging.GetLogEntry(r)
|
||||||
|
}
|
||||||
|
|
||||||
type loginRequest struct {
|
type loginRequest struct {
|
||||||
Email string
|
Email string
|
||||||
|
|
@ -46,7 +110,7 @@ func (rs *Resource) login(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := rs.store.GetByEmail(body.Email)
|
acc, err := rs.Store.GetAccountByEmail(body.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log(r).WithField("email", body.Email).Warn(err)
|
log(r).WithField("email", body.Email).Warn(err)
|
||||||
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
|
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
|
||||||
|
|
@ -58,17 +122,17 @@ func (rs *Resource) login(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lt := rs.Login.CreateToken(acc.ID)
|
lt := rs.LoginAuth.CreateToken(acc.ID)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
content := email.ContentLoginToken{
|
content := email.ContentLoginToken{
|
||||||
Email: acc.Email,
|
Email: acc.Email,
|
||||||
Name: acc.Name,
|
Name: acc.Name,
|
||||||
URL: path.Join(rs.Login.loginURL, lt.Token),
|
URL: path.Join(rs.LoginAuth.loginURL, lt.Token),
|
||||||
Token: lt.Token,
|
Token: lt.Token,
|
||||||
Expiry: lt.Expiry,
|
Expiry: lt.Expiry,
|
||||||
}
|
}
|
||||||
if err := rs.mailer.LoginToken(acc.Name, acc.Email, content); err != nil {
|
if err := rs.Mailer.LoginToken(acc.Name, acc.Email, content); err != nil {
|
||||||
log(r).WithField("module", "email").Error(err)
|
log(r).WithField("module", "email").Error(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
@ -101,13 +165,13 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := rs.Login.GetAccountID(body.Token)
|
id, err := rs.LoginAuth.GetAccountID(body.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Render(w, r, ErrUnauthorized(ErrLoginToken))
|
render.Render(w, r, ErrUnauthorized(ErrLoginToken))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := rs.store.GetByID(id)
|
acc, err := rs.Store.GetAccount(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// account deleted before login token expired
|
// account deleted before login token expired
|
||||||
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
|
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
|
||||||
|
|
@ -122,22 +186,22 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) {
|
||||||
ua := user_agent.New(r.UserAgent())
|
ua := user_agent.New(r.UserAgent())
|
||||||
browser, _ := ua.Browser()
|
browser, _ := ua.Browser()
|
||||||
|
|
||||||
token := &Token{
|
token := &jwt.Token{
|
||||||
Token: uuid.NewV4().String(),
|
Token: uuid.NewV4().String(),
|
||||||
Expiry: time.Now().Add(rs.Token.jwtRefreshExpiry),
|
Expiry: time.Now().Add(rs.TokenAuth.JwtRefreshExpiry),
|
||||||
UpdatedAt: time.Now(),
|
UpdatedAt: time.Now(),
|
||||||
AccountID: acc.ID,
|
AccountID: acc.ID,
|
||||||
Mobile: ua.Mobile(),
|
Mobile: ua.Mobile(),
|
||||||
Identifier: fmt.Sprintf("%s on %s", browser, ua.OS()),
|
Identifier: fmt.Sprintf("%s on %s", browser, ua.OS()),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rs.store.SaveRefreshToken(token); err != nil {
|
if err := rs.Store.CreateOrUpdateToken(token); err != nil {
|
||||||
log(r).Error(err)
|
log(r).Error(err)
|
||||||
render.Render(w, r, ErrInternalServerError)
|
render.Render(w, r, ErrInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
access, refresh, err := rs.Token.GenTokenPair(acc.Claims(), token.Claims())
|
access, refresh, err := rs.TokenAuth.GenTokenPair(acc.Claims(), token.Claims())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log(r).Error(err)
|
log(r).Error(err)
|
||||||
render.Render(w, r, ErrInternalServerError)
|
render.Render(w, r, ErrInternalServerError)
|
||||||
|
|
@ -145,7 +209,7 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
acc.LastLogin = time.Now()
|
acc.LastLogin = time.Now()
|
||||||
if err := rs.store.UpdateAccount(acc); err != nil {
|
if err := rs.Store.UpdateAccount(acc); err != nil {
|
||||||
log(r).Error(err)
|
log(r).Error(err)
|
||||||
render.Render(w, r, ErrInternalServerError)
|
render.Render(w, r, ErrInternalServerError)
|
||||||
return
|
return
|
||||||
|
|
@ -158,17 +222,23 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) {
|
func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) {
|
||||||
rt := RefreshTokenFromCtx(r.Context())
|
rt := jwt.RefreshTokenFromCtx(r.Context())
|
||||||
|
|
||||||
acc, token, err := rs.store.GetByRefreshToken(rt)
|
token, err := rs.Store.GetToken(rt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
render.Render(w, r, ErrUnauthorized(jwt.ErrTokenExpired))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(token.Expiry) {
|
if time.Now().After(token.Expiry) {
|
||||||
rs.store.DeleteRefreshToken(token)
|
rs.Store.DeleteToken(token)
|
||||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
render.Render(w, r, ErrUnauthorized(jwt.ErrTokenExpired))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := rs.Store.GetAccount(token.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -178,24 +248,24 @@ func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
token.Token = uuid.NewV4().String()
|
token.Token = uuid.NewV4().String()
|
||||||
token.Expiry = time.Now().Add(rs.Token.jwtRefreshExpiry)
|
token.Expiry = time.Now().Add(rs.TokenAuth.JwtRefreshExpiry)
|
||||||
token.UpdatedAt = time.Now()
|
token.UpdatedAt = time.Now()
|
||||||
|
|
||||||
access, refresh, err := rs.Token.GenTokenPair(acc.Claims(), token.Claims())
|
access, refresh, err := rs.TokenAuth.GenTokenPair(acc.Claims(), token.Claims())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log(r).Error(err)
|
log(r).Error(err)
|
||||||
render.Render(w, r, ErrInternalServerError)
|
render.Render(w, r, ErrInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rs.store.SaveRefreshToken(token); err != nil {
|
if err := rs.Store.CreateOrUpdateToken(token); err != nil {
|
||||||
log(r).Error(err)
|
log(r).Error(err)
|
||||||
render.Render(w, r, ErrInternalServerError)
|
render.Render(w, r, ErrInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
acc.LastLogin = time.Now()
|
acc.LastLogin = time.Now()
|
||||||
if err := rs.store.UpdateAccount(acc); err != nil {
|
if err := rs.Store.UpdateAccount(acc); err != nil {
|
||||||
log(r).Error(err)
|
log(r).Error(err)
|
||||||
render.Render(w, r, ErrInternalServerError)
|
render.Render(w, r, ErrInternalServerError)
|
||||||
return
|
return
|
||||||
|
|
@ -208,13 +278,13 @@ func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *Resource) logout(w http.ResponseWriter, r *http.Request) {
|
func (rs *Resource) logout(w http.ResponseWriter, r *http.Request) {
|
||||||
rt := RefreshTokenFromCtx(r.Context())
|
rt := jwt.RefreshTokenFromCtx(r.Context())
|
||||||
_, token, err := rs.store.GetByRefreshToken(rt)
|
token, err := rs.Store.GetToken(rt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
render.Render(w, r, ErrUnauthorized(jwt.ErrTokenExpired))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rs.store.DeleteRefreshToken(token)
|
rs.Store.DeleteToken(token)
|
||||||
|
|
||||||
render.Respond(w, r, http.NoBody)
|
render.Respond(w, r, http.NoBody)
|
||||||
}
|
}
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
package auth
|
package pwdless
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
@ -13,16 +14,18 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dhax/go-base/email"
|
|
||||||
"github.com/dhax/go-base/logging"
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/go-chi/jwtauth"
|
"github.com/go-chi/jwtauth"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
|
||||||
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
|
"github.com/dhax/go-base/email"
|
||||||
|
"github.com/dhax/go-base/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
auth *Resource
|
auth *Resource
|
||||||
authstore MockStorer
|
authStore MockAuthStore
|
||||||
mailer email.MockMailer
|
mailer email.MockMailer
|
||||||
ts *httptest.Server
|
ts *httptest.Server
|
||||||
)
|
)
|
||||||
|
|
@ -34,8 +37,9 @@ func TestMain(m *testing.M) {
|
||||||
viper.SetDefault("log_level", "error")
|
viper.SetDefault("log_level", "error")
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
auth, err = NewResource(&authstore, &mailer)
|
auth, err = NewResource(&authStore, &mailer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -51,7 +55,7 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthResource_login(t *testing.T) {
|
func TestAuthResource_login(t *testing.T) {
|
||||||
authstore.GetByEmailFn = func(email string) (*Account, error) {
|
authStore.GetAccountByEmailFn = func(email string) (*Account, error) {
|
||||||
var err error
|
var err error
|
||||||
a := Account{
|
a := Account{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
|
|
@ -100,20 +104,20 @@ func TestAuthResource_login(t *testing.T) {
|
||||||
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
||||||
t.Errorf(" got: %s, expected to contain: %s", body, tc.err.Error())
|
t.Errorf(" got: %s, expected to contain: %s", body, tc.err.Error())
|
||||||
}
|
}
|
||||||
if tc.err == ErrInvalidLogin && authstore.GetByEmailInvoked {
|
if tc.err == ErrInvalidLogin && authStore.GetAccountByEmailInvoked {
|
||||||
t.Error("GetByLoginToken invoked for invalid email")
|
t.Error("GetByLoginToken invoked for invalid email")
|
||||||
}
|
}
|
||||||
if tc.err == nil && !mailer.LoginTokenInvoked {
|
if tc.err == nil && !mailer.LoginTokenInvoked {
|
||||||
t.Error("emailService.LoginToken not invoked")
|
t.Error("emailService.LoginToken not invoked")
|
||||||
}
|
}
|
||||||
authstore.GetByEmailInvoked = false
|
authStore.GetAccountByEmailInvoked = false
|
||||||
mailer.LoginTokenInvoked = false
|
mailer.LoginTokenInvoked = false
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthResource_token(t *testing.T) {
|
func TestAuthResource_token(t *testing.T) {
|
||||||
authstore.GetByIDFn = func(id int) (*Account, error) {
|
authStore.GetAccountFn = func(id int) (*Account, error) {
|
||||||
var err error
|
var err error
|
||||||
a := Account{
|
a := Account{
|
||||||
ID: id,
|
ID: id,
|
||||||
|
|
@ -130,11 +134,11 @@ func TestAuthResource_token(t *testing.T) {
|
||||||
}
|
}
|
||||||
return &a, err
|
return &a, err
|
||||||
}
|
}
|
||||||
authstore.UpdateAccountFn = func(a *Account) error {
|
authStore.UpdateAccountFn = func(a *Account) error {
|
||||||
a.LastLogin = time.Now()
|
a.LastLogin = time.Now()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
authstore.SaveRefreshTokenFn = func(a *Token) error {
|
authStore.CreateOrUpdateTokenFn = func(a *jwt.Token) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -154,7 +158,7 @@ func TestAuthResource_token(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
token := auth.Login.CreateToken(tc.id)
|
token := auth.LoginAuth.CreateToken(tc.id)
|
||||||
if tc.token != "" {
|
if tc.token != "" {
|
||||||
token.Token = tc.token
|
token.Token = tc.token
|
||||||
}
|
}
|
||||||
|
|
@ -171,25 +175,37 @@ func TestAuthResource_token(t *testing.T) {
|
||||||
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
||||||
t.Errorf("got: %s, expected to contain: %s", body, tc.err.Error())
|
t.Errorf("got: %s, expected to contain: %s", body, tc.err.Error())
|
||||||
}
|
}
|
||||||
if tc.err == ErrLoginToken && authstore.SaveRefreshTokenInvoked {
|
if tc.err == ErrLoginToken && authStore.CreateOrUpdateTokenInvoked {
|
||||||
t.Errorf("SaveRefreshToken invoked despite error %s", tc.err.Error())
|
t.Errorf("CreateOrUpdate invoked despite error %s", tc.err.Error())
|
||||||
}
|
}
|
||||||
if tc.err == nil && !authstore.SaveRefreshTokenInvoked {
|
if tc.err == nil && !authStore.CreateOrUpdateTokenInvoked {
|
||||||
t.Error("SaveRefreshToken not invoked")
|
t.Error("CreateOrUpdate not invoked")
|
||||||
}
|
}
|
||||||
authstore.SaveRefreshTokenInvoked = false
|
authStore.CreateOrUpdateTokenInvoked = false
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthResource_refresh(t *testing.T) {
|
func TestAuthResource_refresh(t *testing.T) {
|
||||||
authstore.GetByRefreshTokenFn = func(token string) (*Account, *Token, error) {
|
authStore.GetAccountFn = func(id int) (*Account, error) {
|
||||||
var err error
|
|
||||||
a := Account{
|
a := Account{
|
||||||
Active: true,
|
Active: true,
|
||||||
Name: "Test",
|
Name: "Test",
|
||||||
}
|
}
|
||||||
var t Token
|
switch id {
|
||||||
|
case 999:
|
||||||
|
a.Active = false
|
||||||
|
}
|
||||||
|
return &a, nil
|
||||||
|
}
|
||||||
|
authStore.UpdateAccountFn = func(a *Account) error {
|
||||||
|
a.LastLogin = time.Now()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
authStore.GetTokenFn = func(token string) (*jwt.Token, error) {
|
||||||
|
var err error
|
||||||
|
var t jwt.Token
|
||||||
t.Expiry = time.Now().Add(1 * time.Minute)
|
t.Expiry = time.Now().Add(1 * time.Minute)
|
||||||
|
|
||||||
switch token {
|
switch token {
|
||||||
|
|
@ -198,20 +214,14 @@ func TestAuthResource_refresh(t *testing.T) {
|
||||||
case "expired":
|
case "expired":
|
||||||
t.Expiry = time.Now().Add(-1 * time.Minute)
|
t.Expiry = time.Now().Add(-1 * time.Minute)
|
||||||
case "disabled":
|
case "disabled":
|
||||||
a.Active = false
|
t.AccountID = 999
|
||||||
case "valid":
|
|
||||||
// unmodified
|
|
||||||
}
|
}
|
||||||
return &a, &t, err
|
return &t, err
|
||||||
}
|
}
|
||||||
authstore.UpdateAccountFn = func(a *Account) error {
|
authStore.CreateOrUpdateTokenFn = func(a *jwt.Token) error {
|
||||||
a.LastLogin = time.Now()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
authstore.SaveRefreshTokenFn = func(a *Token) error {
|
authStore.DeleteTokenFn = func(t *jwt.Token) error {
|
||||||
return nil
|
|
||||||
}
|
|
||||||
authstore.DeleteRefreshTokenFn = func(t *Token) error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -222,8 +232,8 @@ func TestAuthResource_refresh(t *testing.T) {
|
||||||
status int
|
status int
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{"notfound", "notfound", 1, http.StatusUnauthorized, errTokenExpired},
|
{"notfound", "notfound", 1, http.StatusUnauthorized, jwt.ErrTokenExpired},
|
||||||
{"expired", "expired", -1, http.StatusUnauthorized, errTokenUnauthorized},
|
{"expired", "expired", -1, http.StatusUnauthorized, jwt.ErrTokenUnauthorized},
|
||||||
{"disabled", "disabled", 1, http.StatusUnauthorized, ErrLoginDisabled},
|
{"disabled", "disabled", 1, http.StatusUnauthorized, ErrLoginDisabled},
|
||||||
{"valid", "valid", 1, http.StatusOK, nil},
|
{"valid", "valid", 1, http.StatusOK, nil},
|
||||||
}
|
}
|
||||||
|
|
@ -238,32 +248,31 @@ func TestAuthResource_refresh(t *testing.T) {
|
||||||
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
||||||
t.Errorf("got: %s, expected error to contain: %s", body, tc.err.Error())
|
t.Errorf("got: %s, expected error to contain: %s", body, tc.err.Error())
|
||||||
}
|
}
|
||||||
if tc.status == http.StatusUnauthorized && authstore.SaveRefreshTokenInvoked {
|
if tc.status == http.StatusUnauthorized && authStore.CreateOrUpdateTokenInvoked {
|
||||||
t.Errorf("SaveRefreshToken invoked for status %d", tc.status)
|
t.Errorf("CreateOrUpdate invoked for status %d", tc.status)
|
||||||
}
|
}
|
||||||
if tc.status == http.StatusOK {
|
if tc.status == http.StatusOK {
|
||||||
if !authstore.GetByRefreshTokenInvoked {
|
if !authStore.GetTokenInvoked {
|
||||||
t.Errorf("GetRefreshToken not invoked")
|
t.Errorf("GetByToken not invoked")
|
||||||
}
|
}
|
||||||
if !authstore.SaveRefreshTokenInvoked {
|
if !authStore.CreateOrUpdateTokenInvoked {
|
||||||
t.Errorf("SaveRefreshToken not invoked")
|
t.Errorf("CreateOrUpdate not invoked")
|
||||||
}
|
}
|
||||||
if authstore.DeleteRefreshTokenInvoked {
|
if authStore.DeleteTokenInvoked {
|
||||||
t.Errorf("DeleteRefreshToken should not be invoked")
|
t.Errorf("Delete should not be invoked")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
authstore.GetByRefreshTokenInvoked = false
|
authStore.GetTokenInvoked = false
|
||||||
authstore.SaveRefreshTokenInvoked = false
|
authStore.CreateOrUpdateTokenInvoked = false
|
||||||
authstore.DeleteRefreshTokenInvoked = false
|
authStore.DeleteTokenInvoked = false
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthResource_logout(t *testing.T) {
|
func TestAuthResource_logout(t *testing.T) {
|
||||||
authstore.GetByRefreshTokenFn = func(token string) (*Account, *Token, error) {
|
authStore.GetTokenFn = func(token string) (*jwt.Token, error) {
|
||||||
var err error
|
var err error
|
||||||
var a Account
|
t := jwt.Token{
|
||||||
t := Token{
|
|
||||||
Expiry: time.Now().Add(1 * time.Minute),
|
Expiry: time.Now().Add(1 * time.Minute),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -271,9 +280,9 @@ func TestAuthResource_logout(t *testing.T) {
|
||||||
case "notfound":
|
case "notfound":
|
||||||
err = errors.New("sql no rows")
|
err = errors.New("sql no rows")
|
||||||
}
|
}
|
||||||
return &a, &t, err
|
return &t, err
|
||||||
}
|
}
|
||||||
authstore.DeleteRefreshTokenFn = func(a *Token) error {
|
authStore.DeleteTokenFn = func(a *jwt.Token) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -284,8 +293,8 @@ func TestAuthResource_logout(t *testing.T) {
|
||||||
status int
|
status int
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{"notfound", "notfound", 1, http.StatusUnauthorized, errTokenExpired},
|
{"notfound", "notfound", 1, http.StatusUnauthorized, jwt.ErrTokenExpired},
|
||||||
{"expired", "valid", -1, http.StatusUnauthorized, errTokenUnauthorized},
|
{"expired", "valid", -1, http.StatusUnauthorized, jwt.ErrTokenUnauthorized},
|
||||||
{"valid", "valid", 1, http.StatusOK, nil},
|
{"valid", "valid", 1, http.StatusOK, nil},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -299,10 +308,10 @@ func TestAuthResource_logout(t *testing.T) {
|
||||||
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
||||||
t.Errorf("got: %x, expected error to contain %s", body, tc.err.Error())
|
t.Errorf("got: %x, expected error to contain %s", body, tc.err.Error())
|
||||||
}
|
}
|
||||||
if tc.status == http.StatusUnauthorized && authstore.DeleteRefreshTokenInvoked {
|
if tc.status == http.StatusUnauthorized && authStore.DeleteTokenInvoked {
|
||||||
t.Errorf("DeleteRefreshToken invoked for status %d", tc.status)
|
t.Errorf("Delete invoked for status %d", tc.status)
|
||||||
}
|
}
|
||||||
authstore.DeleteRefreshTokenInvoked = false
|
authStore.DeleteTokenInvoked = false
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -335,7 +344,7 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io
|
||||||
}
|
}
|
||||||
|
|
||||||
func genJWT(c jwtauth.Claims) string {
|
func genJWT(c jwtauth.Claims) string {
|
||||||
_, tokenString, _ := auth.Token.JwtAuth.Encode(c)
|
_, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(c)
|
||||||
return tokenString
|
return tokenString
|
||||||
}
|
}
|
||||||
|
|
||||||
18
auth/pwdless/chores.go
Normal file
18
auth/pwdless/chores.go
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
package pwdless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dhax/go-base/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (rs *Resource) choresTicker() {
|
||||||
|
ticker := time.NewTicker(time.Hour * 1)
|
||||||
|
go func() {
|
||||||
|
for range ticker.C {
|
||||||
|
if err := rs.Store.PurgeExpiredToken(); err != nil {
|
||||||
|
logging.Logger.WithField("chore", "purgeExpiredToken").Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
50
auth/pwdless/errors.go
Normal file
50
auth/pwdless/errors.go
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
package pwdless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-chi/render"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The list of error types presented to the end user as error message.
|
||||||
|
var (
|
||||||
|
ErrInvalidLogin = errors.New("invalid email address")
|
||||||
|
ErrUnknownLogin = errors.New("email not registered")
|
||||||
|
ErrLoginDisabled = errors.New("login for account disabled")
|
||||||
|
ErrLoginToken = errors.New("invalid or expired login token")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrResponse renderer type for handling all sorts of errors.
|
||||||
|
type ErrResponse struct {
|
||||||
|
Err error `json:"-"` // low-level runtime error
|
||||||
|
HTTPStatusCode int `json:"-"` // http response status code
|
||||||
|
|
||||||
|
StatusText string `json:"status"` // user-level status message
|
||||||
|
AppCode int64 `json:"code,omitempty"` // application-specific error code
|
||||||
|
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render sets the application-specific error code in AppCode.
|
||||||
|
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
render.Status(r, e.HTTPStatusCode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrUnauthorized renders status 401 Unauthorized with custom error message.
|
||||||
|
func ErrUnauthorized(err error) render.Renderer {
|
||||||
|
return &ErrResponse{
|
||||||
|
Err: err,
|
||||||
|
HTTPStatusCode: http.StatusUnauthorized,
|
||||||
|
StatusText: http.StatusText(http.StatusUnauthorized),
|
||||||
|
ErrorText: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The list of default error types without specific error message.
|
||||||
|
var (
|
||||||
|
ErrInternalServerError = &ErrResponse{
|
||||||
|
HTTPStatusCode: http.StatusInternalServerError,
|
||||||
|
StatusText: http.StatusText(http.StatusInternalServerError),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package auth
|
package pwdless
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -87,3 +88,17 @@ func (a *LoginTokenAuth) purgeExpired() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
|
||||||
|
func randStringBytes(n int) string {
|
||||||
|
buf := make([]byte, n)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range buf {
|
||||||
|
buf[k] = letterBytes[v%byte(len(letterBytes))]
|
||||||
|
}
|
||||||
|
return string(buf)
|
||||||
|
}
|
||||||
69
auth/pwdless/mockAuthStore.go
Normal file
69
auth/pwdless/mockAuthStore.go
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
package pwdless
|
||||||
|
|
||||||
|
import "github.com/dhax/go-base/auth/jwt"
|
||||||
|
|
||||||
|
// MockAuthStore mocks AuthStorer interface.
|
||||||
|
type MockAuthStore struct {
|
||||||
|
GetAccountFn func(id int) (*Account, error)
|
||||||
|
GetAccountInvoked bool
|
||||||
|
|
||||||
|
GetAccountByEmailFn func(email string) (*Account, error)
|
||||||
|
GetAccountByEmailInvoked bool
|
||||||
|
|
||||||
|
UpdateAccountFn func(a *Account) error
|
||||||
|
UpdateAccountInvoked bool
|
||||||
|
|
||||||
|
GetTokenFn func(token string) (*jwt.Token, error)
|
||||||
|
GetTokenInvoked bool
|
||||||
|
|
||||||
|
CreateOrUpdateTokenFn func(t *jwt.Token) error
|
||||||
|
CreateOrUpdateTokenInvoked bool
|
||||||
|
|
||||||
|
DeleteTokenFn func(t *jwt.Token) error
|
||||||
|
DeleteTokenInvoked bool
|
||||||
|
|
||||||
|
PurgeExpiredTokenFn func() error
|
||||||
|
PurgeExpiredTokenInvoked bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount mock returns an account by ID.
|
||||||
|
func (s *MockAuthStore) GetAccount(id int) (*Account, error) {
|
||||||
|
s.GetAccountInvoked = true
|
||||||
|
return s.GetAccountFn(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByEmail mock returns an account by email.
|
||||||
|
func (s *MockAuthStore) GetAccountByEmail(email string) (*Account, error) {
|
||||||
|
s.GetAccountByEmailInvoked = true
|
||||||
|
return s.GetAccountByEmailFn(email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccount mock upates account data related to authentication.
|
||||||
|
func (s *MockAuthStore) UpdateAccount(a *Account) error {
|
||||||
|
s.UpdateAccountInvoked = true
|
||||||
|
return s.UpdateAccountFn(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetToken mock returns an account and refresh token by token identifier.
|
||||||
|
func (s *MockAuthStore) GetToken(token string) (*jwt.Token, error) {
|
||||||
|
s.GetTokenInvoked = true
|
||||||
|
return s.GetTokenFn(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOrUpdateToken mock creates or updates a refresh token.
|
||||||
|
func (s *MockAuthStore) CreateOrUpdateToken(t *jwt.Token) error {
|
||||||
|
s.CreateOrUpdateTokenInvoked = true
|
||||||
|
return s.CreateOrUpdateTokenFn(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteToken mock deletes a refresh token.
|
||||||
|
func (s *MockAuthStore) DeleteToken(t *jwt.Token) error {
|
||||||
|
s.DeleteTokenInvoked = true
|
||||||
|
return s.DeleteTokenFn(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PurgeExpiredToken mock deletes expired refresh token.
|
||||||
|
func (s *MockAuthStore) PurgeExpiredToken() error {
|
||||||
|
s.PurgeExpiredTokenInvoked = true
|
||||||
|
return s.PurgeExpiredTokenFn()
|
||||||
|
}
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/dhax/go-base/auth"
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
|
"github.com/dhax/go-base/auth/pwdless"
|
||||||
"github.com/dhax/go-base/models"
|
"github.com/dhax/go-base/models"
|
||||||
"github.com/go-pg/pg"
|
"github.com/go-pg/pg"
|
||||||
)
|
)
|
||||||
|
|
@ -19,8 +20,8 @@ func NewAccountStore(db *pg.DB) *AccountStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get an account by ID.
|
// Get an account by ID.
|
||||||
func (s *AccountStore) Get(id int) (*auth.Account, error) {
|
func (s *AccountStore) Get(id int) (*pwdless.Account, error) {
|
||||||
a := auth.Account{ID: id}
|
a := pwdless.Account{ID: id}
|
||||||
err := s.db.Model(&a).
|
err := s.db.Model(&a).
|
||||||
Where("account.id = ?id").
|
Where("account.id = ?id").
|
||||||
Column("account.*", "Token").
|
Column("account.*", "Token").
|
||||||
|
|
@ -29,7 +30,7 @@ func (s *AccountStore) Get(id int) (*auth.Account, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update an account.
|
// Update an account.
|
||||||
func (s *AccountStore) Update(a *auth.Account) error {
|
func (s *AccountStore) Update(a *pwdless.Account) error {
|
||||||
_, err := s.db.Model(a).
|
_, err := s.db.Model(a).
|
||||||
Column("email", "name").
|
Column("email", "name").
|
||||||
Update()
|
Update()
|
||||||
|
|
@ -37,9 +38,9 @@ func (s *AccountStore) Update(a *auth.Account) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete an account.
|
// Delete an account.
|
||||||
func (s *AccountStore) Delete(a *auth.Account) error {
|
func (s *AccountStore) Delete(a *pwdless.Account) error {
|
||||||
err := s.db.RunInTransaction(func(tx *pg.Tx) error {
|
err := s.db.RunInTransaction(func(tx *pg.Tx) error {
|
||||||
if _, err := tx.Model(&auth.Token{}).
|
if _, err := tx.Model(&jwt.Token{}).
|
||||||
Where("account_id = ?", a.ID).
|
Where("account_id = ?", a.ID).
|
||||||
Delete(); err != nil {
|
Delete(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -55,7 +56,7 @@ func (s *AccountStore) Delete(a *auth.Account) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateToken updates a jwt refresh token.
|
// UpdateToken updates a jwt refresh token.
|
||||||
func (s *AccountStore) UpdateToken(t *auth.Token) error {
|
func (s *AccountStore) UpdateToken(t *jwt.Token) error {
|
||||||
_, err := s.db.Model(t).
|
_, err := s.db.Model(t).
|
||||||
Column("identifier").
|
Column("identifier").
|
||||||
Update()
|
Update()
|
||||||
|
|
@ -63,7 +64,7 @@ func (s *AccountStore) UpdateToken(t *auth.Token) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteToken deletes a jwt refresh token.
|
// DeleteToken deletes a jwt refresh token.
|
||||||
func (s *AccountStore) DeleteToken(t *auth.Token) error {
|
func (s *AccountStore) DeleteToken(t *jwt.Token) error {
|
||||||
err := s.db.Delete(t)
|
err := s.db.Delete(t)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@ package database
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/dhax/go-base/auth"
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
|
"github.com/dhax/go-base/auth/pwdless"
|
||||||
"github.com/dhax/go-base/models"
|
"github.com/dhax/go-base/models"
|
||||||
"github.com/go-pg/pg"
|
"github.com/go-pg/pg"
|
||||||
)
|
)
|
||||||
|
|
@ -26,8 +27,8 @@ func NewAdmAccountStore(db *pg.DB) *AdmAccountStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
// List applies a filter and returns paginated array of matching results and total count.
|
// List applies a filter and returns paginated array of matching results and total count.
|
||||||
func (s *AdmAccountStore) List(f auth.AccountFilter) ([]auth.Account, int, error) {
|
func (s *AdmAccountStore) List(f pwdless.AccountFilter) ([]pwdless.Account, int, error) {
|
||||||
a := []auth.Account{}
|
a := []pwdless.Account{}
|
||||||
count, err := s.db.Model(&a).
|
count, err := s.db.Model(&a).
|
||||||
Apply(f.Filter).
|
Apply(f.Filter).
|
||||||
SelectAndCount()
|
SelectAndCount()
|
||||||
|
|
@ -38,7 +39,7 @@ func (s *AdmAccountStore) List(f auth.AccountFilter) ([]auth.Account, int, error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new account.
|
// Create creates a new account.
|
||||||
func (s *AdmAccountStore) Create(a *auth.Account) error {
|
func (s *AdmAccountStore) Create(a *pwdless.Account) error {
|
||||||
count, _ := s.db.Model(a).
|
count, _ := s.db.Model(a).
|
||||||
Where("email = ?email").
|
Where("email = ?email").
|
||||||
Count()
|
Count()
|
||||||
|
|
@ -62,22 +63,22 @@ func (s *AdmAccountStore) Create(a *auth.Account) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get account by ID.
|
// Get account by ID.
|
||||||
func (s *AdmAccountStore) Get(id int) (*auth.Account, error) {
|
func (s *AdmAccountStore) Get(id int) (*pwdless.Account, error) {
|
||||||
a := auth.Account{ID: id}
|
a := pwdless.Account{ID: id}
|
||||||
err := s.db.Select(&a)
|
err := s.db.Select(&a)
|
||||||
return &a, err
|
return &a, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update account.
|
// Update account.
|
||||||
func (s *AdmAccountStore) Update(a *auth.Account) error {
|
func (s *AdmAccountStore) Update(a *pwdless.Account) error {
|
||||||
err := s.db.Update(a)
|
err := s.db.Update(a)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete account.
|
// Delete account.
|
||||||
func (s *AdmAccountStore) Delete(a *auth.Account) error {
|
func (s *AdmAccountStore) Delete(a *pwdless.Account) error {
|
||||||
err := s.db.RunInTransaction(func(tx *pg.Tx) error {
|
err := s.db.RunInTransaction(func(tx *pg.Tx) error {
|
||||||
if _, err := tx.Model(&auth.Token{}).
|
if _, err := tx.Model(&jwt.Token{}).
|
||||||
Where("account_id = ?", a.ID).
|
Where("account_id = ?", a.ID).
|
||||||
Delete(); err != nil {
|
Delete(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,12 @@ package database
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dhax/go-base/auth"
|
"github.com/dhax/go-base/auth/jwt"
|
||||||
|
"github.com/dhax/go-base/auth/pwdless"
|
||||||
"github.com/go-pg/pg"
|
"github.com/go-pg/pg"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthStore implements database operations for account authentication.
|
// AuthStore implements database operations for account pwdlessentication.
|
||||||
type AuthStore struct {
|
type AuthStore struct {
|
||||||
db *pg.DB
|
db *pg.DB
|
||||||
}
|
}
|
||||||
|
|
@ -19,9 +20,9 @@ func NewAuthStore(db *pg.DB) *AuthStore {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID returns an account by ID.
|
// GetAccount returns an account by ID.
|
||||||
func (s *AuthStore) GetByID(id int) (*auth.Account, error) {
|
func (s *AuthStore) GetAccount(id int) (*pwdless.Account, error) {
|
||||||
a := auth.Account{ID: id}
|
a := pwdless.Account{ID: id}
|
||||||
err := s.db.Model(&a).
|
err := s.db.Model(&a).
|
||||||
Column("account.*").
|
Column("account.*").
|
||||||
Where("id = ?id").
|
Where("id = ?id").
|
||||||
|
|
@ -29,9 +30,9 @@ func (s *AuthStore) GetByID(id int) (*auth.Account, error) {
|
||||||
return &a, err
|
return &a, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByEmail returns an account by email.
|
// GetAccountByEmail returns an account by email.
|
||||||
func (s *AuthStore) GetByEmail(e string) (*auth.Account, error) {
|
func (s *AuthStore) GetAccountByEmail(e string) (*pwdless.Account, error) {
|
||||||
a := auth.Account{Email: e}
|
a := pwdless.Account{Email: e}
|
||||||
err := s.db.Model(&a).
|
err := s.db.Model(&a).
|
||||||
Column("id", "active", "email", "name").
|
Column("id", "active", "email", "name").
|
||||||
Where("email = ?email").
|
Where("email = ?email").
|
||||||
|
|
@ -39,35 +40,26 @@ func (s *AuthStore) GetByEmail(e string) (*auth.Account, error) {
|
||||||
return &a, err
|
return &a, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByRefreshToken returns an account and refresh token by token identifier.
|
// UpdateAccount upates account data related to pwdlessentication.
|
||||||
func (s *AuthStore) GetByRefreshToken(t string) (*auth.Account, *auth.Token, error) {
|
func (s *AuthStore) UpdateAccount(a *pwdless.Account) error {
|
||||||
token := auth.Token{Token: t}
|
|
||||||
err := s.db.Model(&token).
|
|
||||||
Where("token = ?token").
|
|
||||||
First()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
a := auth.Account{ID: token.AccountID}
|
|
||||||
err = s.db.Model(&a).
|
|
||||||
Column("account.*").
|
|
||||||
Where("id = ?id").
|
|
||||||
First()
|
|
||||||
|
|
||||||
return &a, &token, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateAccount upates account data related to authentication.
|
|
||||||
func (s *AuthStore) UpdateAccount(a *auth.Account) error {
|
|
||||||
_, err := s.db.Model(a).
|
_, err := s.db.Model(a).
|
||||||
Column("last_login").
|
Column("last_login").
|
||||||
Update()
|
Update()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveRefreshToken creates or updates a refresh token.
|
// GetToken returns refresh token by token identifier.
|
||||||
func (s *AuthStore) SaveRefreshToken(t *auth.Token) error {
|
func (s *AuthStore) GetToken(t string) (*jwt.Token, error) {
|
||||||
|
token := jwt.Token{Token: t}
|
||||||
|
err := s.db.Model(&token).
|
||||||
|
Where("token = ?token").
|
||||||
|
First()
|
||||||
|
|
||||||
|
return &token, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOrUpdateToken creates or updates an existing refresh token.
|
||||||
|
func (s *AuthStore) CreateOrUpdateToken(t *jwt.Token) error {
|
||||||
var err error
|
var err error
|
||||||
if t.ID == 0 {
|
if t.ID == 0 {
|
||||||
err = s.db.Insert(t)
|
err = s.db.Insert(t)
|
||||||
|
|
@ -77,15 +69,15 @@ func (s *AuthStore) SaveRefreshToken(t *auth.Token) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRefreshToken deletes a refresh token.
|
// DeleteToken deletes a refresh token.
|
||||||
func (s *AuthStore) DeleteRefreshToken(t *auth.Token) error {
|
func (s *AuthStore) DeleteToken(t *jwt.Token) error {
|
||||||
err := s.db.Delete(t)
|
err := s.db.Delete(t)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// PurgeExpiredToken deletes expired refresh token.
|
// PurgeExpiredToken deletes expired refresh token.
|
||||||
func (s *AuthStore) PurgeExpiredToken() error {
|
func (s *AuthStore) PurgeExpiredToken() error {
|
||||||
_, err := s.db.Model(&auth.Token{}).
|
_, err := s.db.Model(&jwt.Token{}).
|
||||||
Where("expiry < ?", time.Now()).
|
Where("expiry < ?", time.Now()).
|
||||||
Delete()
|
Delete()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue