jwt transport: fix retry on unauthorized from CAPI(#3006)

This commit is contained in:
blotus 2024-05-24 14:43:12 +02:00 committed by GitHub
parent 09afcbe93a
commit f06e3e78ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 103 additions and 50 deletions

View file

@ -26,6 +26,7 @@ type JWTTransport struct {
URL *url.URL
VersionPrefix string
UserAgent string
RetryConfig *RetryConfig
// Transport is the underlying HTTP transport to use when making requests.
// It will default to http.DefaultTransport if nil.
Transport http.RoundTripper
@ -165,36 +166,67 @@ func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error)
// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req, err := t.prepareRequest(req)
if err != nil {
return nil, err
var resp *http.Response
attemptsCount := make(map[int]int)
for {
if log.GetLevel() >= log.TraceLevel {
// requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("req-jwt: %s", string(dump))
}
// Make the HTTP request.
clonedReq := cloneRequest(req)
clonedReq, err := t.prepareRequest(clonedReq)
if err != nil {
return nil, err
}
resp, err = t.transport().RoundTrip(clonedReq)
if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
}
if err != nil {
// we had an error (network error for example), reset the token?
t.ResetToken()
return resp, fmt.Errorf("performing jwt auth: %w", err)
}
if resp != nil {
log.Debugf("resp-jwt: %d", resp.StatusCode)
}
config, shouldRetry := t.RetryConfig.StatusCodeConfig[resp.StatusCode]
if !shouldRetry {
break
}
if attemptsCount[resp.StatusCode] >= config.MaxAttempts {
log.Infof("max attempts reached for status code %d", resp.StatusCode)
break
}
if config.InvalidateToken {
log.Debugf("invalidating token for status code %d", resp.StatusCode)
t.ResetToken()
}
log.Debugf("retrying request to %s", req.URL.String())
attemptsCount[resp.StatusCode]++
log.Infof("attempt %d out of %d", attemptsCount[resp.StatusCode], config.MaxAttempts)
if config.Backoff {
backoff := 2*attemptsCount[resp.StatusCode] + 5
log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, attemptsCount[resp.StatusCode], config.MaxAttempts)
time.Sleep(time.Duration(backoff) * time.Second)
}
}
if log.GetLevel() >= log.TraceLevel {
// requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("req-jwt: %s", string(dump))
}
// Make the HTTP request.
resp, err := t.transport().RoundTrip(req)
if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
}
if err != nil {
// we had an error (network error for example, or 401 because token is refused), reset the token?
t.ResetToken()
return resp, fmt.Errorf("performing jwt auth: %w", err)
}
if resp != nil {
log.Debugf("resp-jwt: %d", resp.StatusCode)
}
return resp, nil
}
func (t *JWTTransport) Client() *http.Client {
@ -211,27 +243,8 @@ func (t *JWTTransport) ResetToken() {
// transport() returns a round tripper that retries once when the status is unauthorized,
// and 5 times when the infrastructure is overloaded.
func (t *JWTTransport) transport() http.RoundTripper {
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}
return &retryRoundTripper{
next: &retryRoundTripper{
next: transport,
maxAttempts: 5,
withBackOff: true,
retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout},
},
maxAttempts: 2,
withBackOff: false,
retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden},
onBeforeRequest: func(attempt int) {
// reset the token only in the second attempt as this is when we know we had a 401 or 403
// the second attempt is supposed to refresh the token
if attempt > 0 {
t.ResetToken()
}
},
if t.Transport != nil {
return t.Transport
}
return http.DefaultTransport
}

View file

@ -72,6 +72,13 @@ func NewClient(config *Config) (*ApiClient, error) {
UserAgent: config.UserAgent,
VersionPrefix: config.VersionPrefix,
UpdateScenario: config.UpdateScenario,
RetryConfig: NewRetryConfig(
WithStatusCodeConfig(http.StatusUnauthorized, 2, false, true),
WithStatusCodeConfig(http.StatusForbidden, 2, false, true),
WithStatusCodeConfig(http.StatusTooManyRequests, 5, true, false),
WithStatusCodeConfig(http.StatusServiceUnavailable, 5, true, false),
WithStatusCodeConfig(http.StatusGatewayTimeout, 5, true, false),
),
}
transport, baseURL := createTransport(config.URL)

View file

@ -0,0 +1,33 @@
package apiclient
type StatusCodeConfig struct {
MaxAttempts int
Backoff bool
InvalidateToken bool
}
type RetryConfig struct {
StatusCodeConfig map[int]StatusCodeConfig
}
type RetryConfigOption func(*RetryConfig)
func NewRetryConfig(options ...RetryConfigOption) *RetryConfig {
rc := &RetryConfig{
StatusCodeConfig: make(map[int]StatusCodeConfig),
}
for _, opt := range options {
opt(rc)
}
return rc
}
func WithStatusCodeConfig(statusCode int, maxAttempts int, backOff bool, invalidateToken bool) RetryConfigOption {
return func(rc *RetryConfig) {
rc.StatusCodeConfig[statusCode] = StatusCodeConfig{
MaxAttempts: maxAttempts,
Backoff: backOff,
InvalidateToken: invalidateToken,
}
}
}