initial commit

This commit is contained in:
dhax 2017-09-25 18:23:11 +02:00
commit 93d8310491
46 changed files with 3379 additions and 0 deletions

91
auth/api.go Normal file
View file

@ -0,0 +1,91 @@
package auth
import (
"net/http"
"time"
"github.com/dhax/go-base/email"
"github.com/dhax/go-base/logging"
"github.com/dhax/go-base/models"
"github.com/go-chi/chi"
"github.com/go-chi/render"
"github.com/sirupsen/logrus"
)
// Store defines database operations on account and token data.
type Store interface {
GetByID(id int) (*models.Account, error)
GetByEmail(email string) (*models.Account, error)
GetByRefreshToken(token string) (*models.Account, *models.Token, error)
UpdateAccount(a *models.Account) error
SaveRefreshToken(u *models.Token) error
DeleteRefreshToken(t *models.Token) error
PurgeExpiredToken() error
}
// EmailService defines methods to send account emails.
type EmailService interface {
LoginToken(name, email string, c email.LoginTokenContent) error
}
// Resource implements passwordless token authentication against a database.
type Resource struct {
Login *LoginTokenAuth
Token *TokenAuth
store Store
mailer EmailService
}
// NewResource returns a configured authentication resource.
func NewResource(store Store, mailer EmailService) (*Resource, error) {
loginAuth, err := NewLoginTokenAuth()
if err != nil {
return nil, err
}
tokenAuth, err := NewTokenAuth()
if err != nil {
return nil, err
}
resource := &Resource{
Login: loginAuth,
Token: tokenAuth,
store: store,
mailer: mailer,
}
resource.Cleanup()
return resource, nil
}
// Router provides neccessary routes for passwordless authentication flow.
func (rs *Resource) Router() *chi.Mux {
r := chi.NewRouter()
r.Use(render.SetContentType(render.ContentTypeJSON))
r.Post("/login", rs.login)
r.Post("/token", rs.token)
r.Group(func(r chi.Router) {
r.Use(rs.Token.Verifier())
r.Use(AuthenticateRefreshJWT)
r.Post("/refresh", rs.refresh)
r.Post("/logout", rs.logout)
})
return r
}
func (rs *Resource) Cleanup() {
ticker := time.NewTicker(time.Hour * 1)
go func() {
for range ticker.C {
if err := rs.store.PurgeExpiredToken(); err != nil {
logging.Logger.WithField("auth", "cleanup").Error(err)
}
}
}()
}
func log(r *http.Request) logrus.FieldLogger {
return logging.GetLogEntry(r)
}

87
auth/authenticator.go Normal file
View file

@ -0,0 +1,87 @@
package auth
import (
"context"
"errors"
"net/http"
"github.com/go-chi/jwtauth"
"github.com/go-chi/render"
)
type ctxKey int
const (
ctxClaims ctxKey = iota
ctxRefreshToken
)
var (
errTokenUnauthorized = errors.New("token unauthorized")
errTokenExpired = errors.New("token expired")
errInvalidAccessToken = errors.New("invalid access token")
errInvalidRefreshToken = errors.New("invalid refresh token")
)
// ClaimsFromCtx retrieves the parsed AppClaims from request context.
func ClaimsFromCtx(ctx context.Context) AppClaims {
return ctx.Value(ctxClaims).(AppClaims)
}
// RefreshTokenFromCtx retrieves the parsed refresh token from context.
func RefreshTokenFromCtx(ctx context.Context) string {
return ctx.Value(ctxRefreshToken).(string)
}
// Authenticator is a default authentication middleware to enforce access from the
// Verifier middleware request context values. The Authenticator sends a 401 Unauthorized
// response for any unverified tokens and passes the good ones through.
func Authenticator(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
log(r).Warn(err)
render.Render(w, r, ErrUnauthorized(errTokenUnauthorized))
return
}
if !token.Valid {
render.Render(w, r, ErrUnauthorized(errTokenExpired))
return
}
// Token is authenticated, parse claims
pc, ok := parseClaims(claims)
if !ok {
render.Render(w, r, ErrUnauthorized(errInvalidAccessToken))
return
}
ctx := context.WithValue(r.Context(), ctxClaims, pc)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// AuthenticateRefreshJWT checks validity of refresh tokens and is only used for access token refresh and logout requests. It responds with 401 Unauthorized for invalid or expired refresh tokens.
func AuthenticateRefreshJWT(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
log(r).Warn(err)
render.Render(w, r, ErrUnauthorized(errTokenUnauthorized))
return
}
if !token.Valid {
render.Render(w, r, ErrUnauthorized(errTokenExpired))
return
}
refreshToken, ok := parseRefreshClaims(claims)
if !ok {
render.Render(w, r, ErrUnauthorized(errInvalidRefreshToken))
return
}
// Token is authenticated, set on context
ctx := context.WithValue(r.Context(), ctxRefreshToken, refreshToken)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

