refactor auth pkg into libraries

This commit is contained in:
dhax 2017-10-31 19:10:09 +01:00
parent 521f081ba0
commit aaf0a0928d
26 changed files with 592 additions and 504 deletions

89
auth/jwt/authenticator.go Normal file
View file

@ -0,0 +1,89 @@
package jwt
import (
"context"
"net/http"
"github.com/go-chi/jwtauth"
"github.com/go-chi/render"
"github.com/dhax/go-base/logging"
)
type ctxKey int
const (
ctxClaims ctxKey = iota
ctxRefreshToken
)
// ClaimsFromCtx retrieves the parsed AppClaims from request context.
func ClaimsFromCtx(ctx context.Context) AppClaims {
return ctx.Value(ctxClaims).(AppClaims)
}
// RefreshTokenFromCtx retrieves the parsed refresh token from context.
func RefreshTokenFromCtx(ctx context.Context) string {
return ctx.Value(ctxRefreshToken).(string)
}
// Authenticator is a default authentication middleware to enforce access from the
// Verifier middleware request context values. The Authenticator sends a 401 Unauthorized
// response for any unverified tokens and passes the good ones through.
func Authenticator(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
logging.GetLogEntry(r).Warn(err)
render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
return
}
if !token.Valid {
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
return
}
// Token is authenticated, parse claims
var c AppClaims
err = c.ParseClaims(claims)
if err != nil {
logging.GetLogEntry(r).Error(err)
render.Render(w, r, ErrUnauthorized(ErrInvalidAccessToken))
return
}
// Set AppClaims on context
ctx := context.WithValue(r.Context(), ctxClaims, c)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// AuthenticateRefreshJWT checks validity of refresh tokens and is only used for access token refresh and logout requests. It responds with 401 Unauthorized for invalid or expired refresh tokens.
func AuthenticateRefreshJWT(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
logging.GetLogEntry(r).Warn(err)
render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
return
}
if !token.Valid {
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
return
}
// Token is authenticated, parse refresh token string
var c RefreshClaims
err = c.ParseClaims(claims)
if err != nil {
logging.GetLogEntry(r).Error(err)
render.Render(w, r, ErrUnauthorized(ErrInvalidRefreshToken))
return
}
// Set refresh token string on context
ctx := context.WithValue(r.Context(), ctxRefreshToken, c.Token)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

59
auth/jwt/claims.go Normal file
View file

@ -0,0 +1,59 @@
package jwt
import (
"errors"
"github.com/go-chi/jwtauth"
)
// AppClaims represent the claims parsed from JWT access token.
type AppClaims struct {
ID int
Sub string
Roles []string
}
// ParseClaims parses JWT claims into AppClaims.
func (c *AppClaims) ParseClaims(claims jwtauth.Claims) error {
id, ok := claims.Get("id")
if !ok {
return errors.New("could not parse claim id")
}
c.ID = int(id.(float64))
sub, ok := claims.Get("sub")
if !ok {
return errors.New("could not parse claim sub")
}
c.Sub = sub.(string)
rl, ok := claims.Get("roles")
if !ok {
return errors.New("could not parse claims roles")
}
var roles []string
if rl != nil {
for _, v := range rl.([]interface{}) {
roles = append(roles, v.(string))
}
}
c.Roles = roles
return nil
}
// RefreshClaims represent the claims parsed from JWT refresh token.
type RefreshClaims struct {
Token string
}
// ParseClaims parses the JWT claims into RefreshClaims.
func (c *RefreshClaims) ParseClaims(claims jwtauth.Claims) error {
token, ok := claims.Get("token")
if !ok {
return errors.New("could not parse claim token")
}
c.Token = token.(string)
return nil
}

42
auth/jwt/errors.go Normal file
View file

