refactor pkg/leakybucket (#3371)

* refact pkg/leakybucket - call LoadBuckets with Item instances
* extract compileScopeFilter()
This commit is contained in:
mmetc 2024-12-20 14:33:24 +01:00 committed by GitHub
parent 26c15a1267
commit 4748720a07
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 122 additions and 122 deletions

View file

@ -86,20 +86,15 @@ func (f *Flags) haveTimeMachine() bool {
type labelsMap map[string]string
func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error {
var (
err error
files []string
)
for _, hubScenarioItem := range hub.GetInstalledByType(cwhub.SCENARIOS, false) {
files = append(files, hubScenarioItem.State.LocalPath)
}
var err error
buckets = leakybucket.NewBuckets()
log.Infof("Loading %d scenario files", len(files))
scenarios := hub.GetInstalledByType(cwhub.SCENARIOS, false)
holders, outputEventChan, err = leakybucket.LoadBuckets(cConfig.Crowdsec, hub, files, &bucketsTomb, buckets, flags.OrderEvent)
log.Infof("Loading %d scenario files", len(scenarios))
holders, outputEventChan, err = leakybucket.LoadBuckets(cConfig.Crowdsec, hub, scenarios, &bucketsTomb, buckets, flags.OrderEvent)
if err != nil {
return fmt.Errorf("scenario loading failed: %w", err)
}

View file

@ -103,13 +103,13 @@ func TestSimulatedAlert(t *testing.T) {
// exclude decision in simulation mode
w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", alertContent, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `)
assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `)
// include decision in simulation mode
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", alertContent, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `)
}
@ -120,21 +120,21 @@ func TestCreateAlert(t *testing.T) {
// Create Alert with invalid format
w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password")
assert.Equal(t, 400, w.Code)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Create Alert with invalid input
alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json")
w = lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertContent, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t,
`{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`,
w.Body.String())
// Create Valid Alert
w = lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json")
assert.Equal(t, 201, w.Code)
assert.Equal(t, http.StatusCreated, w.Code)
assert.Equal(t, `["1"]`, w.Body.String())
}
@ -175,13 +175,13 @@ func TestAlertListFilters(t *testing.T) {
// bad filter
w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", alertContent, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
// get without filters
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
// check alert and decision
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
@ -189,149 +189,149 @@ func TestAlertListFilters(t *testing.T) {
// test decision_type filter (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test decision_type filter (bad value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
// test scope (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=Ip", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test scope (bad value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=rarara", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
// test scenario (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test scenario (bad value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
// test ip (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test ip (bad value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
// test ip (invalid value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"message":"invalid ip address 'gruueq'"}`, w.Body.String())
// test range (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test range
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
// test range (invalid value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"message":"invalid ip address 'ratata'"}`, w.Body.String())
// test since (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1h", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test since (ok but yields no results)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1ns", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
// test since (invalid value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`)
// test until (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1ns", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test until (ok but no return)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1m", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
// test until (invalid value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`)
// test simulated (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test simulated (ok)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test has active decision
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
// test has active decision
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
// test has active decision (invalid value)
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
}
@ -343,7 +343,7 @@ func TestAlertBulkInsert(t *testing.T) {
alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json")
w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", alertContent, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestListAlert(t *testing.T) {
@ -353,13 +353,13 @@ func TestListAlert(t *testing.T) {
// List Alert with invalid filter
w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
// List Alert
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "crowdsecurity/test")
}
@ -374,7 +374,7 @@ func TestCreateAlertErrors(t *testing.T) {
req.Header.Add("User-Agent", UserAgent)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "ratata"))
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, http.StatusUnauthorized, w.Code)
// test invalid bearer
w = httptest.NewRecorder()
@ -382,7 +382,7 @@ func TestCreateAlertErrors(t *testing.T) {
req.Header.Add("User-Agent", UserAgent)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s"))
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestDeleteAlert(t *testing.T) {
@ -396,7 +396,7 @@ func TestDeleteAlert(t *testing.T) {
AddAuthHeaders(req, lapi.loginResp)
req.RemoteAddr = "127.0.0.2:4242"
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.JSONEq(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String())
// Delete Alert
@ -405,7 +405,7 @@ func TestDeleteAlert(t *testing.T) {
AddAuthHeaders(req, lapi.loginResp)
req.RemoteAddr = "127.0.0.1:4242"
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String())
}
@ -420,7 +420,7 @@ func TestDeleteAlertByID(t *testing.T) {
AddAuthHeaders(req, lapi.loginResp)
req.RemoteAddr = "127.0.0.2:4242"
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.JSONEq(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String())
// Delete Alert
@ -429,7 +429,7 @@ func TestDeleteAlertByID(t *testing.T) {
AddAuthHeaders(req, lapi.loginResp)
req.RemoteAddr = "127.0.0.1:4242"
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String())
}
@ -463,7 +463,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
req.RemoteAddr = ip + ":1234"
router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.Contains(t, w.Body.String(), fmt.Sprintf(`{"message":"access forbidden from this IP (%s)"}`, ip))
}
@ -474,7 +474,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
req.RemoteAddr = ip + ":1234"
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String())
}

View file

@ -139,14 +139,24 @@ func testOneBucket(t *testing.T, hub *cwhub.Hub, dir string, tomb *tomb.Tomb) er
t.Fatalf("failed to parse %s : %s", stagecfg, err)
}
files := []string{}
scenarios := []*cwhub.Item{}
for _, x := range stages {
files = append(files, x.Filename)
// XXX: LoadBuckets should take an interface, BucketProvider ScenarioProvider or w/e
item := &cwhub.Item{
Name: x.Filename,
State: cwhub.ItemState{
LocalVersion: "",
LocalPath: x.Filename,
LocalHash: "",
},
}
scenarios = append(scenarios, item)
}
cscfg := &csconfig.CrowdsecServiceCfg{}
holders, response, err := LoadBuckets(cscfg, hub, files, tomb, buckets, false)
holders, response, err := LoadBuckets(cscfg, hub, scenarios, tomb, buckets, false)
if err != nil {
t.Fatalf("failed loading bucket : %s", err)
}
@ -184,7 +194,7 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res
}
dec := json.NewDecoder(yamlFile)
dec.DisallowUnknownFields()
//dec.SetStrict(true)
// dec.SetStrict(true)
tf := TestFile{}
err = dec.Decode(&tf)
if err != nil {
@ -196,7 +206,7 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res
}
var latest_ts time.Time
for _, in := range tf.Lines {
//just to avoid any race during ingestion of funny scenarios
// just to avoid any race during ingestion of funny scenarios
time.Sleep(50 * time.Millisecond)
var ts time.Time
@ -226,7 +236,7 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res
time.Sleep(1 * time.Second)
//Read results from chan
// Read results from chan
POLL_AGAIN:
fails := 0
for fails < 2 {
@ -287,37 +297,37 @@ POLL_AGAIN:
log.Tracef("Checking next expected result.")
//empty overflow
// empty overflow
if out.Overflow.Alert == nil && expected.Overflow.Alert == nil {
//match stuff
// match stuff
} else {
if out.Overflow.Alert == nil || expected.Overflow.Alert == nil {
log.Printf("Here ?")
continue
}
//Scenario
// Scenario
if *out.Overflow.Alert.Scenario != *expected.Overflow.Alert.Scenario {
log.Errorf("(scenario) %v != %v", *out.Overflow.Alert.Scenario, *expected.Overflow.Alert.Scenario)
continue
}
log.Infof("(scenario) %v == %v", *out.Overflow.Alert.Scenario, *expected.Overflow.Alert.Scenario)
//EventsCount
// EventsCount
if *out.Overflow.Alert.EventsCount != *expected.Overflow.Alert.EventsCount {
log.Errorf("(EventsCount) %d != %d", *out.Overflow.Alert.EventsCount, *expected.Overflow.Alert.EventsCount)
continue
}
log.Infof("(EventsCount) %d == %d", *out.Overflow.Alert.EventsCount, *expected.Overflow.Alert.EventsCount)
//Sources
// Sources
if !reflect.DeepEqual(out.Overflow.Sources, expected.Overflow.Sources) {
log.Errorf("(Sources %s != %s)", spew.Sdump(out.Overflow.Sources), spew.Sdump(expected.Overflow.Sources))
continue
}
log.Infof("(Sources: %s == %s)", spew.Sdump(out.Overflow.Sources), spew.Sdump(expected.Overflow.Sources))
}
//Events
// Events
// if !reflect.DeepEqual(out.Overflow.Alert.Events, expected.Overflow.Alert.Events) {
// log.Errorf("(Events %s != %s)", spew.Sdump(out.Overflow.Alert.Events), spew.Sdump(expected.Overflow.Alert.Events))
// valid = false
@ -326,10 +336,10 @@ POLL_AGAIN:
// log.Infof("(Events: %s == %s)", spew.Sdump(out.Overflow.Alert.Events), spew.Sdump(expected.Overflow.Alert.Events))
// }
//CheckFailed:
// CheckFailed:
log.Warningf("The test is valid, remove entry %d from expects, and %d from t.Results", eidx, ridx)
//don't do this at home : delete current element from list and redo
// don't do this at home : delete current element from list and redo
results[eidx] = results[len(results)-1]
results = results[:len(results)-1]
tf.Results[ridx] = tf.Results[len(tf.Results)-1]

View file

@ -7,7 +7,6 @@ import (
"io"
"os"
"path/filepath"
"strings"
"sync"
"time"
@ -201,44 +200,41 @@ func ValidateFactory(bucketFactory *BucketFactory) error {
return fmt.Errorf("unknown bucket type '%s'", bucketFactory.Type)
}
switch bucketFactory.ScopeType.Scope {
case types.Undefined:
return compileScopeFilter(bucketFactory)
}
func compileScopeFilter(bucketFactory *BucketFactory) error {
if bucketFactory.ScopeType.Scope == types.Undefined {
bucketFactory.ScopeType.Scope = types.Ip
case types.Ip:
case types.Range:
var (
runTimeFilter *vm.Program
err error
)
if bucketFactory.ScopeType.Filter != "" {
if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil {
return fmt.Errorf("error compiling the scope filter: %w", err)
}
bucketFactory.ScopeType.RunTimeFilter = runTimeFilter
}
default:
// Compile the scope filter
var (
runTimeFilter *vm.Program
err error
)
if bucketFactory.ScopeType.Filter != "" {
if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil {
return fmt.Errorf("error compiling the scope filter: %w", err)
}
bucketFactory.ScopeType.RunTimeFilter = runTimeFilter
}
}
if bucketFactory.ScopeType.Scope == types.Ip {
if bucketFactory.ScopeType.Filter != "" {
return errors.New("filter is not allowed for IP scope")
}
return nil
}
if bucketFactory.ScopeType.Scope == types.Range && bucketFactory.ScopeType.Filter == "" {
return nil
}
if bucketFactory.ScopeType.Filter == "" {
return errors.New("filter is mandatory for non-IP, non-Range scope")
}
runTimeFilter, err := expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
if err != nil {
return fmt.Errorf("error compiling the scope filter: %w", err)
}
bucketFactory.ScopeType.RunTimeFilter = runTimeFilter
return nil
}
func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []string, tomb *tomb.Tomb, buckets *Buckets, orderEvent bool) ([]BucketFactory, chan types.Event, error) {
func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, scenarios []*cwhub.Item, tomb *tomb.Tomb, buckets *Buckets, orderEvent bool) ([]BucketFactory, chan types.Event, error) {
var (
ret = []BucketFactory{}
response chan types.Event
@ -246,18 +242,15 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
response = make(chan types.Event, 1)
for _, f := range files {
log.Debugf("Loading '%s'", f)
for _, item := range scenarios {
log.Debugf("Loading '%s'", item.State.LocalPath)
if !strings.HasSuffix(f, ".yaml") && !strings.HasSuffix(f, ".yml") {
log.Debugf("Skipping %s : not a yaml file", f)
continue
}
itemPath := item.State.LocalPath
// process the yaml
bucketConfigurationFile, err := os.Open(f)
bucketConfigurationFile, err := os.Open(itemPath)
if err != nil {
log.Errorf("Can't access leaky configuration file %s", f)
log.Errorf("Can't access leaky configuration file %s", itemPath)
return nil, nil, err
}
@ -271,8 +264,8 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
err = dec.Decode(&bucketFactory)
if err != nil {
if !errors.Is(err, io.EOF) {
log.Errorf("Bad yaml in %s: %v", f, err)
return nil, nil, fmt.Errorf("bad yaml in %s: %w", f, err)
log.Errorf("Bad yaml in %s: %v", itemPath, err)
return nil, nil, fmt.Errorf("bad yaml in %s: %w", itemPath, err)
}
log.Tracef("End of yaml file")
@ -288,7 +281,7 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
}
// check compat
if bucketFactory.FormatVersion == "" {
log.Tracef("no version in %s : %s, assuming '1.0'", bucketFactory.Name, f)
log.Tracef("no version in %s : %s, assuming '1.0'", bucketFactory.Name, itemPath)
bucketFactory.FormatVersion = "1.0"
}
@ -302,22 +295,17 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
continue
}
bucketFactory.Filename = filepath.Clean(f)
bucketFactory.Filename = filepath.Clean(itemPath)
bucketFactory.BucketName = seed.Generate()
bucketFactory.ret = response
hubItem := hub.GetItemByPath(bucketFactory.Filename)
if hubItem == nil {
log.Errorf("scenario %s (%s) could not be found in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename)
} else {
if cscfg.SimulationConfig != nil {
bucketFactory.Simulated = cscfg.SimulationConfig.IsSimulated(hubItem.Name)
}
bucketFactory.ScenarioVersion = hubItem.State.LocalVersion
bucketFactory.hash = hubItem.State.LocalHash
if cscfg.SimulationConfig != nil {
bucketFactory.Simulated = cscfg.SimulationConfig.IsSimulated(item.Name)
}
bucketFactory.ScenarioVersion = item.State.LocalVersion
bucketFactory.hash = item.State.LocalHash
bucketFactory.wgDumpState = buckets.wgDumpState
bucketFactory.wgPour = buckets.wgPour

View file

@ -9,6 +9,7 @@ import (
func TestIP2Int(t *testing.T) {
tEmpty := net.IP{}
_, _, _, err := IP2Ints(tEmpty)
if !strings.Contains(err.Error(), "unexpected len 0 for <nil>") {
t.Fatalf("unexpected: %s", err)
@ -189,31 +190,37 @@ func TestAdd2Int(t *testing.T) {
if err != nil && test.exp_error == "" {
t.Fatalf("%d unexpected error : %s", idx, err)
}
if test.exp_error != "" {
if !strings.Contains(err.Error(), test.exp_error) {
t.Fatalf("%d unmatched error : %s != %s", idx, err, test.exp_error)
}
continue //we can skip this one
continue // we can skip this one
}
if sz != test.exp_sz {
t.Fatalf("%d unexpected size %d != %d", idx, sz, test.exp_sz)
}
if start_ip != test.exp_start_ip {
t.Fatalf("%d unexpected start_ip %d != %d", idx, start_ip, test.exp_start_ip)
}
if sz == 16 {
if start_sfx != test.exp_start_sfx {
t.Fatalf("%d unexpected start sfx %d != %d", idx, start_sfx, test.exp_start_sfx)
}
}
if end_ip != test.exp_end_ip {
t.Fatalf("%d unexpected end ip %d != %d", idx, end_ip, test.exp_end_ip)
}
if sz == 16 {
if end_sfx != test.exp_end_sfx {
t.Fatalf("%d unexpected end sfx %d != %d", idx, end_sfx, test.exp_end_sfx)
}
}
}
}