From 3945a991bd265b765e9ab80e28aff70c21c707c9 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Tue, 24 Sep 2024 14:13:45 +0200 Subject: [PATCH] context propagation: pkg/database/alerts (#3252) * alerts * drop CTX from dbclient * lint * pkg/database/alerts: context.TODO() * cscli: context.Background() -> cmd.Context() --- .golangci.yml | 4 +- cmd/crowdsec-cli/clialert/alerts.go | 20 ++-- cmd/crowdsec-cli/cliconsole/console.go | 8 +- cmd/crowdsec-cli/clidecision/decisions.go | 22 ++--- cmd/crowdsec-cli/clilapi/lapi.go | 2 +- .../clinotifications/notifications.go | 8 +- pkg/apiserver/apic.go | 10 +- pkg/apiserver/apic_test.go | 10 +- pkg/apiserver/controllers/v1/alerts.go | 18 +++- pkg/apiserver/controllers/v1/decisions.go | 4 +- pkg/apiserver/controllers/v1/metrics.go | 15 ++- pkg/apiserver/decisions_test.go | 10 +- pkg/apiserver/middlewares/v1/cache.go | 2 +- pkg/apiserver/middlewares/v1/crl.go | 10 +- pkg/apiserver/middlewares/v1/jwt.go | 4 +- pkg/apiserver/papi_cmd.go | 4 +- pkg/database/alerts.go | 97 +++++++++---------- pkg/database/database.go | 2 - pkg/database/decisions.go | 6 +- pkg/database/flush.go | 6 +- pkg/database/metrics.go | 2 +- 21 files changed, 141 insertions(+), 123 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index b76e2613b..54c0acb06 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -20,7 +20,7 @@ linters-settings: maintidx: # raise this after refactoring - under: 16 + under: 15 misspell: locale: US @@ -118,7 +118,7 @@ linters-settings: arguments: [6] - name: function-length # lower this after refactoring - arguments: [110, 235] + arguments: [110, 237] - name: get-return disabled: true - name: increment-decrement diff --git a/cmd/crowdsec-cli/clialert/alerts.go b/cmd/crowdsec-cli/clialert/alerts.go index dbb7ca14d..75454e945 100644 --- a/cmd/crowdsec-cli/clialert/alerts.go +++ b/cmd/crowdsec-cli/clialert/alerts.go @@ -235,7 +235,7 @@ func (cli *cliAlerts) NewCommand() *cobra.Command { return cmd } -func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error { +func (cli *cliAlerts) list(ctx context.Context, alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error { var err error *alertListFilter.ScopeEquals, err = SanitizeScope(*alertListFilter.ScopeEquals, *alertListFilter.IPEquals, *alertListFilter.RangeEquals) @@ -311,7 +311,7 @@ func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int, alertListFilter.Contains = new(bool) } - alerts, _, err := cli.client.Alerts.List(context.Background(), alertListFilter) + alerts, _, err := cli.client.Alerts.List(ctx, alertListFilter) if err != nil { return fmt.Errorf("unable to list alerts: %w", err) } @@ -354,7 +354,7 @@ cscli alerts list --type ban`, Long: `List alerts with optional filters`, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { - return cli.list(alertListFilter, limit, contained, printMachine) + return cli.list(cmd.Context(), alertListFilter, limit, contained, printMachine) }, } @@ -377,7 +377,7 @@ cscli alerts list --type ban`, return cmd } -func (cli *cliAlerts) delete(delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error { +func (cli *cliAlerts) delete(ctx context.Context, delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error { var err error if !deleteAll { @@ -423,12 +423,12 @@ func (cli *cliAlerts) delete(delFilter apiclient.AlertsDeleteOpts, activeDecisio var alerts *models.DeleteAlertsResponse if delAlertByID == "" { - alerts, _, err = cli.client.Alerts.Delete(context.Background(), delFilter) + alerts, _, err = cli.client.Alerts.Delete(ctx, delFilter) if err != nil { return fmt.Errorf("unable to delete alerts: %w", err) } } else { - alerts, _, err = cli.client.Alerts.DeleteOne(context.Background(), delAlertByID) + alerts, _, err = cli.client.Alerts.DeleteOne(ctx, delAlertByID) if err != nil { return fmt.Errorf("unable to delete alert: %w", err) } @@ -480,7 +480,7 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`, return nil }, RunE: func(cmd *cobra.Command, _ []string) error { - return cli.delete(delFilter, activeDecision, deleteAll, delAlertByID, contained) + return cli.delete(cmd.Context(), delFilter, activeDecision, deleteAll, delAlertByID, contained) }, } @@ -498,7 +498,7 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`, return cmd } -func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error { +func (cli *cliAlerts) inspect(ctx context.Context, details bool, alertIDs ...string) error { cfg := cli.cfg() for _, alertID := range alertIDs { @@ -507,7 +507,7 @@ func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error { return fmt.Errorf("bad alert id %s", alertID) } - alert, _, err := cli.client.Alerts.GetByID(context.Background(), id) + alert, _, err := cli.client.Alerts.GetByID(ctx, id) if err != nil { return fmt.Errorf("can't find alert with id %s: %w", alertID, err) } @@ -551,7 +551,7 @@ func (cli *cliAlerts) newInspectCmd() *cobra.Command { _ = cmd.Help() return errors.New("missing alert_id") } - return cli.inspect(details, args...) + return cli.inspect(cmd.Context(), details, args...) }, } diff --git a/cmd/crowdsec-cli/cliconsole/console.go b/cmd/crowdsec-cli/cliconsole/console.go index af1ba316c..448ddcee7 100644 --- a/cmd/crowdsec-cli/cliconsole/console.go +++ b/cmd/crowdsec-cli/cliconsole/console.go @@ -66,7 +66,7 @@ func (cli *cliConsole) NewCommand() *cobra.Command { return cmd } -func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []string, opts []string) error { +func (cli *cliConsole) enroll(ctx context.Context, key string, name string, overwrite bool, tags []string, opts []string) error { cfg := cli.cfg() password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password) @@ -127,7 +127,7 @@ func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []st VersionPrefix: "v3", }) - resp, err := c.Auth.EnrollWatcher(context.Background(), key, name, tags, overwrite) + resp, err := c.Auth.EnrollWatcher(ctx, key, name, tags, overwrite) if err != nil { return fmt.Errorf("could not enroll instance: %w", err) } @@ -173,8 +173,8 @@ After running this command your will need to validate the enrollment in the weba valid options are : %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")), Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.enroll(args[0], name, overwrite, tags, opts) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.enroll(cmd.Context(), args[0], name, overwrite, tags, opts) }, } diff --git a/cmd/crowdsec-cli/clidecision/decisions.go b/cmd/crowdsec-cli/clidecision/decisions.go index b82ebe308..1f8781a37 100644 --- a/cmd/crowdsec-cli/clidecision/decisions.go +++ b/cmd/crowdsec-cli/clidecision/decisions.go @@ -170,7 +170,7 @@ func (cli *cliDecisions) NewCommand() *cobra.Command { return cmd } -func (cli *cliDecisions) list(filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error { +func (cli *cliDecisions) list(ctx context.Context, filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error { var err error *filter.ScopeEquals, err = clialert.SanitizeScope(*filter.ScopeEquals, *filter.IPEquals, *filter.RangeEquals) @@ -249,7 +249,7 @@ func (cli *cliDecisions) list(filter apiclient.AlertsListOpts, NoSimu *bool, con filter.Contains = new(bool) } - alerts, _, err := cli.client.Alerts.List(context.Background(), filter) + alerts, _, err := cli.client.Alerts.List(ctx, filter) if err != nil { return fmt.Errorf("unable to retrieve decisions: %w", err) } @@ -293,7 +293,7 @@ cscli decisions list --origin lists --scenario list_name Args: cobra.ExactArgs(0), DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { - return cli.list(filter, NoSimu, contained, printMachine) + return cli.list(cmd.Context(), filter, NoSimu, contained, printMachine) }, } @@ -317,7 +317,7 @@ cscli decisions list --origin lists --scenario list_name return cmd } -func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error { +func (cli *cliDecisions) add(ctx context.Context, addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error { alerts := models.AddAlertsRequest{} origin := types.CscliOrigin capacity := int32(0) @@ -386,7 +386,7 @@ func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, a } alerts = append(alerts, &alert) - _, _, err = cli.client.Alerts.Add(context.Background(), alerts) + _, _, err = cli.client.Alerts.Add(ctx, alerts) if err != nil { return err } @@ -419,7 +419,7 @@ cscli decisions add --scope username --value foobar Args: cobra.ExactArgs(0), DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { - return cli.add(addIP, addRange, addDuration, addValue, addScope, addReason, addType) + return cli.add(cmd.Context(), addIP, addRange, addDuration, addValue, addScope, addReason, addType) }, } @@ -436,7 +436,7 @@ cscli decisions add --scope username --value foobar return cmd } -func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error { +func (cli *cliDecisions) delete(ctx context.Context, delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error { var err error /*take care of shorthand options*/ @@ -480,7 +480,7 @@ func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDeci var decisions *models.DeleteDecisionResponse if delDecisionID == "" { - decisions, _, err = cli.client.Decisions.Delete(context.Background(), delFilter) + decisions, _, err = cli.client.Decisions.Delete(ctx, delFilter) if err != nil { return fmt.Errorf("unable to delete decisions: %w", err) } @@ -489,7 +489,7 @@ func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDeci return fmt.Errorf("id '%s' is not an integer: %w", delDecisionID, err) } - decisions, _, err = cli.client.Decisions.DeleteOne(context.Background(), delDecisionID) + decisions, _, err = cli.client.Decisions.DeleteOne(ctx, delDecisionID) if err != nil { return fmt.Errorf("unable to delete decision: %w", err) } @@ -543,8 +543,8 @@ cscli decisions delete --origin lists --scenario list_name return nil }, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.delete(delFilter, delDecisionID, contained) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.delete(cmd.Context(), delFilter, delDecisionID, contained) }, } diff --git a/cmd/crowdsec-cli/clilapi/lapi.go b/cmd/crowdsec-cli/clilapi/lapi.go index 75fdc5c23..bb721eefe 100644 --- a/cmd/crowdsec-cli/clilapi/lapi.go +++ b/cmd/crowdsec-cli/clilapi/lapi.go @@ -68,7 +68,7 @@ func queryLAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login Scenarios: itemsForAPI, } - _, _, err = client.Auth.AuthenticateWatcher(context.Background(), t) + _, _, err = client.Auth.AuthenticateWatcher(ctx, t) if err != nil { return false, err } diff --git a/cmd/crowdsec-cli/clinotifications/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go index 314f97db2..0641dd1a7 100644 --- a/cmd/crowdsec-cli/clinotifications/notifications.go +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -368,9 +368,9 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not `, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - PreRunE: func(_ *cobra.Command, args []string) error { + PreRunE: func(cmd *cobra.Command, args []string) error { var err error - alert, err = cli.fetchAlertFromArgString(args[0]) + alert, err = cli.fetchAlertFromArgString(cmd.Context(), args[0]) if err != nil { return err } @@ -447,7 +447,7 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return cmd } -func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Alert, error) { +func (cli *cliNotifications) fetchAlertFromArgString(ctx context.Context, toParse string) (*models.Alert, error) { cfg := cli.cfg() id, err := strconv.Atoi(toParse) @@ -470,7 +470,7 @@ func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Al return nil, fmt.Errorf("error creating the client for the API: %w", err) } - alert, _, err := client.Alerts.GetByID(context.Background(), id) + alert, _, err := client.Alerts.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("can't find alert with id %d: %w", id, err) } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 8b09e3e5f..9b56fef65 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -406,13 +406,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { } } -func (a *apic) CAPIPullIsOld() (bool, error) { +func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) { /*only pull community blocklist if it's older than 1h30 */ alerts := a.dbClient.Ent.Alert.Query() alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert - count, err := alerts.Count(a.dbClient.CTX) + count, err := alerts.Count(ctx) if err != nil { return false, fmt.Errorf("while looking for CAPI alert: %w", err) } @@ -634,7 +634,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error { } if !forcePull { - if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil { + if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil { return err } else if !lastPullIsOld { return nil @@ -769,6 +769,8 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis } func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { + ctx := context.TODO() + for _, alert := range alertsFromCapi { setAlertScenario(alert, addCounters, deleteCounters) log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions)) @@ -777,7 +779,7 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") } - alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert) + alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(ctx, alert) if err != nil { return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 97943b495..3bb158acf 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -113,7 +113,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) { func TestAPICCAPIPullIsOld(t *testing.T) { api := getAPIC(t) - isOld, err := api.CAPIPullIsOld() + ctx := context.Background() + + isOld, err := api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.True(t, isOld) @@ -124,7 +126,7 @@ func TestAPICCAPIPullIsOld(t *testing.T) { SetScope("Country"). SetValue("Blah"). SetOrigin(types.CAPIOrigin). - SaveX(context.Background()) + SaveX(ctx) api.dbClient.Ent.Alert.Create(). SetCreatedAt(time.Now()). @@ -132,9 +134,9 @@ func TestAPICCAPIPullIsOld(t *testing.T) { AddDecisions( decision, ). - SaveX(context.Background()) + SaveX(ctx) - isOld, err = api.CAPIPullIsOld() + isOld, err = api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.False(t, isOld) diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 84b309486..d1f932285 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -127,6 +127,7 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin func (c *Controller) CreateAlert(gctx *gin.Context) { var input models.AddAlertsRequest + ctx := gctx.Request.Context() machineID, _ := getMachineIDFromContext(gctx) if err := gctx.ShouldBindJSON(&input); err != nil { @@ -239,7 +240,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { c.DBClient.CanFlush = false } - alerts, err := c.DBClient.CreateAlert(machineID, input) + alerts, err := c.DBClient.CreateAlert(ctx, machineID, input) c.DBClient.CanFlush = true if err != nil { @@ -261,7 +262,9 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { // FindAlerts: returns alerts from the database based on the specified filter func (c *Controller) FindAlerts(gctx *gin.Context) { - result, err := c.DBClient.QueryAlertWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + result, err := c.DBClient.QueryAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return @@ -279,6 +282,7 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { // FindAlertByID returns the alert associated with the ID func (c *Controller) FindAlertByID(gctx *gin.Context) { + ctx := gctx.Request.Context() alertIDStr := gctx.Param("alert_id") alertID, err := strconv.Atoi(alertIDStr) @@ -287,7 +291,7 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) { return } - result, err := c.DBClient.GetAlertByID(alertID) + result, err := c.DBClient.GetAlertByID(ctx, alertID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -307,6 +311,8 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) { func (c *Controller) DeleteAlertByID(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) @@ -321,7 +327,7 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { return } - err = c.DBClient.DeleteAlertByID(decisionID) + err = c.DBClient.DeleteAlertByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -334,13 +340,15 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { // DeleteAlerts deletes alerts from the database based on the specified filter func (c *Controller) DeleteAlerts(gctx *gin.Context) { + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index d1aa79bbf..ffefffc22 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -264,7 +264,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B gctx.Writer.Header().Set("Content-Type", "application/json") gctx.Writer.Header().Set("Transfer-Encoding", "chunked") gctx.Writer.WriteHeader(http.StatusOK) - gctx.Writer.WriteString(`{"new": [`) //No need to check for errors, the doc says it always returns nil + gctx.Writer.WriteString(`{"new": [`) // No need to check for errors, the doc says it always returns nil // if the blocker just started, return all decisions if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" { @@ -340,7 +340,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en return err } - //data = KeepLongestDecision(data) + // data = KeepLongestDecision(data) ret["new"] = FormatDecisions(data) // getting expired decisions diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index ddb38512a..4f6ee0986 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -68,7 +68,8 @@ func PrometheusBouncersHasEmptyDecision(c *gin.Context) { bouncer, _ := getBouncerFromContext(c) if bouncer != nil { LapiNilDecisions.With(prometheus.Labels{ - "bouncer": bouncer.Name}).Inc() + "bouncer": bouncer.Name, + }).Inc() } } @@ -76,7 +77,8 @@ func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { bouncer, _ := getBouncerFromContext(c) if bouncer != nil { LapiNonNilDecisions.With(prometheus.Labels{ - "bouncer": bouncer.Name}).Inc() + "bouncer": bouncer.Name, + }).Inc() } } @@ -87,7 +89,8 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc { LapiMachineHits.With(prometheus.Labels{ "machine": machineID, "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() } c.Next() @@ -101,7 +104,8 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc { LapiBouncerHits.With(prometheus.Labels{ "bouncer": bouncer.Name, "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() } c.Next() @@ -114,7 +118,8 @@ func PrometheusMiddleware() gin.HandlerFunc { LapiRouteHits.With(prometheus.Labels{ "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() c.Next() elapsed := time.Since(startTime) diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index e4c9dda47..1c70c495a 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -191,7 +191,7 @@ func TestDeleteDecisionByID(t *testing.T) { // Create Valid Alert lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") - //Have one alerts + // Have one alert w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) @@ -210,7 +210,7 @@ func TestDeleteDecisionByID(t *testing.T) { errResp, _ = readDecisionsErrorResp(t, w) assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"]) - //Have one alerts + // Have one alert w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) @@ -223,7 +223,7 @@ func TestDeleteDecisionByID(t *testing.T) { resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "1", resp.NbDeleted) - //Have one alert (because we delete an alert that has dup targets) + // Have one alert (because we delete an alert that has dup targets) w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) @@ -251,7 +251,7 @@ func TestDeleteDecision(t *testing.T) { } func TestStreamStartDecisionDedup(t *testing.T) { - //Ensure that at stream startup we only get the longest decision + // Ensure that at stream startup we only get the longest decision lapi := SetupLAPITest(t) // Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3 @@ -299,7 +299,7 @@ func TestStreamStartDecisionDedup(t *testing.T) { w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - //and now we only get a deleted decision + // and now we only get a deleted decision w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) diff --git a/pkg/apiserver/middlewares/v1/cache.go b/pkg/apiserver/middlewares/v1/cache.go index a058ec403..b0037bc4f 100644 --- a/pkg/apiserver/middlewares/v1/cache.go +++ b/pkg/apiserver/middlewares/v1/cache.go @@ -9,7 +9,7 @@ import ( ) type cacheEntry struct { - err error // if nil, the certificate is not revocated + err error // if nil, the certificate is not revocated timestamp time.Time } diff --git a/pkg/apiserver/middlewares/v1/crl.go b/pkg/apiserver/middlewares/v1/crl.go index f85a41099..64d7d3f0d 100644 --- a/pkg/apiserver/middlewares/v1/crl.go +++ b/pkg/apiserver/middlewares/v1/crl.go @@ -12,13 +12,13 @@ import ( ) type CRLChecker struct { - path string // path to the CRL file - fileInfo os.FileInfo // last stat of the CRL file - crls []*x509.RevocationList // parsed CRLs + path string // path to the CRL file + fileInfo os.FileInfo // last stat of the CRL file + crls []*x509.RevocationList // parsed CRLs logger *log.Entry mu sync.RWMutex - lastLoad time.Time // time when the CRL file was last read successfully - onLoad func() // called when the CRL file changes (and is read successfully) + lastLoad time.Time // time when the CRL file was last read successfully + onLoad func() // called when the CRL file changes (and is read successfully) } func NewCRLChecker(crlPath string, onLoad func(), logger *log.Entry) (*CRLChecker, error) { diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 17ca5b283..9171e9fce 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -128,6 +128,8 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { err error ) + ctx := c.Request.Context() + ret := authInput{} if err = c.ShouldBindJSON(&loginInput); err != nil { @@ -144,7 +146,7 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if err != nil { log.Infof("Error machine login for %s : %+v ", ret.machineID, err) return nil, err diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 18b19b034..78f5dc9b0 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -97,6 +97,8 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { } func AlertCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "add": data, err := json.Marshal(message.Data) @@ -155,7 +157,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } // use a different method: alert and/or decision might already be partially present in the database - _, err = p.DBClient.CreateOrUpdateAlert("", alert) + _, err = p.DBClient.CreateOrUpdateAlert(ctx, "", alert) if err != nil { log.Errorf("Failed to create alerts in DB: %s", err) } else { diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index d2760a209..ede9c89fe 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -35,12 +35,12 @@ const ( // CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it // if alert already exists, it checks it associated decisions already exists // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them -func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) { +func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, alertItem *models.Alert) (string, error) { if alertItem.UUID == "" { return "", errors.New("alert UUID is empty") } - alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(c.CTX) + alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(ctx) if err != nil && !ent.IsNotFound(err) { return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err) @@ -48,7 +48,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) // alert wasn't found, insert it (expected hotpath) if ent.IsNotFound(err) || len(alerts) == 0 { - alertIDs, err := c.CreateAlert(machineID, []*models.Alert{alertItem}) + alertIDs, err := c.CreateAlert(ctx, machineID, []*models.Alert{alertItem}) if err != nil { return "", fmt.Errorf("unable to create alert: %w", err) } @@ -165,7 +165,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(c.CTX) + decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { return "", fmt.Errorf("creating alert decisions: %w", err) } @@ -178,7 +178,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(c.CTX) + err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(ctx) if err != nil { return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) } @@ -191,7 +191,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) // it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions: // 1st pull, you get decisions [1,2,3]. it inserts [1,2,3] // 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin -func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) { +func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models.Alert) (int, int, int, error) { if alertItem == nil { return 0, 0, 0, errors.New("nil alert") } @@ -244,7 +244,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in SetScenarioHash(*alertItem.ScenarioHash). SetRemediation(true) // it's from CAPI, we always have decisions - alertRef, err := alertB.Save(c.CTX) + alertRef, err := alertB.Save(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err) } @@ -253,7 +253,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, 0, 0, nil } - txClient, err := c.Ent.Tx(c.CTX) + txClient, err := c.Ent.Tx(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) } @@ -347,7 +347,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in decision.OriginEQ(DecOrigin), decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), decision.ValueIn(deleteChunk...), - )).Exec(c.CTX) + )).Exec(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -363,7 +363,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(c.CTX) + insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -391,7 +391,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, inserted, deleted, nil } -func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { +func (c *Client) createDecisionChunk(ctx context.Context, simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { decisionCreate := []*ent.DecisionCreate{} for _, decisionItem := range decisions { @@ -436,7 +436,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return nil, nil } - ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(c.CTX) + ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(ctx) if err != nil { return nil, err } @@ -444,7 +444,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return ret, nil } -func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { +func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { alertBuilders := []*ent.AlertCreate{} alertDecisions := [][]*ent.Decision{} @@ -540,7 +540,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } - events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX) + events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert events: %s", err) } @@ -554,12 +554,14 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ value := metaItem.Value if len(metaItem.Value) > 4095 { - c.Log.Warningf("truncated meta %s : value too long", metaItem.Key) + c.Log.Warningf("truncated meta %s: value too long", metaItem.Key) + value = value[:4095] } if len(metaItem.Key) > 255 { - c.Log.Warningf("truncated meta %s : key too long", metaItem.Key) + c.Log.Warningf("truncated meta %s: key too long", metaItem.Key) + key = key[:255] } @@ -568,7 +570,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ SetValue(value) } - metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX) + metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(ctx) if err != nil { c.Log.Warningf("error creating alert meta: %s", err) } @@ -578,7 +580,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - decisionRet, err := c.createDecisionChunk(*alertItem.Simulated, stopAtTime, decisionChunk) + decisionRet, err := c.createDecisionChunk(ctx, *alertItem.Simulated, stopAtTime, decisionChunk) if err != nil { return nil, fmt.Errorf("creating alert decisions: %w", err) } @@ -636,7 +638,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ return nil, nil } - alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(c.CTX) + alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) } @@ -653,7 +655,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ for retry < maxLockRetries { // so much for the happy path... but sqlite3 errors work differently - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) + _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(ctx) if err == nil { break } @@ -678,17 +680,16 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ } } } + return ret, nil } -func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) { +func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList []*models.Alert) ([]string, error) { var ( owner *ent.Machine err error ) - ctx := context.TODO() - if machineID != "" { owner, err = c.QueryMachineByID(ctx, machineID) if err != nil { @@ -708,7 +709,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str alertIDs := []string{} for _, alertChunk := range alertChunks { - ids, err := c.createAlertChunk(machineID, owner, alertChunk) + ids, err := c.createAlertChunk(ctx, machineID, owner, alertChunk) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } @@ -717,7 +718,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str } if owner != nil { - err = owner.Update().SetLastPush(time.Now().UTC()).Exec(c.CTX) + err = owner.Update().SetLastPush(time.Now().UTC()).Exec(ctx) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } @@ -919,7 +920,6 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e case "since", "created_before", "until": if err := handleTimeFilters(param, value[0], &predicates); err != nil { return nil, err - } case "decision_type": predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0]))) @@ -954,7 +954,6 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil { return nil, err - } return predicates, nil @@ -996,11 +995,11 @@ func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string] return counts, nil } -func (c *Client) TotalAlerts() (int, error) { - return c.Ent.Alert.Query().Count(c.CTX) +func (c *Client) TotalAlerts(ctx context.Context) (int, error) { + return c.Ent.Alert.Query().Count(ctx) } -func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) { +func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Alert, error) { sort := "DESC" // we sort by desc by default if val, ok := filter["sort"]; ok { @@ -1047,7 +1046,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, WithOwner() if limit == 0 { - limit, err = alerts.Count(c.CTX) + limit, err = alerts.Count(ctx) if err != nil { return nil, fmt.Errorf("unable to count nb alerts: %w", err) } @@ -1059,7 +1058,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID)) } - result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX) + result, err := alerts.Limit(paginationSize).Offset(offset).All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) } @@ -1088,35 +1087,35 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, return ret, nil } -func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { +func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Alert) (int, error) { idList := make([]int, 0) for _, alert := range alertItems { idList = append(idList, alert.ID) } _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events") } _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta") } _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions") } deleted, err := c.Ent.Alert.Delete(). - Where(alert.IDIn(idList...)).Exec(c.CTX) + Where(alert.IDIn(idList...)).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch") @@ -1127,10 +1126,10 @@ func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { return deleted, nil } -func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { +func (c *Client) DeleteAlertGraph(ctx context.Context, alertItem *ent.Alert) error { // delete the associated events _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID) @@ -1138,7 +1137,7 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated meta _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID) @@ -1146,14 +1145,14 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated decisions _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID) } // delete the alert - err = c.Ent.Alert.DeleteOne(alertItem).Exec(c.CTX) + err = c.Ent.Alert.DeleteOne(alertItem).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID) @@ -1162,26 +1161,26 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { return nil } -func (c *Client) DeleteAlertByID(id int) error { - alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(c.CTX) +func (c *Client) DeleteAlertByID(ctx context.Context, id int) error { + alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(ctx) if err != nil { return err } - return c.DeleteAlertGraph(alertItem) + return c.DeleteAlertGraph(ctx, alertItem) } -func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) { +func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) { preds, err := AlertPredicatesFromFilter(filter) if err != nil { return 0, err } - return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX) + return c.Ent.Alert.Delete().Where(preds...).Exec(ctx) } -func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { - alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX) +func (c *Client) GetAlertByID(ctx context.Context, alertID int) (*ent.Alert, error) { + alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(ctx) if err != nil { /*record not found, 404*/ if ent.IsNotFound(err) { diff --git a/pkg/database/database.go b/pkg/database/database.go index e51345919..bb41dd3b6 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -21,7 +21,6 @@ import ( type Client struct { Ent *ent.Client - CTX context.Context Log *log.Logger CanFlush bool Type string @@ -106,7 +105,6 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro return &Client{ Ent: client, - CTX: ctx, Log: clog, CanFlush: true, Type: config.Type, diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index 8547990c2..7522a2727 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -31,7 +31,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ @@ -321,7 +321,7 @@ func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[strin var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer) */ @@ -440,7 +440,7 @@ func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[strin var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ decisions := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now().UTC())) diff --git a/pkg/database/flush.go b/pkg/database/flush.go index 46c8edfa3..8f646ddc9 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -239,7 +239,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e c.FlushOrphans(ctx) c.Log.Debug("Done flushing orphan alerts") - totalAlerts, err = c.TotalAlerts() + totalAlerts, err = c.TotalAlerts(ctx) if err != nil { c.Log.Warningf("FlushAlerts (max items count): %s", err) return fmt.Errorf("unable to get alerts count: %w", err) @@ -252,7 +252,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e "created_before": {MaxAge}, } - nbDeleted, err := c.DeleteAlertWithFilter(filter) + nbDeleted, err := c.DeleteAlertWithFilter(ctx, filter) if err != nil { c.Log.Warningf("FlushAlerts (max age): %s", err) return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) @@ -268,7 +268,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e // This gives us the oldest alert that we want to keep // We then delete all the alerts with an id lower than this one // We can do this because the id is auto-increment, and the database won't reuse the same id twice - lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ + lastAlert, err := c.QueryAlertWithFilter(ctx, map[string][]string{ "sort": {"DESC"}, "limit": {"1"}, // we do not care about fetching the edges, we just want the id diff --git a/pkg/database/metrics.go b/pkg/database/metrics.go index 99ba90c80..eb4c47282 100644 --- a/pkg/database/metrics.go +++ b/pkg/database/metrics.go @@ -17,7 +17,7 @@ func (c *Client) CreateMetric(ctx context.Context, generatedType metric.Generate SetReceivedAt(receivedAt). SetPayload(payload). Save(ctx) - if err != nil { + if err != nil { c.Log.Warningf("CreateMetric: %s", err) return nil, fmt.Errorf("storing metrics snapshot for '%s' at %s: %w", generatedBy, receivedAt, InsertFail) }