31
auth/authorizer.go Normal file
View file

@ -0,0 +1,31 @@
package auth
import (
"net/http"
"github.com/go-chi/render"
)
// RequiresRole middleware restricts access to accounts having role parameter in their jwt claims.
func RequiresRole(role string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) {
claims := ClaimsFromCtx(r.Context())
if !hasRole(role, claims.Roles) {
render.Render(w, r, ErrForbidden)
return
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(hfn)
}
}
func hasRole(role string, roles []string) bool {
for _, r := range roles {
if r == role {
return true
}
}
return false
}

20
auth/crypto.go Normal file
View file

@ -0,0 +1,20 @@
package auth
import (
"math/rand"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func randStringBytes(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return string(b)
}

64
auth/errors.go Normal file
View file

@ -0,0 +1,64 @@
package auth
import (
"net/http"
"github.com/go-chi/render"
)
// ErrResponse renderer type for handling all sorts of errors.
type ErrResponse struct {
Err error `json:"-"` // low-level runtime error
HTTPStatusCode int `json:"-"` // http response status code
StatusText string `json:"status"` // user-level status message
AppCode int64 `json:"code,omitempty"` // application-specific error code
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
}
// Render sets the application-specific error code in AppCode.
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
render.Status(r, e.HTTPStatusCode)
return nil
}
// ErrUnauthorized renders status 401 Unauthorized with custom error message.
func ErrUnauthorized(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: http.StatusUnauthorized,
StatusText: http.StatusText(http.StatusUnauthorized),
ErrorText: err.Error(),
}
}
// ErrRender returns status 422 Unprocessable Entity for invalid request body
func ErrRender(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: http.StatusUnprocessableEntity,
StatusText: http.StatusText(http.StatusUnprocessableEntity),
ErrorText: err.Error(),
}
}
// ErrInvalidRequest returns status 422 Unprocessable Entity with validation errors
func ErrInvalidRequest(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: http.StatusUnprocessableEntity,
StatusText: http.StatusText(http.StatusUnprocessableEntity),
ErrorText: err.Error(),
}
}
// The list of default error types without specific error message.
var (
ErrBadRequest = &ErrResponse{HTTPStatusCode: http.StatusBadRequest, StatusText: http.StatusText(http.StatusBadRequest)}
ErrForbidden = &ErrResponse{HTTPStatusCode: http.StatusForbidden, StatusText: http.StatusText(http.StatusForbidden)}
ErrNotFound = &ErrResponse{HTTPStatusCode: http.StatusNotFound, StatusText: http.StatusText(http.StatusNotFound)}
ErrInternalServerError = &ErrResponse{HTTPStatusCode: http.StatusInternalServerError, StatusText: http.StatusText(http.StatusInternalServerError)}
)

215
auth/handler.go Normal file
View file

