update go-chi and chi/jwtauth to v5

This commit is contained in:
dhax 2021-09-08 23:45:06 +02:00
parent 72fd12d0c4
commit f7b222b7f3
13 changed files with 646 additions and 102 deletions

View file

@ -4,7 +4,7 @@ import (
"context"
"net/http"
"github.com/go-chi/jwtauth"
"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render"
"github.com/dhax/go-base/logging"
@ -32,7 +32,7 @@ func RefreshTokenFromCtx(ctx context.Context) string {
// response for any unverified tokens and passes the good ones through.
func Authenticator(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, claims, err := jwtauth.FromContext(r.Context())
_, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
logging.GetLogEntry(r).Warn(err)
@ -40,11 +40,6 @@ func Authenticator(next http.Handler) http.Handler {
return
}
if !token.Valid {
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
return
}
// Token is authenticated, parse claims
var c AppClaims
err = c.ParseClaims(claims)
@ -63,16 +58,12 @@ func Authenticator(next http.Handler) http.Handler {
// AuthenticateRefreshJWT checks validity of refresh tokens and is only used for access token refresh and logout requests. It responds with 401 Unauthorized for invalid or expired refresh tokens.
func AuthenticateRefreshJWT(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, claims, err := jwtauth.FromContext(r.Context())
_, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
logging.GetLogEntry(r).Warn(err)
render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
return
}
if !token.Valid {
render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
return
}
// Token is authenticated, parse refresh token string
var c RefreshClaims

View file

@ -3,7 +3,7 @@ package jwt
import (
"errors"
"github.com/dgrijalva/jwt-go"
"github.com/lestrrat-go/jwx/jwt"
)
// AppClaims represent the claims parsed from JWT access token.
@ -11,18 +11,17 @@ type AppClaims struct {
ID int `json:"id,omitempty"`
Sub string `json:"sub,omitempty"`
Roles []string `json:"roles,omitempty"`
jwt.StandardClaims
}
// ParseClaims parses JWT claims into AppClaims.
func (c *AppClaims) ParseClaims(claims jwt.MapClaims) error {
func (c *AppClaims) ParseClaims(claims map[string]interface{}) error {
id, ok := claims["id"]
if !ok {
return errors.New("could not parse claim id")
}
c.ID = int(id.(float64))
sub, ok := claims["sub"]
sub, ok := claims[jwt.SubjectKey]
if !ok {
return errors.New("could not parse claim sub")
}
@ -48,11 +47,10 @@ func (c *AppClaims) ParseClaims(claims jwt.MapClaims) error {
type RefreshClaims struct {
ID int `json:"id,omitempty"`
Token string `json:"token,omitempty"`
jwt.StandardClaims
}
// ParseClaims parses the JWT claims into RefreshClaims.
func (c *RefreshClaims) ParseClaims(claims jwt.MapClaims) error {
func (c *RefreshClaims) ParseClaims(claims map[string]interface{}) error {
token, ok := claims["token"]
if !ok {
return errors.New("could not parse claim token")

View file

@ -1,11 +1,13 @@
package jwt
import (
"context"
"crypto/rand"
"net/http"
"time"
"github.com/go-chi/jwtauth"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/jwt"
"github.com/spf13/viper"
)
@ -52,17 +54,35 @@ func (a *TokenAuth) GenTokenPair(accessClaims AppClaims, refreshClaims RefreshCl
// CreateJWT returns an access token for provided account claims.
func (a *TokenAuth) CreateJWT(c AppClaims) (string, error) {
c.IssuedAt = time.Now().Unix()
c.ExpiresAt = time.Now().Add(a.JwtExpiry).Unix()
_, tokenString, err := a.JwtAuth.Encode(c)
token := jwt.New()
token.Set(jwt.IssuedAtKey, time.Now().Unix())
token.Set(jwt.ExpirationKey, time.Now().Add(a.JwtExpiry).Unix())
token.Set(jwt.SubjectKey, c.Sub)
token.Set(`id`, c.ID)
token.Set(`roles`, c.Roles)
tokenMap, err := token.AsMap(context.Background())
if err != nil {
return "", err
}
_, tokenString, err := a.JwtAuth.Encode(tokenMap)
return tokenString, err
}
// CreateRefreshJWT returns a refresh token for provided token Claims.
func (a *TokenAuth) CreateRefreshJWT(c RefreshClaims) (string, error) {
c.IssuedAt = time.Now().Unix()
c.ExpiresAt = time.Now().Add(a.JwtRefreshExpiry).Unix()
_, tokenString, err := a.JwtAuth.Encode(c)
token := jwt.New()
token.Set(jwt.IssuedAtKey, time.Now().Unix())
token.Set(jwt.ExpirationKey, time.Now().Add(a.JwtRefreshExpiry).Unix())
token.Set(`token`, c.Token)
tokenMap, err := token.AsMap(context.Background())
if err != nil {
return "", err
}
_, tokenString, err := a.JwtAuth.Encode(tokenMap)
return tokenString, err
}

View file

@ -12,7 +12,7 @@ import (
"github.com/dhax/go-base/auth/jwt"
"github.com/dhax/go-base/email"
"github.com/dhax/go-base/logging"
"github.com/go-chi/chi"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
validation "github.com/go-ozzo/ozzo-validation"
"github.com/go-ozzo/ozzo-validation/is"

View file

@ -2,6 +2,7 @@ package pwdless
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@ -14,8 +15,8 @@ import (
"testing"
"time"
jwt_go "github.com/dgrijalva/jwt-go"
"github.com/go-chi/chi"
"github.com/go-chi/chi/v5"
jwx_jwt "github.com/lestrrat-go/jwx/jwt"
"github.com/spf13/viper"
"github.com/dhax/go-base/auth/jwt"
@ -246,9 +247,6 @@ func TestAuthResource_refresh(t *testing.T) {
// }
refreshJWT := genRefreshJWT(jwt.RefreshClaims{
Token: tc.token,
StandardClaims: jwt_go.StandardClaims{
ExpiresAt: time.Now().Add(time.Minute * tc.exp).UnixNano(),
},
})
res, body := testRequest(t, ts, "POST", "/refresh", nil, refreshJWT)
@ -312,9 +310,6 @@ func TestAuthResource_logout(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
refreshJWT := genRefreshJWT(jwt.RefreshClaims{
Token: tc.token,
StandardClaims: jwt_go.StandardClaims{
ExpiresAt: time.Now().Add(time.Minute * tc.exp).UnixNano(),
},
})
res, body := testRequest(t, ts, "POST", "/logout", nil, refreshJWT)
@ -340,7 +335,7 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io
}
req.Header.Set("Content-Type", "application/json")
if token != "" {
req.Header.Set("Authorization", "BEARER "+token)
req.Header.Set("Authorization", fmt.Sprintf("BEARER %s", token))
}
resp, err := http.DefaultClient.Do(req)
@ -359,12 +354,30 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io
return resp, string(respBody)
}
func genJWT(c jwt.AppClaims) string {
_, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(c)
return tokenString
}
// func genJWT(c jwt.AppClaims) string {
// token := jwx_jwt.New()
// token.Set(jwx_jwt.IssuedAtKey, time.Now().Unix())
// token.Set(jwx_jwt.ExpirationKey, time.Now().Add(time.Duration(time.Minute)).Unix())
// tokenMap, err := token.AsMap(context.Background())
// if err != nil {
// return ""
// }
// _, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(tokenMap)
// return tokenString
// }
func genRefreshJWT(c jwt.RefreshClaims) string {
_, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(c)
token := jwx_jwt.New()
token.Set(jwx_jwt.IssuedAtKey, time.Now())
token.Set(jwx_jwt.ExpirationKey, time.Now().Add(time.Duration(time.Minute)))
token.Set(`token`, c.Token)
tokenMap, err := token.AsMap(context.Background())
if err != nil {
return ""
}
_, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(tokenMap)
return tokenString
}