move auth related models into auth package

This commit is contained in:
dhax 2017-10-04 19:35:22 +02:00
parent 2a9667a616
commit 6b7b5f2ae9
12 changed files with 159 additions and 151 deletions

View file

@ -8,7 +8,7 @@ import (
"github.com/go-ozzo/ozzo-validation" "github.com/go-ozzo/ozzo-validation"
"github.com/dhax/go-base/models" "github.com/dhax/go-base/auth"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/render" "github.com/go-chi/render"
) )
@ -26,11 +26,11 @@ const (
// AccountStore defines database operations for account management. // AccountStore defines database operations for account management.
type AccountStore interface { type AccountStore interface {
List(f models.AccountFilter) (*[]models.Account, int, error) List(f auth.AccountFilter) (*[]auth.Account, int, error)
Create(*models.Account) error Create(*auth.Account) error
Get(id int) (*models.Account, error) Get(id int) (*auth.Account, error)
Update(*models.Account) error Update(*auth.Account) error
Delete(*models.Account) error Delete(*auth.Account) error
} }
// AccountResource implements account managment handler. // AccountResource implements account managment handler.
@ -76,7 +76,7 @@ func (rs *AccountResource) accountCtx(next http.Handler) http.Handler {
} }
type accountRequest struct { type accountRequest struct {
*models.Account *auth.Account
} }
func (d *accountRequest) Bind(r *http.Request) error { func (d *accountRequest) Bind(r *http.Request) error {
@ -84,20 +84,20 @@ func (d *accountRequest) Bind(r *http.Request) error {
} }
type accountResponse struct { type accountResponse struct {
*models.Account *auth.Account
} }
func newAccountResponse(a *models.Account) *accountResponse { func newAccountResponse(a *auth.Account) *accountResponse {
resp := &accountResponse{Account: a} resp := &accountResponse{Account: a}
return resp return resp
} }
type accountListResponse struct { type accountListResponse struct {
Accounts *[]models.Account `json:"accounts"` Accounts *[]auth.Account `json:"accounts"`
Count int `json:"count"` Count int `json:"count"`
} }
func newAccountListResponse(a *[]models.Account, count int) *accountListResponse { func newAccountListResponse(a *[]auth.Account, count int) *accountListResponse {
resp := &accountListResponse{ resp := &accountListResponse{
Accounts: a, Accounts: a,
Count: count, Count: count,
@ -106,7 +106,7 @@ func newAccountListResponse(a *[]models.Account, count int) *accountListResponse
} }
func (rs *AccountResource) list(w http.ResponseWriter, r *http.Request) { func (rs *AccountResource) list(w http.ResponseWriter, r *http.Request) {
f := models.NewAccountFilter(r.URL.Query()) f := auth.NewAccountFilter(r.URL.Query())
al, count, err := rs.Store.List(f) al, count, err := rs.Store.List(f)
if err != nil { if err != nil {
render.Render(w, r, ErrRender(err)) render.Render(w, r, ErrRender(err))
@ -136,12 +136,12 @@ func (rs *AccountResource) create(w http.ResponseWriter, r *http.Request) {
} }
func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) { func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) {
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
render.Respond(w, r, newAccountResponse(acc)) render.Respond(w, r, newAccountResponse(acc))
} }
func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) { func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
data := &accountRequest{Account: acc} data := &accountRequest{Account: acc}
if err := render.Bind(r, data); err != nil { if err := render.Bind(r, data); err != nil {
render.Render(w, r, ErrInvalidRequest(err)) render.Render(w, r, ErrInvalidRequest(err))
@ -163,7 +163,7 @@ func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
} }
func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) { func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) {
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
if err := rs.Store.Delete(acc); err != nil { if err := rs.Store.Delete(acc); err != nil {
render.Render(w, r, ErrInvalidRequest(err)) render.Render(w, r, ErrInvalidRequest(err))
return return

View file

@ -23,11 +23,11 @@ const (
// AccountStore defines database operations for account. // AccountStore defines database operations for account.
type AccountStore interface { type AccountStore interface {
Get(id int) (*models.Account, error) Get(id int) (*auth.Account, error)
Update(*models.Account) error Update(*auth.Account) error
Delete(*models.Account) error Delete(*auth.Account) error
UpdateToken(*models.Token) error UpdateToken(*auth.Token) error
DeleteToken(*models.Token) error DeleteToken(*auth.Token) error
UpdateProfile(*models.Profile) error UpdateProfile(*models.Profile) error
} }
@ -74,7 +74,7 @@ func (rs *AccountResource) accountCtx(next http.Handler) http.Handler {
} }
type accountRequest struct { type accountRequest struct {
*models.Account *auth.Account
// not really neccessary here as we limit updated database columns in store // not really neccessary here as we limit updated database columns in store
ProtectedID int `json:"id"` ProtectedID int `json:"id"`
ProtectedActive bool `json:"active"` ProtectedActive bool `json:"active"`
@ -88,21 +88,21 @@ func (d *accountRequest) Bind(r *http.Request) error {
} }
type accountResponse struct { type accountResponse struct {
*models.Account *auth.Account
} }
func newAccountResponse(a *models.Account) *accountResponse { func newAccountResponse(a *auth.Account) *accountResponse {
resp := &accountResponse{Account: a} resp := &accountResponse{Account: a}
return resp return resp
} }
func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) { func (rs *AccountResource) get(w http.ResponseWriter, r *http.Request) {
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
render.Respond(w, r, newAccountResponse(acc)) render.Respond(w, r, newAccountResponse(acc))
} }
func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) { func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
data := &accountRequest{Account: acc} data := &accountRequest{Account: acc}
if err := render.Bind(r, data); err != nil { if err := render.Bind(r, data); err != nil {
render.Render(w, r, ErrInvalidRequest(err)) render.Render(w, r, ErrInvalidRequest(err))
@ -124,7 +124,7 @@ func (rs *AccountResource) update(w http.ResponseWriter, r *http.Request) {
} }
func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) { func (rs *AccountResource) delete(w http.ResponseWriter, r *http.Request) {
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
if err := rs.Store.Delete(acc); err != nil { if err := rs.Store.Delete(acc); err != nil {
render.Render(w, r, ErrRender(err)) render.Render(w, r, ErrRender(err))
return return
@ -153,10 +153,10 @@ func (rs *AccountResource) updateToken(w http.ResponseWriter, r *http.Request) {
render.Render(w, r, ErrInvalidRequest(err)) render.Render(w, r, ErrInvalidRequest(err))
return return
} }
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
for _, t := range acc.Token { for _, t := range acc.Token {
if t.ID == id { if t.ID == id {
if err := rs.Store.UpdateToken(&models.Token{ if err := rs.Store.UpdateToken(&auth.Token{
ID: t.ID, ID: t.ID,
Identifier: data.Identifier, Identifier: data.Identifier,
}); err != nil { }); err != nil {
@ -174,10 +174,10 @@ func (rs *AccountResource) deleteToken(w http.ResponseWriter, r *http.Request) {
render.Render(w, r, ErrBadRequest) render.Render(w, r, ErrBadRequest)
return return
} }
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
for _, t := range acc.Token { for _, t := range acc.Token {
if t.ID == id { if t.ID == id {
rs.Store.DeleteToken(&models.Token{ID: t.ID}) rs.Store.DeleteToken(&auth.Token{ID: t.ID})
} }
} }
render.Respond(w, r, http.NoBody) render.Respond(w, r, http.NoBody)
@ -205,7 +205,7 @@ func newProfileResponse(p *models.Profile) *profileResponse {
} }
func (rs *AccountResource) updateProfile(w http.ResponseWriter, r *http.Request) { func (rs *AccountResource) updateProfile(w http.ResponseWriter, r *http.Request) {
acc := r.Context().Value(ctxAccount).(*models.Account) acc := r.Context().Value(ctxAccount).(*auth.Account)
data := &profileRequest{Profile: acc.Profile} data := &profileRequest{Profile: acc.Profile}
if err := render.Bind(r, data); err != nil { if err := render.Bind(r, data); err != nil {
render.Render(w, r, ErrInvalidRequest(err)) render.Render(w, r, ErrInvalidRequest(err))

View file

@ -1,10 +1,11 @@
package models package auth
import ( import (
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/dhax/go-base/models"
"github.com/go-chi/jwtauth" "github.com/go-chi/jwtauth"
validation "github.com/go-ozzo/ozzo-validation" validation "github.com/go-ozzo/ozzo-validation"
"github.com/go-ozzo/ozzo-validation/is" "github.com/go-ozzo/ozzo-validation/is"
@ -23,7 +24,7 @@ type Account struct {
Active bool `sql:",notnull" json:"active"` Active bool `sql:",notnull" json:"active"`
Roles []string `pg:",array" json:"roles,omitempty"` Roles []string `pg:",array" json:"roles,omitempty"`
Profile *Profile `json:"profile,omitempty"` Profile *models.Profile `json:"profile,omitempty"`
Token []*Token `json:"token,omitempty"` Token []*Token `json:"token,omitempty"`
} }
@ -71,6 +72,7 @@ func (a *Account) CanLogin() bool {
return a.Active return a.Active
} }
// Claims returns the account's claims to be signed
func (a *Account) Claims() jwtauth.Claims { func (a *Account) Claims() jwtauth.Claims {
return jwtauth.Claims{ return jwtauth.Claims{
"id": a.ID, "id": a.ID,

View file

@ -6,7 +6,6 @@ import (
"github.com/dhax/go-base/email" "github.com/dhax/go-base/email"
"github.com/dhax/go-base/logging" "github.com/dhax/go-base/logging"
"github.com/dhax/go-base/models"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/render" "github.com/go-chi/render"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -14,12 +13,12 @@ import (
// Storer defines database operations on account and token data. // Storer defines database operations on account and token data.
type Storer interface { type Storer interface {
GetByID(id int) (*models.Account, error) GetByID(id int) (*Account, error)
GetByEmail(email string) (*models.Account, error) GetByEmail(email string) (*Account, error)
GetByRefreshToken(token string) (*models.Account, *models.Token, error) GetByRefreshToken(token string) (*Account, *Token, error)
UpdateAccount(a *models.Account) error UpdateAccount(a *Account) error
SaveRefreshToken(u *models.Token) error SaveRefreshToken(t *Token) error
DeleteRefreshToken(t *models.Token) error DeleteRefreshToken(t *Token) error
PurgeExpiredToken() error PurgeExpiredToken() error
} }

View file

@ -15,7 +15,6 @@ import (
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
"github.com/dhax/go-base/email" "github.com/dhax/go-base/email"
"github.com/dhax/go-base/models"
) )
// The list of error types presented to the end user as error message. // The list of error types presented to the end user as error message.
@ -129,7 +128,7 @@ func (rs *Resource) token(w http.ResponseWriter, r *http.Request) {
ua := user_agent.New(r.UserAgent()) ua := user_agent.New(r.UserAgent())
browser, _ := ua.Browser() browser, _ := ua.Browser()
token := &models.Token{ token := &Token{
Token: uuid.NewV4().String(), Token: uuid.NewV4().String(),
Expiry: time.Now().Add(time.Minute * rs.Token.jwtRefreshExpiry), Expiry: time.Now().Add(time.Minute * rs.Token.jwtRefreshExpiry),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),

View file

@ -15,7 +15,6 @@ import (
"github.com/dhax/go-base/email" "github.com/dhax/go-base/email"
"github.com/dhax/go-base/logging" "github.com/dhax/go-base/logging"
"github.com/dhax/go-base/models"
"github.com/dhax/go-base/testing/mock" "github.com/dhax/go-base/testing/mock"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/jwtauth" "github.com/go-chi/jwtauth"
@ -24,7 +23,7 @@ import (
var ( var (
auth *Resource auth *Resource
authstore mock.AuthStore authstore MockStorer
mailer mock.Mailer mailer mock.Mailer
ts *httptest.Server ts *httptest.Server
) )
@ -53,9 +52,9 @@ func TestMain(m *testing.M) {
} }
func TestAuthResource_login(t *testing.T) { func TestAuthResource_login(t *testing.T) {
authstore.GetByEmailFn = func(email string) (*models.Account, error) { authstore.GetByEmailFn = func(email string) (*Account, error) {
var err error var err error
a := models.Account{ a := Account{
ID: 1, ID: 1,
Email: email, Email: email,
Name: "test", Name: "test",
@ -115,9 +114,9 @@ func TestAuthResource_login(t *testing.T) {
} }
func TestAuthResource_token(t *testing.T) { func TestAuthResource_token(t *testing.T) {
authstore.GetByIDFn = func(id int) (*models.Account, error) { authstore.GetByIDFn = func(id int) (*Account, error) {
var err error var err error
a := models.Account{ a := Account{
ID: id, ID: id,
Active: true, Active: true,
Name: "test", Name: "test",
@ -132,11 +131,11 @@ func TestAuthResource_token(t *testing.T) {
} }
return &a, err return &a, err
} }
authstore.UpdateAccountFn = func(a *models.Account) error { authstore.UpdateAccountFn = func(a *Account) error {
a.LastLogin = time.Now() a.LastLogin = time.Now()
return nil return nil
} }
authstore.SaveRefreshTokenFn = func(a *models.Token) error { authstore.SaveRefreshTokenFn = func(a *Token) error {
return nil return nil
} }
@ -185,13 +184,13 @@ func TestAuthResource_token(t *testing.T) {
} }
func TestAuthResource_refresh(t *testing.T) { func TestAuthResource_refresh(t *testing.T) {
authstore.GetByRefreshTokenFn = func(token string) (*models.Account, *models.Token, error) { authstore.GetByRefreshTokenFn = func(token string) (*Account, *Token, error) {
var err error var err error
a := models.Account{ a := Account{
Active: true, Active: true,
Name: "Test", Name: "Test",
} }
var t models.Token var t Token
t.Expiry = time.Now().Add(1 * time.Minute) t.Expiry = time.Now().Add(1 * time.Minute)
switch token { switch token {
@ -206,14 +205,14 @@ func TestAuthResource_refresh(t *testing.T) {
} }
return &a, &t, err return &a, &t, err
} }
authstore.UpdateAccountFn = func(a *models.Account) error { authstore.UpdateAccountFn = func(a *Account) error {
a.LastLogin = time.Now() a.LastLogin = time.Now()
return nil return nil
} }
authstore.SaveRefreshTokenFn = func(a *models.Token) error { authstore.SaveRefreshTokenFn = func(a *Token) error {
return nil return nil
} }
authstore.DeleteRefreshTokenFn = func(t *models.Token) error { authstore.DeleteRefreshTokenFn = func(t *Token) error {
return nil return nil
} }
@ -260,10 +259,10 @@ func TestAuthResource_refresh(t *testing.T) {
} }
func TestAuthResource_logout(t *testing.T) { func TestAuthResource_logout(t *testing.T) {
authstore.GetByRefreshTokenFn = func(token string) (*models.Account, *models.Token, error) { authstore.GetByRefreshTokenFn = func(token string) (*Account, *Token, error) {
var err error var err error
var a models.Account var a Account
t := models.Token{ t := Token{
Expiry: time.Now().Add(1 * time.Minute), Expiry: time.Now().Add(1 * time.Minute),
} }
@ -273,7 +272,7 @@ func TestAuthResource_logout(t *testing.T) {
} }
return &a, &t, err return &a, &t, err
} }
authstore.DeleteRefreshTokenFn = func(a *models.Token) error { authstore.DeleteRefreshTokenFn = func(a *Token) error {
return nil return nil
} }

67
auth/mockStorer.go Normal file
View file

@ -0,0 +1,67 @@
package auth
// MockStorer mocks Storer interface.
type MockStorer struct {
GetByIDFn func(id int) (*Account, error)
GetByIDInvoked bool
GetByEmailFn func(email string) (*Account, error)
GetByEmailInvoked bool
GetByRefreshTokenFn func(token string) (*Account, *Token, error)
GetByRefreshTokenInvoked bool
UpdateAccountFn func(a *Account) error
UpdateAccountInvoked bool
SaveRefreshTokenFn func(t *Token) error
SaveRefreshTokenInvoked bool
DeleteRefreshTokenFn func(t *Token) error
DeleteRefreshTokenInvoked bool
PurgeExpiredTokenFn func() error
PurgeExpiredTokenInvoked bool
}
// GetByID mock returns an account by ID.
func (s *MockStorer) GetByID(id int) (*Account, error) {
s.GetByIDInvoked = true
return s.GetByIDFn(id)
}
// GetByEmail mock returns an account by email.
func (s *MockStorer) GetByEmail(email string) (*Account, error) {
s.GetByEmailInvoked = true
return s.GetByEmailFn(email)
}
// GetByRefreshToken mock returns an account and refresh token by token identifier.
func (s *MockStorer) GetByRefreshToken(token string) (*Account, *Token, error) {
s.GetByRefreshTokenInvoked = true
return s.GetByRefreshTokenFn(token)
}
// UpdateAccount mock upates account data related to authentication.
func (s *MockStorer) UpdateAccount(a *Account) error {
s.UpdateAccountInvoked = true
return s.UpdateAccountFn(a)
}
// SaveRefreshToken mock creates or updates a refresh token.
func (s *MockStorer) SaveRefreshToken(t *Token) error {
s.SaveRefreshTokenInvoked = true
return s.SaveRefreshTokenFn(t)
}
// DeleteRefreshToken mock deletes a refresh token.
func (s *MockStorer) DeleteRefreshToken(t *Token) error {
s.DeleteRefreshTokenInvoked = true
return s.DeleteRefreshTokenFn(t)
}
// PurgeExpiredToken mock deletes expired refresh token.
func (s *MockStorer) PurgeExpiredToken() error {
s.PurgeExpiredTokenInvoked = true
return s.PurgeExpiredTokenFn()
}

View file

@ -1,4 +1,4 @@
package models package auth
import ( import (
"time" "time"
@ -36,6 +36,7 @@ func (t *Token) BeforeUpdate(db orm.DB) error {
return nil return nil
} }
// Claims returns the token claims to be signed
func (t *Token) Claims() jwtauth.Claims { func (t *Token) Claims() jwtauth.Claims {
return jwtauth.Claims{ return jwtauth.Claims{
"id": t.ID, "id": t.ID,

View file

@ -1,6 +1,7 @@
package database package database
import ( import (
"github.com/dhax/go-base/auth"
"github.com/dhax/go-base/models" "github.com/dhax/go-base/models"
"github.com/go-pg/pg" "github.com/go-pg/pg"
) )
@ -18,8 +19,8 @@ func NewAccountStore(db *pg.DB) *AccountStore {
} }
// Get an account by ID. // Get an account by ID.
func (s *AccountStore) Get(id int) (*models.Account, error) { func (s *AccountStore) Get(id int) (*auth.Account, error) {
a := models.Account{ID: id} a := auth.Account{ID: id}
err := s.db.Model(&a). err := s.db.Model(&a).
Where("account.id = ?id"). Where("account.id = ?id").
Column("account.*", "Profile", "Token"). Column("account.*", "Profile", "Token").
@ -28,7 +29,7 @@ func (s *AccountStore) Get(id int) (*models.Account, error) {
} }
// Update an account. // Update an account.
func (s *AccountStore) Update(a *models.Account) error { func (s *AccountStore) Update(a *auth.Account) error {
_, err := s.db.Model(a). _, err := s.db.Model(a).
Column("email", "name"). Column("email", "name").
Update() Update()
@ -36,9 +37,9 @@ func (s *AccountStore) Update(a *models.Account) error {
} }
// Delete an account. // Delete an account.
func (s *AccountStore) Delete(a *models.Account) error { func (s *AccountStore) Delete(a *auth.Account) error {
err := s.db.RunInTransaction(func(tx *pg.Tx) error { err := s.db.RunInTransaction(func(tx *pg.Tx) error {
if _, err := tx.Model(&models.Token{}). if _, err := tx.Model(&auth.Token{}).
Where("account_id = ?", a.ID). Where("account_id = ?", a.ID).
Delete(); err != nil { Delete(); err != nil {
return err return err
@ -54,7 +55,7 @@ func (s *AccountStore) Delete(a *models.Account) error {
} }
// UpdateToken updates a jwt refresh token. // UpdateToken updates a jwt refresh token.
func (s *AccountStore) UpdateToken(t *models.Token) error { func (s *AccountStore) UpdateToken(t *auth.Token) error {
_, err := s.db.Model(t). _, err := s.db.Model(t).
Column("identifier"). Column("identifier").
Update() Update()
@ -62,7 +63,7 @@ func (s *AccountStore) UpdateToken(t *models.Token) error {
} }
// DeleteToken deletes a jwt refresh token. // DeleteToken deletes a jwt refresh token.
func (s *AccountStore) DeleteToken(t *models.Token) error { func (s *AccountStore) DeleteToken(t *auth.Token) error {
err := s.db.Delete(t) err := s.db.Delete(t)
return err return err
} }

View file

@ -3,6 +3,7 @@ package database
import ( import (
"errors" "errors"
"github.com/dhax/go-base/auth"
"github.com/dhax/go-base/models" "github.com/dhax/go-base/models"
"github.com/go-pg/pg" "github.com/go-pg/pg"
) )
@ -25,8 +26,8 @@ func NewAdmAccountStore(db *pg.DB) *AdmAccountStore {
} }
// List applies a filter and returns paginated array of matching results and total count. // List applies a filter and returns paginated array of matching results and total count.
func (s *AdmAccountStore) List(f models.AccountFilter) (*[]models.Account, int, error) { func (s *AdmAccountStore) List(f auth.AccountFilter) (*[]auth.Account, int, error) {
var a []models.Account var a []auth.Account
count, err := s.db.Model(&a). count, err := s.db.Model(&a).
Apply(f.Filter). Apply(f.Filter).
SelectAndCount() SelectAndCount()
@ -37,7 +38,7 @@ func (s *AdmAccountStore) List(f models.AccountFilter) (*[]models.Account, int,
} }
// Create creates a new account. // Create creates a new account.
func (s *AdmAccountStore) Create(a *models.Account) error { func (s *AdmAccountStore) Create(a *auth.Account) error {
count, _ := s.db.Model(a). count, _ := s.db.Model(a).
Where("email = ?email"). Where("email = ?email").
Count() Count()
@ -61,22 +62,22 @@ func (s *AdmAccountStore) Create(a *models.Account) error {
} }
// Get account by ID. // Get account by ID.
func (s *AdmAccountStore) Get(id int) (*models.Account, error) { func (s *AdmAccountStore) Get(id int) (*auth.Account, error) {
a := models.Account{ID: id} a := auth.Account{ID: id}
err := s.db.Select(&a) err := s.db.Select(&a)
return &a, err return &a, err
} }
// Update account. // Update account.
func (s *AdmAccountStore) Update(a *models.Account) error { func (s *AdmAccountStore) Update(a *auth.Account) error {
err := s.db.Update(a) err := s.db.Update(a)
return err return err
} }
// Delete account. // Delete account.
func (s *AdmAccountStore) Delete(a *models.Account) error { func (s *AdmAccountStore) Delete(a *auth.Account) error {
err := s.db.RunInTransaction(func(tx *pg.Tx) error { err := s.db.RunInTransaction(func(tx *pg.Tx) error {
if _, err := tx.Model(&models.Token{}). if _, err := tx.Model(&auth.Token{}).
Where("account_id = ?", a.ID). Where("account_id = ?", a.ID).
Delete(); err != nil { Delete(); err != nil {
return err return err

View file

@ -3,7 +3,7 @@ package database
import ( import (
"time" "time"
"github.com/dhax/go-base/models" "github.com/dhax/go-base/auth"
"github.com/go-pg/pg" "github.com/go-pg/pg"
) )
@ -20,8 +20,8 @@ func NewAuthStore(db *pg.DB) *AuthStore {
} }
// GetByID returns an account by ID. // GetByID returns an account by ID.
func (s *AuthStore) GetByID(id int) (*models.Account, error) { func (s *AuthStore) GetByID(id int) (*auth.Account, error) {
a := models.Account{ID: id} a := auth.Account{ID: id}
err := s.db.Model(&a). err := s.db.Model(&a).
Column("account.*"). Column("account.*").
Where("id = ?id"). Where("id = ?id").
@ -30,8 +30,8 @@ func (s *AuthStore) GetByID(id int) (*models.Account, error) {
} }
// GetByEmail returns an account by email. // GetByEmail returns an account by email.
func (s *AuthStore) GetByEmail(e string) (*models.Account, error) { func (s *AuthStore) GetByEmail(e string) (*auth.Account, error) {
a := models.Account{Email: e} a := auth.Account{Email: e}
err := s.db.Model(&a). err := s.db.Model(&a).
Column("id", "active", "email", "name"). Column("id", "active", "email", "name").
Where("email = ?email"). Where("email = ?email").
@ -40,8 +40,8 @@ func (s *AuthStore) GetByEmail(e string) (*models.Account, error) {
} }
// GetByRefreshToken returns an account and refresh token by token identifier. // GetByRefreshToken returns an account and refresh token by token identifier.
func (s *AuthStore) GetByRefreshToken(t string) (*models.Account, *models.Token, error) { func (s *AuthStore) GetByRefreshToken(t string) (*auth.Account, *auth.Token, error) {
token := models.Token{Token: t} token := auth.Token{Token: t}
err := s.db.Model(&token). err := s.db.Model(&token).
Where("token = ?token"). Where("token = ?token").
First() First()
@ -49,7 +49,7 @@ func (s *AuthStore) GetByRefreshToken(t string) (*models.Account, *models.Token,
return nil, nil, err return nil, nil, err
} }
a := models.Account{ID: token.AccountID} a := auth.Account{ID: token.AccountID}
err = s.db.Model(&a). err = s.db.Model(&a).
Column("account.*"). Column("account.*").
Where("id = ?id"). Where("id = ?id").
@ -59,7 +59,7 @@ func (s *AuthStore) GetByRefreshToken(t string) (*models.Account, *models.Token,
} }
// UpdateAccount upates account data related to authentication. // UpdateAccount upates account data related to authentication.
func (s *AuthStore) UpdateAccount(a *models.Account) error { func (s *AuthStore) UpdateAccount(a *auth.Account) error {
_, err := s.db.Model(a). _, err := s.db.Model(a).
Column("last_login"). Column("last_login").
Update() Update()
@ -67,7 +67,7 @@ func (s *AuthStore) UpdateAccount(a *models.Account) error {
} }
// SaveRefreshToken creates or updates a refresh token. // SaveRefreshToken creates or updates a refresh token.
func (s *AuthStore) SaveRefreshToken(t *models.Token) error { func (s *AuthStore) SaveRefreshToken(t *auth.Token) error {
var err error var err error
if t.ID == 0 { if t.ID == 0 {
err = s.db.Insert(t) err = s.db.Insert(t)
@ -78,14 +78,14 @@ func (s *AuthStore) SaveRefreshToken(t *models.Token) error {
} }
// DeleteRefreshToken deletes a refresh token. // DeleteRefreshToken deletes a refresh token.
func (s *AuthStore) DeleteRefreshToken(t *models.Token) error { func (s *AuthStore) DeleteRefreshToken(t *auth.Token) error {
err := s.db.Delete(t) err := s.db.Delete(t)
return err return err
} }
// PurgeExpiredToken deletes expired refresh token. // PurgeExpiredToken deletes expired refresh token.
func (s *AuthStore) PurgeExpiredToken() error { func (s *AuthStore) PurgeExpiredToken() error {
_, err := s.db.Model(&models.Token{}). _, err := s.db.Model(&auth.Token{}).
Where("expiry < ?", time.Now()). Where("expiry < ?", time.Now()).
Delete() Delete()

View file

@ -1,61 +0,0 @@
package mock
import "github.com/dhax/go-base/models"
type AuthStore struct {
GetByIDFn func(id int) (*models.Account, error)
GetByIDInvoked bool
GetByEmailFn func(email string) (*models.Account, error)
GetByEmailInvoked bool
GetByRefreshTokenFn func(token string) (*models.Account, *models.Token, error)
GetByRefreshTokenInvoked bool
UpdateAccountFn func(a *models.Account) error
UpdateAccountInvoked bool
SaveRefreshTokenFn func(u *models.Token) error
SaveRefreshTokenInvoked bool
DeleteRefreshTokenFn func(t *models.Token) error
DeleteRefreshTokenInvoked bool
PurgeExpiredTokenFn func() error
PurgeExpiredTokenInvoked bool
}
func (s *AuthStore) GetByID(id int) (*models.Account, error) {
s.GetByIDInvoked = true
return s.GetByIDFn(id)
}
func (s *AuthStore) GetByEmail(email string) (*models.Account, error) {
s.GetByEmailInvoked = true
return s.GetByEmailFn(email)
}
func (s *AuthStore) GetByRefreshToken(token string) (*models.Account, *models.Token, error) {
s.GetByRefreshTokenInvoked = true
return s.GetByRefreshTokenFn(token)
}
func (s *AuthStore) UpdateAccount(a *models.Account) error {
s.UpdateAccountInvoked = true
return s.UpdateAccountFn(a)
}
func (s *AuthStore) SaveRefreshToken(u *models.Token) error {
s.SaveRefreshTokenInvoked = true
return s.SaveRefreshTokenFn(u)
}
func (s *AuthStore) DeleteRefreshToken(t *models.Token) error {
s.DeleteRefreshTokenInvoked = true
return s.DeleteRefreshTokenFn(t)
}
func (s *AuthStore) PurgeExpiredToken() error {
s.PurgeExpiredTokenInvoked = true
return s.PurgeExpiredTokenFn()
}