package jwt import ( "context" "crypto/rand" "net/http" "time" "github.com/go-chi/jwtauth/v5" "github.com/lestrrat-go/jwx/jwt" "github.com/spf13/viper" ) // TokenAuth implements JWT authentication flow. type TokenAuth struct { JwtAuth *jwtauth.JWTAuth JwtExpiry time.Duration JwtRefreshExpiry time.Duration } // NewTokenAuth configures and returns a JWT authentication instance. func NewTokenAuth() (*TokenAuth, error) { secret := viper.GetString("auth_jwt_secret") if secret == "random" { secret = randStringBytes(32) } a := &TokenAuth{ JwtAuth: jwtauth.New("HS256", []byte(secret), nil), JwtExpiry: viper.GetDuration("auth_jwt_expiry"), JwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"), } return a, nil } // Verifier http middleware will verify a jwt string from a http request. func (a *TokenAuth) Verifier() func(http.Handler) http.Handler { return jwtauth.Verifier(a.JwtAuth) } // GenTokenPair returns both an access token and a refresh token. func (a *TokenAuth) GenTokenPair(accessClaims AppClaims, refreshClaims RefreshClaims) (string, string, error) { access, err := a.CreateJWT(accessClaims) if err != nil { return "", "", err } refresh, err := a.CreateRefreshJWT(refreshClaims) if err != nil { return "", "", err } return access, refresh, nil } // CreateJWT returns an access token for provided account claims. func (a *TokenAuth) CreateJWT(c AppClaims) (string, error) { token := jwt.New() token.Set(jwt.IssuedAtKey, time.Now().Unix()) token.Set(jwt.ExpirationKey, time.Now().Add(a.JwtExpiry).Unix()) token.Set(jwt.SubjectKey, c.Sub) token.Set(`id`, c.ID) token.Set(`roles`, c.Roles) tokenMap, err := token.AsMap(context.Background()) if err != nil { return "", err } _, tokenString, err := a.JwtAuth.Encode(tokenMap) return tokenString, err } // CreateRefreshJWT returns a refresh token for provided token Claims. func (a *TokenAuth) CreateRefreshJWT(c RefreshClaims) (string, error) { token := jwt.New() token.Set(jwt.IssuedAtKey, time.Now().Unix()) token.Set(jwt.ExpirationKey, time.Now().Add(a.JwtRefreshExpiry).Unix()) token.Set(`token`, c.Token) tokenMap, err := token.AsMap(context.Background()) if err != nil { return "", err } _, tokenString, err := a.JwtAuth.Encode(tokenMap) return tokenString, err } const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" func randStringBytes(n int) string { buf := make([]byte, n) if _, err := rand.Read(buf); err != nil { panic(err) } for k, v := range buf { buf[k] = letterBytes[v%byte(len(letterBytes))] } return string(buf) }