mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-11 20:36:12 +02:00
jwt transport: fix retry on unauthorized from CAPI(#3006)
This commit is contained in:
parent
09afcbe93a
commit
f06e3e78ab
3 changed files with 103 additions and 50 deletions
|
@ -26,6 +26,7 @@ type JWTTransport struct {
|
|||
URL *url.URL
|
||||
VersionPrefix string
|
||||
UserAgent string
|
||||
RetryConfig *RetryConfig
|
||||
// Transport is the underlying HTTP transport to use when making requests.
|
||||
// It will default to http.DefaultTransport if nil.
|
||||
Transport http.RoundTripper
|
||||
|
@ -165,36 +166,67 @@ func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error)
|
|||
|
||||
// RoundTrip implements the RoundTripper interface.
|
||||
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req, err := t.prepareRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
var resp *http.Response
|
||||
attemptsCount := make(map[int]int)
|
||||
|
||||
for {
|
||||
if log.GetLevel() >= log.TraceLevel {
|
||||
// requestToDump := cloneRequest(req)
|
||||
dump, _ := httputil.DumpRequest(req, true)
|
||||
log.Tracef("req-jwt: %s", string(dump))
|
||||
}
|
||||
// Make the HTTP request.
|
||||
clonedReq := cloneRequest(req)
|
||||
|
||||
clonedReq, err := t.prepareRequest(clonedReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = t.transport().RoundTrip(clonedReq)
|
||||
if log.GetLevel() >= log.TraceLevel {
|
||||
dump, _ := httputil.DumpResponse(resp, true)
|
||||
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// we had an error (network error for example), reset the token?
|
||||
t.ResetToken()
|
||||
return resp, fmt.Errorf("performing jwt auth: %w", err)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
log.Debugf("resp-jwt: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
config, shouldRetry := t.RetryConfig.StatusCodeConfig[resp.StatusCode]
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
|
||||
if attemptsCount[resp.StatusCode] >= config.MaxAttempts {
|
||||
log.Infof("max attempts reached for status code %d", resp.StatusCode)
|
||||
break
|
||||
}
|
||||
|
||||
if config.InvalidateToken {
|
||||
log.Debugf("invalidating token for status code %d", resp.StatusCode)
|
||||
t.ResetToken()
|
||||
}
|
||||
|
||||
log.Debugf("retrying request to %s", req.URL.String())
|
||||
attemptsCount[resp.StatusCode]++
|
||||
log.Infof("attempt %d out of %d", attemptsCount[resp.StatusCode], config.MaxAttempts)
|
||||
|
||||
if config.Backoff {
|
||||
backoff := 2*attemptsCount[resp.StatusCode] + 5
|
||||
log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, attemptsCount[resp.StatusCode], config.MaxAttempts)
|
||||
time.Sleep(time.Duration(backoff) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
if log.GetLevel() >= log.TraceLevel {
|
||||
// requestToDump := cloneRequest(req)
|
||||
dump, _ := httputil.DumpRequest(req, true)
|
||||
log.Tracef("req-jwt: %s", string(dump))
|
||||
}
|
||||
|
||||
// Make the HTTP request.
|
||||
resp, err := t.transport().RoundTrip(req)
|
||||
if log.GetLevel() >= log.TraceLevel {
|
||||
dump, _ := httputil.DumpResponse(resp, true)
|
||||
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// we had an error (network error for example, or 401 because token is refused), reset the token?
|
||||
t.ResetToken()
|
||||
|
||||
return resp, fmt.Errorf("performing jwt auth: %w", err)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
log.Debugf("resp-jwt: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
||||
}
|
||||
|
||||
func (t *JWTTransport) Client() *http.Client {
|
||||
|
@ -211,27 +243,8 @@ func (t *JWTTransport) ResetToken() {
|
|||
// transport() returns a round tripper that retries once when the status is unauthorized,
|
||||
// and 5 times when the infrastructure is overloaded.
|
||||
func (t *JWTTransport) transport() http.RoundTripper {
|
||||
transport := t.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
return &retryRoundTripper{
|
||||
next: &retryRoundTripper{
|
||||
next: transport,
|
||||
maxAttempts: 5,
|
||||
withBackOff: true,
|
||||
retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout},
|
||||
},
|
||||
maxAttempts: 2,
|
||||
withBackOff: false,
|
||||
retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden},
|
||||
onBeforeRequest: func(attempt int) {
|
||||
// reset the token only in the second attempt as this is when we know we had a 401 or 403
|
||||
// the second attempt is supposed to refresh the token
|
||||
if attempt > 0 {
|
||||
t.ResetToken()
|
||||
}
|
||||
},
|
||||
if t.Transport != nil {
|
||||
return t.Transport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
|
|
@ -72,6 +72,13 @@ func NewClient(config *Config) (*ApiClient, error) {
|
|||
UserAgent: config.UserAgent,
|
||||
VersionPrefix: config.VersionPrefix,
|
||||
UpdateScenario: config.UpdateScenario,
|
||||
RetryConfig: NewRetryConfig(
|
||||
WithStatusCodeConfig(http.StatusUnauthorized, 2, false, true),
|
||||
WithStatusCodeConfig(http.StatusForbidden, 2, false, true),
|
||||
WithStatusCodeConfig(http.StatusTooManyRequests, 5, true, false),
|
||||
WithStatusCodeConfig(http.StatusServiceUnavailable, 5, true, false),
|
||||
WithStatusCodeConfig(http.StatusGatewayTimeout, 5, true, false),
|
||||
),
|
||||
}
|
||||
|
||||
transport, baseURL := createTransport(config.URL)
|
||||
|
|
33
pkg/apiclient/retry_config.go
Normal file
33
pkg/apiclient/retry_config.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package apiclient
|
||||
|
||||
type StatusCodeConfig struct {
|
||||
MaxAttempts int
|
||||
Backoff bool
|
||||
InvalidateToken bool
|
||||
}
|
||||
|
||||
type RetryConfig struct {
|
||||
StatusCodeConfig map[int]StatusCodeConfig
|
||||
}
|
||||
|
||||
type RetryConfigOption func(*RetryConfig)
|
||||
|
||||
func NewRetryConfig(options ...RetryConfigOption) *RetryConfig {
|
||||
rc := &RetryConfig{
|
||||
StatusCodeConfig: make(map[int]StatusCodeConfig),
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(rc)
|
||||
}
|
||||
return rc
|
||||
}
|
||||
|
||||
func WithStatusCodeConfig(statusCode int, maxAttempts int, backOff bool, invalidateToken bool) RetryConfigOption {
|
||||
return func(rc *RetryConfig) {
|
||||
rc.StatusCodeConfig[statusCode] = StatusCodeConfig{
|
||||
MaxAttempts: maxAttempts,
|
||||
Backoff: backOff,
|
||||
InvalidateToken: invalidateToken,
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue