From 6810b41dd872670d58e028b1cfa4d12bffc8b19b Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 12 Sep 2024 17:28:16 +0200 Subject: [PATCH] refact pkg/database: context propagation (start) (#3226) * refact pkg/database: context propagation (part) * more context propagation (usagemetrics) * propagate errors when updating metrics --- cmd/crowdsec/metrics.go | 6 ++++-- pkg/apiserver/controllers/v1/usagemetrics.go | 15 ++++++++------- pkg/database/alerts.go | 4 +--- pkg/database/bouncers.go | 5 +++-- pkg/database/decisions.go | 5 +++-- pkg/database/machines.go | 5 +++-- pkg/database/metrics.go | 5 +++-- 7 files changed, 25 insertions(+), 20 deletions(-) diff --git a/cmd/crowdsec/metrics.go b/cmd/crowdsec/metrics.go index d3c6e1720..ff280fc35 100644 --- a/cmd/crowdsec/metrics.go +++ b/cmd/crowdsec/metrics.go @@ -118,7 +118,9 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha return } - decisions, err := dbClient.QueryDecisionCountByScenario() + ctx := r.Context() + + decisions, err := dbClient.QueryDecisionCountByScenario(ctx) if err != nil { log.Errorf("Error querying decisions for metrics: %v", err) next.ServeHTTP(w, r) @@ -138,7 +140,7 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha "include_capi": {"false"}, } - alerts, err := dbClient.AlertsCountPerScenario(alertsFilter) + alerts, err := dbClient.AlertsCountPerScenario(ctx, alertsFilter) if err != nil { log.Errorf("Error querying alerts for metrics: %v", err) next.ServeHTTP(w, r) diff --git a/pkg/apiserver/controllers/v1/usagemetrics.go b/pkg/apiserver/controllers/v1/usagemetrics.go index 74f27bb6c..27b1b819a 100644 --- a/pkg/apiserver/controllers/v1/usagemetrics.go +++ b/pkg/apiserver/controllers/v1/usagemetrics.go @@ -1,6 +1,7 @@ package v1 import ( + "context" "encoding/json" "errors" "net/http" @@ -18,17 +19,15 @@ import ( ) // updateBaseMetrics updates the base metrics for a machine or bouncer -func (c *Controller) updateBaseMetrics(machineID string, bouncer *ent.Bouncer, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { +func (c *Controller) updateBaseMetrics(ctx context.Context, machineID string, bouncer *ent.Bouncer, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { switch { case machineID != "": - c.DBClient.MachineUpdateBaseMetrics(machineID, baseMetrics, hubItems, datasources) + return c.DBClient.MachineUpdateBaseMetrics(ctx, machineID, baseMetrics, hubItems, datasources) case bouncer != nil: - c.DBClient.BouncerUpdateBaseMetrics(bouncer.Name, bouncer.Type, baseMetrics) + return c.DBClient.BouncerUpdateBaseMetrics(ctx, bouncer.Name, bouncer.Type, baseMetrics) default: return errors.New("no machineID or bouncerName set") } - - return nil } // UsageMetrics receives metrics from log processors and remediation components @@ -172,7 +171,9 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { } } - err := c.updateBaseMetrics(machineID, bouncer, baseMetrics, hubItems, datasources) + ctx := gctx.Request.Context() + + err := c.updateBaseMetrics(ctx, machineID, bouncer, baseMetrics, hubItems, datasources) if err != nil { logger.Errorf("Failed to update base metrics: %s", err) c.HandleDBErrors(gctx, err) @@ -190,7 +191,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { receivedAt := time.Now().UTC() - if _, err := c.DBClient.CreateMetric(generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil { + if _, err := c.DBClient.CreateMetric(ctx, generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil { logger.Error(err) c.HandleDBErrors(gctx, err) diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 0f6d87fb1..3e3e480c7 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -941,14 +941,12 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str return alerts.Where(preds...), nil } -func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string]int, error) { +func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string][]string) (map[string]int, error) { var res []struct { Scenario string Count int } - ctx := context.TODO() - query := c.Ent.Alert.Query() query, err := BuildAlertRequestFromFilter(query, filters) diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index f79e9580a..a7378bbb2 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strings" "time" @@ -20,7 +21,7 @@ func (e *BouncerNotFoundError) Error() string { return fmt.Sprintf("'%s' does not exist", e.BouncerName) } -func (c *Client) BouncerUpdateBaseMetrics(bouncerName string, bouncerType string, baseMetrics models.BaseMetrics) error { +func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName string, bouncerType string, baseMetrics models.BaseMetrics) error { os := baseMetrics.Os features := strings.Join(baseMetrics.FeatureFlags, ",") @@ -32,7 +33,7 @@ func (c *Client) BouncerUpdateBaseMetrics(bouncerName string, bouncerType string SetOsversion(*os.Version). SetFeatureflags(features). SetType(bouncerType). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update base bouncer metrics in database: %w", err) } diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index fc582247e..5fd4757c8 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strconv" "strings" @@ -173,7 +174,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( return data, nil } -func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) { +func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*DecisionsByScenario, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -186,7 +187,7 @@ func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) var r []*DecisionsByScenario - err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(c.CTX, &r) + err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(ctx, &r) if err != nil { c.Log.Warningf("QueryDecisionCountByScenario : %s", err) return nil, errors.Wrap(QueryFail, "count all decisions with filters") diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 3c8cbabbf..27d737e62 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strings" "time" @@ -29,7 +30,7 @@ func (e *MachineNotFoundError) Error() string { return fmt.Sprintf("'%s' does not exist", e.MachineID) } -func (c *Client) MachineUpdateBaseMetrics(machineID string, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { +func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { os := baseMetrics.Os features := strings.Join(baseMetrics.FeatureFlags, ",") @@ -63,7 +64,7 @@ func (c *Client) MachineUpdateBaseMetrics(machineID string, baseMetrics models.B SetLastHeartbeat(heartbeat). SetHubstate(hubState). SetDatasources(datasources). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update base machine metrics in database: %w", err) } diff --git a/pkg/database/metrics.go b/pkg/database/metrics.go index 7626c39f6..1619fcc92 100644 --- a/pkg/database/metrics.go +++ b/pkg/database/metrics.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "time" @@ -8,14 +9,14 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" ) -func (c *Client) CreateMetric(generatedType metric.GeneratedType, generatedBy string, receivedAt time.Time, payload string) (*ent.Metric, error) { +func (c *Client) CreateMetric(ctx context.Context, generatedType metric.GeneratedType, generatedBy string, receivedAt time.Time, payload string) (*ent.Metric, error) { metric, err := c.Ent.Metric. Create(). SetGeneratedType(generatedType). SetGeneratedBy(generatedBy). SetReceivedAt(receivedAt). SetPayload(payload). - Save(c.CTX) + Save(ctx) if err != nil { c.Log.Warningf("CreateMetric: %s", err) return nil, fmt.Errorf("storing metrics snapshot for '%s' at %s: %w", generatedBy, receivedAt, InsertFail)