context propagation: pkg/database/alerts (#3252)

* alerts
* drop CTX from dbclient
* lint
* pkg/database/alerts: context.TODO()
* cscli: context.Background() -> cmd.Context()
This commit is contained in:
mmetc 2024-09-24 14:13:45 +02:00 committed by GitHub
parent 1133afe58d
commit 3945a991bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 141 additions and 123 deletions

View file

@ -20,7 +20,7 @@ linters-settings:
maintidx: maintidx:
# raise this after refactoring # raise this after refactoring
under: 16 under: 15
misspell: misspell:
locale: US locale: US
@ -118,7 +118,7 @@ linters-settings:
arguments: [6] arguments: [6]
- name: function-length - name: function-length
# lower this after refactoring # lower this after refactoring
arguments: [110, 235] arguments: [110, 237]
- name: get-return - name: get-return
disabled: true disabled: true
- name: increment-decrement - name: increment-decrement

View file

@ -235,7 +235,7 @@ func (cli *cliAlerts) NewCommand() *cobra.Command {
return cmd return cmd
} }
func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error { func (cli *cliAlerts) list(ctx context.Context, alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error {
var err error var err error
*alertListFilter.ScopeEquals, err = SanitizeScope(*alertListFilter.ScopeEquals, *alertListFilter.IPEquals, *alertListFilter.RangeEquals) *alertListFilter.ScopeEquals, err = SanitizeScope(*alertListFilter.ScopeEquals, *alertListFilter.IPEquals, *alertListFilter.RangeEquals)
@ -311,7 +311,7 @@ func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int,
alertListFilter.Contains = new(bool) alertListFilter.Contains = new(bool)
} }
alerts, _, err := cli.client.Alerts.List(context.Background(), alertListFilter) alerts, _, err := cli.client.Alerts.List(ctx, alertListFilter)
if err != nil { if err != nil {
return fmt.Errorf("unable to list alerts: %w", err) return fmt.Errorf("unable to list alerts: %w", err)
} }
@ -354,7 +354,7 @@ cscli alerts list --type ban`,
Long: `List alerts with optional filters`, Long: `List alerts with optional filters`,
DisableAutoGenTag: true, DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
return cli.list(alertListFilter, limit, contained, printMachine) return cli.list(cmd.Context(), alertListFilter, limit, contained, printMachine)
}, },
} }
@ -377,7 +377,7 @@ cscli alerts list --type ban`,
return cmd return cmd
} }
func (cli *cliAlerts) delete(delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error { func (cli *cliAlerts) delete(ctx context.Context, delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error {
var err error var err error
if !deleteAll { if !deleteAll {
@ -423,12 +423,12 @@ func (cli *cliAlerts) delete(delFilter apiclient.AlertsDeleteOpts, activeDecisio
var alerts *models.DeleteAlertsResponse var alerts *models.DeleteAlertsResponse
if delAlertByID == "" { if delAlertByID == "" {
alerts, _, err = cli.client.Alerts.Delete(context.Background(), delFilter) alerts, _, err = cli.client.Alerts.Delete(ctx, delFilter)
if err != nil { if err != nil {
return fmt.Errorf("unable to delete alerts: %w", err) return fmt.Errorf("unable to delete alerts: %w", err)
} }
} else { } else {
alerts, _, err = cli.client.Alerts.DeleteOne(context.Background(), delAlertByID) alerts, _, err = cli.client.Alerts.DeleteOne(ctx, delAlertByID)
if err != nil { if err != nil {
return fmt.Errorf("unable to delete alert: %w", err) return fmt.Errorf("unable to delete alert: %w", err)
} }
@ -480,7 +480,7 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
return nil return nil
}, },
RunE: func(cmd *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
return cli.delete(delFilter, activeDecision, deleteAll, delAlertByID, contained) return cli.delete(cmd.Context(), delFilter, activeDecision, deleteAll, delAlertByID, contained)
}, },
} }
@ -498,7 +498,7 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
return cmd return cmd
} }
func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error { func (cli *cliAlerts) inspect(ctx context.Context, details bool, alertIDs ...string) error {
cfg := cli.cfg() cfg := cli.cfg()
for _, alertID := range alertIDs { for _, alertID := range alertIDs {
@ -507,7 +507,7 @@ func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error {
return fmt.Errorf("bad alert id %s", alertID) return fmt.Errorf("bad alert id %s", alertID)
} }
alert, _, err := cli.client.Alerts.GetByID(context.Background(), id) alert, _, err := cli.client.Alerts.GetByID(ctx, id)
if err != nil { if err != nil {
return fmt.Errorf("can't find alert with id %s: %w", alertID, err) return fmt.Errorf("can't find alert with id %s: %w", alertID, err)
} }
@ -551,7 +551,7 @@ func (cli *cliAlerts) newInspectCmd() *cobra.Command {
_ = cmd.Help() _ = cmd.Help()
return errors.New("missing alert_id") return errors.New("missing alert_id")
} }
return cli.inspect(details, args...) return cli.inspect(cmd.Context(), details, args...)
}, },
} }

View file

@ -66,7 +66,7 @@ func (cli *cliConsole) NewCommand() *cobra.Command {
return cmd return cmd
} }
func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []string, opts []string) error { func (cli *cliConsole) enroll(ctx context.Context, key string, name string, overwrite bool, tags []string, opts []string) error {
cfg := cli.cfg() cfg := cli.cfg()
password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password) password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password)
@ -127,7 +127,7 @@ func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []st
VersionPrefix: "v3", VersionPrefix: "v3",
}) })
resp, err := c.Auth.EnrollWatcher(context.Background(), key, name, tags, overwrite) resp, err := c.Auth.EnrollWatcher(ctx, key, name, tags, overwrite)
if err != nil { if err != nil {
return fmt.Errorf("could not enroll instance: %w", err) return fmt.Errorf("could not enroll instance: %w", err)
} }
@ -173,8 +173,8 @@ After running this command your will need to validate the enrollment in the weba
valid options are : %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")), valid options are : %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")),
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
DisableAutoGenTag: true, DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return cli.enroll(args[0], name, overwrite, tags, opts) return cli.enroll(cmd.Context(), args[0], name, overwrite, tags, opts)
}, },
} }

View file

@ -170,7 +170,7 @@ func (cli *cliDecisions) NewCommand() *cobra.Command {
return cmd return cmd
} }
func (cli *cliDecisions) list(filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error { func (cli *cliDecisions) list(ctx context.Context, filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error {
var err error var err error
*filter.ScopeEquals, err = clialert.SanitizeScope(*filter.ScopeEquals, *filter.IPEquals, *filter.RangeEquals) *filter.ScopeEquals, err = clialert.SanitizeScope(*filter.ScopeEquals, *filter.IPEquals, *filter.RangeEquals)
@ -249,7 +249,7 @@ func (cli *cliDecisions) list(filter apiclient.AlertsListOpts, NoSimu *bool, con
filter.Contains = new(bool) filter.Contains = new(bool)
} }
alerts, _, err := cli.client.Alerts.List(context.Background(), filter) alerts, _, err := cli.client.Alerts.List(ctx, filter)
if err != nil { if err != nil {
return fmt.Errorf("unable to retrieve decisions: %w", err) return fmt.Errorf("unable to retrieve decisions: %w", err)
} }
@ -293,7 +293,7 @@ cscli decisions list --origin lists --scenario list_name
Args: cobra.ExactArgs(0), Args: cobra.ExactArgs(0),
DisableAutoGenTag: true, DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
return cli.list(filter, NoSimu, contained, printMachine) return cli.list(cmd.Context(), filter, NoSimu, contained, printMachine)
}, },
} }
@ -317,7 +317,7 @@ cscli decisions list --origin lists --scenario list_name
return cmd return cmd
} }
func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error { func (cli *cliDecisions) add(ctx context.Context, addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error {
alerts := models.AddAlertsRequest{} alerts := models.AddAlertsRequest{}
origin := types.CscliOrigin origin := types.CscliOrigin
capacity := int32(0) capacity := int32(0)
@ -386,7 +386,7 @@ func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, a
} }
alerts = append(alerts, &alert) alerts = append(alerts, &alert)
_, _, err = cli.client.Alerts.Add(context.Background(), alerts) _, _, err = cli.client.Alerts.Add(ctx, alerts)
if err != nil { if err != nil {
return err return err
} }
@ -419,7 +419,7 @@ cscli decisions add --scope username --value foobar
Args: cobra.ExactArgs(0), Args: cobra.ExactArgs(0),
DisableAutoGenTag: true, DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
return cli.add(addIP, addRange, addDuration, addValue, addScope, addReason, addType) return cli.add(cmd.Context(), addIP, addRange, addDuration, addValue, addScope, addReason, addType)
}, },
} }
@ -436,7 +436,7 @@ cscli decisions add --scope username --value foobar
return cmd return cmd
} }
func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error { func (cli *cliDecisions) delete(ctx context.Context, delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error {
var err error var err error
/*take care of shorthand options*/ /*take care of shorthand options*/
@ -480,7 +480,7 @@ func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDeci
var decisions *models.DeleteDecisionResponse var decisions *models.DeleteDecisionResponse
if delDecisionID == "" { if delDecisionID == "" {
decisions, _, err = cli.client.Decisions.Delete(context.Background(), delFilter) decisions, _, err = cli.client.Decisions.Delete(ctx, delFilter)
if err != nil { if err != nil {
return fmt.Errorf("unable to delete decisions: %w", err) return fmt.Errorf("unable to delete decisions: %w", err)
} }
@ -489,7 +489,7 @@ func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDeci
return fmt.Errorf("id '%s' is not an integer: %w", delDecisionID, err) return fmt.Errorf("id '%s' is not an integer: %w", delDecisionID, err)
} }
decisions, _, err = cli.client.Decisions.DeleteOne(context.Background(), delDecisionID) decisions, _, err = cli.client.Decisions.DeleteOne(ctx, delDecisionID)
if err != nil { if err != nil {
return fmt.Errorf("unable to delete decision: %w", err) return fmt.Errorf("unable to delete decision: %w", err)
} }
@ -543,8 +543,8 @@ cscli decisions delete --origin lists --scenario list_name
return nil return nil
}, },
RunE: func(_ *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
return cli.delete(delFilter, delDecisionID, contained) return cli.delete(cmd.Context(), delFilter, delDecisionID, contained)
}, },
} }

View file

@ -68,7 +68,7 @@ func queryLAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login
Scenarios: itemsForAPI, Scenarios: itemsForAPI,
} }
_, _, err = client.Auth.AuthenticateWatcher(context.Background(), t) _, _, err = client.Auth.AuthenticateWatcher(ctx, t)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -368,9 +368,9 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
`, `,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
DisableAutoGenTag: true, DisableAutoGenTag: true,
PreRunE: func(_ *cobra.Command, args []string) error { PreRunE: func(cmd *cobra.Command, args []string) error {
var err error var err error
alert, err = cli.fetchAlertFromArgString(args[0]) alert, err = cli.fetchAlertFromArgString(cmd.Context(), args[0])
if err != nil { if err != nil {
return err return err
} }
@ -447,7 +447,7 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
return cmd return cmd
} }
func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Alert, error) { func (cli *cliNotifications) fetchAlertFromArgString(ctx context.Context, toParse string) (*models.Alert, error) {
cfg := cli.cfg() cfg := cli.cfg()
id, err := strconv.Atoi(toParse) id, err := strconv.Atoi(toParse)
@ -470,7 +470,7 @@ func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Al
return nil, fmt.Errorf("error creating the client for the API: %w", err) return nil, fmt.Errorf("error creating the client for the API: %w", err)
} }
alert, _, err := client.Alerts.GetByID(context.Background(), id) alert, _, err := client.Alerts.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("can't find alert with id %d: %w", id, err) return nil, fmt.Errorf("can't find alert with id %d: %w", id, err)
} }

View file

@ -406,13 +406,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
} }
} }
func (a *apic) CAPIPullIsOld() (bool, error) { func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) {
/*only pull community blocklist if it's older than 1h30 */ /*only pull community blocklist if it's older than 1h30 */
alerts := a.dbClient.Ent.Alert.Query() alerts := a.dbClient.Ent.Alert.Query()
alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert
count, err := alerts.Count(a.dbClient.CTX) count, err := alerts.Count(ctx)
if err != nil { if err != nil {
return false, fmt.Errorf("while looking for CAPI alert: %w", err) return false, fmt.Errorf("while looking for CAPI alert: %w", err)
} }
@ -634,7 +634,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
} }
if !forcePull { if !forcePull {
if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil { if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil {
return err return err
} else if !lastPullIsOld { } else if !lastPullIsOld {
return nil return nil
@ -769,6 +769,8 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis
} }
func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error {
ctx := context.TODO()
for _, alert := range alertsFromCapi { for _, alert := range alertsFromCapi {
setAlertScenario(alert, addCounters, deleteCounters) setAlertScenario(alert, addCounters, deleteCounters)
log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions)) log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions))
@ -777,7 +779,7 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string
log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist")
} }
alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert) alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(ctx, alert)
if err != nil { if err != nil {
return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err)
} }

View file

@ -113,7 +113,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) {
func TestAPICCAPIPullIsOld(t *testing.T) { func TestAPICCAPIPullIsOld(t *testing.T) {
api := getAPIC(t) api := getAPIC(t)
isOld, err := api.CAPIPullIsOld() ctx := context.Background()
isOld, err := api.CAPIPullIsOld(ctx)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, isOld) assert.True(t, isOld)
@ -124,7 +126,7 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
SetScope("Country"). SetScope("Country").
SetValue("Blah"). SetValue("Blah").
SetOrigin(types.CAPIOrigin). SetOrigin(types.CAPIOrigin).
SaveX(context.Background()) SaveX(ctx)
api.dbClient.Ent.Alert.Create(). api.dbClient.Ent.Alert.Create().
SetCreatedAt(time.Now()). SetCreatedAt(time.Now()).
@ -132,9 +134,9 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
AddDecisions( AddDecisions(
decision, decision,
). ).
SaveX(context.Background()) SaveX(ctx)
isOld, err = api.CAPIPullIsOld() isOld, err = api.CAPIPullIsOld(ctx)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, isOld) assert.False(t, isOld)

View file

@ -127,6 +127,7 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin
func (c *Controller) CreateAlert(gctx *gin.Context) { func (c *Controller) CreateAlert(gctx *gin.Context) {
var input models.AddAlertsRequest var input models.AddAlertsRequest
ctx := gctx.Request.Context()
machineID, _ := getMachineIDFromContext(gctx) machineID, _ := getMachineIDFromContext(gctx)
if err := gctx.ShouldBindJSON(&input); err != nil { if err := gctx.ShouldBindJSON(&input); err != nil {
@ -239,7 +240,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
c.DBClient.CanFlush = false c.DBClient.CanFlush = false
} }
alerts, err := c.DBClient.CreateAlert(machineID, input) alerts, err := c.DBClient.CreateAlert(ctx, machineID, input)
c.DBClient.CanFlush = true c.DBClient.CanFlush = true
if err != nil { if err != nil {
@ -261,7 +262,9 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
// FindAlerts: returns alerts from the database based on the specified filter // FindAlerts: returns alerts from the database based on the specified filter
func (c *Controller) FindAlerts(gctx *gin.Context) { func (c *Controller) FindAlerts(gctx *gin.Context) {
result, err := c.DBClient.QueryAlertWithFilter(gctx.Request.URL.Query()) ctx := gctx.Request.Context()
result, err := c.DBClient.QueryAlertWithFilter(ctx, gctx.Request.URL.Query())
if err != nil { if err != nil {
c.HandleDBErrors(gctx, err) c.HandleDBErrors(gctx, err)
return return
@ -279,6 +282,7 @@ func (c *Controller) FindAlerts(gctx *gin.Context) {
// FindAlertByID returns the alert associated with the ID // FindAlertByID returns the alert associated with the ID
func (c *Controller) FindAlertByID(gctx *gin.Context) { func (c *Controller) FindAlertByID(gctx *gin.Context) {
ctx := gctx.Request.Context()
alertIDStr := gctx.Param("alert_id") alertIDStr := gctx.Param("alert_id")
alertID, err := strconv.Atoi(alertIDStr) alertID, err := strconv.Atoi(alertIDStr)
@ -287,7 +291,7 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) {
return return
} }
result, err := c.DBClient.GetAlertByID(alertID) result, err := c.DBClient.GetAlertByID(ctx, alertID)
if err != nil { if err != nil {
c.HandleDBErrors(gctx, err) c.HandleDBErrors(gctx, err)
return return
@ -307,6 +311,8 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) {
func (c *Controller) DeleteAlertByID(gctx *gin.Context) { func (c *Controller) DeleteAlertByID(gctx *gin.Context) {
var err error var err error
ctx := gctx.Request.Context()
incomingIP := gctx.ClientIP() incomingIP := gctx.ClientIP()
if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) {
gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)})
@ -321,7 +327,7 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) {
return return
} }
err = c.DBClient.DeleteAlertByID(decisionID) err = c.DBClient.DeleteAlertByID(ctx, decisionID)
if err != nil { if err != nil {
c.HandleDBErrors(gctx, err) c.HandleDBErrors(gctx, err)
return return
@ -334,13 +340,15 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) {
// DeleteAlerts deletes alerts from the database based on the specified filter // DeleteAlerts deletes alerts from the database based on the specified filter
func (c *Controller) DeleteAlerts(gctx *gin.Context) { func (c *Controller) DeleteAlerts(gctx *gin.Context) {
ctx := gctx.Request.Context()
incomingIP := gctx.ClientIP() incomingIP := gctx.ClientIP()
if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) {
gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)})
return return
} }
nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) nbDeleted, err := c.DBClient.DeleteAlertWithFilter(ctx, gctx.Request.URL.Query())
if err != nil { if err != nil {
c.HandleDBErrors(gctx, err) c.HandleDBErrors(gctx, err)
return return

View file

@ -264,7 +264,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
gctx.Writer.Header().Set("Content-Type", "application/json") gctx.Writer.Header().Set("Content-Type", "application/json")
gctx.Writer.Header().Set("Transfer-Encoding", "chunked") gctx.Writer.Header().Set("Transfer-Encoding", "chunked")
gctx.Writer.WriteHeader(http.StatusOK) gctx.Writer.WriteHeader(http.StatusOK)
gctx.Writer.WriteString(`{"new": [`) //No need to check for errors, the doc says it always returns nil gctx.Writer.WriteString(`{"new": [`) // No need to check for errors, the doc says it always returns nil
// if the blocker just started, return all decisions // if the blocker just started, return all decisions
if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" { if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" {
@ -340,7 +340,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
return err return err
} }
//data = KeepLongestDecision(data) // data = KeepLongestDecision(data)
ret["new"] = FormatDecisions(data) ret["new"] = FormatDecisions(data)
// getting expired decisions // getting expired decisions

View file

@ -68,7 +68,8 @@ func PrometheusBouncersHasEmptyDecision(c *gin.Context) {
bouncer, _ := getBouncerFromContext(c) bouncer, _ := getBouncerFromContext(c)
if bouncer != nil { if bouncer != nil {
LapiNilDecisions.With(prometheus.Labels{ LapiNilDecisions.With(prometheus.Labels{
"bouncer": bouncer.Name}).Inc() "bouncer": bouncer.Name,
}).Inc()
} }
} }
@ -76,7 +77,8 @@ func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) {
bouncer, _ := getBouncerFromContext(c) bouncer, _ := getBouncerFromContext(c)
if bouncer != nil { if bouncer != nil {
LapiNonNilDecisions.With(prometheus.Labels{ LapiNonNilDecisions.With(prometheus.Labels{
"bouncer": bouncer.Name}).Inc() "bouncer": bouncer.Name,
}).Inc()
} }
} }
@ -87,7 +89,8 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc {
LapiMachineHits.With(prometheus.Labels{ LapiMachineHits.With(prometheus.Labels{
"machine": machineID, "machine": machineID,
"route": c.Request.URL.Path, "route": c.Request.URL.Path,
"method": c.Request.Method}).Inc() "method": c.Request.Method,
}).Inc()
} }
c.Next() c.Next()
@ -101,7 +104,8 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc {
LapiBouncerHits.With(prometheus.Labels{ LapiBouncerHits.With(prometheus.Labels{
"bouncer": bouncer.Name, "bouncer": bouncer.Name,
"route": c.Request.URL.Path, "route": c.Request.URL.Path,
"method": c.Request.Method}).Inc() "method": c.Request.Method,
}).Inc()
} }
c.Next() c.Next()
@ -114,7 +118,8 @@ func PrometheusMiddleware() gin.HandlerFunc {
LapiRouteHits.With(prometheus.Labels{ LapiRouteHits.With(prometheus.Labels{
"route": c.Request.URL.Path, "route": c.Request.URL.Path,
"method": c.Request.Method}).Inc() "method": c.Request.Method,
}).Inc()
c.Next() c.Next()
elapsed := time.Since(startTime) elapsed := time.Since(startTime)

View file

@ -191,7 +191,7 @@ func TestDeleteDecisionByID(t *testing.T) {
// Create Valid Alert // Create Valid Alert
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
//Have one alerts // Have one alert
w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code := readDecisionsStreamResp(t, w) decisions, code := readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
@ -210,7 +210,7 @@ func TestDeleteDecisionByID(t *testing.T) {
errResp, _ = readDecisionsErrorResp(t, w) errResp, _ = readDecisionsErrorResp(t, w)
assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"]) assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"])
//Have one alerts // Have one alert
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w) decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
@ -223,7 +223,7 @@ func TestDeleteDecisionByID(t *testing.T) {
resp, _ := readDecisionsDeleteResp(t, w) resp, _ := readDecisionsDeleteResp(t, w)
assert.Equal(t, "1", resp.NbDeleted) assert.Equal(t, "1", resp.NbDeleted)
//Have one alert (because we delete an alert that has dup targets) // Have one alert (because we delete an alert that has dup targets)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w) decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
@ -251,7 +251,7 @@ func TestDeleteDecision(t *testing.T) {
} }
func TestStreamStartDecisionDedup(t *testing.T) { func TestStreamStartDecisionDedup(t *testing.T) {
//Ensure that at stream startup we only get the longest decision // Ensure that at stream startup we only get the longest decision
lapi := SetupLAPITest(t) lapi := SetupLAPITest(t)
// Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3 // Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3
@ -299,7 +299,7 @@ func TestStreamStartDecisionDedup(t *testing.T) {
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
//and now we only get a deleted decision // and now we only get a deleted decision
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w) decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)

View file

@ -9,7 +9,7 @@ import (
) )
type cacheEntry struct { type cacheEntry struct {
err error // if nil, the certificate is not revocated err error // if nil, the certificate is not revocated
timestamp time.Time timestamp time.Time
} }

View file

@ -12,13 +12,13 @@ import (
) )
type CRLChecker struct { type CRLChecker struct {
path string // path to the CRL file path string // path to the CRL file
fileInfo os.FileInfo // last stat of the CRL file fileInfo os.FileInfo // last stat of the CRL file
crls []*x509.RevocationList // parsed CRLs crls []*x509.RevocationList // parsed CRLs
logger *log.Entry logger *log.Entry
mu sync.RWMutex mu sync.RWMutex
lastLoad time.Time // time when the CRL file was last read successfully lastLoad time.Time // time when the CRL file was last read successfully
onLoad func() // called when the CRL file changes (and is read successfully) onLoad func() // called when the CRL file changes (and is read successfully)
} }
func NewCRLChecker(crlPath string, onLoad func(), logger *log.Entry) (*CRLChecker, error) { func NewCRLChecker(crlPath string, onLoad func(), logger *log.Entry) (*CRLChecker, error) {

View file

@ -128,6 +128,8 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
err error err error
) )
ctx := c.Request.Context()
ret := authInput{} ret := authInput{}
if err = c.ShouldBindJSON(&loginInput); err != nil { if err = c.ShouldBindJSON(&loginInput); err != nil {
@ -144,7 +146,7 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
Where(machine.MachineId(ret.machineID)). Where(machine.MachineId(ret.machineID)).
First(j.DbClient.CTX) First(ctx)
if err != nil { if err != nil {
log.Infof("Error machine login for %s : %+v ", ret.machineID, err) log.Infof("Error machine login for %s : %+v ", ret.machineID, err)
return nil, err return nil, err

View file

@ -97,6 +97,8 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error {
} }
func AlertCmd(message *Message, p *Papi, sync bool) error { func AlertCmd(message *Message, p *Papi, sync bool) error {
ctx := context.TODO()
switch message.Header.OperationCmd { switch message.Header.OperationCmd {
case "add": case "add":
data, err := json.Marshal(message.Data) data, err := json.Marshal(message.Data)
@ -155,7 +157,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
} }
// use a different method: alert and/or decision might already be partially present in the database // use a different method: alert and/or decision might already be partially present in the database
_, err = p.DBClient.CreateOrUpdateAlert("", alert) _, err = p.DBClient.CreateOrUpdateAlert(ctx, "", alert)
if err != nil { if err != nil {
log.Errorf("Failed to create alerts in DB: %s", err) log.Errorf("Failed to create alerts in DB: %s", err)
} else { } else {

View file

@ -35,12 +35,12 @@ const (
// CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it // CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it
// if alert already exists, it checks it associated decisions already exists // if alert already exists, it checks it associated decisions already exists
// if some associated decisions are missing (ie. previous insert ended up in error) it inserts them // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them
func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) { func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, alertItem *models.Alert) (string, error) {
if alertItem.UUID == "" { if alertItem.UUID == "" {
return "", errors.New("alert UUID is empty") return "", errors.New("alert UUID is empty")
} }
alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(c.CTX) alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(ctx)
if err != nil && !ent.IsNotFound(err) { if err != nil && !ent.IsNotFound(err) {
return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err) return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err)
@ -48,7 +48,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert)
// alert wasn't found, insert it (expected hotpath) // alert wasn't found, insert it (expected hotpath)
if ent.IsNotFound(err) || len(alerts) == 0 { if ent.IsNotFound(err) || len(alerts) == 0 {
alertIDs, err := c.CreateAlert(machineID, []*models.Alert{alertItem}) alertIDs, err := c.CreateAlert(ctx, machineID, []*models.Alert{alertItem})
if err != nil { if err != nil {
return "", fmt.Errorf("unable to create alert: %w", err) return "", fmt.Errorf("unable to create alert: %w", err)
} }
@ -165,7 +165,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert)
builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize)
for _, builderChunk := range builderChunks { for _, builderChunk := range builderChunks {
decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(c.CTX) decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("creating alert decisions: %w", err) return "", fmt.Errorf("creating alert decisions: %w", err)
} }
@ -178,7 +178,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert)
decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize) decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize)
for _, decisionChunk := range decisionChunks { for _, decisionChunk := range decisionChunks {
err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(c.CTX) err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err)
} }
@ -191,7 +191,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert)
// it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions: // it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions:
// 1st pull, you get decisions [1,2,3]. it inserts [1,2,3] // 1st pull, you get decisions [1,2,3]. it inserts [1,2,3]
// 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin // 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin
func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) { func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models.Alert) (int, int, int, error) {
if alertItem == nil { if alertItem == nil {
return 0, 0, 0, errors.New("nil alert") return 0, 0, 0, errors.New("nil alert")
} }
@ -244,7 +244,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
SetScenarioHash(*alertItem.ScenarioHash). SetScenarioHash(*alertItem.ScenarioHash).
SetRemediation(true) // it's from CAPI, we always have decisions SetRemediation(true) // it's from CAPI, we always have decisions
alertRef, err := alertB.Save(c.CTX) alertRef, err := alertB.Save(ctx)
if err != nil { if err != nil {
return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err) return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err)
} }
@ -253,7 +253,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
return alertRef.ID, 0, 0, nil return alertRef.ID, 0, 0, nil
} }
txClient, err := c.Ent.Tx(c.CTX) txClient, err := c.Ent.Tx(ctx)
if err != nil { if err != nil {
return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err)
} }
@ -347,7 +347,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
decision.OriginEQ(DecOrigin), decision.OriginEQ(DecOrigin),
decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))),
decision.ValueIn(deleteChunk...), decision.ValueIn(deleteChunk...),
)).Exec(c.CTX) )).Exec(ctx)
if err != nil { if err != nil {
rollbackErr := txClient.Rollback() rollbackErr := txClient.Rollback()
if rollbackErr != nil { if rollbackErr != nil {
@ -363,7 +363,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize)
for _, builderChunk := range builderChunks { for _, builderChunk := range builderChunks {
insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(c.CTX) insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx)
if err != nil { if err != nil {
rollbackErr := txClient.Rollback() rollbackErr := txClient.Rollback()
if rollbackErr != nil { if rollbackErr != nil {
@ -391,7 +391,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
return alertRef.ID, inserted, deleted, nil return alertRef.ID, inserted, deleted, nil
} }
func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { func (c *Client) createDecisionChunk(ctx context.Context, simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) {
decisionCreate := []*ent.DecisionCreate{} decisionCreate := []*ent.DecisionCreate{}
for _, decisionItem := range decisions { for _, decisionItem := range decisions {
@ -436,7 +436,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis
return nil, nil return nil, nil
} }
ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(c.CTX) ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -444,7 +444,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis
return ret, nil return ret, nil
} }
func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) {
alertBuilders := []*ent.AlertCreate{} alertBuilders := []*ent.AlertCreate{}
alertDecisions := [][]*ent.Decision{} alertDecisions := [][]*ent.Decision{}
@ -540,7 +540,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario)
} }
events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX) events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(ctx)
if err != nil { if err != nil {
return nil, errors.Wrapf(BulkError, "creating alert events: %s", err) return nil, errors.Wrapf(BulkError, "creating alert events: %s", err)
} }
@ -554,12 +554,14 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
value := metaItem.Value value := metaItem.Value
if len(metaItem.Value) > 4095 { if len(metaItem.Value) > 4095 {
c.Log.Warningf("truncated meta %s : value too long", metaItem.Key) c.Log.Warningf("truncated meta %s: value too long", metaItem.Key)
value = value[:4095] value = value[:4095]
} }
if len(metaItem.Key) > 255 { if len(metaItem.Key) > 255 {
c.Log.Warningf("truncated meta %s : key too long", metaItem.Key) c.Log.Warningf("truncated meta %s: key too long", metaItem.Key)
key = key[:255] key = key[:255]
} }
@ -568,7 +570,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
SetValue(value) SetValue(value)
} }
metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX) metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(ctx)
if err != nil { if err != nil {
c.Log.Warningf("error creating alert meta: %s", err) c.Log.Warningf("error creating alert meta: %s", err)
} }
@ -578,7 +580,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize) decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize)
for _, decisionChunk := range decisionChunks { for _, decisionChunk := range decisionChunks {
decisionRet, err := c.createDecisionChunk(*alertItem.Simulated, stopAtTime, decisionChunk) decisionRet, err := c.createDecisionChunk(ctx, *alertItem.Simulated, stopAtTime, decisionChunk)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating alert decisions: %w", err) return nil, fmt.Errorf("creating alert decisions: %w", err)
} }
@ -636,7 +638,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
return nil, nil return nil, nil
} }
alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(c.CTX) alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(ctx)
if err != nil { if err != nil {
return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err)
} }
@ -653,7 +655,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
for retry < maxLockRetries { for retry < maxLockRetries {
// so much for the happy path... but sqlite3 errors work differently // so much for the happy path... but sqlite3 errors work differently
_, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(ctx)
if err == nil { if err == nil {
break break
} }
@ -678,17 +680,16 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
} }
} }
} }
return ret, nil return ret, nil
} }
func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) { func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList []*models.Alert) ([]string, error) {
var ( var (
owner *ent.Machine owner *ent.Machine
err error err error
) )
ctx := context.TODO()
if machineID != "" { if machineID != "" {
owner, err = c.QueryMachineByID(ctx, machineID) owner, err = c.QueryMachineByID(ctx, machineID)
if err != nil { if err != nil {
@ -708,7 +709,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str
alertIDs := []string{} alertIDs := []string{}
for _, alertChunk := range alertChunks { for _, alertChunk := range alertChunks {
ids, err := c.createAlertChunk(machineID, owner, alertChunk) ids, err := c.createAlertChunk(ctx, machineID, owner, alertChunk)
if err != nil { if err != nil {
return nil, fmt.Errorf("machine '%s': %w", machineID, err) return nil, fmt.Errorf("machine '%s': %w", machineID, err)
} }
@ -717,7 +718,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str
} }
if owner != nil { if owner != nil {
err = owner.Update().SetLastPush(time.Now().UTC()).Exec(c.CTX) err = owner.Update().SetLastPush(time.Now().UTC()).Exec(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("machine '%s': %w", machineID, err) return nil, fmt.Errorf("machine '%s': %w", machineID, err)
} }
@ -919,7 +920,6 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e
case "since", "created_before", "until": case "since", "created_before", "until":
if err := handleTimeFilters(param, value[0], &predicates); err != nil { if err := handleTimeFilters(param, value[0], &predicates); err != nil {
return nil, err return nil, err
} }
case "decision_type": case "decision_type":
predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0]))) predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0])))
@ -954,7 +954,6 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e
if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil { if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil {
return nil, err return nil, err
} }
return predicates, nil return predicates, nil
@ -996,11 +995,11 @@ func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string]
return counts, nil return counts, nil
} }
func (c *Client) TotalAlerts() (int, error) { func (c *Client) TotalAlerts(ctx context.Context) (int, error) {
return c.Ent.Alert.Query().Count(c.CTX) return c.Ent.Alert.Query().Count(ctx)
} }
func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) { func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Alert, error) {
sort := "DESC" // we sort by desc by default sort := "DESC" // we sort by desc by default
if val, ok := filter["sort"]; ok { if val, ok := filter["sort"]; ok {
@ -1047,7 +1046,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert,
WithOwner() WithOwner()
if limit == 0 { if limit == 0 {
limit, err = alerts.Count(c.CTX) limit, err = alerts.Count(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to count nb alerts: %w", err) return nil, fmt.Errorf("unable to count nb alerts: %w", err)
} }
@ -1059,7 +1058,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert,
alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID)) alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID))
} }
result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX) result, err := alerts.Limit(paginationSize).Offset(offset).All(ctx)
if err != nil { if err != nil {
return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err)
} }
@ -1088,35 +1087,35 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert,
return ret, nil return ret, nil
} }
func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Alert) (int, error) {
idList := make([]int, 0) idList := make([]int, 0)
for _, alert := range alertItems { for _, alert := range alertItems {
idList = append(idList, alert.ID) idList = append(idList, alert.ID)
} }
_, err := c.Ent.Event.Delete(). _, err := c.Ent.Event.Delete().
Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx)
if err != nil { if err != nil {
c.Log.Warningf("DeleteAlertGraphBatch : %s", err) c.Log.Warningf("DeleteAlertGraphBatch : %s", err)
return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events") return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events")
} }
_, err = c.Ent.Meta.Delete(). _, err = c.Ent.Meta.Delete().
Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx)
if err != nil { if err != nil {
c.Log.Warningf("DeleteAlertGraphBatch : %s", err) c.Log.Warningf("DeleteAlertGraphBatch : %s", err)
return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta") return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta")
} }
_, err = c.Ent.Decision.Delete(). _, err = c.Ent.Decision.Delete().
Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx)
if err != nil { if err != nil {
c.Log.Warningf("DeleteAlertGraphBatch : %s", err) c.Log.Warningf("DeleteAlertGraphBatch : %s", err)
return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions") return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions")
} }
deleted, err := c.Ent.Alert.Delete(). deleted, err := c.Ent.Alert.Delete().
Where(alert.IDIn(idList...)).Exec(c.CTX) Where(alert.IDIn(idList...)).Exec(ctx)
if err != nil { if err != nil {
c.Log.Warningf("DeleteAlertGraphBatch : %s", err) c.Log.Warningf("DeleteAlertGraphBatch : %s", err)
return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch") return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch")
@ -1127,10 +1126,10 @@ func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) {
return deleted, nil return deleted, nil
} }
func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { func (c *Client) DeleteAlertGraph(ctx context.Context, alertItem *ent.Alert) error {
// delete the associated events // delete the associated events
_, err := c.Ent.Event.Delete(). _, err := c.Ent.Event.Delete().
Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx)
if err != nil { if err != nil {
c.Log.Warningf("DeleteAlertGraph : %s", err) c.Log.Warningf("DeleteAlertGraph : %s", err)
return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID) return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID)
@ -1138,7 +1137,7 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error {
// delete the associated meta // delete the associated meta
_, err = c.Ent.Meta.Delete(). _, err = c.Ent.Meta.Delete().
Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx)
if err != nil { if err != nil {
c.Log.Warningf("DeleteAlertGraph : %s", err) c.Log.Warningf("DeleteAlertGraph : %s", err)
return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID) return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID)
@ -1146,14 +1145,14 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error {
// delete the associated decisions // delete the associated decisions
_, err = c.Ent.Decision.Delete(). _, err = c.Ent.Decision.Delete().
Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx)
if err != nil { if err != nil {
c.Log.Warningf("DeleteAlertGraph : %s", err) c.Log.Warningf("DeleteAlertGraph : %s", err)
return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID) return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID)
} }
// delete the alert // delete the alert
err = c.Ent.Alert.DeleteOne(alertItem).Exec(c.CTX) err = c.Ent.Alert.DeleteOne(alertItem).Exec(ctx)
if err != nil { if err != nil {
c.Log.Warningf("DeleteAlertGraph : %s", err) c.Log.Warningf("DeleteAlertGraph : %s", err)
return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID) return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID)
@ -1162,26 +1161,26 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error {
return nil return nil
} }
func (c *Client) DeleteAlertByID(id int) error { func (c *Client) DeleteAlertByID(ctx context.Context, id int) error {
alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(c.CTX) alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(ctx)
if err != nil { if err != nil {
return err return err
} }
return c.DeleteAlertGraph(alertItem) return c.DeleteAlertGraph(ctx, alertItem)
} }
func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) { func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) {
preds, err := AlertPredicatesFromFilter(filter) preds, err := AlertPredicatesFromFilter(filter)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX) return c.Ent.Alert.Delete().Where(preds...).Exec(ctx)
} }
func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { func (c *Client) GetAlertByID(ctx context.Context, alertID int) (*ent.Alert, error) {
alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX) alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(ctx)
if err != nil { if err != nil {
/*record not found, 404*/ /*record not found, 404*/
if ent.IsNotFound(err) { if ent.IsNotFound(err) {

View file

@ -21,7 +21,6 @@ import (
type Client struct { type Client struct {
Ent *ent.Client Ent *ent.Client
CTX context.Context
Log *log.Logger Log *log.Logger
CanFlush bool CanFlush bool
Type string Type string
@ -106,7 +105,6 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro
return &Client{ return &Client{
Ent: client, Ent: client,
CTX: ctx,
Log: clog, Log: clog,
CanFlush: true, CanFlush: true,
Type: config.Type, Type: config.Type,

View file

@ -31,7 +31,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string]
var err error var err error
var start_ip, start_sfx, end_ip, end_sfx int64 var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int var ip_sz int
var contains = true contains := true
/*if contains is true, return bans that *contains* the given value (value is the inner) /*if contains is true, return bans that *contains* the given value (value is the inner)
else, return bans that are *contained* by the given value (value is the outer)*/ else, return bans that are *contained* by the given value (value is the outer)*/
@ -321,7 +321,7 @@ func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[strin
var err error var err error
var start_ip, start_sfx, end_ip, end_sfx int64 var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int var ip_sz int
var contains = true contains := true
/*if contains is true, return bans that *contains* the given value (value is the inner) /*if contains is true, return bans that *contains* the given value (value is the inner)
else, return bans that are *contained* by the given value (value is the outer) */ else, return bans that are *contained* by the given value (value is the outer) */
@ -440,7 +440,7 @@ func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[strin
var err error var err error
var start_ip, start_sfx, end_ip, end_sfx int64 var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int var ip_sz int
var contains = true contains := true
/*if contains is true, return bans that *contains* the given value (value is the inner) /*if contains is true, return bans that *contains* the given value (value is the inner)
else, return bans that are *contained* by the given value (value is the outer)*/ else, return bans that are *contained* by the given value (value is the outer)*/
decisions := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now().UTC())) decisions := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now().UTC()))

View file

@ -239,7 +239,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e
c.FlushOrphans(ctx) c.FlushOrphans(ctx)
c.Log.Debug("Done flushing orphan alerts") c.Log.Debug("Done flushing orphan alerts")
totalAlerts, err = c.TotalAlerts() totalAlerts, err = c.TotalAlerts(ctx)
if err != nil { if err != nil {
c.Log.Warningf("FlushAlerts (max items count): %s", err) c.Log.Warningf("FlushAlerts (max items count): %s", err)
return fmt.Errorf("unable to get alerts count: %w", err) return fmt.Errorf("unable to get alerts count: %w", err)
@ -252,7 +252,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e
"created_before": {MaxAge}, "created_before": {MaxAge},
} }
nbDeleted, err := c.DeleteAlertWithFilter(filter) nbDeleted, err := c.DeleteAlertWithFilter(ctx, filter)
if err != nil { if err != nil {
c.Log.Warningf("FlushAlerts (max age): %s", err) c.Log.Warningf("FlushAlerts (max age): %s", err)
return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err)
@ -268,7 +268,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e
// This gives us the oldest alert that we want to keep // This gives us the oldest alert that we want to keep
// We then delete all the alerts with an id lower than this one // We then delete all the alerts with an id lower than this one
// We can do this because the id is auto-increment, and the database won't reuse the same id twice // We can do this because the id is auto-increment, and the database won't reuse the same id twice
lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ lastAlert, err := c.QueryAlertWithFilter(ctx, map[string][]string{
"sort": {"DESC"}, "sort": {"DESC"},
"limit": {"1"}, "limit": {"1"},
// we do not care about fetching the edges, we just want the id // we do not care about fetching the edges, we just want the id

View file

@ -17,7 +17,7 @@ func (c *Client) CreateMetric(ctx context.Context, generatedType metric.Generate
SetReceivedAt(receivedAt). SetReceivedAt(receivedAt).
SetPayload(payload). SetPayload(payload).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
c.Log.Warningf("CreateMetric: %s", err) c.Log.Warningf("CreateMetric: %s", err)
return nil, fmt.Errorf("storing metrics snapshot for '%s' at %s: %w", generatedBy, receivedAt, InsertFail) return nil, fmt.Errorf("storing metrics snapshot for '%s' at %s: %w", generatedBy, receivedAt, InsertFail)
} }