context propagation: papi, loki (#3308)

* context propagation: AuthenticatedLAPIClient()

* context propagation: papi

* context propagation: loki
This commit is contained in:
mmetc 2024-11-15 15:31:10 +01:00 committed by GitHub
parent b96a7a5f06
commit a4497da6b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 22 additions and 21 deletions

View file

@ -136,7 +136,7 @@ func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client
t.Go(papi.SyncDecisions)
err = papi.PullOnce(time.Time{}, true)
err = papi.PullOnce(ctx, time.Time{}, true)
if err != nil {
return fmt.Errorf("unable to sync decisions: %w", err)
}

View file

@ -116,7 +116,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H
})
bucketWg.Wait()
apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub)
apiClient, err := AuthenticatedLAPIClient(context.TODO(), *cConfig.API.Client.Credentials, hub)
if err != nil {
return err
}

View file

@ -14,7 +14,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models"
)
func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) {
func AuthenticatedLAPIClient(ctx context.Context, credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) {
apiURL, err := url.Parse(credentials.URL)
if err != nil {
return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err)
@ -44,7 +44,7 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.
return nil, fmt.Errorf("new client api: %w", err)
}
authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
authResp, _, err := client.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{
MachineID: &credentials.Login,
Password: &password,
Scenarios: itemsForAPI,

View file

@ -119,7 +119,7 @@ func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQu
case <-lc.t.Dying():
return lc.t.Err()
case <-ticker.C:
resp, err := lc.Get(uri)
resp, err := lc.Get(ctx, uri)
if err != nil {
if ok := lc.shouldRetry(); !ok {
return fmt.Errorf("error querying range: %w", err)
@ -215,7 +215,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error {
return lc.t.Err()
case <-tick.C:
lc.Logger.Debug("Checking if Loki is ready")
resp, err := lc.Get(url)
resp, err := lc.Get(ctx, url)
if err != nil {
lc.Logger.Warnf("Error checking if Loki is ready: %s", err)
continue
@ -300,8 +300,8 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ
}
// Create a wrapper for http.Get to be able to set headers and auth
func (lc *LokiClient) Get(url string) (*http.Response, error) {
request, err := http.NewRequest(http.MethodGet, url, nil)
func (lc *LokiClient) Get(ctx context.Context, url string) (*http.Response, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}

View file

@ -205,8 +205,8 @@ func reverse(s []longpollclient.Event) []longpollclient.Event {
return a
}
func (p *Papi) PullOnce(since time.Time, sync bool) error {
events, err := p.Client.PullOnce(since)
func (p *Papi) PullOnce(ctx context.Context, since time.Time, sync bool) error {
events, err := p.Client.PullOnce(ctx, since)
if err != nil {
return err
}
@ -261,7 +261,7 @@ func (p *Papi) Pull(ctx context.Context) error {
p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)
for event := range p.Client.Start(lastTimestamp) {
for event := range p.Client.Start(ctx, lastTimestamp) {
logger := p.Logger.WithField("request-id", event.RequestId)
// update last timestamp in database
newTime := time.Now().UTC()

View file

@ -1,6 +1,7 @@
package longpollclient
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -50,7 +51,7 @@ var errUnauthorized = errors.New("user is not authorized to use PAPI")
const timeoutMessage = "no events before timeout"
func (c *LongPollClient) doQuery() (*http.Response, error) {
func (c *LongPollClient) doQuery(ctx context.Context) (*http.Response, error) {
logger := c.logger.WithField("method", "doQuery")
query := c.url.Query()
query.Set("since_time", fmt.Sprintf("%d", c.since))
@ -59,7 +60,7 @@ func (c *LongPollClient) doQuery() (*http.Response, error) {
logger.Debugf("Query parameters: %s", c.url.RawQuery)
req, err := http.NewRequest(http.MethodGet, c.url.String(), nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url.String(), nil)
if err != nil {
logger.Errorf("failed to create request: %s", err)
return nil, err
@ -73,10 +74,10 @@ func (c *LongPollClient) doQuery() (*http.Response, error) {
return resp, nil
}
func (c *LongPollClient) poll() error {
func (c *LongPollClient) poll(ctx context.Context) error {
logger := c.logger.WithField("method", "poll")
resp, err := c.doQuery()
resp, err := c.doQuery(ctx)
if err != nil {
return err
}
@ -146,7 +147,7 @@ func (c *LongPollClient) poll() error {
}
}
func (c *LongPollClient) pollEvents() error {
func (c *LongPollClient) pollEvents(ctx context.Context) error {
for {
select {
case <-c.t.Dying():
@ -154,7 +155,7 @@ func (c *LongPollClient) pollEvents() error {
return nil
default:
c.logger.Debug("Polling PAPI")
err := c.poll()
err := c.poll(ctx)
if err != nil {
c.logger.Errorf("failed to poll: %s", err)
if errors.Is(err, errUnauthorized) {
@ -168,12 +169,12 @@ func (c *LongPollClient) pollEvents() error {
}
}
func (c *LongPollClient) Start(since time.Time) chan Event {
func (c *LongPollClient) Start(ctx context.Context, since time.Time) chan Event {
c.logger.Infof("starting polling client")
c.c = make(chan Event)
c.since = since.Unix() * 1000
c.timeout = "45"
c.t.Go(c.pollEvents)
c.t.Go(func() error {return c.pollEvents(ctx)})
return c.c
}
@ -182,11 +183,11 @@ func (c *LongPollClient) Stop() error {
return nil
}
func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) {
func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event, error) {
c.logger.Debug("Pulling PAPI once")
c.since = since.Unix() * 1000
c.timeout = "1"
resp, err := c.doQuery()
resp, err := c.doQuery(ctx)
if err != nil {
return nil, err
}