Refact pkg/database/decisions.go (#3541)

This commit is contained in:
mmetc 2025-04-16 11:53:52 +02:00 committed by GitHub
parent c17d42278f
commit 620bd0117a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 132 additions and 245 deletions

View file

@ -62,7 +62,7 @@ func handleTimeFilters(param, value string, predicates *[]predicate.Alert) error
return nil
}
func handleIPv4Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) {
func handleAlertIPv4Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) {
if contains { // decision contains {start_ip,end_ip}
*predicates = append(*predicates, alert.And(
alert.HasDecisionsWith(decision.StartIPLTE(start_ip)),
@ -78,7 +78,7 @@ func handleIPv4Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip,
}
}
func handleIPv6Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) {
func handleAlertIPv6Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) {
if contains { // decision contains {start_ip,end_ip}
*predicates = append(*predicates, alert.And(
// matching addr size
@ -132,11 +132,11 @@ func handleIPv6Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip,
}
}
func handleIPPredicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) error {
func handleAlertIPPredicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) error {
if ip_sz == 4 {
handleIPv4Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates)
handleAlertIPv4Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates)
} else if ip_sz == 16 {
handleIPv6Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates)
handleAlertIPv6Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates)
} else if ip_sz != 0 {
return errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
}
@ -170,7 +170,7 @@ func handleIncludeCapiFilter(value string, predicates *[]predicate.Alert) error
return nil
}
func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) {
func alertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) {
predicates := make([]predicate.Alert, 0)
var (
@ -241,7 +241,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e
}
}
if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil {
if err := handleAlertIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil {
return nil, err
}
@ -249,7 +249,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e
}
func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) {
preds, err := AlertPredicatesFromFilter(filter)
preds, err := alertPredicatesFromFilter(filter)
if err != nil {
return nil, err
}

View file

@ -913,7 +913,7 @@ func (c *Client) DeleteAlertByID(ctx context.Context, id int) error {
}
func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) {
preds, err := AlertPredicatesFromFilter(filter)
preds, err := alertPredicatesFromFilter(filter)
if err != nil {
return 0, err
}

View file

@ -28,9 +28,12 @@ type DecisionsByScenario struct {
}
func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
var (
err error
start_ip, start_sfx, end_ip, end_sfx int64
ip_sz int
)
contains := true
/*if contains is true, return bans that *contains* the given value (value is the inner)
else, return bans that are *contained* by the given value (value is the outer)*/
@ -113,7 +116,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string]
}
}
query, err = applyStartIpEndIpFilter(query, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
query, err = decisionIPFilter(query, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
if err != nil {
return nil, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err)
}
@ -197,8 +200,10 @@ func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*Decisions
}
func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) {
var data []*ent.Decision
var err error
var (
err error
data []*ent.Decision
)
decisions := c.Ent.Decision.Query().
Where(decision.UntilGTE(time.Now().UTC()))
@ -318,9 +323,12 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *t
}
func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
var (
err error
start_ip, start_sfx, end_ip, end_sfx int64
ip_sz int
)
contains := true
/*if contains is true, return bans that *contains* the given value (value is the inner)
else, return bans that are *contained* by the given value (value is the outer) */
@ -352,72 +360,9 @@ func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[strin
}
}
if ip_sz == 4 {
if contains { /*decision contains {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
decision.StartIPLTE(start_ip),
decision.EndIPGTE(end_ip),
decision.IPSizeEQ(int64(ip_sz)),
))
} else { /*decision is contained within {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
decision.StartIPGTE(start_ip),
decision.EndIPLTE(end_ip),
decision.IPSizeEQ(int64(ip_sz)),
))
}
} else if ip_sz == 16 {
if contains { /*decision contains {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
// matching addr size
decision.IPSizeEQ(int64(ip_sz)),
decision.Or(
// decision.start_ip < query.start_ip
decision.StartIPLT(start_ip),
decision.And(
// decision.start_ip == query.start_ip
decision.StartIPEQ(start_ip),
// decision.start_suffix <= query.start_suffix
decision.StartSuffixLTE(start_sfx),
)),
decision.Or(
// decision.end_ip > query.end_ip
decision.EndIPGT(end_ip),
decision.And(
// decision.end_ip == query.end_ip
decision.EndIPEQ(end_ip),
// decision.end_suffix >= query.end_suffix
decision.EndSuffixGTE(end_sfx),
),
),
))
} else {
decisions = decisions.Where(decision.And(
// matching addr size
decision.IPSizeEQ(int64(ip_sz)),
decision.Or(
// decision.start_ip > query.start_ip
decision.StartIPGT(start_ip),
decision.And(
// decision.start_ip == query.start_ip
decision.StartIPEQ(start_ip),
// decision.start_suffix >= query.start_suffix
decision.StartSuffixGTE(start_sfx),
)),
decision.Or(
// decision.end_ip < query.end_ip
decision.EndIPLT(end_ip),
decision.And(
// decision.end_ip == query.end_ip
decision.EndIPEQ(end_ip),
// decision.end_suffix <= query.end_suffix
decision.EndSuffixLTE(end_sfx),
),
),
))
}
} else if ip_sz != 0 {
return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
decisions, err = decisionIPFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
if err != nil {
return "0", nil, err
}
toDelete, err := decisions.All(ctx)
@ -437,9 +382,12 @@ func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[strin
// ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items
func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
var (
err error
start_ip, start_sfx, end_ip, end_sfx int64
ip_sz int
)
contains := true
/*if contains is true, return bans that *contains* the given value (value is the inner)
else, return bans that are *contained* by the given value (value is the outer)*/
@ -473,76 +421,10 @@ func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[strin
return "0", nil, errors.Wrapf(InvalidFilter, "'%s' doesn't exist", param)
}
}
if ip_sz == 4 {
if contains {
/*Decision contains {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
decision.StartIPLTE(start_ip),
decision.EndIPGTE(end_ip),
decision.IPSizeEQ(int64(ip_sz)),
))
} else {
/*Decision is contained within {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
decision.StartIPGTE(start_ip),
decision.EndIPLTE(end_ip),
decision.IPSizeEQ(int64(ip_sz)),
))
}
} else if ip_sz == 16 {
/*decision contains {start_ip,end_ip}*/
if contains {
decisions = decisions.Where(decision.And(
// matching addr size
decision.IPSizeEQ(int64(ip_sz)),
decision.Or(
// decision.start_ip < query.start_ip
decision.StartIPLT(start_ip),
decision.And(
// decision.start_ip == query.start_ip
decision.StartIPEQ(start_ip),
// decision.start_suffix <= query.start_suffix
decision.StartSuffixLTE(start_sfx),
)),
decision.Or(
// decision.end_ip > query.end_ip
decision.EndIPGT(end_ip),
decision.And(
// decision.end_ip == query.end_ip
decision.EndIPEQ(end_ip),
// decision.end_suffix >= query.end_suffix
decision.EndSuffixGTE(end_sfx),
),
),
))
} else {
/*decision is contained within {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
// matching addr size
decision.IPSizeEQ(int64(ip_sz)),
decision.Or(
// decision.start_ip > query.start_ip
decision.StartIPGT(start_ip),
decision.And(
// decision.start_ip == query.start_ip
decision.StartIPEQ(start_ip),
// decision.start_suffix >= query.start_suffix
decision.StartSuffixGTE(start_sfx),
)),
decision.Or(
// decision.end_ip < query.end_ip
decision.EndIPLT(end_ip),
decision.And(
// decision.end_ip == query.end_ip
decision.EndIPEQ(end_ip),
// decision.end_suffix <= query.end_suffix
decision.EndSuffixLTE(end_sfx),
),
),
))
}
} else if ip_sz != 0 {
return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
decisions, err = decisionIPFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
if err != nil {
return "0", nil, err
}
DecisionsToDelete, err := decisions.All(ctx)
@ -652,9 +534,11 @@ func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, [
}
func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz, count int
var (
err error
start_ip, start_sfx, end_ip, end_sfx int64
ip_sz, count int
)
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue)
if err != nil {
@ -664,7 +548,7 @@ func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string
contains := true
decisions := c.Ent.Decision.Query()
decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
decisions, err = decisionIPFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
if err != nil {
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
}
@ -678,9 +562,11 @@ func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string
}
func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz, count int
var (
err error
start_ip, start_sfx, end_ip, end_sfx int64
ip_sz, count int
)
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue)
if err != nil {
@ -690,7 +576,7 @@ func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue
contains := true
decisions := c.Ent.Decision.Query()
decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
decisions, err = decisionIPFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
if err != nil {
return 0, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err)
}
@ -706,9 +592,11 @@ func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue
}
func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
var (
err error
start_ip, start_sfx, end_ip, end_sfx int64
ip_sz int
)
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue)
if err != nil {
@ -720,7 +608,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decision
decision.UntilGT(time.Now().UTC()),
)
decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
decisions, err = decisionIPFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
if err != nil {
return 0, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err)
}
@ -750,7 +638,7 @@ func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue s
decision.CreatedAtGT(since),
)
decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
decisions, err = decisionIPFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
if err != nil {
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
}
@ -763,88 +651,87 @@ func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue s
return count, nil
}
func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) {
if ip_sz == 4 {
if contains {
/*Decision contains {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
decision.StartIPLTE(start_ip),
decision.EndIPGTE(end_ip),
decision.IPSizeEQ(int64(ip_sz)),
))
} else {
/*Decision is contained within {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
decision.StartIPGTE(start_ip),
decision.EndIPLTE(end_ip),
decision.IPSizeEQ(int64(ip_sz)),
))
}
return decisions, nil
func decisionIPv4Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) {
if contains {
/*Decision contains {start_ip,end_ip}*/
return decisions.Where(decision.And(
decision.StartIPLTE(start_ip),
decision.EndIPGTE(end_ip),
decision.IPSizeEQ(int64(ip_sz)))), nil
}
if ip_sz == 16 {
/*decision contains {start_ip,end_ip}*/
if contains {
decisions = decisions.Where(decision.And(
// matching addr size
decision.IPSizeEQ(int64(ip_sz)),
decision.Or(
// decision.start_ip < query.start_ip
decision.StartIPLT(start_ip),
decision.And(
// decision.start_ip == query.start_ip
decision.StartIPEQ(start_ip),
// decision.start_suffix <= query.start_suffix
decision.StartSuffixLTE(start_sfx),
)),
decision.Or(
// decision.end_ip > query.end_ip
decision.EndIPGT(end_ip),
decision.And(
// decision.end_ip == query.end_ip
decision.EndIPEQ(end_ip),
// decision.end_suffix >= query.end_suffix
decision.EndSuffixGTE(end_sfx),
),
),
))
} else {
/*decision is contained within {start_ip,end_ip}*/
decisions = decisions.Where(decision.And(
// matching addr size
decision.IPSizeEQ(int64(ip_sz)),
decision.Or(
// decision.start_ip > query.start_ip
decision.StartIPGT(start_ip),
decision.And(
// decision.start_ip == query.start_ip
decision.StartIPEQ(start_ip),
// decision.start_suffix >= query.start_suffix
decision.StartSuffixGTE(start_sfx),
)),
decision.Or(
// decision.end_ip < query.end_ip
decision.EndIPLT(end_ip),
decision.And(
// decision.end_ip == query.end_ip
decision.EndIPEQ(end_ip),
// decision.end_suffix <= query.end_suffix
decision.EndSuffixLTE(end_sfx),
),
),
))
}
/*Decision is contained within {start_ip,end_ip}*/
return decisions.Where(decision.And(
decision.StartIPGTE(start_ip),
decision.EndIPLTE(end_ip),
decision.IPSizeEQ(int64(ip_sz)))), nil
}
return decisions, nil
func decisionIPv6Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) {
/*decision contains {start_ip,end_ip}*/
if contains {
return decisions.Where(decision.And(
// matching addr size
decision.IPSizeEQ(int64(ip_sz)),
decision.Or(
// decision.start_ip < query.start_ip
decision.StartIPLT(start_ip),
decision.And(
// decision.start_ip == query.start_ip
decision.StartIPEQ(start_ip),
// decision.start_suffix <= query.start_suffix
decision.StartSuffixLTE(start_sfx),
)),
decision.Or(
// decision.end_ip > query.end_ip
decision.EndIPGT(end_ip),
decision.And(
// decision.end_ip == query.end_ip
decision.EndIPEQ(end_ip),
// decision.end_suffix >= query.end_suffix
decision.EndSuffixGTE(end_sfx),
),
),
)), nil
}
if ip_sz != 0 {
/*decision is contained within {start_ip,end_ip}*/
return decisions.Where(decision.And(
// matching addr size
decision.IPSizeEQ(int64(ip_sz)),
decision.Or(
// decision.start_ip > query.start_ip
decision.StartIPGT(start_ip),
decision.And(
// decision.start_ip == query.start_ip
decision.StartIPEQ(start_ip),
// decision.start_suffix >= query.start_suffix
decision.StartSuffixGTE(start_sfx),
)),
decision.Or(
// decision.end_ip < query.end_ip
decision.EndIPLT(end_ip),
decision.And(
// decision.end_ip == query.end_ip
decision.EndIPEQ(end_ip),
// decision.end_suffix <= query.end_suffix
decision.EndSuffixLTE(end_sfx),
),
),
)), nil
}
func decisionIPFilter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) {
switch ip_sz {
case 4:
return decisionIPv4Filter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
case 16:
return decisionIPv6Filter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
case 0:
return decisions, nil
default:
return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz)
}
return decisions, nil
}
func decisionPredicatesFromStr(s string, predicateFunc func(string) predicate.Decision) []predicate.Decision {