@ -0,0 +1,215 @@
package auth
import (
"errors"
"fmt"
"net/http"
"path"
"strings"
"time"
"github.com/go-chi/render"
validation "github.com/go-ozzo/ozzo-validation"
"github.com/go-ozzo/ozzo-validation/is"
"github.com/mssola/user_agent"
uuid "github.com/satori/go.uuid"
"github.com/dhax/go-base/email"
"github.com/dhax/go-base/models"
)
// The list of error types presented to the end user as error message.
var (
ErrInvalidLogin = errors.New("invalid email address")
ErrUnknownLogin = errors.New("email not registered")
ErrLoginDisabled = errors.New("login for account disabled")
ErrLoginToken = errors.New("invalid or expired login token")
)
type loginRequest struct {
Email string
}
func (body *loginRequest) Bind(r *http.Request) error {
body.Email = strings.TrimSpace(body.Email)
body.Email = strings.ToLower(body.Email)
if err := validation.ValidateStruct(body,
validation.Field(&body.Email, validation.Required, is.Email),
); err != nil {
return err
}
return nil
}
func (rs *Resource) login(w http.ResponseWriter, r *http.Request) {
body := &loginRequest{}
if err := render.Bind(r, body); err != nil {
log(r).WithField("email", body.Email).Warn(err)
render.Render(w, r, ErrUnauthorized(ErrInvalidLogin))
return
}
acc, err := rs.store.GetByEmail(body.Email)
if err != nil {
log(r).WithField("email", body.Email).Warn(err)
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
return
}
if !acc.CanLogin() {
render.Render(w, r, ErrUnauthorized(ErrLoginDisabled))
return
}
lt := rs.Login.CreateToken(acc.ID)
go func() {
err := rs.mailer.LoginToken(acc.Name, acc.Email, email.LoginTokenContent{
Email: acc.Email,
Name: acc.Name,
URL: path.Join(rs.Login.loginURL, lt.Token),
Token: lt.Token,
Expiry: lt.Expiry,
})
if err != nil {
log(r).WithField("module", "email").Error(err.Error())
}
}()
render.Respond(w, r, http.NoBody)
}
type tokenRequest struct {
Token string `json:"token"`
}
type tokenResponse struct {
Access string `json:"access_token"`
Refresh string `json:"refresh_token"`
}
func (body *tokenRequest) Bind(r *http.Request) error {
body.Token = strings.TrimSpace(body.Token)
if err := validation.ValidateStruct(body,
validation.Field(&body.Token, validation.Required, is.Alphanumeric),
); err != nil {
return err
}
return nil
}
func (rs *Resource) token(w http.ResponseWriter, r *http.Request) {
body := &tokenRequest{}
if err := render.Bind(r, body); err != nil {
log(r).Warn(err)
render.Render(w, r, ErrUnauthorized(ErrLoginToken))
return
}
id, err := rs.Login.GetAccountID(body.Token)
if err != nil {
render.Render(w, r, ErrUnauthorized(ErrLoginToken))
return
}
acc, err := rs.store.GetByID(id)
if err != nil {
// account deleted before login token expired
render.Render(w, r, ErrUnauthorized(ErrUnknownLogin))
return
}
if !acc.CanLogin() {
render.Render(w, r, ErrUnauthorized(ErrLoginDisabled))
return
}
ua := user_agent.New(r.UserAgent())
browser, _ := ua.Browser()
token := &models.Token{
Token: uuid.NewV4().String(),
Expiry: time.Now().Add(time.Minute * rs.Token.jwtRefreshExpiry),
UpdatedAt: time.Now(),
AccountID: acc.ID,
Mobile: ua.Mobile(),
Identifier: fmt.Sprintf("%s on %s", browser, ua.OS()),
}
if err := rs.store.SaveRefreshToken(token); err != nil {
log(r).Error(err)
render.Respond(w, r, ErrInternalServerError)
return
}
access, refresh := rs.Token.GenTokenPair(acc, token)
acc.LastLogin = time.Now()
if err := rs.store.UpdateAccount(acc); err != nil {
log(r).Error(err)
render.Respond(w, r, ErrInternalServerError)
return
}
render.Respond(w, r, &tokenResponse{
Access: access,
Refresh: refresh,
})
}
func (rs *Resource) refresh(w http.ResponseWriter, r *http.Request) {
rt := RefreshTokenFromCtx(r.Context())
acc, token, err := rs.store.GetByRefreshToken(rt)
if err != nil {
render.Render(w, r, ErrUnauthorized(errTokenExpired))
return
}
if time.Now().After(token.Expiry) {
rs.store.DeleteRefreshToken(token)
render.Render(w, r, ErrUnauthorized(errTokenExpired))
return
}
if !acc.CanLogin() {
render.Render(w, r, ErrUnauthorized(ErrLoginDisabled))
return
}
token.Token = uuid.NewV4().String()
token.Expiry = time.Now().Add(time.Minute * rs.Token.jwtRefreshExpiry)
token.UpdatedAt = time.Now()
access, refresh := rs.Token.GenTokenPair(acc, token)
if err := rs.store.SaveRefreshToken(token); err != nil {
log(r).Error(err)
render.Respond(w, r, ErrInternalServerError)
return
}
acc.LastLogin = time.Now()
if err := rs.store.UpdateAccount(acc); err != nil {
log(r).Error(err)
render.Respond(w, r, ErrInternalServerError)
return
}
render.Respond(w, r, &tokenResponse{
Access: access,
Refresh: refresh,
})
}
func (rs *Resource) logout(w http.ResponseWriter, r *http.Request) {
rt := RefreshTokenFromCtx(r.Context())
_, token, err := rs.store.GetByRefreshToken(rt)
if err != nil {
render.Render(w, r, ErrUnauthorized(errTokenExpired))
return
}
rs.store.DeleteRefreshToken(token)
render.Respond(w, r, http.NoBody)
}

346
auth/handler_test.go Normal file
View file

@ -0,0 +1,346 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/dhax/go-base/email"
"github.com/dhax/go-base/logging"
"github.com/dhax/go-base/models"
"github.com/dhax/go-base/testing/mock"
"github.com/go-chi/chi"
"github.com/go-chi/jwtauth"
"github.com/spf13/viper"
)
var (
auth *Resource
authstore mock.AuthStore
mailer mock.EmailService
ts *httptest.Server
)
func TestMain(m *testing.M) {
viper.SetDefault("auth_login_token_length", 8)
viper.SetDefault("auth_login_token_expiry", 11)
viper.SetDefault("auth_jwt_secret", "random")
viper.SetDefault("log_level", "error")
var err error
auth, err = NewResource(&authstore, &mailer)
if err != nil {
os.Exit(1)
}
r := chi.NewRouter()
r.Use(logging.NewStructuredLogger(logging.NewLogger()))
r.Mount("/", auth.Router())
ts = httptest.NewServer(r)
code := m.Run()
ts.Close()
os.Exit(code)
}
func TestAuthResource_login(t *testing.T) {
authstore.GetByEmailFn = func(email string) (*models.Account, error) {
var err error
a := models.Account{
ID: 1,
Email: email,
Name: "test",
}
switch email {
case "not@exists.io":
err = errors.New("sql no row")
case "disabled@account.io":
a.Active = false
case "valid@account.io":
a.Active = true
}
return &a, err
}
mailer.LoginTokenFn = func(n, e string, c email.LoginTokenContent) error {
return nil
}
tests := []struct {
name string
email string
status int
err error
}{
{"missing", "", http.StatusUnauthorized, ErrInvalidLogin},
{"inexistent", "not@exists.io", http.StatusUnauthorized, ErrUnknownLogin},
{"disabled", "disabled@account.io", http.StatusUnauthorized, ErrLoginDisabled},
{"valid", "valid@account.io", http.StatusOK, nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req, err := encode(&loginRequest{Email: tc.email})
if err != nil {
t.Fatal("failed to encode request body")
}
res, body := testRequest(t, ts, "POST", "/login", req, "")
if res.StatusCode != tc.status {
t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
}
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
t.Errorf(" got: %s, expected to contain: %s", body, tc.err.Error())
}
if tc.err == ErrInvalidLogin && authstore.GetByEmailInvoked {
t.Error("GetByLoginToken invoked for invalid email")
}
if tc.err == nil && !mailer.LoginTokenInvoked {
t.Error("emailService.LoginToken not invoked")
}
authstore.GetByEmailInvoked = false
mailer.LoginTokenInvoked = false
})
}
}
func TestAuthResource_token(t *testing.T) {
authstore.GetByIDFn = func(id int) (*models.Account, error) {
var err error
a := models.Account{
ID: id,
Active: true,
Name: "test",
}
switch id {
case 2:
a.Active = false
case 3:
// unmodified
default:
err = errors.New("sql no rows")
}
return &a, err
}
authstore.UpdateAccountFn = func(a *models.Account) error {
a.LastLogin = time.Now()
return nil
}
authstore.SaveRefreshTokenFn = func(a *models.Token) error {
return nil
}
tests := []struct {
name string
token string
id int
status int
err error
}{
{"invalid", "#§$%", 0, http.StatusUnauthorized, ErrLoginToken},
{"expired", "12345678", 0, http.StatusUnauthorized, ErrLoginToken},
{"deleted_account", "", 1, http.StatusUnauthorized, ErrUnknownLogin},
{"disabled", "", 2, http.StatusUnauthorized, ErrLoginDisabled},
{"valid", "", 3, http.StatusOK, nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
token := auth.Login.CreateToken(tc.id)
if tc.token != "" {
token.Token = tc.token
}
req, err := encode(tokenRequest{Token: token.Token})
if err != nil {
t.Fatal("failed to encode request body")
}
res, body := testRequest(t, ts, "POST", "/token", req, "")
if res.StatusCode != tc.status {
t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
}
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
t.Errorf("got: %s, expected to contain: %s", body, tc.err.Error())
}
if tc.err == ErrLoginToken && authstore.SaveRefreshTokenInvoked {
t.Errorf("SaveRefreshToken invoked despite error %s", tc.err.Error())
}
if tc.err == nil && !authstore.SaveRefreshTokenInvoked {
t.Error("SaveRefreshToken not invoked")
}
authstore.SaveRefreshTokenInvoked = false
})
}
}
func TestAuthResource_refresh(t *testing.T) {
authstore.GetByRefreshTokenFn = func(token string) (*models.Account, *models.Token, error) {
var err error
a := models.Account{
Active: true,
Name: "Test",
}
var t models.Token
t.Expiry = time.Now().Add(1 * time.Minute)
switch token {
case "notfound":
err = errors.New("sql no rows")
case "expired":
t.Expiry = time.Now().Add(-1 * time.Minute)
case "disabled":
a.Active = false
case "valid":
// unmodified
}
return &a, &t, err
}
authstore.UpdateAccountFn = func(a *models.Account) error {
a.LastLogin = time.Now()
return nil
}
authstore.SaveRefreshTokenFn = func(a *models.Token) error {
return nil
}
authstore.DeleteRefreshTokenFn = func(t *models.Token) error {
return nil
}
tests := []struct {
name string
token string
exp time.Duration
status int
err error
}{
{"notfound", "notfound", 1, http.StatusUnauthorized, errTokenExpired},
{"expired", "expired", -1, http.StatusUnauthorized, errTokenUnauthorized},
{"disabled", "disabled", 1, http.StatusUnauthorized, ErrLoginDisabled},
{"valid", "valid", 1, http.StatusOK, nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
jwt := genJWT(jwtauth.Claims{"token": tc.token, "exp": time.Minute * tc.exp})
res, body := testRequest(t, ts, "POST", "/refresh", nil, jwt)
if res.StatusCode != tc.status {
t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
}
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
t.Errorf("got: %s, expected error to contain: %s", body, tc.err.Error())
}
if tc.status == http.StatusUnauthorized && authstore.SaveRefreshTokenInvoked {
t.Errorf("SaveRefreshToken invoked for status %d", tc.status)
}
if tc.status == http.StatusOK && !authstore.GetByRefreshTokenInvoked {
t.Errorf("GetRefreshToken not invoked")
}
if tc.status == http.StatusOK && !authstore.SaveRefreshTokenInvoked {
t.Errorf("SaveRefreshToken not invoked")
}
if tc.status == http.StatusOK && authstore.DeleteRefreshTokenInvoked {
t.Errorf("DeleteRefreshToken should not be invoked")
}
authstore.GetByRefreshTokenInvoked = false
authstore.SaveRefreshTokenInvoked = false
authstore.DeleteRefreshTokenInvoked = false
})
}
}
func TestAuthResource_logout(t *testing.T) {
authstore.GetByRefreshTokenFn = func(token string) (*models.Account, *models.Token, error) {
var err error
var a models.Account
t := models.Token{
Expiry: time.Now().Add(1 * time.Minute),
}
switch token {
case "notfound":
err = errors.New("sql no rows")
}
return &a, &t, err
}
authstore.DeleteRefreshTokenFn = func(a *models.Token) error {
return nil
}
tests := []struct {
name string
token string
exp time.Duration
status int
err error
}{
{"notfound", "notfound", 1, http.StatusUnauthorized, errTokenExpired},
{"expired", "valid", -1, http.StatusUnauthorized, errTokenUnauthorized},
{"valid", "valid", 1, http.StatusOK, nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
jwt := genJWT(jwtauth.Claims{"token": tc.token, "exp": time.Minute * tc.exp})
res, body := testRequest(t, ts, "POST", "/logout", nil, jwt)
if res.StatusCode != tc.status {
t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
}
if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
t.Errorf("got: %x, expected error to contain %s", body, tc.err.Error())
}
if tc.status == http.StatusUnauthorized && authstore.DeleteRefreshTokenInvoked {
t.Errorf("DeleteRefreshToken invoked for status %d", tc.status)
}
authstore.DeleteRefreshTokenInvoked = false
})
}
}
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader, token string) (*http.Response, string) {
req, err := http.NewRequest(method, ts.URL+path, body)
if err != nil {
t.Fatal(err)
return nil, ""
}
req.Header.Set("Content-Type", "application/json")
if token != "" {
req.Header.Set("Authorization", "BEARER "+token)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
return nil, ""
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
return nil, ""
}
return resp, string(respBody)
}
func genJWT(c jwtauth.Claims) string {
_, tokenString, _ := auth.Token.JwtAuth.Encode(c)
return tokenString
}
func encode(v interface{}) (*bytes.Buffer, error) {
data := new(bytes.Buffer)
err := json.NewEncoder(data).Encode(v)
return data, err
}

