refactor auth pkg into libraries
This commit is contained in:
parent
521f081ba0
commit
aaf0a0928d
26 changed files with 592 additions and 504 deletions
89
auth/jwt/authenticator.go
Normal file
89
auth/jwt/authenticator.go
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/go-chi/render"
|
||||
|
||||
"github.com/dhax/go-base/logging"
|
||||
)
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
ctxClaims ctxKey = iota
|
||||
ctxRefreshToken
|
||||
)
|
||||
|
||||
// ClaimsFromCtx retrieves the parsed AppClaims from request context.
|
||||
func ClaimsFromCtx(ctx context.Context) AppClaims {
|
||||
return ctx.Value(ctxClaims).(AppClaims)
|
||||
}
|
||||
|
||||
// RefreshTokenFromCtx retrieves the parsed refresh token from context.
|
||||
func RefreshTokenFromCtx(ctx context.Context) string {
|
||||
return ctx.Value(ctxRefreshToken).(string)
|
||||
}
|
||||
|
||||
// Authenticator is a default authentication middleware to enforce access from the
|
||||
// Verifier middleware request context values. The Authenticator sends a 401 Unauthorized
|
||||
// response for any unverified tokens and passes the good ones through.
|
||||
func Authenticator(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, claims, err := jwtauth.FromContext(r.Context())
|
||||
|
||||
if err != nil {
|
||||
logging.GetLogEntry(r).Warn(err)
|
||||
render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
|
||||
return
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
|
||||
return
|
||||
}
|
||||
|
||||
// Token is authenticated, parse claims
|
||||
var c AppClaims
|
||||
err = c.ParseClaims(claims)
|
||||
if err != nil {
|
||||
logging.GetLogEntry(r).Error(err)
|
||||
render.Render(w, r, ErrUnauthorized(ErrInvalidAccessToken))
|
||||
return
|
||||
}
|
||||
|
||||
// Set AppClaims on context
|
||||
ctx := context.WithValue(r.Context(), ctxClaims, c)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// AuthenticateRefreshJWT checks validity of refresh tokens and is only used for access token refresh and logout requests. It responds with 401 Unauthorized for invalid or expired refresh tokens.
|
||||
func AuthenticateRefreshJWT(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, claims, err := jwtauth.FromContext(r.Context())
|
||||
if err != nil {
|
||||
logging.GetLogEntry(r).Warn(err)
|
||||
render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
|
||||
return
|
||||
}
|
||||
if !token.Valid {
|
||||
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
|
||||
return
|
||||
}
|
||||
|
||||
// Token is authenticated, parse refresh token string
|
||||
var c RefreshClaims
|
||||
err = c.ParseClaims(claims)
|
||||
if err != nil {
|
||||
logging.GetLogEntry(r).Error(err)
|
||||
render.Render(w, r, ErrUnauthorized(ErrInvalidRefreshToken))
|
||||
return
|
||||
}
|
||||
// Set refresh token string on context
|
||||
ctx := context.WithValue(r.Context(), ctxRefreshToken, c.Token)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
59
auth/jwt/claims.go
Normal file
59
auth/jwt/claims.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/go-chi/jwtauth"
|
||||
)
|
||||
|
||||
// AppClaims represent the claims parsed from JWT access token.
|
||||
type AppClaims struct {
|
||||
ID int
|
||||
Sub string
|
||||
Roles []string
|
||||
}
|
||||
|
||||
// ParseClaims parses JWT claims into AppClaims.
|
||||
func (c *AppClaims) ParseClaims(claims jwtauth.Claims) error {
|
||||
id, ok := claims.Get("id")
|
||||
if !ok {
|
||||
return errors.New("could not parse claim id")
|
||||
}
|
||||
c.ID = int(id.(float64))
|
||||
|
||||
sub, ok := claims.Get("sub")
|
||||
if !ok {
|
||||
return errors.New("could not parse claim sub")
|
||||
}
|
||||
c.Sub = sub.(string)
|
||||
|
||||
rl, ok := claims.Get("roles")
|
||||
if !ok {
|
||||
return errors.New("could not parse claims roles")
|
||||
}
|
||||
|
||||
var roles []string
|
||||
if rl != nil {
|
||||
for _, v := range rl.([]interface{}) {
|
||||
roles = append(roles, v.(string))
|
||||
}
|
||||
}
|
||||
c.Roles = roles
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshClaims represent the claims parsed from JWT refresh token.
|
||||
type RefreshClaims struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
// ParseClaims parses the JWT claims into RefreshClaims.
|
||||
func (c *RefreshClaims) ParseClaims(claims jwtauth.Claims) error {
|
||||
token, ok := claims.Get("token")
|
||||
if !ok {
|
||||
return errors.New("could not parse claim token")
|
||||
}
|
||||
c.Token = token.(string)
|
||||
return nil
|
||||
}
|
||||
42
auth/jwt/errors.go
Normal file
42
auth/jwt/errors.go
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
// The list of jwt token errors presented to the end user.
|
||||
var (
|
||||
ErrTokenUnauthorized = errors.New("token unauthorized")
|
||||
ErrTokenExpired = errors.New("token expired")
|
||||
ErrInvalidAccessToken = errors.New("invalid access token")
|
||||
ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||
)
|
||||
|
||||
// ErrResponse renderer type for handling all sorts of errors.
|
||||
type ErrResponse struct {
|
||||
Err error `json:"-"` // low-level runtime error
|
||||
HTTPStatusCode int `json:"-"` // http response status code
|
||||
|
||||
StatusText string `json:"status"` // user-level status message
|
||||
AppCode int64 `json:"code,omitempty"` // application-specific error code
|
||||
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
|
||||
}
|
||||
|
||||
// Render sets the application-specific error code in AppCode.
|
||||
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
render.Status(r, e.HTTPStatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrUnauthorized renders status 401 Unauthorized with custom error message.
|
||||
func ErrUnauthorized(err error) render.Renderer {
|
||||
return &ErrResponse{
|
||||
Err: err,
|
||||
HTTPStatusCode: http.StatusUnauthorized,
|
||||
StatusText: http.StatusText(http.StatusUnauthorized),
|
||||
ErrorText: err.Error(),
|
||||
}
|
||||
}
|
||||
45
auth/jwt/token.go
Normal file
45
auth/jwt/token.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/go-pg/pg/orm"
|
||||
)
|
||||
|
||||
// Token holds refresh jwt information.
|
||||
type Token struct {
|
||||
ID int `json:"id,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
AccountID int `json:"-"`
|
||||
|
||||
Token string `json:"-"`
|
||||
Expiry time.Time `json:"-"`
|
||||
Mobile bool `sql:",notnull" json:"mobile"`
|
||||
Identifier string `json:"identifier,omitempty"`
|
||||
}
|
||||
|
||||
// BeforeInsert hook executed before database insert operation.
|
||||
func (t *Token) BeforeInsert(db orm.DB) error {
|
||||
now := time.Now()
|
||||
if t.CreatedAt.IsZero() {
|
||||
t.CreatedAt = now
|
||||
t.UpdatedAt = now
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeUpdate hook executed before database update operation.
|
||||
func (t *Token) BeforeUpdate(db orm.DB) error {
|
||||
t.UpdatedAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Claims returns the token claims to be signed
|
||||
func (t *Token) Claims() jwtauth.Claims {
|
||||
return jwtauth.Claims{
|
||||
"id": t.ID,
|
||||
"token": t.Token,
|
||||
}
|
||||
}
|
||||
81
auth/jwt/tokenauth.go
Normal file
81
auth/jwt/tokenauth.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// TokenAuth implements JWT authentication flow.
|
||||
type TokenAuth struct {
|
||||
JwtAuth *jwtauth.JwtAuth
|
||||
JwtExpiry time.Duration
|
||||
JwtRefreshExpiry time.Duration
|
||||
}
|
||||
|
||||
// NewTokenAuth configures and returns a JWT authentication instance.
|
||||
func NewTokenAuth() (*TokenAuth, error) {
|
||||
secret := viper.GetString("auth_jwt_secret")
|
||||
if secret == "random" {
|
||||
secret = randStringBytes(32)
|
||||
}
|
||||
|
||||
a := &TokenAuth{
|
||||
JwtAuth: jwtauth.New("HS256", []byte(secret), nil),
|
||||
JwtExpiry: viper.GetDuration("auth_jwt_expiry"),
|
||||
JwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Verifier http middleware will verify a jwt string from a http request.
|
||||
func (a *TokenAuth) Verifier() func(http.Handler) http.Handler {
|
||||
return jwtauth.Verifier(a.JwtAuth)
|
||||
}
|
||||
|
||||
// GenTokenPair returns both an access token and a refresh token.
|
||||
func (a *TokenAuth) GenTokenPair(ca jwtauth.Claims, cr jwtauth.Claims) (string, string, error) {
|
||||
access, err := a.CreateJWT(ca)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
refresh, err := a.CreateRefreshJWT(cr)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return access, refresh, nil
|
||||
}
|
||||
|
||||
// CreateJWT returns an access token for provided account claims.
|
||||
func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) {
|
||||
c.SetIssuedNow()
|
||||
c.SetExpiryIn(a.JwtExpiry)
|
||||
_, tokenString, err := a.JwtAuth.Encode(c)
|
||||
return tokenString, err
|
||||
}
|
||||
|
||||
// CreateRefreshJWT returns a refresh token for provided token Claims.
|
||||
func (a *TokenAuth) CreateRefreshJWT(c jwtauth.Claims) (string, error) {
|
||||
c.SetIssuedNow()
|
||||
c.SetExpiryIn(a.JwtRefreshExpiry)
|
||||
_, tokenString, err := a.JwtAuth.Encode(c)
|
||||
return tokenString, err
|
||||
}
|
||||
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func randStringBytes(n int) string {
|
||||
buf := make([]byte, n)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for k, v := range buf {
|
||||
buf[k] = letterBytes[v%byte(len(letterBytes))]
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue