initial commit
This commit is contained in:
commit
93d8310491
46 changed files with 3379 additions and 0 deletions
91
auth/api.go
Normal file
91
auth/api.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dhax/go-base/email"
|
||||
"github.com/dhax/go-base/logging"
|
||||
"github.com/dhax/go-base/models"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Store defines database operations on account and token data.
|
||||
type Store interface {
|
||||
GetByID(id int) (*models.Account, error)
|
||||
GetByEmail(email string) (*models.Account, error)
|
||||
GetByRefreshToken(token string) (*models.Account, *models.Token, error)
|
||||
UpdateAccount(a *models.Account) error
|
||||
SaveRefreshToken(u *models.Token) error
|
||||
DeleteRefreshToken(t *models.Token) error
|
||||
PurgeExpiredToken() error
|
||||
}
|
||||
|
||||
// EmailService defines methods to send account emails.
|
||||
type EmailService interface {
|
||||
LoginToken(name, email string, c email.LoginTokenContent) error
|
||||
}
|
||||
|
||||
// Resource implements passwordless token authentication against a database.
|
||||
type Resource struct {
|
||||
Login *LoginTokenAuth
|
||||
Token *TokenAuth
|
||||
store Store
|
||||
mailer EmailService
|
||||
}
|
||||
|
||||
// NewResource returns a configured authentication resource.
|
||||
func NewResource(store Store, mailer EmailService) (*Resource, error) {
|
||||
loginAuth, err := NewLoginTokenAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenAuth, err := NewTokenAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resource := &Resource{
|
||||
Login: loginAuth,
|
||||
Token: tokenAuth,
|
||||
store: store,
|
||||
mailer: mailer,
|
||||
}
|
||||
|
||||
resource.Cleanup()
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
|
||||
// Router provides neccessary routes for passwordless authentication flow.
|
||||
func (rs *Resource) Router() *chi.Mux {
|
||||
r := chi.NewRouter()
|
||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||
r.Post("/login", rs.login)
|
||||
r.Post("/token", rs.token)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(rs.Token.Verifier())
|
||||
r.Use(AuthenticateRefreshJWT)
|
||||
r.Post("/refresh", rs.refresh)
|
||||
r.Post("/logout", rs.logout)
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func (rs *Resource) Cleanup() {
|
||||
ticker := time.NewTicker(time.Hour * 1)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
if err := rs.store.PurgeExpiredToken(); err != nil {
|
||||
logging.Logger.WithField("auth", "cleanup").Error(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func log(r *http.Request) logrus.FieldLogger {
|
||||
return logging.GetLogEntry(r)
|
||||
}
|
||||
87
auth/authenticator.go
Normal file
87
auth/authenticator.go
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
ctxClaims ctxKey = iota
|
||||
ctxRefreshToken
|
||||
)
|
||||
|
||||
var (
|
||||
errTokenUnauthorized = errors.New("token unauthorized")
|
||||
errTokenExpired = errors.New("token expired")
|
||||
errInvalidAccessToken = errors.New("invalid access token")
|
||||
errInvalidRefreshToken = errors.New("invalid refresh token")
|
||||
)
|
||||
|
||||
// ClaimsFromCtx retrieves the parsed AppClaims from request context.
|
||||
func ClaimsFromCtx(ctx context.Context) AppClaims {
|
||||
return ctx.Value(ctxClaims).(AppClaims)
|
||||
}
|
||||
|
||||
// RefreshTokenFromCtx retrieves the parsed refresh token from context.
|
||||
func RefreshTokenFromCtx(ctx context.Context) string {
|
||||
return ctx.Value(ctxRefreshToken).(string)
|
||||
}
|
||||
|
||||
// Authenticator is a default authentication middleware to enforce access from the
|
||||
// Verifier middleware request context values. The Authenticator sends a 401 Unauthorized
|
||||
// response for any unverified tokens and passes the good ones through.
|
||||
func Authenticator(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, claims, err := jwtauth.FromContext(r.Context())
|
||||
|
||||
if err != nil {
|
||||
log(r).Warn(err)
|
||||
render.Render(w, r, ErrUnauthorized(errTokenUnauthorized))
|
||||
return
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
||||
return
|
||||
}
|
||||
|
||||
// Token is authenticated, parse claims
|
||||
pc, ok := parseClaims(claims)
|
||||
if !ok {
|
||||
render.Render(w, r, ErrUnauthorized(errInvalidAccessToken))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ctxClaims, pc)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// AuthenticateRefreshJWT checks validity of refresh tokens and is only used for access token refresh and logout requests. It responds with 401 Unauthorized for invalid or expired refresh tokens.
|
||||
func AuthenticateRefreshJWT(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, claims, err := jwtauth.FromContext(r.Context())
|
||||
if err != nil {
|
||||
log(r).Warn(err)
|
||||
render.Render(w, r, ErrUnauthorized(errTokenUnauthorized))
|
||||
return
|
||||
}
|
||||
if !token.Valid {
|
||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
||||
return
|
||||
}
|
||||
refreshToken, ok := parseRefreshClaims(claims)
|
||||
if !ok {
|
||||
render.Render(w, r, ErrUnauthorized(errInvalidRefreshToken))
|
||||
return
|
||||
}
|
||||
// Token is authenticated, set on context
|
||||
ctx := context.WithValue(r.Context(), ctxRefreshToken, refreshToken)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
31
auth/authorizer.go
Normal file
31
auth/authorizer.go
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
// RequiresRole middleware restricts access to accounts having role parameter in their jwt claims.
|
||||
func RequiresRole(role string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
hfn := func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := ClaimsFromCtx(r.Context())
|
||||
if !hasRole(role, claims.Roles) {
|
||||
render.Render(w, r, ErrForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(hfn)
|
||||
}
|
||||
}
|
||||
|
||||
func hasRole(role string, roles []string) bool {
|
||||
for _, r := range roles {
|
||||
if r == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
20
auth/crypto.go
Normal file
20
auth/crypto.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func randStringBytes(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
64
auth/errors.go
Normal file
64
auth/errors.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
// ErrResponse renderer type for handling all sorts of errors.
|
||||
type ErrResponse struct {
|
||||
Err error `json:"-"` // low-level runtime error
|
||||
HTTPStatusCode int `json:"-"` // http response status code
|
||||
|
||||
StatusText string `json:"status"` // user-level status message
|
||||
AppCode int64 `json:"code,omitempty"` // application-specific error code
|
||||
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
|
||||
}
|
||||
|
||||
// Render sets the application-specific error code in AppCode.
|
||||
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
render.Status(r, e.HTTPStatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrUnauthorized renders status 401 Unauthorized with custom error message.
|
||||
func ErrUnauthorized(err error) render.Renderer {
|
||||
return &ErrResponse{
|
||||
Err: err,
|
||||
HTTPStatusCode: http.StatusUnauthorized,
|
||||
StatusText: http.StatusText(http.StatusUnauthorized),
|
||||
ErrorText: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// ErrRender returns status 422 Unprocessable Entity for invalid request body
|
||||
func ErrRender(err error) render.Renderer {
|
||||
return &ErrResponse{
|
||||
Err: err,
|
||||
HTTPStatusCode: http.StatusUnprocessableEntity,
|
||||
StatusText: http.StatusText(http.StatusUnprocessableEntity),
|
||||
ErrorText: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInvalidRequest returns status 422 Unprocessable Entity with validation errors
|
||||
func ErrInvalidRequest(err error) render.Renderer {
|
||||
return &ErrResponse{
|
||||
Err: err,
|
||||
HTTPStatusCode: http.StatusUnprocessableEntity,
|
||||
StatusText: http.StatusText(http.StatusUnprocessableEntity),
|
||||
ErrorText: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// The list of default error types without specific error message.
|
||||
var (
|
||||
ErrBadRequest = &ErrResponse{HTTPStatusCode: http.StatusBadRequest, StatusText: http.StatusText(http.StatusBadRequest)}
|
||||
|
||||
ErrForbidden = &ErrResponse{HTTPStatusCode: http.StatusForbidden, StatusText: http.StatusText(http.StatusForbidden)}
|
||||
|
||||
ErrNotFound = &ErrResponse{HTTPStatusCode: http.StatusNotFound, StatusText: http.StatusText(http.StatusNotFound)}
|
||||
|
||||
ErrInternalServerError = &ErrResponse{HTTPStatusCode: http.StatusInternalServerError, StatusText: http.StatusText(http.StatusInternalServerError)}
|
||||
)
|
||||
215
auth/handler.go
Normal file
215
auth/handler.go
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
validation "github.com/go-ozzo/ozzo-validation"
|
||||
"github.com/go-ozzo/ozzo-validation/is"
|
||||
"github.com/mssola/user_agent"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
|
||||
"github.com/dhax/go-base/email"
|
||||
"github.com/dhax/go-base/models"
|
||||
)
|
||||
|
||||
// The list of error types presented to the end user as error message.
|
||||
var (
|
||||
ErrInvalidLogin = errors.New("invalid email address")
|
||||
ErrUnknownLogin = errors.New("email not registered")
|
||||
ErrLoginDisabled = errors.New("login for account disabled")
|
||||
ErrLoginToken = errors.New("invalid or expired login token")
|
||||
)
|
||||
|
||||
type loginRequest struct {
|
||||
Email string
|
||||
}
|
||||
|
||||
func (body *loginRequest) Bind(r *http.Request) error {
|
||||
body.Email = strings.TrimSpace(body.Email)
|
||||
body.Email = strings.ToLower(body.Email)
|
||||
|
||||
if err := validation.ValidateStruct(body,
|
||||
validation.Field(&body.Email, validation.Required, is.Email),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *Resource) login(w http.ResponseWriter, r *http.Request) {
|
||||
body := &loginRequest{}
|
||||
if err := render.Bind(r, body); err != nil {
|
||||
log(r).WithField("email", body.Email).Warn(err)
|
||||
render.Render(w, r, ErrUnauthorized(ErrInvalidLogin))
|
||||
return
|
||||
}
|
||||
|
||||
acc, err := rs.store.GetByEmail(body.Email)
|
||||
if err != nil {
|
||||
log(r).WithField("email", body.Email).Warn(err)
|
||||
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
|
||||
return
|
||||
}
|
||||
|
||||
if !acc.CanLogin() {
|
||||
render.Render(w, r, ErrUnauthorized(ErrLoginDisabled))
|
||||
return
|
||||
}
|
||||
|
||||
lt := rs.Login.CreateToken(acc.ID)
|
||||
|
||||
go func() {
|
||||
err := rs.mailer.LoginToken(acc.Name, acc.Email, email.LoginTokenContent{
|
||||
Email: acc.Email,
|
||||
Name: acc.Name,
|
||||
URL: path.Join(rs.Login.loginURL, lt.Token),
|
||||
Token: lt.Token,
|
||||
Expiry: lt.Expiry,
|
||||
})
|
||||
if err != nil {
|
||||
log(r).WithField("module", "email").Error(err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
render.Respond(w, r, http.NoBody)
|
||||
}
|
||||
|
||||
type tokenRequest struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type tokenResponse struct {
|
||||
Access string `json:"access_token"`
|
||||
Refresh string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
func (body *tokenRequest) Bind(r *http.Request) error {
|
||||
body.Token = strings.TrimSpace(body.Token)
|
||||
|
||||
if err := validation.ValidateStruct(body,
|
||||
validation.Field(&body.Token, validation.Required, is.Alphanumeric),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *Resource) token(w http.ResponseWriter, r *http.Request) {
|
||||
body := &tokenRequest{}
|
||||
if err := render.Bind(r, body); err != nil {
|
||||
log(r).Warn(err)
|
||||
render.Render(w, r, ErrUnauthorized(ErrLoginToken))
|
||||
return
|
||||
}
|
||||
|
||||
id, err := rs.Login.GetAccountID(body.Token)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrUnauthorized(ErrLoginToken))
|
||||
return
|
||||
}
|
||||
|
||||
acc, err := rs.store.GetByID(id)
|
||||
if err != nil {
|
||||
// account deleted before login token expired
|
||||
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
|
||||
return
|
||||
}
|
||||
|
||||
if !acc.CanLogin() {
|
||||
render.Render(w, r, ErrUnauthorized(ErrLoginDisabled))
|
||||
return
|
||||
}
|
||||
|
||||
ua := user_agent.New(r.UserAgent())
|
||||
browser, _ := ua.Browser()
|
||||
token := &models.Token{
|
||||
Token: uuid.NewV4().String(),
|
||||
Expiry: time.Now().Add(time.Minute * rs.Token.jwtRefreshExpiry),
|
||||
UpdatedAt: time.Now(),
|
||||
AccountID: acc.ID,
|
||||
Mobile: ua.Mobile(),
|
||||
Identifier: fmt.Sprintf("%s on %s", browser, ua.OS()),
|
||||
}
|
||||
|
||||
if err := rs.store.SaveRefreshToken(token); err != nil {
|
||||
log(r).Error(err)
|
||||
render.Respond(w, r, ErrInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
access, refresh := rs.Token.GenTokenPair(acc, token)
|
||||
|
||||
acc.LastLogin = time.Now()
|
||||
if err := rs.store.UpdateAccount(acc); err != nil {
|
||||
log(r).Error(err)
|
||||
render.Respond(w, r, ErrInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
render.Respond(w, r, &tokenResponse{
|
||||
Access: access,
|
||||
Refresh: refresh,
|
||||
})
|
||||
}
|
||||
|
||||
func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) {
|
||||
rt := RefreshTokenFromCtx(r.Context())
|
||||
|
||||
acc, token, err := rs.store.GetByRefreshToken(rt)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
||||
return
|
||||
}
|
||||
|
||||
if time.Now().After(token.Expiry) {
|
||||
rs.store.DeleteRefreshToken(token)
|
||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
||||
return
|
||||
}
|
||||
|
||||
if !acc.CanLogin() {
|
||||
render.Render(w, r, ErrUnauthorized(ErrLoginDisabled))
|
||||
return
|
||||
}
|
||||
|
||||
token.Token = uuid.NewV4().String()
|
||||
token.Expiry = time.Now().Add(time.Minute * rs.Token.jwtRefreshExpiry)
|
||||
token.UpdatedAt = time.Now()
|
||||
|
||||
access, refresh := rs.Token.GenTokenPair(acc, token)
|
||||
if err := rs.store.SaveRefreshToken(token); err != nil {
|
||||
log(r).Error(err)
|
||||
render.Respond(w, r, ErrInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
acc.LastLogin = time.Now()
|
||||
if err := rs.store.UpdateAccount(acc); err != nil {
|
||||
log(r).Error(err)
|
||||
render.Respond(w, r, ErrInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
render.Respond(w, r, &tokenResponse{
|
||||
Access: access,
|
||||
Refresh: refresh,
|
||||
})
|
||||
}
|
||||
|
||||
func (rs *Resource) logout(w http.ResponseWriter, r *http.Request) {
|
||||
rt := RefreshTokenFromCtx(r.Context())
|
||||
_, token, err := rs.store.GetByRefreshToken(rt)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrUnauthorized(errTokenExpired))
|
||||
return
|
||||
}
|
||||
rs.store.DeleteRefreshToken(token)
|
||||
|
||||
render.Respond(w, r, http.NoBody)
|
||||
}
|
||||
346
auth/handler_test.go
Normal file
346
auth/handler_test.go
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dhax/go-base/email"
|
||||
"github.com/dhax/go-base/logging"
|
||||
"github.com/dhax/go-base/models"
|
||||
"github.com/dhax/go-base/testing/mock"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var (
|
||||
auth *Resource
|
||||
authstore mock.AuthStore
|
||||
mailer mock.EmailService
|
||||
ts *httptest.Server
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
viper.SetDefault("auth_login_token_length", 8)
|
||||
viper.SetDefault("auth_login_token_expiry", 11)
|
||||
viper.SetDefault("auth_jwt_secret", "random")
|
||||
viper.SetDefault("log_level", "error")
|
||||
|
||||
var err error
|
||||
auth, err = NewResource(&authstore, &mailer)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(logging.NewStructuredLogger(logging.NewLogger()))
|
||||
r.Mount("/", auth.Router())
|
||||
|
||||
ts = httptest.NewServer(r)
|
||||
|
||||
code := m.Run()
|
||||
ts.Close()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestAuthResource_login(t *testing.T) {
|
||||
authstore.GetByEmailFn = func(email string) (*models.Account, error) {
|
||||
var err error
|
||||
a := models.Account{
|
||||
ID: 1,
|
||||
Email: email,
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
switch email {
|
||||
case "not@exists.io":
|
||||
err = errors.New("sql no row")
|
||||
case "disabled@account.io":
|
||||
a.Active = false
|
||||
case "valid@account.io":
|
||||
a.Active = true
|
||||
}
|
||||
return &a, err
|
||||
}
|
||||
|
||||
mailer.LoginTokenFn = func(n, e string, c email.LoginTokenContent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
status int
|
||||
err error
|
||||
}{
|
||||
{"missing", "", http.StatusUnauthorized, ErrInvalidLogin},
|
||||
{"inexistent", "not@exists.io", http.StatusUnauthorized, ErrUnknownLogin},
|
||||
{"disabled", "disabled@account.io", http.StatusUnauthorized, ErrLoginDisabled},
|
||||
{"valid", "valid@account.io", http.StatusOK, nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, err := encode(&loginRequest{Email: tc.email})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode request body")
|
||||
}
|
||||
res, body := testRequest(t, ts, "POST", "/login", req, "")
|
||||
|
||||
if res.StatusCode != tc.status {
|
||||
t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
|
||||
}
|
||||
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
||||
t.Errorf(" got: %s, expected to contain: %s", body, tc.err.Error())
|
||||
}
|
||||
if tc.err == ErrInvalidLogin && authstore.GetByEmailInvoked {
|
||||
t.Error("GetByLoginToken invoked for invalid email")
|
||||
}
|
||||
if tc.err == nil && !mailer.LoginTokenInvoked {
|
||||
t.Error("emailService.LoginToken not invoked")
|
||||
}
|
||||
authstore.GetByEmailInvoked = false
|
||||
mailer.LoginTokenInvoked = false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthResource_token(t *testing.T) {
|
||||
authstore.GetByIDFn = func(id int) (*models.Account, error) {
|
||||
var err error
|
||||
a := models.Account{
|
||||
ID: id,
|
||||
Active: true,
|
||||
Name: "test",
|
||||
}
|
||||
switch id {
|
||||
case 2:
|
||||
a.Active = false
|
||||
case 3:
|
||||
// unmodified
|
||||
default:
|
||||
err = errors.New("sql no rows")
|
||||
}
|
||||
return &a, err
|
||||
}
|
||||
authstore.UpdateAccountFn = func(a *models.Account) error {
|
||||
a.LastLogin = time.Now()
|
||||
return nil
|
||||
}
|
||||
authstore.SaveRefreshTokenFn = func(a *models.Token) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
id int
|
||||
status int
|
||||
err error
|
||||
}{
|
||||
{"invalid", "#§$%", 0, http.StatusUnauthorized, ErrLoginToken},
|
||||
{"expired", "12345678", 0, http.StatusUnauthorized, ErrLoginToken},
|
||||
{"deleted_account", "", 1, http.StatusUnauthorized, ErrUnknownLogin},
|
||||
{"disabled", "", 2, http.StatusUnauthorized, ErrLoginDisabled},
|
||||
{"valid", "", 3, http.StatusOK, nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token := auth.Login.CreateToken(tc.id)
|
||||
if tc.token != "" {
|
||||
token.Token = tc.token
|
||||
}
|
||||
|
||||
req, err := encode(tokenRequest{Token: token.Token})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode request body")
|
||||
}
|
||||
res, body := testRequest(t, ts, "POST", "/token", req, "")
|
||||
|
||||
if res.StatusCode != tc.status {
|
||||
t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
|
||||
}
|
||||
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
||||
t.Errorf("got: %s, expected to contain: %s", body, tc.err.Error())
|
||||
}
|
||||
if tc.err == ErrLoginToken && authstore.SaveRefreshTokenInvoked {
|
||||
t.Errorf("SaveRefreshToken invoked despite error %s", tc.err.Error())
|
||||
}
|
||||
if tc.err == nil && !authstore.SaveRefreshTokenInvoked {
|
||||
t.Error("SaveRefreshToken not invoked")
|
||||
}
|
||||
authstore.SaveRefreshTokenInvoked = false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthResource_refresh(t *testing.T) {
|
||||
authstore.GetByRefreshTokenFn = func(token string) (*models.Account, *models.Token, error) {
|
||||
var err error
|
||||
a := models.Account{
|
||||
Active: true,
|
||||
Name: "Test",
|
||||
}
|
||||
var t models.Token
|
||||
t.Expiry = time.Now().Add(1 * time.Minute)
|
||||
|
||||
switch token {
|
||||
case "notfound":
|
||||
err = errors.New("sql no rows")
|
||||
case "expired":
|
||||
t.Expiry = time.Now().Add(-1 * time.Minute)
|
||||
case "disabled":
|
||||
a.Active = false
|
||||
case "valid":
|
||||
// unmodified
|
||||
}
|
||||
return &a, &t, err
|
||||
}
|
||||
authstore.UpdateAccountFn = func(a *models.Account) error {
|
||||
a.LastLogin = time.Now()
|
||||
return nil
|
||||
}
|
||||
authstore.SaveRefreshTokenFn = func(a *models.Token) error {
|
||||
return nil
|
||||
}
|
||||
authstore.DeleteRefreshTokenFn = func(t *models.Token) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
exp time.Duration
|
||||
status int
|
||||
err error
|
||||
}{
|
||||
{"notfound", "notfound", 1, http.StatusUnauthorized, errTokenExpired},
|
||||
{"expired", "expired", -1, http.StatusUnauthorized, errTokenUnauthorized},
|
||||
{"disabled", "disabled", 1, http.StatusUnauthorized, ErrLoginDisabled},
|
||||
{"valid", "valid", 1, http.StatusOK, nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwt := genJWT(jwtauth.Claims{"token": tc.token, "exp": time.Minute * tc.exp})
|
||||
res, body := testRequest(t, ts, "POST", "/refresh", nil, jwt)
|
||||
if res.StatusCode != tc.status {
|
||||
t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
|
||||
}
|
||||
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
||||
t.Errorf("got: %s, expected error to contain: %s", body, tc.err.Error())
|
||||
}
|
||||
if tc.status == http.StatusUnauthorized && authstore.SaveRefreshTokenInvoked {
|
||||
t.Errorf("SaveRefreshToken invoked for status %d", tc.status)
|
||||
}
|
||||
if tc.status == http.StatusOK && !authstore.GetByRefreshTokenInvoked {
|
||||
t.Errorf("GetRefreshToken not invoked")
|
||||
}
|
||||
if tc.status == http.StatusOK && !authstore.SaveRefreshTokenInvoked {
|
||||
t.Errorf("SaveRefreshToken not invoked")
|
||||
}
|
||||
if tc.status == http.StatusOK && authstore.DeleteRefreshTokenInvoked {
|
||||
t.Errorf("DeleteRefreshToken should not be invoked")
|
||||
}
|
||||
authstore.GetByRefreshTokenInvoked = false
|
||||
authstore.SaveRefreshTokenInvoked = false
|
||||
authstore.DeleteRefreshTokenInvoked = false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthResource_logout(t *testing.T) {
|
||||
authstore.GetByRefreshTokenFn = func(token string) (*models.Account, *models.Token, error) {
|
||||
var err error
|
||||
var a models.Account
|
||||
t := models.Token{
|
||||
Expiry: time.Now().Add(1 * time.Minute),
|
||||
}
|
||||
|
||||
switch token {
|
||||
case "notfound":
|
||||
err = errors.New("sql no rows")
|
||||
}
|
||||
return &a, &t, err
|
||||
}
|
||||
authstore.DeleteRefreshTokenFn = func(a *models.Token) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
exp time.Duration
|
||||
status int
|
||||
err error
|
||||
}{
|
||||
{"notfound", "notfound", 1, http.StatusUnauthorized, errTokenExpired},
|
||||
{"expired", "valid", -1, http.StatusUnauthorized, errTokenUnauthorized},
|
||||
{"valid", "valid", 1, http.StatusOK, nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwt := genJWT(jwtauth.Claims{"token": tc.token, "exp": time.Minute * tc.exp})
|
||||
res, body := testRequest(t, ts, "POST", "/logout", nil, jwt)
|
||||
if res.StatusCode != tc.status {
|
||||
t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
|
||||
}
|
||||
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
|
||||
t.Errorf("got: %x, expected error to contain %s", body, tc.err.Error())
|
||||
}
|
||||
if tc.status == http.StatusUnauthorized && authstore.DeleteRefreshTokenInvoked {
|
||||
t.Errorf("DeleteRefreshToken invoked for status %d", tc.status)
|
||||
}
|
||||
authstore.DeleteRefreshTokenInvoked = false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader, token string) (*http.Response, string) {
|
||||
req, err := http.NewRequest(method, ts.URL+path, body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil, ""
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "BEARER "+token)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil, ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
return resp, string(respBody)
|
||||
}
|
||||
|
||||
func genJWT(c jwtauth.Claims) string {
|
||||
_, tokenString, _ := auth.Token.JwtAuth.Encode(c)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func encode(v interface{}) (*bytes.Buffer, error) {
|
||||
data := new(bytes.Buffer)
|
||||
err := json.NewEncoder(data).Encode(v)
|
||||
return data, err
|
||||
}
|
||||
118
auth/jwt.go
Normal file
118
auth/jwt.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dhax/go-base/models"
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// AppClaims represent the claims extracted from JWT token.
|
||||
type AppClaims struct {
|
||||
ID int
|
||||
Sub string
|
||||
Roles []string
|
||||
}
|
||||
|
||||
// TokenAuth implements JWT authentication flow.
|
||||
type TokenAuth struct {
|
||||
JwtAuth *jwtauth.JwtAuth
|
||||
jwtExpiry time.Duration
|
||||
jwtRefreshExpiry time.Duration
|
||||
}
|
||||
|
||||
// NewTokenAuth configures and returns a JWT authentication instance.
|
||||
func NewTokenAuth() (*TokenAuth, error) {
|
||||
secret := viper.GetString("auth_jwt_secret")
|
||||
if secret == "random" {
|
||||
secret = randStringBytes(32)
|
||||
}
|
||||
|
||||
a := &TokenAuth{
|
||||
JwtAuth: jwtauth.New("HS256", []byte(secret), nil),
|
||||
jwtExpiry: viper.GetDuration("auth_jwt_expiry"),
|
||||
jwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Verifier http middleware will verify a jwt string from a http request.
|
||||
func (a *TokenAuth) Verifier() func(http.Handler) http.Handler {
|
||||
return jwtauth.Verifier(a.JwtAuth)
|
||||
}
|
||||
|
||||
// GenTokenPair returns both an access token and a refresh token for provided account.
|
||||
func (a *TokenAuth) GenTokenPair(u *models.Account, tok *models.Token) (string, string) {
|
||||
access := a.CreateJWT(u)
|
||||
refresh := a.CreateRefreshJWT(tok)
|
||||
return access, refresh
|
||||
}
|
||||
|
||||
// CreateJWT returns an access token for provided account.
|
||||
func (a *TokenAuth) CreateJWT(acc *models.Account) string {
|
||||
claims := jwtauth.Claims{
|
||||
"id": acc.ID,
|
||||
"sub": acc.Name,
|
||||
"roles": acc.Roles,
|
||||
}
|
||||
claims.SetIssuedNow()
|
||||
claims.SetExpiryIn(a.jwtExpiry * time.Minute)
|
||||
|
||||
_, tokenString, _ := a.JwtAuth.Encode(claims)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// CreateRefreshJWT returns a refresh token for provided account.
|
||||
func (a *TokenAuth) CreateRefreshJWT(tok *models.Token) string {
|
||||
claims := jwtauth.Claims{
|
||||
"id": tok.ID,
|
||||
"token": tok.Token,
|
||||
}
|
||||
claims.SetIssuedNow()
|
||||
claims.SetExpiryIn(time.Minute * a.jwtRefreshExpiry)
|
||||
|
||||
_, tokenString, _ := a.JwtAuth.Encode(claims)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func parseClaims(c jwtauth.Claims) (AppClaims, bool) {
|
||||
var claims AppClaims
|
||||
allOK := true
|
||||
id, ok := c.Get("id")
|
||||
if !ok {
|
||||
allOK = false
|
||||
}
|
||||
claims.ID = int(id.(float64))
|
||||
|
||||
sub, ok := c.Get("sub")
|
||||
if !ok {
|
||||
allOK = false
|
||||
}
|
||||
claims.Sub = sub.(string)
|
||||
|
||||
rl, ok := c.Get("roles")
|
||||
if !ok {
|
||||
allOK = false
|
||||
}
|
||||
|
||||
var roles []string
|
||||
if rl != nil {
|
||||
for _, v := range rl.([]interface{}) {
|
||||
roles = append(roles, v.(string))
|
||||
}
|
||||
}
|
||||
claims.Roles = roles
|
||||
|
||||
return claims, allOK
|
||||
}
|
||||
|
||||
func parseRefreshClaims(c jwtauth.Claims) (string, bool) {
|
||||
token, ok := c.Get("token")
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return token.(string), ok
|
||||
}
|
||||
89
auth/logintoken.go
Normal file
89
auth/logintoken.go
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var (
|
||||
errTokenNotFound = errors.New("login token not found")
|
||||
)
|
||||
|
||||
// LoginToken is an in-memory saved token referencing an account ID and an expiry date.
|
||||
type LoginToken struct {
|
||||
Token string
|
||||
AccountID int
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
// LoginTokenAuth implements passwordless login authentication flow using temporary in-memory stored tokens.
|
||||
type LoginTokenAuth struct {
|
||||
token map[string]LoginToken
|
||||
mux sync.RWMutex
|
||||
loginURL string
|
||||
loginTokenLength int
|
||||
loginTokenExpiry time.Duration
|
||||
}
|
||||
|
||||
// NewLoginTokenAuth configures and returns a LoginToken authentication instance.
|
||||
func NewLoginTokenAuth() (*LoginTokenAuth, error) {
|
||||
a := &LoginTokenAuth{
|
||||
token: make(map[string]LoginToken),
|
||||
loginURL: viper.GetString("auth_login_url"),
|
||||
loginTokenLength: viper.GetInt("auth_login_token_length"),
|
||||
loginTokenExpiry: viper.GetDuration("auth_login_token_expiry"),
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// CreateToken creates an in-memory login token referencing account ID. It returns a token containing a random tokenstring and expiry date.
|
||||
func (a *LoginTokenAuth) CreateToken(id int) LoginToken {
|
||||
lt := LoginToken{
|
||||
Token: randStringBytes(a.loginTokenLength),
|
||||
AccountID: id,
|
||||
Expiry: time.Now().Add(time.Minute * a.loginTokenExpiry),
|
||||
}
|
||||
a.add(lt)
|
||||
a.purgeExpired()
|
||||
return lt
|
||||
}
|
||||
|
||||
// GetAccountID looks up the token by tokenstring and returns the account ID or error if token not found or expired.
|
||||
func (a *LoginTokenAuth) GetAccountID(token string) (int, error) {
|
||||
lt, exists := a.get(token)
|
||||
if !exists || time.Now().After(lt.Expiry) {
|
||||
return 0, errTokenNotFound
|
||||
}
|
||||
a.delete(lt.Token)
|
||||
return lt.AccountID, nil
|
||||
}
|
||||
|
||||
func (a *LoginTokenAuth) get(token string) (LoginToken, bool) {
|
||||
a.mux.RLock()
|
||||
lt, ok := a.token[token]
|
||||
a.mux.RUnlock()
|
||||
return lt, ok
|
||||
}
|
||||
|
||||
func (a *LoginTokenAuth) add(lt LoginToken) {
|
||||
a.mux.Lock()
|
||||
a.token[lt.Token] = lt
|
||||
a.mux.Unlock()
|
||||
}
|
||||
|
||||
func (a *LoginTokenAuth) delete(token string) {
|
||||
a.mux.Lock()
|
||||
delete(a.token, token)
|
||||
a.mux.Unlock()
|
||||
}
|
||||
|
||||
func (a *LoginTokenAuth) purgeExpired() {
|
||||
for t, v := range a.token {
|
||||
if time.Now().After(v.Expiry) {
|
||||
a.delete(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue