context propagation: pkg/database/alerts (#3252)

* alerts
* drop CTX from dbclient
* lint
* pkg/database/alerts: context.TODO()
* cscli: context.Background() -> cmd.Context()
This commit is contained in:
mmetc 2024-09-24 14:13:45 +02:00 committed by GitHub
parent 1133afe58d
commit 3945a991bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 141 additions and 123 deletions

View file

@ -35,12 +35,12 @@ const (
// CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it
// if alert already exists, it checks it associated decisions already exists
// if some associated decisions are missing (ie. previous insert ended up in error) it inserts them
func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) {
func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, alertItem *models.Alert) (string, error) {
if alertItem.UUID == "" {
return "", errors.New("alert UUID is empty")
}
alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(c.CTX)
alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(ctx)
if err != nil && !ent.IsNotFound(err) {
return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err)
@ -48,7 +48,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert)
// alert wasn't found, insert it (expected hotpath)
if ent.IsNotFound(err) || len(alerts) == 0 {
alertIDs, err := c.CreateAlert(machineID, []*models.Alert{alertItem})
alertIDs, err := c.CreateAlert(ctx, machineID, []*models.Alert{alertItem})
if err != nil {
return "", fmt.Errorf("unable to create alert: %w", err)
}
@ -165,7 +165,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert)
builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize)
for _, builderChunk := range builderChunks {
decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(c.CTX)
decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(ctx)
if err != nil {
return "", fmt.Errorf("creating alert decisions: %w", err)
}
@ -178,7 +178,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert)
decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize)
for _, decisionChunk := range decisionChunks {
err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(c.CTX)
err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(ctx)
if err != nil {
return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err)
}
@ -191,7 +191,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert)
// it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions:
// 1st pull, you get decisions [1,2,3]. it inserts [1,2,3]
// 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin
func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) {
func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models.Alert) (int, int, int, error) {
if alertItem == nil {
return 0, 0, 0, errors.New("nil alert")
}
@ -244,7 +244,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
SetScenarioHash(*alertItem.ScenarioHash).
SetRemediation(true) // it's from CAPI, we always have decisions
alertRef, err := alertB.Save(c.CTX)
alertRef, err := alertB.Save(ctx)
if err != nil {
return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err)
}
@ -253,7 +253,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
return alertRef.ID, 0, 0, nil
}
txClient, err := c.Ent.Tx(c.CTX)
txClient, err := c.Ent.Tx(ctx)
if err != nil {
return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err)
}
@ -347,7 +347,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
decision.OriginEQ(DecOrigin),
decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))),
decision.ValueIn(deleteChunk...),
)).Exec(c.CTX)
)).Exec(ctx)
if err != nil {
rollbackErr := txClient.Rollback()
if rollbackErr != nil {
@ -363,7 +363,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize)
for _, builderChunk := range builderChunks {
insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(c.CTX)
insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx)
if err != nil {
rollbackErr := txClient.Rollback()
if rollbackErr != nil {
@ -391,7 +391,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
return alertRef.ID, inserted, deleted, nil
}
func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) {
func (c *Client) createDecisionChunk(ctx context.Context, simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) {
decisionCreate := []*ent.DecisionCreate{}
for _, decisionItem := range decisions {
@ -436,7 +436,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis
return nil, nil
}
ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(c.CTX)
ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(ctx)
if err != nil {
return nil, err
}
@ -444,7 +444,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis
return ret, nil
}
func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) {
func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) {
alertBuilders := []*ent.AlertCreate{}
alertDecisions := [][]*ent.Decision{}
@ -540,7 +540,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario)
}
events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX)
events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(ctx)
if err != nil {
return nil, errors.Wrapf(BulkError, "creating alert events: %s", err)
}
@ -554,12 +554,14 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
value := metaItem.Value
if len(metaItem.Value) > 4095 {
c.Log.Warningf("truncated meta %s : value too long", metaItem.Key)
c.Log.Warningf("truncated meta %s: value too long", metaItem.Key)
value = value[:4095]
}
if len(metaItem.Key) > 255 {
c.Log.Warningf("truncated meta %s : key too long", metaItem.Key)
c.Log.Warningf("truncated meta %s: key too long", metaItem.Key)
key = key[:255]
}
@ -568,7 +570,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
SetValue(value)
}
metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX)
metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(ctx)
if err != nil {
c.Log.Warningf("error creating alert meta: %s", err)
}
@ -578,7 +580,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize)
for _, decisionChunk := range decisionChunks {
decisionRet, err := c.createDecisionChunk(*alertItem.Simulated, stopAtTime, decisionChunk)
decisionRet, err := c.createDecisionChunk(ctx, *alertItem.Simulated, stopAtTime, decisionChunk)
if err != nil {
return nil, fmt.Errorf("creating alert decisions: %w", err)
}
@ -636,7 +638,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
return nil, nil
}
alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(c.CTX)
alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(ctx)
if err != nil {
return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err)
}
@ -653,7 +655,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
for retry < maxLockRetries {
// so much for the happy path... but sqlite3 errors work differently
_, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX)
_, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(ctx)
if err == nil {
break
}
@ -678,17 +680,16 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [
}
}
}
return ret, nil
}
func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) {
func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList []*models.Alert) ([]string, error) {
var (
owner *ent.Machine
err error
)
ctx := context.TODO()
if machineID != "" {
owner, err = c.QueryMachineByID(ctx, machineID)
if err != nil {
@ -708,7 +709,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str
alertIDs := []string{}
for _, alertChunk := range alertChunks {
ids, err := c.createAlertChunk(machineID, owner, alertChunk)
ids, err := c.createAlertChunk(ctx, machineID, owner, alertChunk)
if err != nil {
return nil, fmt.Errorf("machine '%s': %w", machineID, err)
}
@ -717,7 +718,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str
}
if owner != nil {
err = owner.Update().SetLastPush(time.Now().UTC()).Exec(c.CTX)
err = owner.Update().SetLastPush(time.Now().UTC()).Exec(ctx)
if err != nil {
return nil, fmt.Errorf("machine '%s': %w", machineID, err)
}
@ -919,7 +920,6 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e
case "since", "created_before", "until":
if err := handleTimeFilters(param, value[0], &predicates); err != nil {
return nil, err
}
case "decision_type":
predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0])))
@ -954,7 +954,6 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e
if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil {
return nil, err
}
return predicates, nil
@ -996,11 +995,11 @@ func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string]
return counts, nil
}
func (c *Client) TotalAlerts() (int, error) {
return c.Ent.Alert.Query().Count(c.CTX)
func (c *Client) TotalAlerts(ctx context.Context) (int, error) {
return c.Ent.Alert.Query().Count(ctx)
}
func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) {
func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Alert, error) {
sort := "DESC" // we sort by desc by default
if val, ok := filter["sort"]; ok {
@ -1047,7 +1046,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert,
WithOwner()
if limit == 0 {
limit, err = alerts.Count(c.CTX)
limit, err = alerts.Count(ctx)
if err != nil {
return nil, fmt.Errorf("unable to count nb alerts: %w", err)
}
@ -1059,7 +1058,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert,
alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID))
}
result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX)
result, err := alerts.Limit(paginationSize).Offset(offset).All(ctx)
if err != nil {
return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err)
}
@ -1088,35 +1087,35 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert,
return ret, nil
}
func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) {
func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Alert) (int, error) {
idList := make([]int, 0)
for _, alert := range alertItems {
idList = append(idList, alert.ID)
}
_, err := c.Ent.Event.Delete().
Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX)
Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx)
if err != nil {
c.Log.Warningf("DeleteAlertGraphBatch : %s", err)
return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events")
}
_, err = c.Ent.Meta.Delete().
Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX)
Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx)
if err != nil {
c.Log.Warningf("DeleteAlertGraphBatch : %s", err)
return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta")
}
_, err = c.Ent.Decision.Delete().
Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX)
Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx)
if err != nil {
c.Log.Warningf("DeleteAlertGraphBatch : %s", err)
return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions")
}
deleted, err := c.Ent.Alert.Delete().
Where(alert.IDIn(idList...)).Exec(c.CTX)
Where(alert.IDIn(idList...)).Exec(ctx)
if err != nil {
c.Log.Warningf("DeleteAlertGraphBatch : %s", err)
return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch")
@ -1127,10 +1126,10 @@ func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) {
return deleted, nil
}
func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error {
func (c *Client) DeleteAlertGraph(ctx context.Context, alertItem *ent.Alert) error {
// delete the associated events
_, err := c.Ent.Event.Delete().
Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX)
Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx)
if err != nil {
c.Log.Warningf("DeleteAlertGraph : %s", err)
return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID)
@ -1138,7 +1137,7 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error {
// delete the associated meta
_, err = c.Ent.Meta.Delete().
Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX)
Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx)
if err != nil {
c.Log.Warningf("DeleteAlertGraph : %s", err)
return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID)
@ -1146,14 +1145,14 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error {
// delete the associated decisions
_, err = c.Ent.Decision.Delete().
Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX)
Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx)
if err != nil {
c.Log.Warningf("DeleteAlertGraph : %s", err)
return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID)
}
// delete the alert
err = c.Ent.Alert.DeleteOne(alertItem).Exec(c.CTX)
err = c.Ent.Alert.DeleteOne(alertItem).Exec(ctx)
if err != nil {
c.Log.Warningf("DeleteAlertGraph : %s", err)
return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID)
@ -1162,26 +1161,26 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error {
return nil
}
func (c *Client) DeleteAlertByID(id int) error {
alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(c.CTX)
func (c *Client) DeleteAlertByID(ctx context.Context, id int) error {
alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(ctx)
if err != nil {
return err
}
return c.DeleteAlertGraph(alertItem)
return c.DeleteAlertGraph(ctx, alertItem)
}
func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) {
func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) {
preds, err := AlertPredicatesFromFilter(filter)
if err != nil {
return 0, err
}
return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX)
return c.Ent.Alert.Delete().Where(preds...).Exec(ctx)
}
func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) {
alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX)
func (c *Client) GetAlertByID(ctx context.Context, alertID int) (*ent.Alert, error) {
alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(ctx)
if err != nil {
/*record not found, 404*/
if ent.IsNotFound(err) {