101 lines
2.6 KiB
Go
101 lines
2.6 KiB
Go
package jwt
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/go-chi/jwtauth/v5"
|
|
"github.com/lestrrat-go/jwx/jwt"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// TokenAuth implements JWT authentication flow.
|
|
type TokenAuth struct {
|
|
JwtAuth *jwtauth.JWTAuth
|
|
JwtExpiry time.Duration
|
|
JwtRefreshExpiry time.Duration
|
|
}
|
|
|
|
// NewTokenAuth configures and returns a JWT authentication instance.
|
|
func NewTokenAuth() (*TokenAuth, error) {
|
|
secret := viper.GetString("auth_jwt_secret")
|
|
if secret == "random" {
|
|
secret = randStringBytes(32)
|
|
}
|
|
|
|
a := &TokenAuth{
|
|
JwtAuth: jwtauth.New("HS256", []byte(secret), nil),
|
|
JwtExpiry: viper.GetDuration("auth_jwt_expiry"),
|
|
JwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
|
|
}
|
|
|
|
return a, nil
|
|
}
|
|
|
|
// Verifier http middleware will verify a jwt string from a http request.
|
|
func (a *TokenAuth) Verifier() func(http.Handler) http.Handler {
|
|
return jwtauth.Verifier(a.JwtAuth)
|
|
}
|
|
|
|
// GenTokenPair returns both an access token and a refresh token.
|
|
func (a *TokenAuth) GenTokenPair(accessClaims AppClaims, refreshClaims RefreshClaims) (string, string, error) {
|
|
access, err := a.CreateJWT(accessClaims)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
refresh, err := a.CreateRefreshJWT(refreshClaims)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
return access, refresh, nil
|
|
}
|
|
|
|
// CreateJWT returns an access token for provided account claims.
|
|
func (a *TokenAuth) CreateJWT(c AppClaims) (string, error) {
|
|
token := jwt.New()
|
|
token.Set(jwt.IssuedAtKey, time.Now().Unix())
|
|
token.Set(jwt.ExpirationKey, time.Now().Add(a.JwtExpiry).Unix())
|
|
|
|
token.Set(jwt.SubjectKey, c.Sub)
|
|
token.Set(`id`, c.ID)
|
|
token.Set(`roles`, c.Roles)
|
|
|
|
tokenMap, err := token.AsMap(context.Background())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
_, tokenString, err := a.JwtAuth.Encode(tokenMap)
|
|
return tokenString, err
|
|
}
|
|
|
|
// CreateRefreshJWT returns a refresh token for provided token Claims.
|
|
func (a *TokenAuth) CreateRefreshJWT(c RefreshClaims) (string, error) {
|
|
token := jwt.New()
|
|
token.Set(jwt.IssuedAtKey, time.Now().Unix())
|
|
token.Set(jwt.ExpirationKey, time.Now().Add(a.JwtRefreshExpiry).Unix())
|
|
|
|
token.Set(`token`, c.Token)
|
|
|
|
tokenMap, err := token.AsMap(context.Background())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
_, tokenString, err := a.JwtAuth.Encode(tokenMap)
|
|
return tokenString, err
|
|
}
|
|
|
|
// letterBytes is the alphabet random strings are drawn from.
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// randStringBytes returns a string of n characters, each drawn uniformly from
// letterBytes using crypto/rand.
//
// It panics if the random source returns an error, since the signature
// offers no way to report one.
func randStringBytes(n int) string {
	// 256 is not a multiple of len(letterBytes), so reducing a raw byte with
	// % would favour the first 256%len(letterBytes) letters. Bytes at or
	// above the largest multiple that fits in a byte are discarded instead
	// (rejection sampling).
	const limit = 256 - 256%len(letterBytes)

	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		// Only request as many bytes as are still missing; rejected bytes
		// are made up for on the next pass.
		chunk := buf[:n-len(out)]
		if _, err := rand.Read(chunk); err != nil {
			panic(err)
		}
		for _, v := range chunk {
			if int(v) < limit {
				out = append(out, letterBytes[int(v)%len(letterBytes)])
			}
		}
	}
	return string(out)
}
|