context propagation: pkg/database/{lock,decision} (#3251)

* context propagation: pkg/database/lock

* QueryAllDecisionsWithFilters(ctx...), QueryExpiredDecisionsWithFilters(ctx...)

* more Query...Decision...(ctx..)

* rest of decisions

* lint
This commit is contained in:
mmetc 2024-09-23 17:33:46 +02:00 committed by GitHub
parent 4a2a663227
commit 1133afe58d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 133 additions and 100 deletions

View file

@ -426,6 +426,7 @@ func (a *apic) CAPIPullIsOld() (bool, error) {
}
func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) {
ctx := context.TODO()
nbDeleted := 0
for _, decision := range deletedDecisions {
@ -438,7 +439,7 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
filter["scopes"] = []string{*decision.Scope}
}
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter)
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter)
if err != nil {
return 0, fmt.Errorf("expiring decisions error: %w", err)
}
@ -458,6 +459,8 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) {
var nbDeleted int
ctx := context.TODO()
for _, decisions := range deletedDecisions {
scope := decisions.Scope
@ -470,7 +473,7 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi
filter["scopes"] = []string{*scope}
}
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter)
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter)
if err != nil {
return 0, fmt.Errorf("expiring decisions error: %w", err)
}
@ -640,7 +643,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
log.Debug("Acquiring lock for pullCAPI")
err = a.dbClient.AcquirePullCAPILock()
err = a.dbClient.AcquirePullCAPILock(ctx)
if a.dbClient.IsLocked(err) {
log.Info("PullCAPI is already running, skipping")
return nil
@ -650,7 +653,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
defer func() {
log.Debug("Releasing lock for pullCAPI")
if err := a.dbClient.ReleasePullCAPILock(); err != nil {
if err := a.dbClient.ReleasePullCAPILock(ctx); err != nil {
log.Errorf("while releasing lock: %v", err)
}
}()

View file

@ -1,8 +1,8 @@
package v1
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
@ -52,7 +52,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) {
return
}
data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query())
data, err = c.DBClient.QueryDecisionWithFilter(ctx, gctx.Request.URL.Query())
if err != nil {
c.HandleDBErrors(gctx, err)
@ -93,7 +93,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
return
}
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(decisionID)
ctx := gctx.Request.Context()
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID)
if err != nil {
c.HandleDBErrors(gctx, err)
@ -115,7 +117,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
}
func (c *Controller) DeleteDecisions(gctx *gin.Context) {
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(gctx.Request.URL.Query())
ctx := gctx.Request.Context()
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query())
if err != nil {
c.HandleDBErrors(gctx, err)
@ -136,32 +140,37 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) {
gctx.JSON(http.StatusOK, deleteDecisionResp)
}
func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(map[string][]string) ([]*ent.Decision, error)) error {
func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error {
// respBuffer := bytes.NewBuffer([]byte{})
limit := 30000 //FIXME : make it configurable
limit := 30000 // FIXME : make it configurable
needComma := false
lastId := 0
limitStr := fmt.Sprintf("%d", limit)
ctx := gctx.Request.Context()
limitStr := strconv.Itoa(limit)
filters["limit"] = []string{limitStr}
for {
if lastId > 0 {
lastIdStr := fmt.Sprintf("%d", lastId)
lastIdStr := strconv.Itoa(lastId)
filters["id_gt"] = []string{lastIdStr}
}
data, err := dbFunc(filters)
data, err := dbFunc(ctx, filters)
if err != nil {
return err
}
if len(data) > 0 {
lastId = data[len(data)-1].ID
results := FormatDecisions(data)
for _, decision := range results {
decisionJSON, _ := json.Marshal(decision)
if needComma {
//respBuffer.Write([]byte(","))
// respBuffer.Write([]byte(","))
gctx.Writer.WriteString(",")
} else {
needComma = true
@ -174,10 +183,12 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun
return err
}
//respBuffer.Reset()
// respBuffer.Reset()
}
}
log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId)
if len(data) < limit {
gctx.Writer.Flush()
@ -188,32 +199,37 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun
return nil
}
func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(*time.Time, map[string][]string) ([]*ent.Decision, error)) error {
//respBuffer := bytes.NewBuffer([]byte{})
limit := 30000 //FIXME : make it configurable
func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(context.Context, *time.Time, map[string][]string) ([]*ent.Decision, error)) error {
// respBuffer := bytes.NewBuffer([]byte{})
limit := 30000 // FIXME : make it configurable
needComma := false
lastId := 0
limitStr := fmt.Sprintf("%d", limit)
ctx := gctx.Request.Context()
limitStr := strconv.Itoa(limit)
filters["limit"] = []string{limitStr}
for {
if lastId > 0 {
lastIdStr := fmt.Sprintf("%d", lastId)
lastIdStr := strconv.Itoa(lastId)
filters["id_gt"] = []string{lastIdStr}
}
data, err := dbFunc(lastPull, filters)
data, err := dbFunc(ctx, lastPull, filters)
if err != nil {
return err
}
if len(data) > 0 {
lastId = data[len(data)-1].ID
results := FormatDecisions(data)
for _, decision := range results {
decisionJSON, _ := json.Marshal(decision)
if needComma {
//respBuffer.Write([]byte(","))
// respBuffer.Write([]byte(","))
gctx.Writer.WriteString(",")
} else {
needComma = true
@ -226,10 +242,12 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul
return err
}
//respBuffer.Reset()
// respBuffer.Reset()
}
}
log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId)
if len(data) < limit {
gctx.Writer.Flush()
@ -261,7 +279,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
}
gctx.Writer.WriteString(`], "deleted": [`)
//Expired decisions
// Expired decisions
err = writeStartupDecisions(gctx, filters, c.DBClient.QueryExpiredDecisionsWithFilters)
if err != nil {
log.Errorf("failed sending expired decisions for startup: %v", err)
@ -302,8 +320,12 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
}
func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error {
var data []*ent.Decision
var err error
var (
data []*ent.Decision
err error
)
ctx := gctx.Request.Context()
ret := make(map[string][]*models.Decision, 0)
ret["new"] = []*models.Decision{}
@ -311,7 +333,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
if val, ok := gctx.Request.URL.Query()["startup"]; ok {
if val[0] == "true" {
data, err = c.DBClient.QueryAllDecisionsWithFilters(filters)
data, err = c.DBClient.QueryAllDecisionsWithFilters(ctx, filters)
if err != nil {
log.Errorf("failed querying decisions: %v", err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
@ -322,7 +344,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
ret["new"] = FormatDecisions(data)
// getting expired decisions
data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters)
data, err = c.DBClient.QueryExpiredDecisionsWithFilters(ctx, filters)
if err != nil {
log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
@ -339,14 +361,14 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
}
// getting new decisions
data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters)
data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(ctx, bouncerInfo.LastPull, filters)
if err != nil {
log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return err
}
//data = KeepLongestDecision(data)
// data = KeepLongestDecision(data)
ret["new"] = FormatDecisions(data)
since := time.Time{}
@ -355,7 +377,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
}
// getting expired decisions
data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(&since, filters) // do we want to give exactly lastPull time ?
data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(ctx, &since, filters) // do we want to give exactly lastPull time ?
if err != nil {
log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
@ -384,8 +406,8 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
}
if gctx.Request.Method == http.MethodHead {
//For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db
//We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true)
// For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db
// We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true)
gctx.String(http.StatusOK, "")
return
@ -403,7 +425,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
}
if err == nil {
//Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions
// Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions
if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil {
log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err)
}

View file

@ -43,6 +43,8 @@ type listUnsubscribe struct {
}
func DecisionCmd(message *Message, p *Papi, sync bool) error {
ctx := context.TODO()
switch message.Header.OperationCmd {
case "delete":
data, err := json.Marshal(message.Data)
@ -65,7 +67,7 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error {
filter := make(map[string][]string)
filter["uuid"] = UUIDs
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter)
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter)
if err != nil {
return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err)
}
@ -168,6 +170,8 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
}
func ManagementCmd(message *Message, p *Papi, sync bool) error {
ctx := context.TODO()
if sync {
p.Logger.Infof("Ignoring management command from PAPI in sync mode")
return nil
@ -195,7 +199,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
filter["origin"] = []string{types.ListOrigin}
filter["scenario"] = []string{unsubscribeMsg.Name}
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter)
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter)
if err != nil {
return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err)
}

View file

@ -121,7 +121,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string]
return query, nil
}
func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) {
func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) {
query := c.Ent.Decision.Query().Where(
decision.UntilGT(time.Now().UTC()),
)
@ -138,7 +138,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e
query = query.Order(ent.Asc(decision.FieldID))
data, err := query.All(c.CTX)
data, err := query.All(ctx)
if err != nil {
c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err)
return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters")
@ -147,7 +147,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e
return data, nil
}
func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) {
func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) {
query := c.Ent.Decision.Query().Where(
decision.UntilLT(time.Now().UTC()),
)
@ -165,7 +165,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) (
return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters")
}
data, err := query.All(c.CTX)
data, err := query.All(ctx)
if err != nil {
c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err)
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
@ -196,7 +196,7 @@ func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*Decisions
return r, nil
}
func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Decision, error) {
func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) {
var data []*ent.Decision
var err error
@ -218,7 +218,7 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec
decision.FieldValue,
decision.FieldScope,
decision.FieldOrigin,
).Scan(c.CTX, &data)
).Scan(ctx, &data)
if err != nil {
c.Log.Warningf("QueryDecisionWithFilter : %s", err)
return []*ent.Decision{}, errors.Wrap(QueryFail, "query decision failed")
@ -255,7 +255,7 @@ func longestDecisionForScopeTypeValue(s *sql.Selector) {
)
}
func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters map[string][]string) ([]*ent.Decision, error) {
func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) {
query := c.Ent.Decision.Query().Where(
decision.UntilLT(time.Now().UTC()),
)
@ -277,7 +277,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters
query = query.Order(ent.Asc(decision.FieldID))
data, err := query.All(c.CTX)
data, err := query.All(ctx)
if err != nil {
c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err)
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters")
@ -286,7 +286,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters
return data, nil
}
func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map[string][]string) ([]*ent.Decision, error) {
func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) {
query := c.Ent.Decision.Query().Where(
decision.UntilGT(time.Now().UTC()),
)
@ -308,7 +308,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map
query = query.Order(ent.Asc(decision.FieldID))
data, err := query.All(c.CTX)
data, err := query.All(ctx)
if err != nil {
c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err)
return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String())
@ -317,20 +317,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map
return data, nil
}
func (c *Client) DeleteDecisionById(decisionID int) ([]*ent.Decision, error) {
toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX)
if err != nil {
c.Log.Warningf("DeleteDecisionById : %s", err)
return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID)
}
count, err := c.DeleteDecisions(toDelete)
c.Log.Debugf("deleted %d decisions", count)
return toDelete, err
}
func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) {
func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
@ -433,13 +420,13 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string,
return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
}
toDelete, err := decisions.All(c.CTX)
toDelete, err := decisions.All(ctx)
if err != nil {
c.Log.Warningf("DeleteDecisionsWithFilter : %s", err)
return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter")
}
count, err := c.DeleteDecisions(toDelete)
count, err := c.DeleteDecisions(ctx, toDelete)
if err != nil {
c.Log.Warningf("While deleting decisions : %s", err)
return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter")
@ -449,7 +436,7 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string,
}
// ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items
func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) {
func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
@ -558,13 +545,13 @@ func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string,
return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
}
DecisionsToDelete, err := decisions.All(c.CTX)
DecisionsToDelete, err := decisions.All(ctx)
if err != nil {
c.Log.Warningf("ExpireDecisionsWithFilter : %s", err)
return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter")
}
count, err := c.ExpireDecisions(DecisionsToDelete)
count, err := c.ExpireDecisions(ctx, DecisionsToDelete)
if err != nil {
return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err)
}
@ -583,13 +570,13 @@ func decisionIDs(decisions []*ent.Decision) []int {
// ExpireDecisions sets the expiration of a list of decisions to now()
// It returns the number of impacted decisions for the CAPI/PAPI
func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {
func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) {
if len(decisions) <= decisionDeleteBulkSize {
ids := decisionIDs(decisions)
rows, err := c.Ent.Decision.Update().Where(
decision.IDIn(ids...),
).SetUntil(time.Now().UTC()).Save(c.CTX)
).SetUntil(time.Now().UTC()).Save(ctx)
if err != nil {
return 0, fmt.Errorf("expire decisions with provided filter: %w", err)
}
@ -602,7 +589,7 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {
total := 0
for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) {
rows, err := c.ExpireDecisions(chunk)
rows, err := c.ExpireDecisions(ctx, chunk)
if err != nil {
return total, err
}
@ -615,13 +602,13 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {
// DeleteDecisions removes a list of decisions from the database
// It returns the number of impacted decisions for the CAPI/PAPI
func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) {
if len(decisions) < decisionDeleteBulkSize {
ids := decisionIDs(decisions)
rows, err := c.Ent.Decision.Delete().Where(
decision.IDIn(ids...),
).Exec(c.CTX)
).Exec(ctx)
if err != nil {
return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err)
}
@ -634,7 +621,7 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
tot := 0
for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) {
rows, err := c.DeleteDecisions(chunk)
rows, err := c.DeleteDecisions(ctx, chunk)
if err != nil {
return tot, err
}
@ -646,8 +633,8 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
}
// ExpireDecision set the expiration of a decision to now()
func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error) {
toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX)
func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) {
toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx)
// XXX: do we want 500 or 404 here?
if err != nil || len(toUpdate) == 0 {
@ -659,12 +646,12 @@ func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error
return 0, nil, ItemNotFound
}
count, err := c.ExpireDecisions(toUpdate)
count, err := c.ExpireDecisions(ctx, toUpdate)
return count, toUpdate, err
}
func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz, count int
@ -682,7 +669,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
}
count, err = decisions.Count(c.CTX)
count, err = decisions.Count(ctx)
if err != nil {
return 0, errors.Wrapf(err, "fail to count decisions")
}
@ -690,7 +677,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
return count, nil
}
func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) {
func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz, count int
@ -710,7 +697,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error)
decisions = decisions.Where(decision.UntilGT(time.Now().UTC()))
count, err = decisions.Count(c.CTX)
count, err = decisions.Count(ctx)
if err != nil {
return 0, fmt.Errorf("fail to count decisions: %w", err)
}
@ -718,7 +705,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error)
return count, nil
}
func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.Duration, error) {
func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
@ -740,7 +727,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D
decisions = decisions.Order(ent.Desc(decision.FieldUntil))
decision, err := decisions.First(c.CTX)
decision, err := decisions.First(ctx)
if err != nil && !ent.IsNotFound(err) {
return 0, fmt.Errorf("fail to get decision: %w", err)
}
@ -752,7 +739,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D
return decision.Until.Sub(time.Now().UTC()), nil
}
func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) {
func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) {
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue)
if err != nil {
return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err)
@ -768,7 +755,7 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
}
count, err := decisions.Count(c.CTX)
count, err := decisions.Count(ctx)
if err != nil {
return 0, errors.Wrapf(err, "fail to count decisions")
}

