mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-10 20:05:55 +02:00
context propagation: pkg/database/{lock,decision} (#3251)
* context propagation: pkg/database/lock * QueryAllDecisionsWithFilters(ctx...), QueryExpiredDecisionsWithFilters(ctx...) * more Query...Decision...(ctx..) * rest of decisions * lint
This commit is contained in:
parent
4a2a663227
commit
1133afe58d
6 changed files with 133 additions and 100 deletions
|
@ -426,6 +426,7 @@ func (a *apic) CAPIPullIsOld() (bool, error) {
|
|||
}
|
||||
|
||||
func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) {
|
||||
ctx := context.TODO()
|
||||
nbDeleted := 0
|
||||
|
||||
for _, decision := range deletedDecisions {
|
||||
|
@ -438,7 +439,7 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
|
|||
filter["scopes"] = []string{*decision.Scope}
|
||||
}
|
||||
|
||||
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter)
|
||||
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("expiring decisions error: %w", err)
|
||||
}
|
||||
|
@ -458,6 +459,8 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
|
|||
func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) {
|
||||
var nbDeleted int
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
for _, decisions := range deletedDecisions {
|
||||
scope := decisions.Scope
|
||||
|
||||
|
@ -470,7 +473,7 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi
|
|||
filter["scopes"] = []string{*scope}
|
||||
}
|
||||
|
||||
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter)
|
||||
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("expiring decisions error: %w", err)
|
||||
}
|
||||
|
@ -640,7 +643,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
|
|||
|
||||
log.Debug("Acquiring lock for pullCAPI")
|
||||
|
||||
err = a.dbClient.AcquirePullCAPILock()
|
||||
err = a.dbClient.AcquirePullCAPILock(ctx)
|
||||
if a.dbClient.IsLocked(err) {
|
||||
log.Info("PullCAPI is already running, skipping")
|
||||
return nil
|
||||
|
@ -650,7 +653,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
|
|||
defer func() {
|
||||
log.Debug("Releasing lock for pullCAPI")
|
||||
|
||||
if err := a.dbClient.ReleasePullCAPILock(); err != nil {
|
||||
if err := a.dbClient.ReleasePullCAPILock(ctx); err != nil {
|
||||
log.Errorf("while releasing lock: %v", err)
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -52,7 +52,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query())
|
||||
data, err = c.DBClient.QueryDecisionWithFilter(ctx, gctx.Request.URL.Query())
|
||||
if err != nil {
|
||||
c.HandleDBErrors(gctx, err)
|
||||
|
||||
|
@ -93,7 +93,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(decisionID)
|
||||
ctx := gctx.Request.Context()
|
||||
|
||||
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID)
|
||||
if err != nil {
|
||||
c.HandleDBErrors(gctx, err)
|
||||
|
||||
|
@ -115,7 +117,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
|
|||
}
|
||||
|
||||
func (c *Controller) DeleteDecisions(gctx *gin.Context) {
|
||||
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(gctx.Request.URL.Query())
|
||||
ctx := gctx.Request.Context()
|
||||
|
||||
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query())
|
||||
if err != nil {
|
||||
c.HandleDBErrors(gctx, err)
|
||||
|
||||
|
@ -136,32 +140,37 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) {
|
|||
gctx.JSON(http.StatusOK, deleteDecisionResp)
|
||||
}
|
||||
|
||||
func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(map[string][]string) ([]*ent.Decision, error)) error {
|
||||
func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error {
|
||||
// respBuffer := bytes.NewBuffer([]byte{})
|
||||
limit := 30000 //FIXME : make it configurable
|
||||
limit := 30000 // FIXME : make it configurable
|
||||
needComma := false
|
||||
lastId := 0
|
||||
|
||||
limitStr := fmt.Sprintf("%d", limit)
|
||||
ctx := gctx.Request.Context()
|
||||
|
||||
limitStr := strconv.Itoa(limit)
|
||||
filters["limit"] = []string{limitStr}
|
||||
|
||||
for {
|
||||
if lastId > 0 {
|
||||
lastIdStr := fmt.Sprintf("%d", lastId)
|
||||
lastIdStr := strconv.Itoa(lastId)
|
||||
filters["id_gt"] = []string{lastIdStr}
|
||||
}
|
||||
|
||||
data, err := dbFunc(filters)
|
||||
data, err := dbFunc(ctx, filters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
lastId = data[len(data)-1].ID
|
||||
|
||||
results := FormatDecisions(data)
|
||||
for _, decision := range results {
|
||||
decisionJSON, _ := json.Marshal(decision)
|
||||
|
||||
if needComma {
|
||||
//respBuffer.Write([]byte(","))
|
||||
// respBuffer.Write([]byte(","))
|
||||
gctx.Writer.WriteString(",")
|
||||
} else {
|
||||
needComma = true
|
||||
|
@ -174,10 +183,12 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun
|
|||
|
||||
return err
|
||||
}
|
||||
//respBuffer.Reset()
|
||||
// respBuffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId)
|
||||
|
||||
if len(data) < limit {
|
||||
gctx.Writer.Flush()
|
||||
|
||||
|
@ -188,32 +199,37 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun
|
|||
return nil
|
||||
}
|
||||
|
||||
func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(*time.Time, map[string][]string) ([]*ent.Decision, error)) error {
|
||||
//respBuffer := bytes.NewBuffer([]byte{})
|
||||
limit := 30000 //FIXME : make it configurable
|
||||
func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(context.Context, *time.Time, map[string][]string) ([]*ent.Decision, error)) error {
|
||||
// respBuffer := bytes.NewBuffer([]byte{})
|
||||
limit := 30000 // FIXME : make it configurable
|
||||
needComma := false
|
||||
lastId := 0
|
||||
|
||||
limitStr := fmt.Sprintf("%d", limit)
|
||||
ctx := gctx.Request.Context()
|
||||
|
||||
limitStr := strconv.Itoa(limit)
|
||||
filters["limit"] = []string{limitStr}
|
||||
|
||||
for {
|
||||
if lastId > 0 {
|
||||
lastIdStr := fmt.Sprintf("%d", lastId)
|
||||
lastIdStr := strconv.Itoa(lastId)
|
||||
filters["id_gt"] = []string{lastIdStr}
|
||||
}
|
||||
|
||||
data, err := dbFunc(lastPull, filters)
|
||||
data, err := dbFunc(ctx, lastPull, filters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
lastId = data[len(data)-1].ID
|
||||
|
||||
results := FormatDecisions(data)
|
||||
for _, decision := range results {
|
||||
decisionJSON, _ := json.Marshal(decision)
|
||||
|
||||
if needComma {
|
||||
//respBuffer.Write([]byte(","))
|
||||
// respBuffer.Write([]byte(","))
|
||||
gctx.Writer.WriteString(",")
|
||||
} else {
|
||||
needComma = true
|
||||
|
@ -226,10 +242,12 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul
|
|||
|
||||
return err
|
||||
}
|
||||
//respBuffer.Reset()
|
||||
// respBuffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId)
|
||||
|
||||
if len(data) < limit {
|
||||
gctx.Writer.Flush()
|
||||
|
||||
|
@ -261,7 +279,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
|
|||
}
|
||||
|
||||
gctx.Writer.WriteString(`], "deleted": [`)
|
||||
//Expired decisions
|
||||
// Expired decisions
|
||||
err = writeStartupDecisions(gctx, filters, c.DBClient.QueryExpiredDecisionsWithFilters)
|
||||
if err != nil {
|
||||
log.Errorf("failed sending expired decisions for startup: %v", err)
|
||||
|
@ -302,8 +320,12 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
|
|||
}
|
||||
|
||||
func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error {
|
||||
var data []*ent.Decision
|
||||
var err error
|
||||
var (
|
||||
data []*ent.Decision
|
||||
err error
|
||||
)
|
||||
|
||||
ctx := gctx.Request.Context()
|
||||
|
||||
ret := make(map[string][]*models.Decision, 0)
|
||||
ret["new"] = []*models.Decision{}
|
||||
|
@ -311,7 +333,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
|
|||
|
||||
if val, ok := gctx.Request.URL.Query()["startup"]; ok {
|
||||
if val[0] == "true" {
|
||||
data, err = c.DBClient.QueryAllDecisionsWithFilters(filters)
|
||||
data, err = c.DBClient.QueryAllDecisionsWithFilters(ctx, filters)
|
||||
if err != nil {
|
||||
log.Errorf("failed querying decisions: %v", err)
|
||||
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
|
||||
|
@ -322,7 +344,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
|
|||
ret["new"] = FormatDecisions(data)
|
||||
|
||||
// getting expired decisions
|
||||
data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters)
|
||||
data, err = c.DBClient.QueryExpiredDecisionsWithFilters(ctx, filters)
|
||||
if err != nil {
|
||||
log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
|
||||
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
|
||||
|
@ -339,14 +361,14 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
|
|||
}
|
||||
|
||||
// getting new decisions
|
||||
data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters)
|
||||
data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(ctx, bouncerInfo.LastPull, filters)
|
||||
if err != nil {
|
||||
log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err)
|
||||
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
|
||||
|
||||
return err
|
||||
}
|
||||
//data = KeepLongestDecision(data)
|
||||
// data = KeepLongestDecision(data)
|
||||
ret["new"] = FormatDecisions(data)
|
||||
|
||||
since := time.Time{}
|
||||
|
@ -355,7 +377,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
|
|||
}
|
||||
|
||||
// getting expired decisions
|
||||
data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(&since, filters) // do we want to give exactly lastPull time ?
|
||||
data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(ctx, &since, filters) // do we want to give exactly lastPull time ?
|
||||
if err != nil {
|
||||
log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
|
||||
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
|
||||
|
@ -384,8 +406,8 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
|
|||
}
|
||||
|
||||
if gctx.Request.Method == http.MethodHead {
|
||||
//For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db
|
||||
//We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true)
|
||||
// For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db
|
||||
// We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true)
|
||||
gctx.String(http.StatusOK, "")
|
||||
|
||||
return
|
||||
|
@ -403,7 +425,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
|
|||
}
|
||||
|
||||
if err == nil {
|
||||
//Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions
|
||||
// Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions
|
||||
if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil {
|
||||
log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err)
|
||||
}
|
||||
|
|
|
@ -43,6 +43,8 @@ type listUnsubscribe struct {
|
|||
}
|
||||
|
||||
func DecisionCmd(message *Message, p *Papi, sync bool) error {
|
||||
ctx := context.TODO()
|
||||
|
||||
switch message.Header.OperationCmd {
|
||||
case "delete":
|
||||
data, err := json.Marshal(message.Data)
|
||||
|
@ -65,7 +67,7 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error {
|
|||
filter := make(map[string][]string)
|
||||
filter["uuid"] = UUIDs
|
||||
|
||||
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter)
|
||||
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err)
|
||||
}
|
||||
|
@ -168,6 +170,8 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
|
|||
}
|
||||
|
||||
func ManagementCmd(message *Message, p *Papi, sync bool) error {
|
||||
ctx := context.TODO()
|
||||
|
||||
if sync {
|
||||
p.Logger.Infof("Ignoring management command from PAPI in sync mode")
|
||||
return nil
|
||||
|
@ -195,7 +199,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
|
|||
filter["origin"] = []string{types.ListOrigin}
|
||||
filter["scenario"] = []string{unsubscribeMsg.Name}
|
||||
|
||||
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter)
|
||||
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err)
|
||||
}
|
||||
|
|
|
@ -121,7 +121,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string]
|
|||
return query, nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) {
|
||||
func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) {
|
||||
query := c.Ent.Decision.Query().Where(
|
||||
decision.UntilGT(time.Now().UTC()),
|
||||
)
|
||||
|
@ -138,7 +138,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e
|
|||
|
||||
query = query.Order(ent.Asc(decision.FieldID))
|
||||
|
||||
data, err := query.All(c.CTX)
|
||||
data, err := query.All(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters")
|
||||
|
@ -147,7 +147,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) {
|
||||
func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) {
|
||||
query := c.Ent.Decision.Query().Where(
|
||||
decision.UntilLT(time.Now().UTC()),
|
||||
)
|
||||
|
@ -165,7 +165,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) (
|
|||
return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters")
|
||||
}
|
||||
|
||||
data, err := query.All(c.CTX)
|
||||
data, err := query.All(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
|
||||
|
@ -196,7 +196,7 @@ func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*Decisions
|
|||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Decision, error) {
|
||||
func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) {
|
||||
var data []*ent.Decision
|
||||
var err error
|
||||
|
||||
|
@ -218,7 +218,7 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec
|
|||
decision.FieldValue,
|
||||
decision.FieldScope,
|
||||
decision.FieldOrigin,
|
||||
).Scan(c.CTX, &data)
|
||||
).Scan(ctx, &data)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryDecisionWithFilter : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrap(QueryFail, "query decision failed")
|
||||
|
@ -255,7 +255,7 @@ func longestDecisionForScopeTypeValue(s *sql.Selector) {
|
|||
)
|
||||
}
|
||||
|
||||
func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters map[string][]string) ([]*ent.Decision, error) {
|
||||
func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) {
|
||||
query := c.Ent.Decision.Query().Where(
|
||||
decision.UntilLT(time.Now().UTC()),
|
||||
)
|
||||
|
@ -277,7 +277,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters
|
|||
|
||||
query = query.Order(ent.Asc(decision.FieldID))
|
||||
|
||||
data, err := query.All(c.CTX)
|
||||
data, err := query.All(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters")
|
||||
|
@ -286,7 +286,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map[string][]string) ([]*ent.Decision, error) {
|
||||
func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) {
|
||||
query := c.Ent.Decision.Query().Where(
|
||||
decision.UntilGT(time.Now().UTC()),
|
||||
)
|
||||
|
@ -308,7 +308,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map
|
|||
|
||||
query = query.Order(ent.Asc(decision.FieldID))
|
||||
|
||||
data, err := query.All(c.CTX)
|
||||
data, err := query.All(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String())
|
||||
|
@ -317,20 +317,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (c *Client) DeleteDecisionById(decisionID int) ([]*ent.Decision, error) {
|
||||
toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX)
|
||||
if err != nil {
|
||||
c.Log.Warningf("DeleteDecisionById : %s", err)
|
||||
return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID)
|
||||
}
|
||||
|
||||
count, err := c.DeleteDecisions(toDelete)
|
||||
c.Log.Debugf("deleted %d decisions", count)
|
||||
|
||||
return toDelete, err
|
||||
}
|
||||
|
||||
func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) {
|
||||
func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) {
|
||||
var err error
|
||||
var start_ip, start_sfx, end_ip, end_sfx int64
|
||||
var ip_sz int
|
||||
|
@ -433,13 +420,13 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string,
|
|||
return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
|
||||
}
|
||||
|
||||
toDelete, err := decisions.All(c.CTX)
|
||||
toDelete, err := decisions.All(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("DeleteDecisionsWithFilter : %s", err)
|
||||
return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter")
|
||||
}
|
||||
|
||||
count, err := c.DeleteDecisions(toDelete)
|
||||
count, err := c.DeleteDecisions(ctx, toDelete)
|
||||
if err != nil {
|
||||
c.Log.Warningf("While deleting decisions : %s", err)
|
||||
return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter")
|
||||
|
@ -449,7 +436,7 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string,
|
|||
}
|
||||
|
||||
// ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items
|
||||
func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) {
|
||||
func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) {
|
||||
var err error
|
||||
var start_ip, start_sfx, end_ip, end_sfx int64
|
||||
var ip_sz int
|
||||
|
@ -558,13 +545,13 @@ func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string,
|
|||
return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
|
||||
}
|
||||
|
||||
DecisionsToDelete, err := decisions.All(c.CTX)
|
||||
DecisionsToDelete, err := decisions.All(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("ExpireDecisionsWithFilter : %s", err)
|
||||
return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter")
|
||||
}
|
||||
|
||||
count, err := c.ExpireDecisions(DecisionsToDelete)
|
||||
count, err := c.ExpireDecisions(ctx, DecisionsToDelete)
|
||||
if err != nil {
|
||||
return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err)
|
||||
}
|
||||
|
@ -583,13 +570,13 @@ func decisionIDs(decisions []*ent.Decision) []int {
|
|||
|
||||
// ExpireDecisions sets the expiration of a list of decisions to now()
|
||||
// It returns the number of impacted decisions for the CAPI/PAPI
|
||||
func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {
|
||||
func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) {
|
||||
if len(decisions) <= decisionDeleteBulkSize {
|
||||
ids := decisionIDs(decisions)
|
||||
|
||||
rows, err := c.Ent.Decision.Update().Where(
|
||||
decision.IDIn(ids...),
|
||||
).SetUntil(time.Now().UTC()).Save(c.CTX)
|
||||
).SetUntil(time.Now().UTC()).Save(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("expire decisions with provided filter: %w", err)
|
||||
}
|
||||
|
@ -602,7 +589,7 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {
|
|||
total := 0
|
||||
|
||||
for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) {
|
||||
rows, err := c.ExpireDecisions(chunk)
|
||||
rows, err := c.ExpireDecisions(ctx, chunk)
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
@ -615,13 +602,13 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {
|
|||
|
||||
// DeleteDecisions removes a list of decisions from the database
|
||||
// It returns the number of impacted decisions for the CAPI/PAPI
|
||||
func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
|
||||
func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) {
|
||||
if len(decisions) < decisionDeleteBulkSize {
|
||||
ids := decisionIDs(decisions)
|
||||
|
||||
rows, err := c.Ent.Decision.Delete().Where(
|
||||
decision.IDIn(ids...),
|
||||
).Exec(c.CTX)
|
||||
).Exec(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err)
|
||||
}
|
||||
|
@ -634,7 +621,7 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
|
|||
tot := 0
|
||||
|
||||
for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) {
|
||||
rows, err := c.DeleteDecisions(chunk)
|
||||
rows, err := c.DeleteDecisions(ctx, chunk)
|
||||
if err != nil {
|
||||
return tot, err
|
||||
}
|
||||
|
@ -646,8 +633,8 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
|
|||
}
|
||||
|
||||
// ExpireDecision set the expiration of a decision to now()
|
||||
func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error) {
|
||||
toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX)
|
||||
func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) {
|
||||
toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx)
|
||||
|
||||
// XXX: do we want 500 or 404 here?
|
||||
if err != nil || len(toUpdate) == 0 {
|
||||
|
@ -659,12 +646,12 @@ func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error
|
|||
return 0, nil, ItemNotFound
|
||||
}
|
||||
|
||||
count, err := c.ExpireDecisions(toUpdate)
|
||||
count, err := c.ExpireDecisions(ctx, toUpdate)
|
||||
|
||||
return count, toUpdate, err
|
||||
}
|
||||
|
||||
func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
|
||||
func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) {
|
||||
var err error
|
||||
var start_ip, start_sfx, end_ip, end_sfx int64
|
||||
var ip_sz, count int
|
||||
|
@ -682,7 +669,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
|
|||
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
|
||||
}
|
||||
|
||||
count, err = decisions.Count(c.CTX)
|
||||
count, err = decisions.Count(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(err, "fail to count decisions")
|
||||
}
|
||||
|
@ -690,7 +677,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
|
|||
return count, nil
|
||||
}
|
||||
|
||||
func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) {
|
||||
func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) {
|
||||
var err error
|
||||
var start_ip, start_sfx, end_ip, end_sfx int64
|
||||
var ip_sz, count int
|
||||
|
@ -710,7 +697,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error)
|
|||
|
||||
decisions = decisions.Where(decision.UntilGT(time.Now().UTC()))
|
||||
|
||||
count, err = decisions.Count(c.CTX)
|
||||
count, err = decisions.Count(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("fail to count decisions: %w", err)
|
||||
}
|
||||
|
@ -718,7 +705,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error)
|
|||
return count, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.Duration, error) {
|
||||
func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) {
|
||||
var err error
|
||||
var start_ip, start_sfx, end_ip, end_sfx int64
|
||||
var ip_sz int
|
||||
|
@ -740,7 +727,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D
|
|||
|
||||
decisions = decisions.Order(ent.Desc(decision.FieldUntil))
|
||||
|
||||
decision, err := decisions.First(c.CTX)
|
||||
decision, err := decisions.First(ctx)
|
||||
if err != nil && !ent.IsNotFound(err) {
|
||||
return 0, fmt.Errorf("fail to get decision: %w", err)
|
||||
}
|
||||
|
@ -752,7 +739,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D
|
|||
return decision.Until.Sub(time.Now().UTC()), nil
|
||||
}
|
||||
|
||||
func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) {
|
||||
func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) {
|
||||
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue)
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err)
|
||||
|
@ -768,7 +755,7 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim
|
|||
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
|
||||
}
|
||||
|
||||
count, err := decisions.Count(c.CTX)
|
||||
count, err := decisions.Count(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(err, "fail to count decisions")
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -16,40 +17,45 @@ const (
|
|||
CapiPullLockName = "pullCAPI"
|
||||
)
|
||||
|
||||
func (c *Client) AcquireLock(name string) error {
|
||||
func (c *Client) AcquireLock(ctx context.Context, name string) error {
|
||||
log.Debugf("acquiring lock %s", name)
|
||||
_, err := c.Ent.Lock.Create().
|
||||
SetName(name).
|
||||
SetCreatedAt(types.UtcNow()).
|
||||
Save(c.CTX)
|
||||
Save(ctx)
|
||||
|
||||
if ent.IsConstraintError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrapf(InsertFail, "insert lock: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ReleaseLock(name string) error {
|
||||
func (c *Client) ReleaseLock(ctx context.Context, name string) error {
|
||||
log.Debugf("releasing lock %s", name)
|
||||
_, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(c.CTX)
|
||||
_, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrapf(DeleteFail, "delete lock: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ReleaseLockWithTimeout(name string, timeout int) error {
|
||||
func (c *Client) ReleaseLockWithTimeout(ctx context.Context, name string, timeout int) error {
|
||||
log.Debugf("releasing lock %s with timeout of %d minutes", name, timeout)
|
||||
|
||||
_, err := c.Ent.Lock.Delete().Where(
|
||||
lock.NameEQ(name),
|
||||
lock.CreatedAtLT(time.Now().UTC().Add(-time.Duration(timeout)*time.Minute)),
|
||||
).Exec(c.CTX)
|
||||
|
||||
).Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrapf(DeleteFail, "delete lock: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -57,23 +63,25 @@ func (c *Client) IsLocked(err error) bool {
|
|||
return ent.IsConstraintError(err)
|
||||
}
|
||||
|
||||
func (c *Client) AcquirePullCAPILock() error {
|
||||
|
||||
/*delete orphan "old" lock if present*/
|
||||
err := c.ReleaseLockWithTimeout(CapiPullLockName, CAPIPullLockTimeout)
|
||||
func (c *Client) AcquirePullCAPILock(ctx context.Context) error {
|
||||
// delete orphan "old" lock if present
|
||||
err := c.ReleaseLockWithTimeout(ctx, CapiPullLockName, CAPIPullLockTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("unable to release pullCAPI lock: %s", err)
|
||||
}
|
||||
return c.AcquireLock(CapiPullLockName)
|
||||
|
||||
return c.AcquireLock(ctx, CapiPullLockName)
|
||||
}
|
||||
|
||||
func (c *Client) ReleasePullCAPILock() error {
|
||||
func (c *Client) ReleasePullCAPILock(ctx context.Context) error {
|
||||
log.Debugf("deleting lock %s", CapiPullLockName)
|
||||
|
||||
_, err := c.Ent.Lock.Delete().Where(
|
||||
lock.NameEQ(CapiPullLockName),
|
||||
).Exec(c.CTX)
|
||||
).Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrapf(DeleteFail, "delete lock: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package exprhelpers
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -592,7 +593,10 @@ func GetDecisionsCount(params ...any) (any, error) {
|
|||
return 0, nil
|
||||
|
||||
}
|
||||
count, err := dbClient.CountDecisionsByValue(value)
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
count, err := dbClient.CountDecisionsByValue(ctx, value)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get decisions count from value '%s'", value)
|
||||
return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
|
||||
|
@ -613,8 +617,11 @@ func GetDecisionsSinceCount(params ...any) (any, error) {
|
|||
log.Errorf("Failed to parse since parameter '%s' : %s", since, err)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
sinceTime := time.Now().UTC().Add(-sinceDuration)
|
||||
count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime)
|
||||
|
||||
count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get decisions count from value '%s'", value)
|
||||
return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
|
||||
|
@ -628,7 +635,8 @@ func GetActiveDecisionsCount(params ...any) (any, error) {
|
|||
log.Error("No database config to call GetActiveDecisionsCount()")
|
||||
return 0, nil
|
||||
}
|
||||
count, err := dbClient.CountActiveDecisionsByValue(value)
|
||||
ctx := context.TODO()
|
||||
count, err := dbClient.CountActiveDecisionsByValue(ctx, value)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get active decisions count from value '%s'", value)
|
||||
return 0, err
|
||||
|
@ -642,7 +650,8 @@ func GetActiveDecisionsTimeLeft(params ...any) (any, error) {
|
|||
log.Error("No database config to call GetActiveDecisionsTimeLeft()")
|
||||
return 0, nil
|
||||
}
|
||||
timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(value)
|
||||
ctx := context.TODO()
|
||||
timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get active decisions time left from value '%s'", value)
|
||||
return 0, err
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue