mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-11 20:36:12 +02:00
refact: context propagation (apiclient, cticlient...) (#3477)
This commit is contained in:
parent
8da6a4dc92
commit
a3187d6f2c
16 changed files with 131 additions and 68 deletions
|
@ -194,19 +194,16 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, metricsLe
|
|||
|
||||
// let's load the associated appsec_config:
|
||||
if w.config.AppsecConfigPath != "" {
|
||||
err := appsecCfg.LoadByPath(w.config.AppsecConfigPath)
|
||||
if err != nil {
|
||||
if err = appsecCfg.LoadByPath(w.config.AppsecConfigPath); err != nil {
|
||||
return fmt.Errorf("unable to load appsec_config: %w", err)
|
||||
}
|
||||
} else if w.config.AppsecConfig != "" {
|
||||
err := appsecCfg.Load(w.config.AppsecConfig)
|
||||
if err != nil {
|
||||
if err = appsecCfg.Load(w.config.AppsecConfig); err != nil {
|
||||
return fmt.Errorf("unable to load appsec_config: %w", err)
|
||||
}
|
||||
} else if len(w.config.AppsecConfigs) > 0 {
|
||||
for _, appsecConfig := range w.config.AppsecConfigs {
|
||||
err := appsecCfg.Load(appsecConfig)
|
||||
if err != nil {
|
||||
if err = appsecCfg.Load(appsecConfig); err != nil {
|
||||
return fmt.Errorf("unable to load appsec_config: %w", err)
|
||||
}
|
||||
}
|
||||
|
@ -233,6 +230,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, metricsLe
|
|||
if err != nil {
|
||||
return fmt.Errorf("unable to get authenticated LAPI client: %w", err)
|
||||
}
|
||||
|
||||
w.appsecAllowlistClient = allowlists.NewAppsecAllowlist(w.apiClient, w.logger)
|
||||
|
||||
for nbRoutine := range w.config.Routines {
|
||||
|
@ -371,12 +369,12 @@ func (w *AppsecSource) Dump() interface{} {
|
|||
return w
|
||||
}
|
||||
|
||||
func (w *AppsecSource) IsAuth(apiKey string) bool {
|
||||
func (w *AppsecSource) IsAuth(ctx context.Context, apiKey string) bool {
|
||||
client := &http.Client{
|
||||
Timeout: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodHead, w.lapiURL, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, w.lapiURL, nil)
|
||||
if err != nil {
|
||||
log.Errorf("Error creating request: %s", err)
|
||||
return false
|
||||
|
@ -397,6 +395,7 @@ func (w *AppsecSource) IsAuth(apiKey string) bool {
|
|||
|
||||
// should this be in the runner ?
|
||||
func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
w.logger.Debugf("Received request from '%s' on %s", r.RemoteAddr, r.URL.Path)
|
||||
|
||||
apiKey := r.Header.Get(appsec.APIKeyHeaderName)
|
||||
|
@ -413,7 +412,7 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
|
|||
expiration, exists := w.AuthCache.Get(apiKey)
|
||||
// if the apiKey is not in cache or has expired, just recheck the auth
|
||||
if !exists || time.Now().After(expiration) {
|
||||
if !w.IsAuth(apiKey) {
|
||||
if !w.IsAuth(ctx, apiKey) {
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
w.logger.Errorf("Unauthorized request from '%s' (real IP = %s)", remoteIP, clientIP)
|
||||
|
||||
|
|
|
@ -288,6 +288,7 @@ basic_auth:
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionBasicAuth(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -306,7 +307,7 @@ basic_auth:
|
|||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
|
||||
require.NoError(t, err)
|
||||
req.SetBasicAuth("test", "WrongPassword")
|
||||
|
||||
|
@ -321,6 +322,7 @@ basic_auth:
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionBadHeaders(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -334,7 +336,7 @@ headers:
|
|||
|
||||
client := &http.Client{}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("Key", "wrong")
|
||||
|
@ -349,6 +351,7 @@ headers:
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionMaxBodySize(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -362,7 +365,7 @@ max_body_size: 5`), 0)
|
|||
time.Sleep(1 * time.Second)
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest"))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("Key", "test")
|
||||
|
@ -378,6 +381,7 @@ max_body_size: 5`), 0)
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionSuccess(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -388,13 +392,14 @@ headers:
|
|||
key: test`), 2)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
rawEvt := `{"test": "test"}`
|
||||
|
||||
errChan := make(chan error)
|
||||
go assertEvents(out, []string{rawEvt}, errChan)
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("Key", "test")
|
||||
|
@ -414,6 +419,7 @@ headers:
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -430,9 +436,10 @@ custom_headers:
|
|||
|
||||
rawEvt := `{"test": "test"}`
|
||||
errChan := make(chan error)
|
||||
|
||||
go assertEvents(out, []string{rawEvt}, errChan)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("Key", "test")
|
||||
|
@ -463,9 +470,11 @@ func (sr *slowReader) Read(p []byte) (int, error) {
|
|||
if sr.index >= len(sr.body) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
time.Sleep(sr.delay) // Simulate a delay in reading
|
||||
n := copy(p, sr.body[sr.index:])
|
||||
sr.index += n
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
@ -492,10 +501,12 @@ func assertEvents(out chan types.Event, expected []string, errChan chan error) {
|
|||
errChan <- fmt.Errorf(`expected %s, got '%+v'`, expected, evt.Line.Raw)
|
||||
return
|
||||
}
|
||||
|
||||
if evt.Line.Src != "127.0.0.1" {
|
||||
errChan <- fmt.Errorf("expected '127.0.0.1', got '%s'", evt.Line.Src)
|
||||
return
|
||||
}
|
||||
|
||||
if evt.Line.Module != "http" {
|
||||
errChan <- fmt.Errorf("expected 'http', got '%s'", evt.Line.Module)
|
||||
return
|
||||
|
@ -505,6 +516,7 @@ func assertEvents(out chan types.Event, expected []string, errChan chan error) {
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionTimeout(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -522,7 +534,7 @@ timeout: 1s`), 0)
|
|||
body: []byte(`{"test": "delayed_payload"}`),
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("Key", "test")
|
||||
|
@ -566,6 +578,7 @@ tls:
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionTLSWithHeadersAuthSuccess(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -599,9 +612,10 @@ tls:
|
|||
|
||||
rawEvt := `{"test": "test"}`
|
||||
errChan := make(chan error)
|
||||
|
||||
go assertEvents(out, []string{rawEvt}, errChan)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("Key", "test")
|
||||
|
@ -622,6 +636,7 @@ tls:
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionMTLS(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -658,9 +673,10 @@ tls:
|
|||
|
||||
rawEvt := `{"test": "test"}`
|
||||
errChan := make(chan error)
|
||||
|
||||
go assertEvents(out, []string{rawEvt}, errChan)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
@ -680,6 +696,7 @@ tls:
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionGzipData(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -693,6 +710,7 @@ headers:
|
|||
|
||||
rawEvt := `{"test": "test"}`
|
||||
errChan := make(chan error)
|
||||
|
||||
go assertEvents(out, []string{rawEvt, rawEvt}, errChan)
|
||||
|
||||
var b strings.Builder
|
||||
|
@ -709,7 +727,7 @@ headers:
|
|||
|
||||
// send gzipped compressed data
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String()))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String()))
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("Key", "test")
|
||||
|
@ -733,6 +751,7 @@ headers:
|
|||
}
|
||||
|
||||
func TestStreamingAcquisitionNDJson(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
h := &HTTPSource{}
|
||||
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||
source: http
|
||||
|
@ -743,13 +762,14 @@ headers:
|
|||
key: test`), 2)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
rawEvt := `{"test": "test"}`
|
||||
|
||||
rawEvt := `{"test": "test"}`
|
||||
errChan := make(chan error)
|
||||
|
||||
go assertEvents(out, []string{rawEvt, rawEvt}, errChan)
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt)))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt)))
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -776,10 +796,13 @@ func assertMetrics(t *testing.T, reg *prometheus.Registry, metrics []prometheus.
|
|||
require.NoError(t, err)
|
||||
|
||||
isExist := false
|
||||
|
||||
for _, metricFamily := range promMetrics {
|
||||
if metricFamily.GetName() == "cs_httpsource_hits_total" {
|
||||
isExist = true
|
||||
|
||||
assert.Len(t, metricFamily.GetMetric(), 1)
|
||||
|
||||
for _, metric := range metricFamily.GetMetric() {
|
||||
assert.InDelta(t, float64(expected), metric.GetCounter().GetValue(), 0.000001)
|
||||
labels := metric.GetLabel()
|
||||
|
@ -791,6 +814,7 @@ func assertMetrics(t *testing.T, reg *prometheus.Registry, metrics []prometheus.
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isExist && expected > 0 {
|
||||
t.Fatalf("expected metric cs_httpsource_hits_total not found")
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ type AlertsDeleteOpts struct {
|
|||
func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) {
|
||||
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodPost, u, &alerts)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &alerts)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -78,7 +78,7 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.
|
|||
URI = fmt.Sprintf("%s?%s", URI, params.Encode())
|
||||
}
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, URI, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, URI, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("building request: %w", err)
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
|
|||
|
||||
u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode())
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -120,7 +120,7 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
|
|||
func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) {
|
||||
u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.
|
|||
func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) {
|
||||
u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ func (s *AllowlistsService) List(ctx context.Context, opts AllowlistListOpts) (*
|
|||
|
||||
u += "?" + params.Encode()
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -58,7 +58,7 @@ func (s *AllowlistsService) Get(ctx context.Context, name string, opts Allowlist
|
|||
|
||||
log.Debugf("GET %s", u)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func (s *AllowlistsService) Get(ctx context.Context, name string, opts Allowlist
|
|||
func (s *AllowlistsService) CheckIfAllowlisted(ctx context.Context, value string) (bool, *Response, error) {
|
||||
u := s.client.URLPrefix + "/allowlists/check/" + value
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodHead, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodHead, u, nil)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ func (s *AllowlistsService) CheckIfAllowlisted(ctx context.Context, value string
|
|||
func (s *AllowlistsService) CheckIfAllowlistedWithReason(ctx context.Context, value string) (*models.CheckAllowlistResponse, *Response, error) {
|
||||
u := s.client.URLPrefix + "/allowlists/check/" + value
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ func (t *JWTTransport) refreshJwtToken() error {
|
|||
return fmt.Errorf("could not encode jwt auth body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create request: %w", err)
|
||||
}
|
||||
|
@ -170,6 +170,7 @@ func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error)
|
|||
// RoundTrip implements the RoundTripper interface.
|
||||
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
var resp *http.Response
|
||||
|
||||
attemptsCount := make(map[int]int)
|
||||
|
||||
for {
|
||||
|
@ -213,6 +214,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
}
|
||||
|
||||
log.Debugf("retrying request to %s", req.URL.String())
|
||||
|
||||
attemptsCount[resp.StatusCode]++
|
||||
log.Infof("attempt %d out of %d", attemptsCount[resp.StatusCode], config.MaxAttempts)
|
||||
|
||||
|
@ -222,6 +224,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
time.Sleep(time.Duration(backoff) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
@ -242,5 +245,6 @@ func (t *JWTTransport) transport() http.RoundTripper {
|
|||
if t.Transport != nil {
|
||||
return t.Transport
|
||||
}
|
||||
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ type enrollRequest struct {
|
|||
func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) {
|
||||
u := fmt.Sprintf("%s/watchers", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error)
|
|||
func (s *AuthService) RegisterWatcher(ctx context.Context, registration models.WatcherRegistrationRequest) (*Response, error) {
|
||||
u := fmt.Sprintf("%s/watchers", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodPost, u, ®istration)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, ®istration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch
|
|||
|
||||
u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodPost, u, &auth)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &auth)
|
||||
if err != nil {
|
||||
return authResp, nil, err
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch
|
|||
func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) {
|
||||
u := fmt.Sprintf("%s/watchers/enroll", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite})
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Request, error) {
|
||||
func (c *ApiClient) NewRequestWithContext(ctx context.Context, method, url string, body interface{}) (*http.Request, error) {
|
||||
if !strings.HasSuffix(c.BaseURL.Path, "/") {
|
||||
return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL)
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ
|
|||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, u.String(), buf)
|
||||
req, err := http.NewRequestWithContext(ctx, method, u.String(), buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -45,8 +45,8 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
//Those 2 are a bit different
|
||||
//They default to true, and we only want to include them if they are false
|
||||
// Those 2 are a bit different
|
||||
// They default to true, and we only want to include them if they are false
|
||||
|
||||
if params.Get("community_pull") == "true" {
|
||||
params.Del("community_pull")
|
||||
|
@ -81,7 +81,7 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m
|
|||
|
||||
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -97,7 +97,7 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m
|
|||
}
|
||||
|
||||
func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) {
|
||||
req, err := s.client.NewRequest(http.MethodGet, url, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
|
|||
scenarioDeleted := "deleted"
|
||||
durationDeleted := "1h"
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, url, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -183,7 +183,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
|||
|
||||
client := http.Client{}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, *blocklist.URL, nil)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
@ -192,7 +192,6 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
|||
req.Header.Set("If-Modified-Since", *lastPullTimestamp)
|
||||
}
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
log.Debugf("[URL] %s %s", req.Method, req.URL)
|
||||
|
||||
// we don't use client_http Do method because we need the reader and is not provided.
|
||||
|
@ -272,7 +271,7 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -290,7 +289,7 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
|
|||
func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) {
|
||||
u := fmt.Sprintf("%s/decisions", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -311,7 +310,7 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts)
|
|||
|
||||
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -329,7 +328,7 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts)
|
|||
func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) {
|
||||
u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ type DecisionDeleteService service
|
|||
func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) {
|
||||
u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix)
|
||||
|
||||
req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions)
|
||||
req, err := d.client.NewRequestWithContext(ctx, http.MethodPost, u, &deletedDecisions)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("while building request: %w", err)
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ type HeartBeatService service
|
|||
func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) {
|
||||
u := fmt.Sprintf("%s/heartbeat", h.client.URLPrefix)
|
||||
|
||||
req, err := h.client.NewRequest(http.MethodGet, u, nil)
|
||||
req, err := h.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
@ -33,7 +33,9 @@ func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) {
|
|||
func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
|
||||
t.Go(func() error {
|
||||
defer trace.CatchPanic("crowdsec/apiClient/heartbeat")
|
||||
|
||||
hbTimer := time.NewTicker(1 * time.Minute)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hbTimer.C:
|
||||
|
@ -46,6 +48,7 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
|
|||
}
|
||||
|
||||
resp.Response.Body.Close()
|
||||
|
||||
if resp.Response.StatusCode != http.StatusOK {
|
||||
log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode)
|
||||
continue
|
||||
|
@ -58,6 +61,7 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
|
|||
case <-t.Dying():
|
||||
log.Debugf("heartbeat: stopping")
|
||||
hbTimer.Stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ type MetricsService service
|
|||
func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (interface{}, *Response, error) {
|
||||
u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodPost, u, &metrics)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &metrics)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ type SignalService service
|
|||
func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsRequest) (interface{}, *Response, error) {
|
||||
u := fmt.Sprintf("%s/signals", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodPost, u, &signals)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &signals)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("while building request: %w", err)
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ type UsageMetricsService service
|
|||
func (s *UsageMetricsService) Add(ctx context.Context, metrics *models.AllMetrics) (interface{}, *Response, error) {
|
||||
u := fmt.Sprintf("%s/usage-metrics", s.client.URLPrefix)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodPost, u, &metrics)
|
||||
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &metrics)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package v1
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
|
@ -22,14 +23,14 @@ func NewOCSPChecker(logger *log.Entry) *OCSPChecker {
|
|||
}
|
||||
}
|
||||
|
||||
func (oc *OCSPChecker) query(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) {
|
||||
func (oc *OCSPChecker) query(ctx context.Context, server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) {
|
||||
req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256})
|
||||
if err != nil {
|
||||
oc.logger.Errorf("TLSAuth: error creating OCSP request: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req))
|
||||
httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, server, bytes.NewBuffer(req))
|
||||
if err != nil {
|
||||
oc.logger.Error("TLSAuth: cannot create HTTP request for OCSP")
|
||||
return nil, err
|
||||
|
@ -69,14 +70,14 @@ func (oc *OCSPChecker) query(server string, cert *x509.Certificate, issuer *x509
|
|||
// isRevokedBy checks if the client certificate is revoked by the issuer via any of the OCSP servers present in the certificate.
|
||||
// It returns a boolean indicating if the certificate is revoked and a boolean indicating
|
||||
// if the OCSP check was successful and could be cached.
|
||||
func (oc *OCSPChecker) isRevokedBy(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) {
|
||||
func (oc *OCSPChecker) isRevokedBy(ctx context.Context, cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) {
|
||||
if len(cert.OCSPServer) == 0 {
|
||||
oc.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
|
||||
return false, true
|
||||
}
|
||||
|
||||
for _, server := range cert.OCSPServer {
|
||||
ocspResponse, err := oc.query(server, cert, issuer)
|
||||
ocspResponse, err := oc.query(ctx, server, cert, issuer)
|
||||
if err != nil {
|
||||
oc.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err)
|
||||
continue
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -36,7 +37,7 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
|
|||
}
|
||||
|
||||
// checkRevocationPath checks a single chain against OCSP and CRL
|
||||
func (ta *TLSAuth) checkRevocationPath(chain []*x509.Certificate) (error, bool) { //nolint:revive
|
||||
func (ta *TLSAuth) checkRevocationPath(ctx context.Context, chain []*x509.Certificate) (error, bool) { //nolint:revive
|
||||
// if we ever fail to check OCSP or CRL, we should not cache the result
|
||||
couldCheck := true
|
||||
|
||||
|
@ -46,7 +47,7 @@ func (ta *TLSAuth) checkRevocationPath(chain []*x509.Certificate) (error, bool)
|
|||
cert := chain[i-1]
|
||||
issuer := chain[i]
|
||||
|
||||
revokedByOCSP, checkedByOCSP := ta.ocspChecker.isRevokedBy(cert, issuer)
|
||||
revokedByOCSP, checkedByOCSP := ta.ocspChecker.isRevokedBy(ctx, cert, issuer)
|
||||
couldCheck = couldCheck && checkedByOCSP
|
||||
|
||||
if revokedByOCSP && checkedByOCSP {
|
||||
|
@ -130,12 +131,13 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (string, error) {
|
|||
|
||||
okToCache := true
|
||||
|
||||
var validErr error
|
||||
|
||||
var couldCheck bool
|
||||
var (
|
||||
validErr error
|
||||
couldCheck bool
|
||||
)
|
||||
|
||||
for _, chain := range c.Request.TLS.VerifiedChains {
|
||||
validErr, couldCheck = ta.checkRevocationPath(chain)
|
||||
validErr, couldCheck = ta.checkRevocationPath(c.Request.Context(), chain)
|
||||
okToCache = okToCache && couldCheck
|
||||
|
||||
if validErr != nil {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cticlient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -33,7 +34,7 @@ type CrowdsecCTIClient struct {
|
|||
Logger *log.Entry
|
||||
}
|
||||
|
||||
func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map[string]string) ([]byte, error) {
|
||||
func (c *CrowdsecCTIClient) doRequest(ctx context.Context, method string, endpoint string, params map[string]string) ([]byte, error) {
|
||||
url := CTIBaseUrl + endpoint
|
||||
if len(params) > 0 {
|
||||
url += "?"
|
||||
|
@ -41,7 +42,8 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map
|
|||
url += fmt.Sprintf("%s=%s&", k, v)
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequest(method, url, nil)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -53,78 +55,103 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
return nil, ErrLimit
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected http code : %s", resp.Status)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
func (c *CrowdsecCTIClient) GetIPInfo(ip string) (*SmokeItem, error) {
|
||||
body, err := c.doRequest(http.MethodGet, smokeEndpoint+"/"+ip, nil)
|
||||
ctx := context.TODO()
|
||||
|
||||
body, err := c.doRequest(ctx, http.MethodGet, smokeEndpoint+"/"+ip, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
return &SmokeItem{}, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
item := SmokeItem{}
|
||||
|
||||
err = json.Unmarshal(body, &item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (c *CrowdsecCTIClient) SearchIPs(ips []string) (*SearchIPResponse, error) {
|
||||
ctx := context.TODO()
|
||||
params := make(map[string]string)
|
||||
params["ips"] = strings.Join(ips, ",")
|
||||
body, err := c.doRequest(http.MethodGet, smokeEndpoint, params)
|
||||
|
||||
body, err := c.doRequest(ctx, http.MethodGet, smokeEndpoint, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
searchIPResponse := SearchIPResponse{}
|
||||
|
||||
err = json.Unmarshal(body, &searchIPResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &searchIPResponse, nil
|
||||
}
|
||||
|
||||
func (c *CrowdsecCTIClient) Fire(params FireParams) (*FireResponse, error) {
|
||||
ctx := context.TODO()
|
||||
paramsMap := make(map[string]string)
|
||||
|
||||
if params.Page != nil {
|
||||
paramsMap["page"] = fmt.Sprintf("%d", *params.Page)
|
||||
}
|
||||
|
||||
if params.Since != nil {
|
||||
paramsMap["since"] = *params.Since
|
||||
}
|
||||
|
||||
if params.Limit != nil {
|
||||
paramsMap["limit"] = fmt.Sprintf("%d", *params.Limit)
|
||||
}
|
||||
|
||||
body, err := c.doRequest(http.MethodGet, fireEndpoint, paramsMap)
|
||||
body, err := c.doRequest(ctx, http.MethodGet, fireEndpoint, paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fireResponse := FireResponse{}
|
||||
|
||||
err = json.Unmarshal(body, &fireResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &fireResponse, nil
|
||||
}
|
||||
|
||||
|
@ -133,13 +160,16 @@ func NewCrowdsecCTIClient(options ...func(*CrowdsecCTIClient)) *CrowdsecCTIClien
|
|||
for _, option := range options {
|
||||
option(client)
|
||||
}
|
||||
|
||||
if client.httpClient == nil {
|
||||
client.httpClient = &http.Client{}
|
||||
}
|
||||
// we cannot return with a ni logger, so we set a default one
|
||||
|
||||
// we cannot return with a nil logger, so we set a default one
|
||||
if client.Logger == nil {
|
||||
client.Logger = log.NewEntry(log.New())
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue