diff --git a/auth/jwt/claims.go b/auth/jwt/claims.go index 6ad3da0..1513c1a 100644 --- a/auth/jwt/claims.go +++ b/auth/jwt/claims.go @@ -2,18 +2,20 @@ package jwt import ( "errors" - "github.com/go-chi/jwtauth" + + "github.com/dgrijalva/jwt-go" ) // AppClaims represent the claims parsed from JWT access token. type AppClaims struct { - ID int - Sub string - Roles []string + ID int `json:"id,omitempty"` + Sub string `json:"sub,omitempty"` + Roles []string `json:"roles,omitempty"` + jwt.StandardClaims } // ParseClaims parses JWT claims into AppClaims. -func (c *AppClaims) ParseClaims(claims jwtauth.Claims) error { +func (c *AppClaims) ParseClaims(claims jwt.MapClaims) error { id, ok := claims["id"] if !ok { return errors.New("could not parse claim id") @@ -42,13 +44,15 @@ func (c *AppClaims) ParseClaims(claims jwtauth.Claims) error { return nil } -// RefreshClaims represent the claims parsed from JWT refresh token. +// RefreshClaims represents the claims parsed from JWT refresh token. type RefreshClaims struct { - Token string + ID int `json:"id,omitempty"` + Token string `json:"token,omitempty"` + jwt.StandardClaims } // ParseClaims parses the JWT claims into RefreshClaims. -func (c *RefreshClaims) ParseClaims(claims jwtauth.Claims) error { +func (c *RefreshClaims) ParseClaims(claims jwt.MapClaims) error { token, ok := claims["token"] if !ok { return errors.New("could not parse claim token") diff --git a/auth/jwt/token.go b/auth/jwt/token.go index 48aed23..ffb20ff 100644 --- a/auth/jwt/token.go +++ b/auth/jwt/token.go @@ -3,7 +3,6 @@ package jwt import ( "time" - "github.com/go-chi/jwtauth" "github.com/go-pg/pg/orm" ) @@ -37,9 +36,9 @@ func (t *Token) BeforeUpdate(db orm.DB) error { } // Claims returns the token claims to be signed -func (t *Token) Claims() jwtauth.Claims { - return jwtauth.Claims{ - "id": t.ID, - "token": t.Token, +func (t *Token) Claims() RefreshClaims { + return RefreshClaims{ + ID: t.ID, + Token: t.Token, } } diff --git a/auth/jwt/tokenauth.go b/auth/jwt/tokenauth.go index 2d2d589..8e49ec9 100644 --- a/auth/jwt/tokenauth.go +++ b/auth/jwt/tokenauth.go @@ -38,12 +38,12 @@ func (a *TokenAuth) Verifier() func(http.Handler) http.Handler { } // GenTokenPair returns both an access token and a refresh token. -func (a *TokenAuth) GenTokenPair(ca jwtauth.Claims, cr jwtauth.Claims) (string, string, error) { - access, err := a.CreateJWT(ca) +func (a *TokenAuth) GenTokenPair(accessClaims AppClaims, refreshClaims RefreshClaims) (string, string, error) { + access, err := a.CreateJWT(accessClaims) if err != nil { return "", "", err } - refresh, err := a.CreateRefreshJWT(cr) + refresh, err := a.CreateRefreshJWT(refreshClaims) if err != nil { return "", "", err } @@ -51,17 +51,17 @@ func (a *TokenAuth) GenTokenPair(ca jwtauth.Claims, cr jwtauth.Claims) (string, } // CreateJWT returns an access token for provided account claims. -func (a *TokenAuth) CreateJWT(c jwtauth.Claims) (string, error) { - c.SetIssuedNow() - c.SetExpiryIn(a.JwtExpiry) +func (a *TokenAuth) CreateJWT(c AppClaims) (string, error) { + c.IssuedAt = time.Now().Unix() + c.ExpiresAt = time.Now().Add(a.JwtExpiry).Unix() _, tokenString, err := a.JwtAuth.Encode(c) return tokenString, err } // CreateRefreshJWT returns a refresh token for provided token Claims. -func (a *TokenAuth) CreateRefreshJWT(c jwtauth.Claims) (string, error) { - c.SetIssuedNow() - c.SetExpiryIn(a.JwtRefreshExpiry) +func (a *TokenAuth) CreateRefreshJWT(c RefreshClaims) (string, error) { + c.IssuedAt = time.Now().Unix() + c.ExpiresAt = time.Now().Add(a.JwtExpiry).Unix() _, tokenString, err := a.JwtAuth.Encode(c) return tokenString, err } diff --git a/auth/pwdless/account.go b/auth/pwdless/account.go index 74a623c..95fa756 100644 --- a/auth/pwdless/account.go +++ b/auth/pwdless/account.go @@ -4,11 +4,11 @@ import ( "strings" "time" - "github.com/dhax/go-base/auth/jwt" - "github.com/go-chi/jwtauth" validation "github.com/go-ozzo/ozzo-validation" "github.com/go-ozzo/ozzo-validation/is" "github.com/go-pg/pg/orm" + + "github.com/dhax/go-base/auth/jwt" ) // Account represents an authenticated application user @@ -65,10 +65,10 @@ func (a *Account) CanLogin() bool { } // Claims returns the account's claims to be signed -func (a *Account) Claims() jwtauth.Claims { - return jwtauth.Claims{ - "id": a.ID, - "sub": a.Name, - "roles": a.Roles, +func (a *Account) Claims() jwt.AppClaims { + return jwt.AppClaims{ + ID: a.ID, + Sub: a.Name, + Roles: a.Roles, } } diff --git a/auth/pwdless/api_test.go b/auth/pwdless/api_test.go index 11ce752..33ea053 100644 --- a/auth/pwdless/api_test.go +++ b/auth/pwdless/api_test.go @@ -14,8 +14,8 @@ import ( "testing" "time" + jwt_go "github.com/dgrijalva/jwt-go" "github.com/go-chi/chi" - "github.com/go-chi/jwtauth" "github.com/spf13/viper" "github.com/dhax/go-base/auth/jwt" @@ -209,7 +209,7 @@ func TestAuthResource_refresh(t *testing.T) { t.Expiry = time.Now().Add(1 * time.Minute) switch token { - case "notfound": + case "not_found": err = errors.New("sql no rows") case "expired": t.Expiry = time.Now().Add(-1 * time.Minute) @@ -232,16 +232,26 @@ func TestAuthResource_refresh(t *testing.T) { status int err error }{ - {"notfound", "notfound", 1, http.StatusUnauthorized, jwt.ErrTokenExpired}, - {"expired", "expired", -1, http.StatusUnauthorized, jwt.ErrTokenUnauthorized}, + {"not_found", "not_found", 1, http.StatusUnauthorized, jwt.ErrTokenExpired}, + {"expired", "expired", -1, http.StatusUnauthorized, jwt.ErrTokenExpired}, {"disabled", "disabled", 1, http.StatusUnauthorized, ErrLoginDisabled}, {"valid", "valid", 1, http.StatusOK, nil}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - jwt := genJWT(jwtauth.Claims{"token": tc.token, "exp": time.Minute * tc.exp}) - res, body := testRequest(t, ts, "POST", "/refresh", nil, jwt) + // refreshJWT, err := auth.TokenAuth.CreateRefreshJWT(jwt.RefreshClaims{Token: tc.token}) + // if err != nil { + // t.Errorf("failed to create refresh jwt") + // } + refreshJWT := genRefreshJWT(jwt.RefreshClaims{ + Token: tc.token, + StandardClaims: jwt_go.StandardClaims{ + ExpiresAt: time.Now().Add(time.Minute * tc.exp).UnixNano(), + }, + }) + + res, body := testRequest(t, ts, "POST", "/refresh", nil, refreshJWT) if res.StatusCode != tc.status { t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status) } @@ -294,14 +304,20 @@ func TestAuthResource_logout(t *testing.T) { err error }{ {"notfound", "notfound", 1, http.StatusUnauthorized, jwt.ErrTokenExpired}, - {"expired", "valid", -1, http.StatusUnauthorized, jwt.ErrTokenUnauthorized}, + {"expired", "valid", -1, http.StatusOK, nil}, {"valid", "valid", 1, http.StatusOK, nil}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - jwt := genJWT(jwtauth.Claims{"token": tc.token, "exp": time.Minute * tc.exp}) - res, body := testRequest(t, ts, "POST", "/logout", nil, jwt) + refreshJWT := genRefreshJWT(jwt.RefreshClaims{ + Token: tc.token, + StandardClaims: jwt_go.StandardClaims{ + ExpiresAt: time.Now().Add(time.Minute * tc.exp).UnixNano(), + }, + }) + + res, body := testRequest(t, ts, "POST", "/logout", nil, refreshJWT) if res.StatusCode != tc.status { t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status) } @@ -343,7 +359,11 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io return resp, string(respBody) } -func genJWT(c jwtauth.Claims) string { +func genJWT(c jwt.AppClaims) string { + _, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(c) + return tokenString +} +func genRefreshJWT(c jwt.RefreshClaims) string { _, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(c) return tokenString } diff --git a/go.mod b/go.mod index 725e696..70ebbb8 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,11 @@ require ( github.com/andybalholm/cascadia v1.0.0 // indirect github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf // indirect github.com/coreos/go-etcd v2.0.0+incompatible // indirect - github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/go-chi/chi v4.0.3+incompatible github.com/go-chi/cors v1.0.0 github.com/go-chi/docgen v1.0.5 - github.com/go-chi/jwtauth v3.3.0+incompatible + github.com/go-chi/jwtauth v4.0.4+incompatible github.com/go-chi/render v1.0.1 github.com/go-mail/mail v2.3.1+incompatible github.com/go-ozzo/ozzo-validation v3.5.0+incompatible diff --git a/go.sum b/go.sum index d14f898..f2712f9 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/go-chi/docgen v1.0.5 h1:TiGvJAuVPZJ9zFSwoF52eORe0SztOYqf9C79LVw/xbY= github.com/go-chi/docgen v1.0.5/go.mod h1:Nm4H4RaynSlvTexxWYWwXBzrwZKRE00MrkIIcJelhWM= github.com/go-chi/jwtauth v3.3.0+incompatible h1:BEOEx6OueP61EfhuOTDqgroY0SYdcFsFsbY/n4f5+Kk= github.com/go-chi/jwtauth v3.3.0+incompatible/go.mod h1:Q5EIArY/QnD6BdS+IyDw7B2m6iNbnPxtfd6/BcmtWbs= +github.com/go-chi/jwtauth v4.0.4+incompatible h1:LGIxg6YfvSBzxU2BljXbrzVc1fMlgqSKBQgKOGAVtPY= +github.com/go-chi/jwtauth v4.0.4+incompatible/go.mod h1:Q5EIArY/QnD6BdS+IyDw7B2m6iNbnPxtfd6/BcmtWbs= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=