Decoupled code from DefaultSigningKey (#16743)
Decoupled code from `DefaultSigningKey`. Makes testing a little bit easier and is cleaner.
This commit is contained in:
parent
cd8db3a83d
commit
88abb0dc8a
4 changed files with 27 additions and 27 deletions
|
@ -115,7 +115,7 @@ type AccessTokenResponse struct {
|
||||||
IDToken string `json:"id_token,omitempty"`
|
IDToken string `json:"id_token,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
|
func newAccessTokenResponse(grant *models.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
|
||||||
if setting.OAuth2.InvalidateRefreshTokens {
|
if setting.OAuth2.InvalidateRefreshTokens {
|
||||||
if err := grant.IncreaseCounter(); err != nil {
|
if err := grant.IncreaseCounter(); err != nil {
|
||||||
return nil, &AccessTokenError{
|
return nil, &AccessTokenError{
|
||||||
|
@ -133,7 +133,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
|
||||||
ExpiresAt: expirationDate.AsTime().Unix(),
|
ExpiresAt: expirationDate.AsTime().Unix(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
signedAccessToken, err := accessToken.SignToken()
|
signedAccessToken, err := accessToken.SignToken(serverKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &AccessTokenError{
|
return nil, &AccessTokenError{
|
||||||
ErrorCode: AccessTokenErrorCodeInvalidRequest,
|
ErrorCode: AccessTokenErrorCodeInvalidRequest,
|
||||||
|
@ -151,7 +151,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
|
||||||
ExpiresAt: refreshExpirationDate,
|
ExpiresAt: refreshExpirationDate,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
signedRefreshToken, err := refreshToken.SignToken()
|
signedRefreshToken, err := refreshToken.SignToken(serverKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &AccessTokenError{
|
return nil, &AccessTokenError{
|
||||||
ErrorCode: AccessTokenErrorCodeInvalidRequest,
|
ErrorCode: AccessTokenErrorCodeInvalidRequest,
|
||||||
|
@ -207,7 +207,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
|
||||||
idToken.EmailVerified = user.IsActive
|
idToken.EmailVerified = user.IsActive
|
||||||
}
|
}
|
||||||
|
|
||||||
signedIDToken, err = idToken.SignToken(signingKey)
|
signedIDToken, err = idToken.SignToken(clientKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &AccessTokenError{
|
return nil, &AccessTokenError{
|
||||||
ErrorCode: AccessTokenErrorCodeInvalidRequest,
|
ErrorCode: AccessTokenErrorCodeInvalidRequest,
|
||||||
|
@ -265,7 +265,7 @@ func IntrospectOAuth(ctx *context.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
form := web.GetForm(ctx).(*forms.IntrospectTokenForm)
|
form := web.GetForm(ctx).(*forms.IntrospectTokenForm)
|
||||||
token, err := oauth2.ParseToken(form.Token)
|
token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if token.Valid() == nil {
|
if token.Valid() == nil {
|
||||||
grant, err := models.GetOAuth2GrantByID(token.GrantID)
|
grant, err := models.GetOAuth2GrantByID(token.GrantID)
|
||||||
|
@ -544,9 +544,11 @@ func AccessTokenOAuth(ctx *context.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
signingKey := oauth2.DefaultSigningKey
|
serverKey := oauth2.DefaultSigningKey
|
||||||
if signingKey.IsSymmetric() {
|
clientKey := serverKey
|
||||||
clientKey, err := oauth2.CreateJWTSigningKey(signingKey.SigningMethod().Alg(), []byte(form.ClientSecret))
|
if serverKey.IsSymmetric() {
|
||||||
|
var err error
|
||||||
|
clientKey, err = oauth2.CreateJWTSigningKey(serverKey.SigningMethod().Alg(), []byte(form.ClientSecret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleAccessTokenError(ctx, AccessTokenError{
|
handleAccessTokenError(ctx, AccessTokenError{
|
||||||
ErrorCode: AccessTokenErrorCodeInvalidRequest,
|
ErrorCode: AccessTokenErrorCodeInvalidRequest,
|
||||||
|
@ -554,14 +556,13 @@ func AccessTokenOAuth(ctx *context.Context) {
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
signingKey = clientKey
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch form.GrantType {
|
switch form.GrantType {
|
||||||
case "refresh_token":
|
case "refresh_token":
|
||||||
handleRefreshToken(ctx, form, signingKey)
|
handleRefreshToken(ctx, form, serverKey, clientKey)
|
||||||
case "authorization_code":
|
case "authorization_code":
|
||||||
handleAuthorizationCode(ctx, form, signingKey)
|
handleAuthorizationCode(ctx, form, serverKey, clientKey)
|
||||||
default:
|
default:
|
||||||
handleAccessTokenError(ctx, AccessTokenError{
|
handleAccessTokenError(ctx, AccessTokenError{
|
||||||
ErrorCode: AccessTokenErrorCodeUnsupportedGrantType,
|
ErrorCode: AccessTokenErrorCodeUnsupportedGrantType,
|
||||||
|
@ -570,8 +571,8 @@ func AccessTokenOAuth(ctx *context.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) {
|
func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
|
||||||
token, err := oauth2.ParseToken(form.RefreshToken)
|
token, err := oauth2.ParseToken(form.RefreshToken, serverKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleAccessTokenError(ctx, AccessTokenError{
|
handleAccessTokenError(ctx, AccessTokenError{
|
||||||
ErrorCode: AccessTokenErrorCodeUnauthorizedClient,
|
ErrorCode: AccessTokenErrorCodeUnauthorizedClient,
|
||||||
|
@ -598,7 +599,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin
|
||||||
log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID)
|
log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accessToken, tokenErr := newAccessTokenResponse(grant, signingKey)
|
accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey)
|
||||||
if tokenErr != nil {
|
if tokenErr != nil {
|
||||||
handleAccessTokenError(ctx, *tokenErr)
|
handleAccessTokenError(ctx, *tokenErr)
|
||||||
return
|
return
|
||||||
|
@ -606,7 +607,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin
|
||||||
ctx.JSON(http.StatusOK, accessToken)
|
ctx.JSON(http.StatusOK, accessToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) {
|
func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
|
||||||
app, err := models.GetOAuth2ApplicationByClientID(form.ClientID)
|
app, err := models.GetOAuth2ApplicationByClientID(form.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleAccessTokenError(ctx, AccessTokenError{
|
handleAccessTokenError(ctx, AccessTokenError{
|
||||||
|
@ -660,7 +661,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s
|
||||||
ErrorDescription: "cannot proceed your request",
|
ErrorDescription: "cannot proceed your request",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, signingKey)
|
resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey)
|
||||||
if tokenErr != nil {
|
if tokenErr != nil {
|
||||||
handleAccessTokenError(ctx, *tokenErr)
|
handleAccessTokenError(ctx, *tokenErr)
|
||||||
return
|
return
|
||||||
|
|
|
@ -18,9 +18,8 @@ func createAndParseToken(t *testing.T, grant *models.OAuth2Grant) *oauth2.OIDCTo
|
||||||
signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32))
|
signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, signingKey)
|
assert.NotNil(t, signingKey)
|
||||||
oauth2.DefaultSigningKey = signingKey
|
|
||||||
|
|
||||||
response, terr := newAccessTokenResponse(grant, signingKey)
|
response, terr := newAccessTokenResponse(grant, signingKey, signingKey)
|
||||||
assert.Nil(t, terr)
|
assert.Nil(t, terr)
|
||||||
assert.NotNil(t, response)
|
assert.NotNil(t, response)
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,9 @@ func CheckOAuthAccessToken(accessToken string) int64 {
|
||||||
if !strings.Contains(accessToken, ".") {
|
if !strings.Contains(accessToken, ".") {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
token, err := oauth2.ParseToken(accessToken)
|
token, err := oauth2.ParseToken(accessToken, oauth2.DefaultSigningKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace("ParseOAuth2Token: %v", err)
|
log.Trace("oauth2.ParseToken: %v", err)
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
var grant *models.OAuth2Grant
|
var grant *models.OAuth2Grant
|
||||||
|
|
|
@ -40,12 +40,12 @@ type Token struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseToken parses a signed jwt string
|
// ParseToken parses a signed jwt string
|
||||||
func ParseToken(jwtToken string) (*Token, error) {
|
func ParseToken(jwtToken string, signingKey JWTSigningKey) (*Token, error) {
|
||||||
parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (interface{}, error) {
|
parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
if token.Method == nil || token.Method.Alg() != DefaultSigningKey.SigningMethod().Alg() {
|
if token.Method == nil || token.Method.Alg() != signingKey.SigningMethod().Alg() {
|
||||||
return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"])
|
||||||
}
|
}
|
||||||
return DefaultSigningKey.VerifyKey(), nil
|
return signingKey.VerifyKey(), nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -59,11 +59,11 @@ func ParseToken(jwtToken string) (*Token, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignToken signs the token with the JWT secret
|
// SignToken signs the token with the JWT secret
|
||||||
func (token *Token) SignToken() (string, error) {
|
func (token *Token) SignToken(signingKey JWTSigningKey) (string, error) {
|
||||||
token.IssuedAt = time.Now().Unix()
|
token.IssuedAt = time.Now().Unix()
|
||||||
jwtToken := jwt.NewWithClaims(DefaultSigningKey.SigningMethod(), token)
|
jwtToken := jwt.NewWithClaims(signingKey.SigningMethod(), token)
|
||||||
DefaultSigningKey.PreProcessToken(jwtToken)
|
signingKey.PreProcessToken(jwtToken)
|
||||||
return jwtToken.SignedString(DefaultSigningKey.SignKey())
|
return jwtToken.SignedString(signingKey.SignKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
// OIDCToken represents an OpenID Connect id_token
|
// OIDCToken represents an OpenID Connect id_token
|
||||||
|
|
Reference in a new issue