fix: avoid possible race condition while compiling expressions (#3582)

This commit is contained in:
mmetc 2025-04-17 17:34:40 +02:00 committed by GitHub
parent 4004868245
commit a0fab0ac5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -57,6 +57,13 @@ var dbClient *database.Client
var exprFunctionOptions []expr.Option var exprFunctionOptions []expr.Option
func init() { //nolint:gochecknoinits
exprFunctionOptions = make([]expr.Option, len(exprFuncs))
for i, fn := range exprFuncs {
exprFunctionOptions[i] = expr.Function(fn.name, fn.function, fn.signature...)
}
}
var keyValuePattern = regexp.MustCompile(`(?P<key>[^=\s]+)=(?:"(?P<quoted_value>[^"\\]*(?:\\.[^"\\]*)*)"|(?P<value>[^=\s]+)|\s*)`) var keyValuePattern = regexp.MustCompile(`(?P<key>[^=\s]+)=(?:"(?P<quoted_value>[^"\\]*(?:\\.[^"\\]*)*)"|(?P<value>[^=\s]+)|\s*)`)
var ( var (
@ -65,23 +72,13 @@ var (
geoIPRangeReader *maxminddb.Reader geoIPRangeReader *maxminddb.Reader
) )
func GetExprOptions(ctx map[string]interface{}) []expr.Option { func GetExprOptions(ctx map[string]any) []expr.Option {
if len(exprFunctionOptions) == 0 { // copy the prebuilt options + one Env(...) for this call
exprFunctionOptions = []expr.Option{} opts := make([]expr.Option, len(exprFunctionOptions)+1)
for _, function := range exprFuncs { copy(opts, exprFunctionOptions)
exprFunctionOptions = append(exprFunctionOptions, opts[len(opts)-1] = expr.Env(ctx)
expr.Function(function.name,
function.function,
function.signature...,
))
}
}
ret := []expr.Option{} return opts
ret = append(ret, exprFunctionOptions...)
ret = append(ret, expr.Env(ctx))
return ret
} }
func GeoIPInit(datadir string) error { func GeoIPInit(datadir string) error {
@ -199,6 +196,7 @@ func FileInit(fileFolder string, filename string, fileType string) error {
log.Debugf("ignored file %s%s because already loaded", fileFolder, filename) log.Debugf("ignored file %s%s because already loaded", fileFolder, filename)
return nil return nil
} }
if err != nil { if err != nil {
return err return err
} }
@ -244,13 +242,13 @@ func Distinct(params ...any) (any, error) {
return nil, nil return nil, nil
} }
array := params[0].([]interface{}) array := params[0].([]any)
if array == nil { if array == nil {
return []interface{}{}, nil return []any{}, nil
} }
exists := make(map[any]bool) exists := make(map[any]bool)
ret := make([]interface{}, 0) ret := make([]any, 0)
for _, val := range array { for _, val := range array {
if _, ok := exists[val]; !ok { if _, ok := exists[val]; !ok {
@ -270,7 +268,7 @@ func Flatten(params ...any) (any, error) {
return flatten(nil, reflect.ValueOf(params)), nil return flatten(nil, reflect.ValueOf(params)), nil
} }
func flatten(args []interface{}, v reflect.Value) []interface{} { func flatten(args []any, v reflect.Value) []any {
if v.Kind() == reflect.Interface { if v.Kind() == reflect.Interface {
v = v.Elem() v = v.Elem()
} }
@ -501,9 +499,11 @@ func RegexpInFile(params ...any) (any, error) {
// func IpInRange(ip string, ipRange string) bool { // func IpInRange(ip string, ipRange string) bool {
func IpInRange(params ...any) (any, error) { func IpInRange(params ...any) (any, error) {
var err error var (
var ipParsed net.IP err error
var ipRangeParsed *net.IPNet ipParsed net.IP
ipRangeParsed *net.IPNet
)
ip := params[0].(string) ip := params[0].(string)
ipRange := params[1].(string) ipRange := params[1].(string)
@ -513,13 +513,16 @@ func IpInRange(params ...any) (any, error) {
log.Debugf("'%s' is not a valid IP", ip) log.Debugf("'%s' is not a valid IP", ip)
return false, nil return false, nil
} }
if _, ipRangeParsed, err = net.ParseCIDR(ipRange); err != nil { if _, ipRangeParsed, err = net.ParseCIDR(ipRange); err != nil {
log.Debugf("'%s' is not a valid IP Range", ipRange) log.Debugf("'%s' is not a valid IP Range", ipRange)
return false, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility return false, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
} }
if ipRangeParsed.Contains(ipParsed) { if ipRangeParsed.Contains(ipParsed) {
return true, nil return true, nil
} }
return false, nil return false, nil
} }
@ -527,6 +530,7 @@ func IpInRange(params ...any) (any, error) {
func IsIPV6(params ...any) (any, error) { func IsIPV6(params ...any) (any, error) {
ip := params[0].(string) ip := params[0].(string)
ipParsed := net.ParseIP(ip) ipParsed := net.ParseIP(ip)
if ipParsed == nil { if ipParsed == nil {
log.Debugf("'%s' is not a valid IP", ip) log.Debugf("'%s' is not a valid IP", ip)
return false, nil return false, nil
@ -540,10 +544,12 @@ func IsIPV6(params ...any) (any, error) {
func IsIPV4(params ...any) (any, error) { func IsIPV4(params ...any) (any, error) {
ip := params[0].(string) ip := params[0].(string)
ipParsed := net.ParseIP(ip) ipParsed := net.ParseIP(ip)
if ipParsed == nil { if ipParsed == nil {
log.Debugf("'%s' is not a valid IP", ip) log.Debugf("'%s' is not a valid IP", ip)
return false, nil return false, nil
} }
return ipParsed.To4() != nil, nil return ipParsed.To4() != nil, nil
} }
@ -551,10 +557,12 @@ func IsIPV4(params ...any) (any, error) {
func IsIP(params ...any) (any, error) { func IsIP(params ...any) (any, error) {
ip := params[0].(string) ip := params[0].(string)
ipParsed := net.ParseIP(ip) ipParsed := net.ParseIP(ip)
if ipParsed == nil { if ipParsed == nil {
log.Debugf("'%s' is not a valid IP", ip) log.Debugf("'%s' is not a valid IP", ip)
return false, nil return false, nil
} }
return true, nil return true, nil
} }
@ -563,6 +571,7 @@ func IpToRange(params ...any) (any, error) {
ip := params[0].(string) ip := params[0].(string)
cidr := params[1].(string) cidr := params[1].(string)
cidr = strings.TrimPrefix(cidr, "/") cidr = strings.TrimPrefix(cidr, "/")
mask, err := strconv.Atoi(cidr) mask, err := strconv.Atoi(cidr)
if err != nil { if err != nil {
log.Errorf("bad cidr '%s': %s", cidr, err) log.Errorf("bad cidr '%s': %s", cidr, err)
@ -574,11 +583,13 @@ func IpToRange(params ...any) (any, error) {
log.Errorf("can't parse IP address '%s'", ip) log.Errorf("can't parse IP address '%s'", ip)
return "", nil return "", nil
} }
ipRange := iplib.NewNet(ipAddr, mask) ipRange := iplib.NewNet(ipAddr, mask)
if ipRange.IP() == nil { if ipRange.IP() == nil {
log.Errorf("can't get cidr '%s' of '%s'", cidr, ip) log.Errorf("can't get cidr '%s' of '%s'", cidr, ip)
return "", nil return "", nil
} }
return ipRange.String(), nil return ipRange.String(), nil
} }
@ -591,37 +602,42 @@ func TimeNow(params ...any) (any, error) {
func ParseUri(params ...any) (any, error) { func ParseUri(params ...any) (any, error) {
uri := params[0].(string) uri := params[0].(string)
ret := make(map[string][]string) ret := make(map[string][]string)
u, err := url.Parse(uri) u, err := url.Parse(uri)
if err != nil { if err != nil {
log.Errorf("Could not parse URI: %s", err) log.Errorf("Could not parse URI: %s", err)
return ret, nil return ret, nil
} }
parsed, err := url.ParseQuery(u.RawQuery) parsed, err := url.ParseQuery(u.RawQuery)
if err != nil { if err != nil {
log.Errorf("Could not parse query uri : %s", err) log.Errorf("Could not parse query uri : %s", err)
return ret, nil return ret, nil
} }
for k, v := range parsed { for k, v := range parsed {
ret[k] = v ret[k] = v
} }
return ret, nil return ret, nil
} }
// func KeyExists(key string, dict map[string]interface{}) bool { // func KeyExists(key string, dict map[string]interface{}) bool {
func KeyExists(params ...any) (any, error) { func KeyExists(params ...any) (any, error) {
key := params[0].(string) key := params[0].(string)
dict := params[1].(map[string]interface{}) dict := params[1].(map[string]any)
_, ok := dict[key] _, ok := dict[key]
return ok, nil return ok, nil
} }
// func GetDecisionsCount(value string) int { // func GetDecisionsCount(value string) int {
func GetDecisionsCount(params ...any) (any, error) { func GetDecisionsCount(params ...any) (any, error) {
value := params[0].(string) value := params[0].(string)
if dbClient == nil { if dbClient == nil {
log.Error("No database config to call GetDecisionsCount()") log.Error("No database config to call GetDecisionsCount()")
return 0, nil return 0, nil
} }
ctx := context.TODO() ctx := context.TODO()
@ -631,6 +647,7 @@ func GetDecisionsCount(params ...any) (any, error) {
log.Errorf("Failed to get decisions count from value '%s'", value) log.Errorf("Failed to get decisions count from value '%s'", value)
return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
} }
return count, nil return count, nil
} }
@ -638,10 +655,12 @@ func GetDecisionsCount(params ...any) (any, error) {
func GetDecisionsSinceCount(params ...any) (any, error) { func GetDecisionsSinceCount(params ...any) (any, error) {
value := params[0].(string) value := params[0].(string)
since := params[1].(string) since := params[1].(string)
if dbClient == nil { if dbClient == nil {
log.Error("No database config to call GetDecisionsSinceCount()") log.Error("No database config to call GetDecisionsSinceCount()")
return 0, nil return 0, nil
} }
sinceDuration, err := time.ParseDuration(since) sinceDuration, err := time.ParseDuration(since)
if err != nil { if err != nil {
log.Errorf("Failed to parse since parameter '%s' : %s", since, err) log.Errorf("Failed to parse since parameter '%s' : %s", since, err)
@ -656,79 +675,95 @@ func GetDecisionsSinceCount(params ...any) (any, error) {
log.Errorf("Failed to get decisions count from value '%s'", value) log.Errorf("Failed to get decisions count from value '%s'", value)
return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
} }
return count, nil return count, nil
} }
func GetActiveDecisionsCount(params ...any) (any, error) { func GetActiveDecisionsCount(params ...any) (any, error) {
value := params[0].(string) value := params[0].(string)
if dbClient == nil { if dbClient == nil {
log.Error("No database config to call GetActiveDecisionsCount()") log.Error("No database config to call GetActiveDecisionsCount()")
return 0, nil return 0, nil
} }
ctx := context.TODO() ctx := context.TODO()
count, err := dbClient.CountActiveDecisionsByValue(ctx, value) count, err := dbClient.CountActiveDecisionsByValue(ctx, value)
if err != nil { if err != nil {
log.Errorf("Failed to get active decisions count from value '%s'", value) log.Errorf("Failed to get active decisions count from value '%s'", value)
return 0, err return 0, err
} }
return count, nil return count, nil
} }
func GetActiveDecisionsTimeLeft(params ...any) (any, error) { func GetActiveDecisionsTimeLeft(params ...any) (any, error) {
value := params[0].(string) value := params[0].(string)
if dbClient == nil { if dbClient == nil {
log.Error("No database config to call GetActiveDecisionsTimeLeft()") log.Error("No database config to call GetActiveDecisionsTimeLeft()")
return 0, nil return 0, nil
} }
ctx := context.TODO() ctx := context.TODO()
timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value) timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value)
if err != nil { if err != nil {
log.Errorf("Failed to get active decisions time left from value '%s'", value) log.Errorf("Failed to get active decisions time left from value '%s'", value)
return 0, err return 0, err
} }
return timeLeft, nil return timeLeft, nil
} }
// func LookupHost(value string) []string { // func LookupHost(value string) []string {
func LookupHost(params ...any) (any, error) { func LookupHost(params ...any) (any, error) {
value := params[0].(string) value := params[0].(string)
addresses, err := net.LookupHost(value) addresses, err := net.LookupHost(value)
if err != nil { if err != nil {
log.Errorf("Failed to lookup host '%s' : %s", value, err) log.Errorf("Failed to lookup host '%s' : %s", value, err)
return []string{}, nil return []string{}, nil
} }
return addresses, nil return addresses, nil
} }
// func ParseUnixTime(value string) (time.Time, error) { // func ParseUnixTime(value string) (time.Time, error) {
func ParseUnixTime(params ...any) (any, error) { func ParseUnixTime(params ...any) (any, error) {
value := params[0].(string) value := params[0].(string)
//Splitting string here as some unix timestamp may have milliseconds and break ParseInt // Splitting string here as some unix timestamp may have milliseconds and break ParseInt
i, err := strconv.ParseInt(strings.Split(value, ".")[0], 10, 64) i, err := strconv.ParseInt(strings.Split(value, ".")[0], 10, 64)
if err != nil || i <= 0 { if err != nil || i <= 0 {
return time.Time{}, fmt.Errorf("unable to parse %s as unix timestamp", value) return time.Time{}, fmt.Errorf("unable to parse %s as unix timestamp", value)
} }
return time.Unix(i, 0), nil return time.Unix(i, 0), nil
} }
// func ParseUnix(value string) string { // func ParseUnix(value string) string {
func ParseUnix(params ...any) (any, error) { func ParseUnix(params ...any) (any, error) {
value := params[0].(string) value := params[0].(string)
t, err := ParseUnixTime(value) t, err := ParseUnixTime(value)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return "", nil return "", nil
} }
return t.(time.Time).Format(time.RFC3339), nil return t.(time.Time).Format(time.RFC3339), nil
} }
// func ToString(value interface{}) string { // func ToString(value interface{}) string {
func ToString(params ...any) (any, error) { func ToString(params ...any) (any, error) {
value := params[0] value := params[0]
s, ok := value.(string) s, ok := value.(string)
if !ok { if !ok {
return "", nil return "", nil
} }
return s, nil return s, nil
} }
@ -736,6 +771,7 @@ func ToString(params ...any) (any, error) {
func GetFromStash(params ...any) (any, error) { func GetFromStash(params ...any) (any, error) {
cacheName := params[0].(string) cacheName := params[0].(string)
key := params[1].(string) key := params[1].(string)
return cache.GetKey(cacheName, key) return cache.GetKey(cacheName, key)
} }
@ -745,6 +781,7 @@ func SetInStash(params ...any) (any, error) {
key := params[1].(string) key := params[1].(string)
value := params[2].(string) value := params[2].(string)
expiration := params[3].(*time.Duration) expiration := params[3].(*time.Duration)
return cache.SetKey(cacheName, key, value, expiration), nil return cache.SetKey(cacheName, key, value, expiration), nil
} }
@ -763,12 +800,15 @@ func Match(params ...any) (any, error) {
if pattern == "" { if pattern == "" {
return name == "", nil return name == "", nil
} }
if name == "" { if name == "" {
if pattern == "*" || pattern == "" { if pattern == "*" || pattern == "" {
return true, nil return true, nil
} }
return false, nil return false, nil
} }
if pattern[0] == '*' { if pattern[0] == '*' {
for i := 0; i <= len(name); i++ { for i := 0; i <= len(name); i++ {
matched, _ := Match(pattern[1:], name[i:]) matched, _ := Match(pattern[1:], name[i:])
@ -776,11 +816,14 @@ func Match(params ...any) (any, error) {
return matched, nil return matched, nil
} }
} }
return matched, nil return matched, nil
} }
if pattern[0] == '?' || pattern[0] == name[0] { if pattern[0] == '?' || pattern[0] == name[0] {
return Match(pattern[1:], name[1:]) return Match(pattern[1:], name[1:])
} }
return matched, nil return matched, nil
} }
@ -791,21 +834,24 @@ func FloatApproxEqual(params ...any) (any, error) {
if math.Abs(float1-float2) < 1e-6 { if math.Abs(float1-float2) < 1e-6 {
return true, nil return true, nil
} }
return false, nil return false, nil
} }
func B64Decode(params ...any) (any, error) { func B64Decode(params ...any) (any, error) {
encoded := params[0].(string) encoded := params[0].(string)
decoded, err := base64.StdEncoding.DecodeString(encoded) decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil { if err != nil {
return "", err return "", err
} }
return string(decoded), nil return string(decoded), nil
} }
func ParseKV(params ...any) (any, error) { func ParseKV(params ...any) (any, error) {
blob := params[0].(string) blob := params[0].(string)
target := params[1].(map[string]interface{}) target := params[1].(map[string]any)
prefix := params[2].(string) prefix := params[2].(string)
matches := keyValuePattern.FindAllStringSubmatch(blob, -1) matches := keyValuePattern.FindAllStringSubmatch(blob, -1)
@ -813,6 +859,7 @@ func ParseKV(params ...any) (any, error) {
log.Errorf("could not find any key/value pair in line") log.Errorf("could not find any key/value pair in line")
return nil, errors.New("invalid input format") return nil, errors.New("invalid input format")
} }
if _, ok := target[prefix]; !ok { if _, ok := target[prefix]; !ok {
target[prefix] = make(map[string]string) target[prefix] = make(map[string]string)
} else { } else {
@ -822,9 +869,11 @@ func ParseKV(params ...any) (any, error) {
return nil, errors.New("target is not a map[string]string") return nil, errors.New("target is not a map[string]string")
} }
} }
for _, match := range matches { for _, match := range matches {
key := "" key := ""
value := "" value := ""
for i, name := range keyValuePattern.SubexpNames() { for i, name := range keyValuePattern.SubexpNames() {
if name == "key" { if name == "key" {
key = match[i] key = match[i]
@ -834,9 +883,12 @@ func ParseKV(params ...any) (any, error) {
value = match[i] value = match[i]
} }
} }
target[prefix].(map[string]string)[key] = value target[prefix].(map[string]string)[key] = value
} }
log.Tracef("unmarshaled KV: %+v", target[prefix]) log.Tracef("unmarshaled KV: %+v", target[prefix])
return nil, nil return nil, nil
} }
@ -845,5 +897,6 @@ func Hostname(params ...any) (any, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return hostname, nil return hostname, nil
} }