From 34e306505c9ae08b996e28f592d35b8f2a0a6bb5 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 24 Apr 2025 10:25:48 +0200 Subject: [PATCH] refact pkg/database: dry decision count (#3586) --- pkg/database/decisions.go | 74 ++++++-------------------------------- pkg/exprhelpers/helpers.go | 6 ++-- 2 files changed, 14 insertions(+), 66 deletions(-) diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index 049560a48..94b8a54b7 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -533,44 +533,10 @@ func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, [ return count, toUpdate, err } -func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { - var ( - err error - start_ip, start_sfx, end_ip, end_sfx int64 - ip_sz, count int - ) - - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) +func (c *Client) CountDecisionsByValue(ctx context.Context, value string, since *time.Time, onlyActive bool) (int, error) { + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(value) if err != nil { - return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) - } - - contains := true - decisions := c.Ent.Decision.Query() - - decisions, err = decisionIPFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) - if err != nil { - return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") - } - - count, err = decisions.Count(ctx) - if err != nil { - return 0, errors.Wrapf(err, "fail to count decisions") - } - - return count, nil -} - -func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { - var ( - err error - start_ip, start_sfx, end_ip, end_sfx int64 - ip_sz, count int - ) - - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) - if err != nil { - return 0, fmt.Errorf("unable to convert '%s' to int: %w", decisionValue, err) + return 0, fmt.Errorf("unable to convert '%s' to int: %w", value, err) } contains := true @@ -581,9 +547,15 @@ func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue return 0, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err) } - decisions = decisions.Where(decision.UntilGT(time.Now().UTC())) + if since != nil { + decisions = decisions.Where(decision.CreatedAtGT(*since)) + } - count, err = decisions.Count(ctx) + if onlyActive { + decisions = decisions.Where(decision.UntilGT(time.Now().UTC())) + } + + count, err := decisions.Count(ctx) if err != nil { return 0, fmt.Errorf("fail to count decisions: %w", err) } @@ -627,30 +599,6 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decision return decision.Until.Sub(time.Now().UTC()), nil } -func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) { - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) - if err != nil { - return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) - } - - contains := true - decisions := c.Ent.Decision.Query().Where( - decision.CreatedAtGT(since), - ) - - decisions, err = decisionIPFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) - if err != nil { - return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") - } - - count, err := decisions.Count(ctx) - if err != nil { - return 0, errors.Wrapf(err, "fail to count decisions") - } - - return count, nil -} - func decisionIPv4Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) { if contains { /*Decision contains {start_ip,end_ip}*/ diff --git a/pkg/exprhelpers/helpers.go b/pkg/exprhelpers/helpers.go index 1e5426fdb..8d6e0cd65 100644 --- a/pkg/exprhelpers/helpers.go +++ b/pkg/exprhelpers/helpers.go @@ -642,7 +642,7 @@ func GetDecisionsCount(params ...any) (any, error) { ctx := context.TODO() - count, err := dbClient.CountDecisionsByValue(ctx, value) + count, err := dbClient.CountDecisionsByValue(ctx, value, nil, false) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -670,7 +670,7 @@ func GetDecisionsSinceCount(params ...any) (any, error) { ctx := context.TODO() sinceTime := time.Now().UTC().Add(-sinceDuration) - count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime) + count, err := dbClient.CountDecisionsByValue(ctx, value, &sinceTime, false) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -689,7 +689,7 @@ func GetActiveDecisionsCount(params ...any) (any, error) { ctx := context.TODO() - count, err := dbClient.CountActiveDecisionsByValue(ctx, value) + count, err := dbClient.CountDecisionsByValue(ctx, value, nil, true) if err != nil { log.Errorf("Failed to get active decisions count from value '%s'", value) return 0, err