118
auth/jwt.go Normal file
View file

@ -0,0 +1,118 @@
package auth
import (
"net/http"
"time"
"github.com/dhax/go-base/models"
"github.com/go-chi/jwtauth"
"github.com/spf13/viper"
)
// AppClaims represent the claims extracted from JWT token.
type AppClaims struct {
ID int
Sub string
Roles []string
}
// TokenAuth implements JWT authentication flow.
type TokenAuth struct {
JwtAuth *jwtauth.JwtAuth
jwtExpiry time.Duration
jwtRefreshExpiry time.Duration
}
// NewTokenAuth configures and returns a JWT authentication instance.
func NewTokenAuth() (*TokenAuth, error) {
secret := viper.GetString("auth_jwt_secret")
if secret == "random" {
secret = randStringBytes(32)
}
a := &TokenAuth{
JwtAuth: jwtauth.New("HS256", []byte(secret), nil),
jwtExpiry: viper.GetDuration("auth_jwt_expiry"),
jwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
}
return a, nil
}
// Verifier http middleware will verify a jwt string from a http request.
func (a *TokenAuth) Verifier() func(http.Handler) http.Handler {
return jwtauth.Verifier(a.JwtAuth)
}
// GenTokenPair returns both an access token and a refresh token for provided account.
func (a *TokenAuth) GenTokenPair(u *models.Account, tok *models.Token) (string, string) {
access := a.CreateJWT(u)
refresh := a.CreateRefreshJWT(tok)
return access, refresh
}
// CreateJWT returns an access token for provided account.
func (a *TokenAuth) CreateJWT(acc *models.Account) string {
claims := jwtauth.Claims{
"id": acc.ID,
"sub": acc.Name,
"roles": acc.Roles,
}
claims.SetIssuedNow()
claims.SetExpiryIn(a.jwtExpiry * time.Minute)
_, tokenString, _ := a.JwtAuth.Encode(claims)
return tokenString
}
// CreateRefreshJWT returns a refresh token for provided account.
func (a *TokenAuth) CreateRefreshJWT(tok *models.Token) string {
claims := jwtauth.Claims{
"id": tok.ID,
"token": tok.Token,
}
claims.SetIssuedNow()
claims.SetExpiryIn(time.Minute * a.jwtRefreshExpiry)
_, tokenString, _ := a.JwtAuth.Encode(claims)
return tokenString
}
func parseClaims(c jwtauth.Claims) (AppClaims, bool) {
var claims AppClaims
allOK := true
id, ok := c.Get("id")
if !ok {
allOK = false
}
claims.ID = int(id.(float64))
sub, ok := c.Get("sub")
if !ok {
allOK = false
}
claims.Sub = sub.(string)
rl, ok := c.Get("roles")
if !ok {
allOK = false
}
var roles []string
if rl != nil {
for _, v := range rl.([]interface{}) {
roles = append(roles, v.(string))
}
}
claims.Roles = roles
return claims, allOK
}
func parseRefreshClaims(c jwtauth.Claims) (string, bool) {
token, ok := c.Get("token")
if !ok {
return "", false
}
return token.(string), ok
}

89
auth/logintoken.go Normal file
View file

@ -0,0 +1,89 @@
package auth
import (
"errors"
"sync"
"time"
"github.com/spf13/viper"
)
var (
errTokenNotFound = errors.New("login token not found")
)
// LoginToken is an in-memory saved token referencing an account ID and an expiry date.
type LoginToken struct {
Token string
AccountID int
Expiry time.Time
}
// LoginTokenAuth implements passwordless login authentication flow using temporary in-memory stored tokens.
type LoginTokenAuth struct {
token map[string]LoginToken
mux sync.RWMutex
loginURL string
loginTokenLength int
loginTokenExpiry time.Duration
}
// NewLoginTokenAuth configures and returns a LoginToken authentication instance.
func NewLoginTokenAuth() (*LoginTokenAuth, error) {
a := &LoginTokenAuth{
token: make(map[string]LoginToken),
loginURL: viper.GetString("auth_login_url"),
loginTokenLength: viper.GetInt("auth_login_token_length"),
loginTokenExpiry: viper.GetDuration("auth_login_token_expiry"),
}
return a, nil
}
// CreateToken creates an in-memory login token referencing account ID. It returns a token containing a random tokenstring and expiry date.
func (a *LoginTokenAuth) CreateToken(id int) LoginToken {
lt := LoginToken{
Token: randStringBytes(a.loginTokenLength),
AccountID: id,
Expiry: time.Now().Add(time.Minute * a.loginTokenExpiry),
}
a.add(lt)
a.purgeExpired()
return lt
}
// GetAccountID looks up the token by tokenstring and returns the account ID or error if token not found or expired.
func (a *LoginTokenAuth) GetAccountID(token string) (int, error) {
lt, exists := a.get(token)
if !exists || time.Now().After(lt.Expiry) {
return 0, errTokenNotFound
}
a.delete(lt.Token)
return lt.AccountID, nil
}
func (a *LoginTokenAuth) get(token string) (LoginToken, bool) {
a.mux.RLock()
lt, ok := a.token[token]
a.mux.RUnlock()
return lt, ok
}
func (a *LoginTokenAuth) add(lt LoginToken) {
a.mux.Lock()
a.token[lt.Token] = lt
a.mux.Unlock()
}
func (a *LoginTokenAuth) delete(token string) {
a.mux.Lock()
delete(a.token, token)
a.mux.Unlock()
}
func (a *LoginTokenAuth) purgeExpired() {
for t, v := range a.token {
if time.Now().After(v.Expiry) {
a.delete(t)
}
}
}