Merge branch 'update/jwt-package' of https://github.com/hyperyuri/go-base into hyperyuri-update/jwt-package

This commit is contained in:
dhax 2021-11-15 15:03:15 +01:00
commit d8d770478f
7 changed files with 70 additions and 49 deletions

View file

@ -4,6 +4,8 @@ import (
"context"
"net/http"
"github.com/lestrrat-go/jwx/jwt"
"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render"
@ -32,7 +34,7 @@ func RefreshTokenFromCtx(ctx context.Context) string {
// response for any unverified tokens and passes the good ones through.
func Authenticator(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
token, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
logging.GetLogEntry(r).Warn(err)
@ -40,6 +42,11 @@ func Authenticator(next http.Handler) http.Handler {
return
}
if err := jwt.Validate(token); err != nil {
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
return
}
// Token is authenticated, parse claims
var c AppClaims
err = c.ParseClaims(claims)
@ -58,13 +65,18 @@ func Authenticator(next http.Handler) http.Handler {
// AuthenticateRefreshJWT checks validity of refresh tokens and is only used for access token refresh and logout requests. It responds with 401 Unauthorized for invalid or expired refresh tokens.
func AuthenticateRefreshJWT(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
token, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
logging.GetLogEntry(r).Warn(err)
render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
return
}
if err := jwt.Validate(token); err != nil {
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
return
}
// Token is authenticated, parse refresh token string
var c RefreshClaims
err = c.ParseClaims(claims)

View file

@ -6,11 +6,17 @@ import (
"github.com/lestrrat-go/jwx/jwt"
)
type CommonClaims struct {
ExpiresAt int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
}
// AppClaims represent the claims parsed from JWT access token.
type AppClaims struct {
ID int `json:"id,omitempty"`
Sub string `json:"sub,omitempty"`
Roles []string `json:"roles,omitempty"`
CommonClaims
}
// ParseClaims parses JWT claims into AppClaims.
@ -47,6 +53,7 @@ func (c *AppClaims) ParseClaims(claims map[string]interface{}) error {
type RefreshClaims struct {
ID int `json:"id,omitempty"`
Token string `json:"token,omitempty"`
CommonClaims
}
// ParseClaims parses the JWT claims into RefreshClaims.

View file

@ -1,13 +1,12 @@
package jwt
import (
"context"
"crypto/rand"
"encoding/json"
"net/http"
"time"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/jwt"
"github.com/spf13/viper"
)
@ -54,35 +53,40 @@ func (a *TokenAuth) GenTokenPair(accessClaims AppClaims, refreshClaims RefreshCl
// CreateJWT returns an access token for provided account claims.
func (a *TokenAuth) CreateJWT(c AppClaims) (string, error) {
token := jwt.New()
token.Set(jwt.IssuedAtKey, time.Now().Unix())
token.Set(jwt.ExpirationKey, time.Now().Add(a.JwtExpiry).Unix())
c.IssuedAt = time.Now().Unix()
c.ExpiresAt = time.Now().Add(a.JwtExpiry).Unix()
token.Set(jwt.SubjectKey, c.Sub)
token.Set(`id`, c.ID)
token.Set(`roles`, c.Roles)
tokenMap, err := token.AsMap(context.Background())
claims, err := ParseStructToMap(c)
if err != nil {
return "", err
}
_, tokenString, err := a.JwtAuth.Encode(tokenMap)
_, tokenString, err := a.JwtAuth.Encode(claims)
return tokenString, err
}
func ParseStructToMap(c interface{}) (map[string]interface{}, error) {
var claims map[string]interface{}
inrec, _ := json.Marshal(c)
err := json.Unmarshal(inrec, &claims)
if err != nil {
return nil, err
}
return claims, err
}
// CreateRefreshJWT returns a refresh token for provided token Claims.
func (a *TokenAuth) CreateRefreshJWT(c RefreshClaims) (string, error) {
token := jwt.New()
token.Set(jwt.IssuedAtKey, time.Now().Unix())
token.Set(jwt.ExpirationKey, time.Now().Add(a.JwtRefreshExpiry).Unix())
c.IssuedAt = time.Now().Unix()
c.ExpiresAt = time.Now().Add(a.JwtRefreshExpiry).Unix()
token.Set(`token`, c.Token)
tokenMap, err := token.AsMap(context.Background())
claims, err := ParseStructToMap(c)
if err != nil {
return "", err
}
_, tokenString, err := a.JwtAuth.Encode(tokenMap)
_, tokenString, err := a.JwtAuth.Encode(claims)
return tokenString, err
}