View file

@ -1,6 +1,7 @@
package database
import (
"context"
"time"
"github.com/pkg/errors"
@ -16,40 +17,45 @@ const (
CapiPullLockName = "pullCAPI"
)
func (c *Client) AcquireLock(name string) error {
func (c *Client) AcquireLock(ctx context.Context, name string) error {
log.Debugf("acquiring lock %s", name)
_, err := c.Ent.Lock.Create().
SetName(name).
SetCreatedAt(types.UtcNow()).
Save(c.CTX)
Save(ctx)
if ent.IsConstraintError(err) {
return err
}
if err != nil {
return errors.Wrapf(InsertFail, "insert lock: %s", err)
}
return nil
}
func (c *Client) ReleaseLock(name string) error {
func (c *Client) ReleaseLock(ctx context.Context, name string) error {
log.Debugf("releasing lock %s", name)
_, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(c.CTX)
_, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(ctx)
if err != nil {
return errors.Wrapf(DeleteFail, "delete lock: %s", err)
}
return nil
}
func (c *Client) ReleaseLockWithTimeout(name string, timeout int) error {
func (c *Client) ReleaseLockWithTimeout(ctx context.Context, name string, timeout int) error {
log.Debugf("releasing lock %s with timeout of %d minutes", name, timeout)
_, err := c.Ent.Lock.Delete().Where(
lock.NameEQ(name),
lock.CreatedAtLT(time.Now().UTC().Add(-time.Duration(timeout)*time.Minute)),
).Exec(c.CTX)
).Exec(ctx)
if err != nil {
return errors.Wrapf(DeleteFail, "delete lock: %s", err)
}
return nil
}
@ -57,23 +63,25 @@ func (c *Client) IsLocked(err error) bool {
return ent.IsConstraintError(err)
}
func (c *Client) AcquirePullCAPILock() error {
/*delete orphan "old" lock if present*/
err := c.ReleaseLockWithTimeout(CapiPullLockName, CAPIPullLockTimeout)
func (c *Client) AcquirePullCAPILock(ctx context.Context) error {
// delete orphan "old" lock if present
err := c.ReleaseLockWithTimeout(ctx, CapiPullLockName, CAPIPullLockTimeout)
if err != nil {
log.Errorf("unable to release pullCAPI lock: %s", err)
}
return c.AcquireLock(CapiPullLockName)
return c.AcquireLock(ctx, CapiPullLockName)
}
func (c *Client) ReleasePullCAPILock() error {
func (c *Client) ReleasePullCAPILock(ctx context.Context) error {
log.Debugf("deleting lock %s", CapiPullLockName)
_, err := c.Ent.Lock.Delete().Where(
lock.NameEQ(CapiPullLockName),
).Exec(c.CTX)
).Exec(ctx)
if err != nil {
return errors.Wrapf(DeleteFail, "delete lock: %s", err)
}
return nil
}

View file

@ -2,6 +2,7 @@ package exprhelpers
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
@ -592,7 +593,10 @@ func GetDecisionsCount(params ...any) (any, error) {
return 0, nil
}
count, err := dbClient.CountDecisionsByValue(value)
ctx := context.TODO()
count, err := dbClient.CountDecisionsByValue(ctx, value)
if err != nil {
log.Errorf("Failed to get decisions count from value '%s'", value)
return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
@ -613,8 +617,11 @@ func GetDecisionsSinceCount(params ...any) (any, error) {
log.Errorf("Failed to parse since parameter '%s' : %s", since, err)
return 0, nil
}
ctx := context.TODO()
sinceTime := time.Now().UTC().Add(-sinceDuration)
count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime)
count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime)
if err != nil {
log.Errorf("Failed to get decisions count from value '%s'", value)
return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
@ -628,7 +635,8 @@ func GetActiveDecisionsCount(params ...any) (any, error) {
log.Error("No database config to call GetActiveDecisionsCount()")
return 0, nil
}
count, err := dbClient.CountActiveDecisionsByValue(value)
ctx := context.TODO()
count, err := dbClient.CountActiveDecisionsByValue(ctx, value)
if err != nil {
log.Errorf("Failed to get active decisions count from value '%s'", value)
return 0, err
@ -642,7 +650,8 @@ func GetActiveDecisionsTimeLeft(params ...any) (any, error) {
log.Error("No database config to call GetActiveDecisionsTimeLeft()")
return 0, nil
}
timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(value)
ctx := context.TODO()
timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value)
if err != nil {
log.Errorf("Failed to get active decisions time left from value '%s'", value)
return 0, err