@ -0,0 +1,42 @@
package jwt
import (
"errors"
"net/http"
"github.com/go-chi/render"
)
// The list of jwt token errors presented to the end user.
var (
ErrTokenUnauthorized = errors.New("token unauthorized")
ErrTokenExpired = errors.New("token expired")
ErrInvalidAccessToken = errors.New("invalid access token")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
)
// ErrResponse renderer type for handling all sorts of errors.
type ErrResponse struct {
Err error `json:"-"` // low-level runtime error
HTTPStatusCode int `json:"-"` // http response status code
StatusText string `json:"status"` // user-level status message
AppCode int64 `json:"code,omitempty"` // application-specific error code
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
}
// Render sets the application-specific error code in AppCode.
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
render.Status(r, e.HTTPStatusCode)
return nil
}
// ErrUnauthorized renders status 401 Unauthorized with custom error message.
func ErrUnauthorized(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: http.StatusUnauthorized,
StatusText: http.StatusText(http.StatusUnauthorized),
ErrorText: err.Error(),
}
}

45
auth/jwt/token.go Normal file
View file

@ -0,0 +1,45 @@
package jwt
import (
"time"
"github.com/go-chi/jwtauth"
"github.com/go-pg/pg/orm"
)
// Token holds refresh jwt information.
type Token struct {
ID int `json:"id,omitempty"`
CreatedAt time.Time `json:"created_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
AccountID int `json:"-"`
Token string `json:"-"`
Expiry time.Time `json:"-"`
Mobile bool `sql:",notnull" json:"mobile"`
Identifier string `json:"identifier,omitempty"`
}
// BeforeInsert hook executed before database insert operation.
func (t *Token) BeforeInsert(db orm.DB) error {
now := time.Now()
if t.CreatedAt.IsZero() {
t.CreatedAt = now
t.UpdatedAt = now
}
return nil
}
// BeforeUpdate hook executed before database update operation.
func (t *Token) BeforeUpdate(db orm.DB) error {
t.UpdatedAt = time.Now()
return nil
}
// Claims returns the token claims to be signed
func (t *Token) Claims() jwtauth.Claims {
return jwtauth.Claims{
"id": t.ID,
"token": t.Token,
}
}

81
auth/jwt/tokenauth.go Normal file
View file

@ -0,0 +1,81 @@
package jwt
import (
"crypto/rand"
"net/http"
"time"
"github.com/go-chi/jwtauth"
"github.com/spf13/viper"
)
// TokenAuth implements JWT authentication flow.
type TokenAuth struct {
JwtAuth *jwtauth.JwtAuth
JwtExpiry time.Duration
JwtRefreshExpiry time.Duration
}
// NewTokenAuth configures and returns a JWT authentication instance.
func NewTokenAuth() (*TokenAuth, error) {
secret := viper.GetString("auth_jwt_secret")
if secret == "random" {
secret = randStringBytes(32)
}
a := &TokenAuth{
JwtAuth: jwtauth.New("HS256", []byte(secret), nil),
JwtExpiry: viper.GetDuration("auth_jwt_expiry"),
JwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
}
return a, nil
}
// Verifier http middleware will verify a jwt string from a http request.
func (a *TokenAuth) Verifier() func(http.Handler) http.Handler {
return jwtauth.Verifier(a.JwtAuth)
}
// GenTokenPair returns both an access token and a refresh token.
func (a *TokenAuth) GenTokenPair(ca jwtauth.Claims, cr jwtauth.Claims) (string, string, error) {
access, err := a.CreateJWT(ca)
if err != nil {
return "", "", err
}
refresh, err := a.CreateRefreshJWT(cr)
if err != nil {
return "", "", err
}
return access, refresh, nil
}
// CreateJWT returns an access token for provided account claims.
func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) {
c.SetIssuedNow()
c.SetExpiryIn(a.JwtExpiry)
_, tokenString, err := a.JwtAuth.Encode(c)
return tokenString, err
}
// CreateRefreshJWT returns a refresh token for provided token Claims.
func (a *TokenAuth) CreateRefreshJWT(c jwtauth.Claims) (string, error) {
c.SetIssuedNow()
c.SetExpiryIn(a.JwtRefreshExpiry)
_, tokenString, err := a.JwtAuth.Encode(c)
return tokenString, err
}
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func randStringBytes(n int) string {
buf := make([]byte, n)
if _, err := rand.Read(buf); err != nil {
panic(err)
}
for k, v := range buf {
buf[k] = letterBytes[v%byte(len(letterBytes))]
}
return string(buf)
}