refact: context propagation (apiclient, cticlient...) (#3477)

This commit is contained in:
mmetc 2025-02-21 13:23:39 +01:00 committed by GitHub
parent 8da6a4dc92
commit a3187d6f2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 131 additions and 68 deletions

View file

@ -194,19 +194,16 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, metricsLe
// let's load the associated appsec_config:
if w.config.AppsecConfigPath != "" {
err := appsecCfg.LoadByPath(w.config.AppsecConfigPath)
if err != nil {
if err = appsecCfg.LoadByPath(w.config.AppsecConfigPath); err != nil {
return fmt.Errorf("unable to load appsec_config: %w", err)
}
} else if w.config.AppsecConfig != "" {
err := appsecCfg.Load(w.config.AppsecConfig)
if err != nil {
if err = appsecCfg.Load(w.config.AppsecConfig); err != nil {
return fmt.Errorf("unable to load appsec_config: %w", err)
}
} else if len(w.config.AppsecConfigs) > 0 {
for _, appsecConfig := range w.config.AppsecConfigs {
err := appsecCfg.Load(appsecConfig)
if err != nil {
if err = appsecCfg.Load(appsecConfig); err != nil {
return fmt.Errorf("unable to load appsec_config: %w", err)
}
}
@ -233,6 +230,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, metricsLe
if err != nil {
return fmt.Errorf("unable to get authenticated LAPI client: %w", err)
}
w.appsecAllowlistClient = allowlists.NewAppsecAllowlist(w.apiClient, w.logger)
for nbRoutine := range w.config.Routines {
@ -371,12 +369,12 @@ func (w *AppsecSource) Dump() interface{} {
return w
}
func (w *AppsecSource) IsAuth(apiKey string) bool {
func (w *AppsecSource) IsAuth(ctx context.Context, apiKey string) bool {
client := &http.Client{
Timeout: 200 * time.Millisecond,
}
req, err := http.NewRequest(http.MethodHead, w.lapiURL, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodHead, w.lapiURL, nil)
if err != nil {
log.Errorf("Error creating request: %s", err)
return false
@ -397,6 +395,7 @@ func (w *AppsecSource) IsAuth(apiKey string) bool {
// should this be in the runner ?
func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
w.logger.Debugf("Received request from '%s' on %s", r.RemoteAddr, r.URL.Path)
apiKey := r.Header.Get(appsec.APIKeyHeaderName)
@ -413,7 +412,7 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
expiration, exists := w.AuthCache.Get(apiKey)
// if the apiKey is not in cache or has expired, just recheck the auth
if !exists || time.Now().After(expiration) {
if !w.IsAuth(apiKey) {
if !w.IsAuth(ctx, apiKey) {
rw.WriteHeader(http.StatusUnauthorized)
w.logger.Errorf("Unauthorized request from '%s' (real IP = %s)", remoteIP, clientIP)

View file

@ -288,6 +288,7 @@ basic_auth:
}
func TestStreamingAcquisitionBasicAuth(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -306,7 +307,7 @@ basic_auth:
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
require.NoError(t, err)
req.SetBasicAuth("test", "WrongPassword")
@ -321,6 +322,7 @@ basic_auth:
}
func TestStreamingAcquisitionBadHeaders(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -334,7 +336,7 @@ headers:
client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
require.NoError(t, err)
req.Header.Add("Key", "wrong")
@ -349,6 +351,7 @@ headers:
}
func TestStreamingAcquisitionMaxBodySize(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -362,7 +365,7 @@ max_body_size: 5`), 0)
time.Sleep(1 * time.Second)
client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest"))
require.NoError(t, err)
req.Header.Add("Key", "test")
@ -378,6 +381,7 @@ max_body_size: 5`), 0)
}
func TestStreamingAcquisitionSuccess(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -388,13 +392,14 @@ headers:
key: test`), 2)
time.Sleep(1 * time.Second)
rawEvt := `{"test": "test"}`
errChan := make(chan error)
go assertEvents(out, []string{rawEvt}, errChan)
client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
require.NoError(t, err)
req.Header.Add("Key", "test")
@ -414,6 +419,7 @@ headers:
}
func TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -430,9 +436,10 @@ custom_headers:
rawEvt := `{"test": "test"}`
errChan := make(chan error)
go assertEvents(out, []string{rawEvt}, errChan)
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
require.NoError(t, err)
req.Header.Add("Key", "test")
@ -463,9 +470,11 @@ func (sr *slowReader) Read(p []byte) (int, error) {
if sr.index >= len(sr.body) {
return 0, io.EOF
}
time.Sleep(sr.delay) // Simulate a delay in reading
n := copy(p, sr.body[sr.index:])
sr.index += n
return n, nil
}
@ -492,10 +501,12 @@ func assertEvents(out chan types.Event, expected []string, errChan chan error) {
errChan <- fmt.Errorf(`expected %s, got '%+v'`, expected, evt.Line.Raw)
return
}
if evt.Line.Src != "127.0.0.1" {
errChan <- fmt.Errorf("expected '127.0.0.1', got '%s'", evt.Line.Src)
return
}
if evt.Line.Module != "http" {
errChan <- fmt.Errorf("expected 'http', got '%s'", evt.Line.Module)
return
@ -505,6 +516,7 @@ func assertEvents(out chan types.Event, expected []string, errChan chan error) {
}
func TestStreamingAcquisitionTimeout(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -522,7 +534,7 @@ timeout: 1s`), 0)
body: []byte(`{"test": "delayed_payload"}`),
}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow)
require.NoError(t, err)
req.Header.Add("Key", "test")
@ -566,6 +578,7 @@ tls:
}
func TestStreamingAcquisitionTLSWithHeadersAuthSuccess(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -599,9 +612,10 @@ tls:
rawEvt := `{"test": "test"}`
errChan := make(chan error)
go assertEvents(out, []string{rawEvt}, errChan)
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
require.NoError(t, err)
req.Header.Add("Key", "test")
@ -622,6 +636,7 @@ tls:
}
func TestStreamingAcquisitionMTLS(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -658,9 +673,10 @@ tls:
rawEvt := `{"test": "test"}`
errChan := make(chan error)
go assertEvents(out, []string{rawEvt}, errChan)
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
require.NoError(t, err)
resp, err := client.Do(req)
@ -680,6 +696,7 @@ tls:
}
func TestStreamingAcquisitionGzipData(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -693,6 +710,7 @@ headers:
rawEvt := `{"test": "test"}`
errChan := make(chan error)
go assertEvents(out, []string{rawEvt, rawEvt}, errChan)
var b strings.Builder
@ -709,7 +727,7 @@ headers:
// send gzipped compressed data
client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String()))
require.NoError(t, err)
req.Header.Add("Key", "test")
@ -733,6 +751,7 @@ headers:
}
func TestStreamingAcquisitionNDJson(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
@ -743,13 +762,14 @@ headers:
key: test`), 2)
time.Sleep(1 * time.Second)
rawEvt := `{"test": "test"}`
rawEvt := `{"test": "test"}`
errChan := make(chan error)
go assertEvents(out, []string{rawEvt, rawEvt}, errChan)
client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt)))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt)))
require.NoError(t, err)
@ -776,10 +796,13 @@ func assertMetrics(t *testing.T, reg *prometheus.Registry, metrics []prometheus.
require.NoError(t, err)
isExist := false
for _, metricFamily := range promMetrics {
if metricFamily.GetName() == "cs_httpsource_hits_total" {
isExist = true
assert.Len(t, metricFamily.GetMetric(), 1)
for _, metric := range metricFamily.GetMetric() {
assert.InDelta(t, float64(expected), metric.GetCounter().GetValue(), 0.000001)
labels := metric.GetLabel()
@ -791,6 +814,7 @@ func assertMetrics(t *testing.T, reg *prometheus.Registry, metrics []prometheus.
}
}
}
if !isExist && expected > 0 {
t.Fatalf("expected metric cs_httpsource_hits_total not found")
}

View file

@ -49,7 +49,7 @@ type AlertsDeleteOpts struct {
func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) {
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &alerts)
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &alerts)
if err != nil {
return nil, nil, err
}
@ -78,7 +78,7 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.
URI = fmt.Sprintf("%s?%s", URI, params.Encode())
}
req, err := s.client.NewRequest(http.MethodGet, URI, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, URI, nil)
if err != nil {
return nil, nil, fmt.Errorf("building request: %w", err)
}
@ -102,7 +102,7 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode())
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
if err != nil {
return nil, nil, err
}
@ -120,7 +120,7 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) {
u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID)
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
if err != nil {
return nil, nil, err
}
@ -138,7 +138,7 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.
func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) {
u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID)
req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}

View file

@ -27,7 +27,7 @@ func (s *AllowlistsService) List(ctx context.Context, opts AllowlistListOpts) (*
u += "?" + params.Encode()
req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
@ -58,7 +58,7 @@ func (s *AllowlistsService) Get(ctx context.Context, name string, opts Allowlist
log.Debugf("GET %s", u)
req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
@ -76,7 +76,7 @@ func (s *AllowlistsService) Get(ctx context.Context, name string, opts Allowlist
func (s *AllowlistsService) CheckIfAllowlisted(ctx context.Context, value string) (bool, *Response, error) {
u := s.client.URLPrefix + "/allowlists/check/" + value
req, err := s.client.NewRequest(http.MethodHead, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodHead, u, nil)
if err != nil {
return false, nil, err
}
@ -94,7 +94,7 @@ func (s *AllowlistsService) CheckIfAllowlisted(ctx context.Context, value string
func (s *AllowlistsService) CheckIfAllowlistedWithReason(ctx context.Context, value string) (*models.CheckAllowlistResponse, *Response, error) {
u := s.client.URLPrefix + "/allowlists/check/" + value
req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}

View file

@ -67,7 +67,7 @@ func (t *JWTTransport) refreshJwtToken() error {
return fmt.Errorf("could not encode jwt auth body: %w", err)
}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
if err != nil {
return fmt.Errorf("could not create request: %w", err)
}
@ -170,6 +170,7 @@ func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error)
// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
var resp *http.Response
attemptsCount := make(map[int]int)
for {
@ -213,6 +214,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
log.Debugf("retrying request to %s", req.URL.String())
attemptsCount[resp.StatusCode]++
log.Infof("attempt %d out of %d", attemptsCount[resp.StatusCode], config.MaxAttempts)
@ -222,6 +224,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
time.Sleep(time.Duration(backoff) * time.Second)
}
}
return resp, nil
}
@ -242,5 +245,6 @@ func (t *JWTTransport) transport() http.RoundTripper {
if t.Transport != nil {
return t.Transport
}
return http.DefaultTransport
}

View file

@ -21,7 +21,7 @@ type enrollRequest struct {
func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) {
u := fmt.Sprintf("%s/watchers", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
if err != nil {
return nil, err
}
@ -37,7 +37,7 @@ func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error)
func (s *AuthService) RegisterWatcher(ctx context.Context, registration models.WatcherRegistrationRequest) (*Response, error) {
u := fmt.Sprintf("%s/watchers", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &registration)
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &registration)
if err != nil {
return nil, err
}
@ -55,7 +55,7 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch
u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &auth)
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &auth)
if err != nil {
return authResp, nil, err
}
@ -71,7 +71,7 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch
func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) {
u := fmt.Sprintf("%s/watchers/enroll", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite})
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite})
if err != nil {
return nil, err
}

View file

@ -15,7 +15,7 @@ import (
log "github.com/sirupsen/logrus"
)
func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Request, error) {
func (c *ApiClient) NewRequestWithContext(ctx context.Context, method, url string, body interface{}) (*http.Request, error) {
if !strings.HasSuffix(c.BaseURL.Path, "/") {
return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL)
}
@ -36,7 +36,7 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ
}
}
req, err := http.NewRequest(method, u.String(), buf)
req, err := http.NewRequestWithContext(ctx, method, u.String(), buf)
if err != nil {
return nil, err
}

View file

@ -45,8 +45,8 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) {
return "", err
}
//Those 2 are a bit different
//They default to true, and we only want to include them if they are false
// Those 2 are a bit different
// They default to true, and we only want to include them if they are false
if params.Get("community_pull") == "true" {
params.Del("community_pull")
@ -81,7 +81,7 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
@ -97,7 +97,7 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m
}
func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) {
req, err := s.client.NewRequest(http.MethodGet, url, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, nil, err
}
@ -138,7 +138,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
scenarioDeleted := "deleted"
durationDeleted := "1h"
req, err := s.client.NewRequest(http.MethodGet, url, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, nil, err
}
@ -183,7 +183,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
client := http.Client{}
req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, *blocklist.URL, nil)
if err != nil {
return nil, false, err
}
@ -192,7 +192,6 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
req.Header.Set("If-Modified-Since", *lastPullTimestamp)
}
req = req.WithContext(ctx)
log.Debugf("[URL] %s %s", req.Method, req.URL)
// we don't use client_http Do method because we need the reader and is not provided.
@ -272,7 +271,7 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
return nil, nil, err
}
req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
@ -290,7 +289,7 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) {
u := fmt.Sprintf("%s/decisions", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
if err != nil {
return nil, err
}
@ -311,7 +310,7 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts)
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
if err != nil {
return nil, nil, err
}
@ -329,7 +328,7 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts)
func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) {
u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID)
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
if err != nil {
return nil, nil, err
}

View file

@ -16,7 +16,7 @@ type DecisionDeleteService service
func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) {
u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix)
req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions)
req, err := d.client.NewRequestWithContext(ctx, http.MethodPost, u, &deletedDecisions)
if err != nil {
return nil, nil, fmt.Errorf("while building request: %w", err)
}

View file

@ -17,7 +17,7 @@ type HeartBeatService service
func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) {
u := fmt.Sprintf("%s/heartbeat", h.client.URLPrefix)
req, err := h.client.NewRequest(http.MethodGet, u, nil)
req, err := h.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return false, nil, err
}
@ -33,7 +33,9 @@ func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) {
func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
t.Go(func() error {
defer trace.CatchPanic("crowdsec/apiClient/heartbeat")
hbTimer := time.NewTicker(1 * time.Minute)
for {
select {
case <-hbTimer.C:
@ -46,6 +48,7 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
}
resp.Response.Body.Close()
if resp.Response.StatusCode != http.StatusOK {
log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode)
continue
@ -58,6 +61,7 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
case <-t.Dying():
log.Debugf("heartbeat: stopping")
hbTimer.Stop()
return nil
}
}

View file

@ -13,7 +13,7 @@ type MetricsService service
func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (interface{}, *Response, error) {
u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &metrics)
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &metrics)
if err != nil {
return nil, nil, err
}

View file

@ -15,7 +15,7 @@ type SignalService service
func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsRequest) (interface{}, *Response, error) {
u := fmt.Sprintf("%s/signals", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &signals)
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &signals)
if err != nil {
return nil, nil, fmt.Errorf("while building request: %w", err)
}

View file

@ -13,7 +13,7 @@ type UsageMetricsService service
func (s *UsageMetricsService) Add(ctx context.Context, metrics *models.AllMetrics) (interface{}, *Response, error) {
u := fmt.Sprintf("%s/usage-metrics", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &metrics)
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &metrics)
if err != nil {
return nil, nil, err
}

View file

@ -2,6 +2,7 @@ package v1
import (
"bytes"
"context"
"crypto"
"crypto/x509"
"io"
@ -22,14 +23,14 @@ func NewOCSPChecker(logger *log.Entry) *OCSPChecker {
}
}
func (oc *OCSPChecker) query(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) {
func (oc *OCSPChecker) query(ctx context.Context, server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) {
req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256})
if err != nil {
oc.logger.Errorf("TLSAuth: error creating OCSP request: %s", err)
return nil, err
}
httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req))
httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, server, bytes.NewBuffer(req))
if err != nil {
oc.logger.Error("TLSAuth: cannot create HTTP request for OCSP")
return nil, err
@ -69,14 +70,14 @@ func (oc *OCSPChecker) query(server string, cert *x509.Certificate, issuer *x509
// isRevokedBy checks if the client certificate is revoked by the issuer via any of the OCSP servers present in the certificate.
// It returns a boolean indicating if the certificate is revoked and a boolean indicating
// if the OCSP check was successful and could be cached.
func (oc *OCSPChecker) isRevokedBy(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) {
func (oc *OCSPChecker) isRevokedBy(ctx context.Context, cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) {
if len(cert.OCSPServer) == 0 {
oc.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
return false, true
}
for _, server := range cert.OCSPServer {
ocspResponse, err := oc.query(server, cert, issuer)
ocspResponse, err := oc.query(ctx, server, cert, issuer)
if err != nil {
oc.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err)
continue

View file

@ -1,6 +1,7 @@
package v1
import (
"context"
"crypto/x509"
"errors"
"fmt"
@ -36,7 +37,7 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
}
// checkRevocationPath checks a single chain against OCSP and CRL
func (ta *TLSAuth) checkRevocationPath(chain []*x509.Certificate) (error, bool) { //nolint:revive
func (ta *TLSAuth) checkRevocationPath(ctx context.Context, chain []*x509.Certificate) (error, bool) { //nolint:revive
// if we ever fail to check OCSP or CRL, we should not cache the result
couldCheck := true
@ -46,7 +47,7 @@ func (ta *TLSAuth) checkRevocationPath(chain []*x509.Certificate) (error, bool)
cert := chain[i-1]
issuer := chain[i]
revokedByOCSP, checkedByOCSP := ta.ocspChecker.isRevokedBy(cert, issuer)
revokedByOCSP, checkedByOCSP := ta.ocspChecker.isRevokedBy(ctx, cert, issuer)
couldCheck = couldCheck && checkedByOCSP
if revokedByOCSP && checkedByOCSP {
@ -130,12 +131,13 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (string, error) {
okToCache := true
var validErr error
var couldCheck bool
var (
validErr error
couldCheck bool
)
for _, chain := range c.Request.TLS.VerifiedChains {
validErr, couldCheck = ta.checkRevocationPath(chain)
validErr, couldCheck = ta.checkRevocationPath(c.Request.Context(), chain)
okToCache = okToCache && couldCheck
if validErr != nil {

View file

@ -1,6 +1,7 @@
package cticlient
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -33,7 +34,7 @@ type CrowdsecCTIClient struct {
Logger *log.Entry
}
func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map[string]string) ([]byte, error) {
func (c *CrowdsecCTIClient) doRequest(ctx context.Context, method string, endpoint string, params map[string]string) ([]byte, error) {
url := CTIBaseUrl + endpoint
if len(params) > 0 {
url += "?"
@ -41,7 +42,8 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map
url += fmt.Sprintf("%s=%s&", k, v)
}
}
req, err := http.NewRequest(method, url, nil)
req, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, err
}
@ -53,78 +55,103 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusForbidden {
return nil, ErrUnauthorized
}
if resp.StatusCode == http.StatusTooManyRequests {
return nil, ErrLimit
}
if resp.StatusCode == http.StatusNotFound {
return nil, ErrNotFound
}
return nil, fmt.Errorf("unexpected http code : %s", resp.Status)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return respBody, nil
}
func (c *CrowdsecCTIClient) GetIPInfo(ip string) (*SmokeItem, error) {
body, err := c.doRequest(http.MethodGet, smokeEndpoint+"/"+ip, nil)
ctx := context.TODO()
body, err := c.doRequest(ctx, http.MethodGet, smokeEndpoint+"/"+ip, nil)
if err != nil {
if errors.Is(err, ErrNotFound) {
return &SmokeItem{}, nil
}
return nil, err
}
item := SmokeItem{}
err = json.Unmarshal(body, &item)
if err != nil {
return nil, err
}
return &item, nil
}
func (c *CrowdsecCTIClient) SearchIPs(ips []string) (*SearchIPResponse, error) {
ctx := context.TODO()
params := make(map[string]string)
params["ips"] = strings.Join(ips, ",")
body, err := c.doRequest(http.MethodGet, smokeEndpoint, params)
body, err := c.doRequest(ctx, http.MethodGet, smokeEndpoint, params)
if err != nil {
return nil, err
}
searchIPResponse := SearchIPResponse{}
err = json.Unmarshal(body, &searchIPResponse)
if err != nil {
return nil, err
}
return &searchIPResponse, nil
}
func (c *CrowdsecCTIClient) Fire(params FireParams) (*FireResponse, error) {
ctx := context.TODO()
paramsMap := make(map[string]string)
if params.Page != nil {
paramsMap["page"] = fmt.Sprintf("%d", *params.Page)
}
if params.Since != nil {
paramsMap["since"] = *params.Since
}
if params.Limit != nil {
paramsMap["limit"] = fmt.Sprintf("%d", *params.Limit)
}
body, err := c.doRequest(http.MethodGet, fireEndpoint, paramsMap)
body, err := c.doRequest(ctx, http.MethodGet, fireEndpoint, paramsMap)
if err != nil {
return nil, err
}
fireResponse := FireResponse{}
err = json.Unmarshal(body, &fireResponse)
if err != nil {
return nil, err
}
return &fireResponse, nil
}
@ -133,13 +160,16 @@ func NewCrowdsecCTIClient(options ...func(*CrowdsecCTIClient)) *CrowdsecCTIClien
for _, option := range options {
option(client)
}
if client.httpClient == nil {
client.httpClient = &http.Client{}
}
// we cannot return with a ni logger, so we set a default one
// we cannot return with a nil logger, so we set a default one
if client.Logger == nil {
client.Logger = log.NewEntry(log.New())
}
return client
}