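// Package auth provides JWT-based authentication helpers: issuing access and
// refresh tokens, verifying tokens in HTTP middleware, and reading claims.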
package auth
|
|
|
|
import (
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/go-chi/jwtauth"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// AppClaims holds the application-level claims read out of a JWT.
type AppClaims struct {
	// ID is the numeric identifier taken from the "id" claim.
	ID int
	// Sub is the value of the "sub" claim.
	Sub string
	// Roles lists the role names taken from the "roles" claim.
	Roles []string
}
|
|
|
|
// TokenAuth implements JWT authentication flow.
|
|
type TokenAuth struct {
|
|
JwtAuth *jwtauth.JwtAuth
|
|
jwtExpiry time.Duration
|
|
jwtRefreshExpiry time.Duration
|
|
}
|
|
|
|
// NewTokenAuth configures and returns a JWT authentication instance.
|
|
func NewTokenAuth() (*TokenAuth, error) {
|
|
secret := viper.GetString("auth_jwt_secret")
|
|
if secret == "random" {
|
|
secret = randStringBytes(32)
|
|
}
|
|
|
|
a := &TokenAuth{
|
|
JwtAuth: jwtauth.New("HS256", []byte(secret), nil),
|
|
jwtExpiry: viper.GetDuration("auth_jwt_expiry"),
|
|
jwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
|
|
}
|
|
|
|
return a, nil
|
|
}
|
|
|
|
// Verifier http middleware will verify a jwt string from a http request.
|
|
func (a *TokenAuth) Verifier() func(http.Handler) http.Handler {
|
|
return jwtauth.Verifier(a.JwtAuth)
|
|
}
|
|
|
|
// GenTokenPair returns both an access token and a refresh token.
|
|
func (a *TokenAuth) GenTokenPair(ca jwtauth.Claims, cr jwtauth.Claims) (string, string, error) {
|
|
access, err := a.CreateJWT(ca)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
refresh, err := a.CreateRefreshJWT(cr)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
return access, refresh, nil
|
|
}
|
|
|
|
// CreateJWT returns an access token for provided account claims.
|
|
func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) {
|
|
c.SetIssuedNow()
|
|
c.SetExpiryIn(a.jwtExpiry)
|
|
_, tokenString, err := a.JwtAuth.Encode(c)
|
|
return tokenString, err
|
|
}
|
|
|
|
// CreateRefreshJWT returns a refresh token for provided token Claims.
|
|
func (a *TokenAuth) CreateRefreshJWT(c jwtauth.Claims) (string, error) {
|
|
c.SetIssuedNow()
|
|
c.SetExpiryIn(a.jwtRefreshExpiry)
|
|
_, tokenString, err := a.JwtAuth.Encode(c)
|
|
return tokenString, err
|
|
}
|
|
|
|
func parseClaims(c jwtauth.Claims) (AppClaims, bool) {
|
|
var claims AppClaims
|
|
allOK := true
|
|
id, ok := c.Get("id")
|
|
if !ok {
|
|
allOK = false
|
|
}
|
|
claims.ID = int(id.(float64))
|
|
|
|
sub, ok := c.Get("sub")
|
|
if !ok {
|
|
allOK = false
|
|
}
|
|
claims.Sub = sub.(string)
|
|
|
|
rl, ok := c.Get("roles")
|
|
if !ok {
|
|
allOK = false
|
|
}
|
|
|
|
var roles []string
|
|
if rl != nil {
|
|
for _, v := range rl.([]interface{}) {
|
|
roles = append(roles, v.(string))
|
|
}
|
|
}
|
|
claims.Roles = roles
|
|
|
|
return claims, allOK
|
|
}
|
|
|
|
func parseRefreshClaims(c jwtauth.Claims) (string, bool) {
|
|
token, ok := c.Get("token")
|
|
if !ok {
|
|
return "", false
|
|
}
|
|
return token.(string), ok
|
|
}
|