From ecf34c2fa12e54798c7ef08ee5280332a82f1804 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:43:19 +0100 Subject: [PATCH] lint/deep-exit: avoid log.Fatal (#3367) * lint/deep-exit: don't fail on invalid alert * lint/deep-exit: kinesis_test.go * lint/deep-exit: watcher_test.go * lint/deep-exit: parsing_test.go * lint/deep-exit: client_test.go --- .golangci.yml | 11 --- .../modules/kinesis/kinesis_test.go | 89 ++++++++++--------- pkg/apiclient/client_test.go | 12 ++- pkg/csplugin/watcher_test.go | 13 ++- pkg/leakybucket/bucket.go | 2 +- pkg/leakybucket/overflows.go | 5 +- pkg/parser/parsing_test.go | 79 ++++++---------- pkg/setup/detect_test.go | 2 +- 8 files changed, 87 insertions(+), 126 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 097cc86d2..d0fdd3b37 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -402,12 +402,6 @@ issues: path: "pkg/(.+)_test.go" text: "line-length-limit: .*" - # tolerate deep exit in tests, for now - - linters: - - revive - path: "pkg/(.+)_test.go" - text: "deep-exit: .*" - # we use t,ctx instead of ctx,t in tests - linters: - revive @@ -420,11 +414,6 @@ issues: path: "cmd/crowdsec-cli/main.go" text: "deep-exit: .*" - - linters: - - revive - path: "pkg/leakybucket/overflows.go" - text: "deep-exit: .*" - - linters: - revive path: "cmd/crowdsec/crowdsec.go" diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go index 027cbde92..778dda4a6 100644 --- a/pkg/acquisition/modules/kinesis/kinesis_test.go +++ b/pkg/acquisition/modules/kinesis/kinesis_test.go @@ -9,6 +9,7 @@ import ( "net" "os" "runtime" + "strconv" "strings" "testing" "time" @@ -18,6 +19,7 @@ import ( "github.com/aws/aws-sdk-go/service/kinesis" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -28,17 +30,20 @@ import ( func getLocalStackEndpoint() (string, error) { endpoint := "http://localhost:4566" + if v := os.Getenv("AWS_ENDPOINT_FORCE"); v != "" { v = strings.TrimPrefix(v, "http://") + _, err := net.Dial("tcp", v) if err != nil { return "", fmt.Errorf("while dialing %s: %w: aws endpoint isn't available", v, err) } } + return endpoint, nil } -func GenSubObject(i int) []byte { +func GenSubObject(t *testing.T, i int) []byte { r := CloudWatchSubscriptionRecord{ MessageType: "subscription", Owner: "test", @@ -48,15 +53,14 @@ func GenSubObject(i int) []byte { LogEvents: []CloudwatchSubscriptionLogEvent{ { ID: "testid", - Message: fmt.Sprintf("%d", i), + Message: strconv.Itoa(i), Timestamp: time.Now().UTC().Unix(), }, }, } body, err := json.Marshal(r) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + var b bytes.Buffer gz := gzip.NewWriter(&b) gz.Write(body) @@ -66,33 +70,33 @@ func GenSubObject(i int) []byte { return b.Bytes() } -func WriteToStream(streamName string, count int, shards int, sub bool) { +func WriteToStream(t *testing.T, streamName string, count int, shards int, sub bool) { endpoint, err := getLocalStackEndpoint() - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + sess := session.Must(session.NewSession()) kinesisClient := kinesis.New(sess, aws.NewConfig().WithEndpoint(endpoint).WithRegion("us-east-1")) + for i := range count { partition := "partition" if shards != 1 { partition = fmt.Sprintf("partition-%d", i%shards) } + var data []byte + if sub { - data = GenSubObject(i) + data = GenSubObject(t, i) } else { - data = []byte(fmt.Sprintf("%d", i)) + data = []byte(strconv.Itoa(i)) } + _, err = kinesisClient.PutRecord(&kinesis.PutRecordInput{ Data: data, PartitionKey: aws.String(partition), StreamName: aws.String(streamName), }) - if err != nil { - fmt.Printf("Error writing to stream: %s\n", err) - log.Fatal(err) - } + require.NoError(t, err) } } @@ -111,6 +115,7 @@ func TestBadConfiguration(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -142,6 +147,7 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, } subLogger := log.WithField("type", "kinesis") + for _, test := range tests { f := KinesisSource{} err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) @@ -151,9 +157,11 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, func TestReadFromStream(t *testing.T) { ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string count int @@ -169,26 +177,26 @@ stream_name: stream-1-shard`, }, } endpoint, _ := getLocalStackEndpoint() + for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) - if err != nil { - t.Fatalf("Error configuring source: %s", err) - } + require.NoError(t, err) + tomb := &tomb.Tomb{} out := make(chan types.Event) err = f.StreamingAcquisition(ctx, out, tomb) - if err != nil { - t.Fatalf("Error starting source: %s", err) - } + require.NoError(t, err) // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) - WriteToStream(f.Config.StreamName, test.count, test.shards, false) + WriteToStream(t, f.Config.StreamName, test.count, test.shards, false) + for i := range test.count { e := <-out - assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) + assert.Equal(t, strconv.Itoa(i), e.Line.Raw) } + tomb.Kill(nil) tomb.Wait() } @@ -196,9 +204,11 @@ stream_name: stream-1-shard`, func TestReadFromMultipleShards(t *testing.T) { ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string count int @@ -214,23 +224,22 @@ stream_name: stream-2-shards`, }, } endpoint, _ := getLocalStackEndpoint() + for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) - if err != nil { - t.Fatalf("Error configuring source: %s", err) - } + require.NoError(t, err) tomb := &tomb.Tomb{} out := make(chan types.Event) err = f.StreamingAcquisition(ctx, out, tomb) - if err != nil { - t.Fatalf("Error starting source: %s", err) - } + require.NoError(t, err) // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) - WriteToStream(f.Config.StreamName, test.count, test.shards, false) + WriteToStream(t, f.Config.StreamName, test.count, test.shards, false) + c := 0 + for range test.count { <-out c += 1 @@ -243,9 +252,11 @@ stream_name: stream-2-shards`, func TestFromSubscription(t *testing.T) { ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string count int @@ -266,18 +277,14 @@ from_subscription: true`, f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) - if err != nil { - t.Fatalf("Error configuring source: %s", err) - } + require.NoError(t, err) tomb := &tomb.Tomb{} out := make(chan types.Event) err = f.StreamingAcquisition(ctx, out, tomb) - if err != nil { - t.Fatalf("Error starting source: %s", err) - } + require.NoError(t, err) // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) - WriteToStream(f.Config.StreamName, test.count, test.shards, true) + WriteToStream(t, f.Config.StreamName, test.count, test.shards, true) for i := range test.count { e := <-out assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) @@ -310,15 +317,11 @@ use_enhanced_fanout: true`, f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) err := f.Configure([]byte(config), log.WithField("type", "kinesis")) - if err != nil { - t.Fatalf("Error configuring source: %s", err) - } + require.NoError(t, err) tomb := &tomb.Tomb{} out := make(chan types.Event) err = f.StreamingAcquisition(out, tomb) - if err != nil { - t.Fatalf("Error starting source: %s", err) - } + require.NoError(t, err) //Allow the datasource to start listening to the stream time.Sleep(10 * time.Second) WriteToStream("stream-1-shard", test.count, test.shards) diff --git a/pkg/apiclient/client_test.go b/pkg/apiclient/client_test.go index d1f58f33a..327bf8fbd 100644 --- a/pkg/apiclient/client_test.go +++ b/pkg/apiclient/client_test.go @@ -56,13 +56,11 @@ func toUNCPath(path string) (string, error) { return uncPath, nil } -func setupUnixSocketWithPrefix(socket string, urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) { +func setupUnixSocketWithPrefix(t *testing.T, socket string, urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) { var err error if runtime.GOOS == "windows" { socket, err = toUNCPath(socket) - if err != nil { - log.Fatalf("converting to UNC path: %s", err) - } + require.NoError(t, err, "converting to UNC path") } mux = http.NewServeMux() @@ -120,7 +118,7 @@ func TestNewClientOk_UnixSocket(t *testing.T) { tmpDir := t.TempDir() socket := path.Join(tmpDir, "socket") - mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + mux, urlx, teardown := setupUnixSocketWithPrefix(t, socket, "v1") defer teardown() apiURL, err := url.Parse(urlx) @@ -215,7 +213,7 @@ func TestNewDefaultClient_UnixSocket(t *testing.T) { tmpDir := t.TempDir() socket := path.Join(tmpDir, "socket") - mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + mux, urlx, teardown := setupUnixSocketWithPrefix(t, socket, "v1") defer teardown() apiURL, err := url.Parse(urlx) @@ -293,7 +291,7 @@ func TestNewClientRegisterOK_UnixSocket(t *testing.T) { tmpDir := t.TempDir() socket := path.Join(tmpDir, "socket") - mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + mux, urlx, teardown := setupUnixSocketWithPrefix(t, socket, "v1") defer teardown() /*mock login*/ diff --git a/pkg/csplugin/watcher_test.go b/pkg/csplugin/watcher_test.go index 84e63ec64..9868b8433 100644 --- a/pkg/csplugin/watcher_test.go +++ b/pkg/csplugin/watcher_test.go @@ -15,13 +15,12 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) { +func resetTestTomb(t *testing.T, testTomb *tomb.Tomb, pw *PluginWatcher) { testTomb.Kill(nil) <-pw.PluginEvents - if err := testTomb.Wait(); err != nil { - log.Fatal(err) - } + err := testTomb.Wait() + require.NoError(t, err) } func resetWatcherAlertCounter(pw *PluginWatcher) { @@ -72,7 +71,7 @@ func TestPluginWatcherInterval(t *testing.T) { err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") - resetTestTomb(&testTomb, &pw) + resetTestTomb(t, &testTomb, &pw) testTomb = tomb.Tomb{} pw.Start(&testTomb) @@ -81,7 +80,7 @@ func TestPluginWatcherInterval(t *testing.T) { err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) - resetTestTomb(&testTomb, &pw) + resetTestTomb(t, &testTomb, &pw) // This is to avoid the int complaining } @@ -130,5 +129,5 @@ func TestPluginAlertCountWatcher(t *testing.T) { err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) - resetTestTomb(&testTomb, &pw) + resetTestTomb(t, &testTomb, &pw) } diff --git a/pkg/leakybucket/bucket.go b/pkg/leakybucket/bucket.go index bc81a5059..e7ea6e3e2 100644 --- a/pkg/leakybucket/bucket.go +++ b/pkg/leakybucket/bucket.go @@ -316,7 +316,7 @@ func LeakRoutine(leaky *Leaky) error { alert, err = NewAlert(leaky, ofw) if err != nil { - log.Errorf("%s", err) + log.Error(err) } for _, f := range leaky.BucketConfig.processors { alert, ofw = f.OnBucketOverflow(leaky.BucketConfig)(leaky, alert, ofw) diff --git a/pkg/leakybucket/overflows.go b/pkg/leakybucket/overflows.go index 62ba3bc9a..9357caefa 100644 --- a/pkg/leakybucket/overflows.go +++ b/pkg/leakybucket/overflows.go @@ -363,10 +363,7 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { } if err := newApiAlert.Validate(strfmt.Default); err != nil { - log.Errorf("Generated alerts isn't valid") - log.Errorf("->%s", spew.Sdump(newApiAlert)) - // XXX: deep-exit - note other errors returned from this function are not fatal - log.Fatalf("error : %s", err) + return runtimeAlert, fmt.Errorf("invalid generated alert: %w: %s", err, spew.Sdump(newApiAlert)) } runtimeAlert.APIAlerts = append(runtimeAlert.APIAlerts, newApiAlert) diff --git a/pkg/parser/parsing_test.go b/pkg/parser/parsing_test.go index 5f6f924e7..84d5f4db7 100644 --- a/pkg/parser/parsing_test.go +++ b/pkg/parser/parsing_test.go @@ -13,6 +13,8 @@ import ( "github.com/davecgh/go-spew/spew" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" @@ -33,14 +35,11 @@ func TestParser(t *testing.T) { envSetting := os.Getenv("TEST_ONLY") - pctx, ectx, err := prepTests() - if err != nil { - t.Fatalf("failed to load env : %s", err) - } + pctx, ectx := prepTests(t) // Init the enricher if envSetting != "" { - if err := testOneParser(pctx, ectx, envSetting, nil); err != nil { + if err := testOneParser(t, pctx, ectx, envSetting, nil); err != nil { t.Fatalf("Test '%s' failed : %s", envSetting, err) } } else { @@ -57,7 +56,7 @@ func TestParser(t *testing.T) { fname := "./tests/" + fd.Name() log.Infof("Running test on %s", fname) - if err := testOneParser(pctx, ectx, fname, nil); err != nil { + if err := testOneParser(t, pctx, ectx, fname, nil); err != nil { t.Fatalf("Test '%s' failed : %s", fname, err) } } @@ -71,22 +70,16 @@ func BenchmarkParser(t *testing.B) { log.SetLevel(log.ErrorLevel) - pctx, ectx, err := prepTests() - if err != nil { - t.Fatalf("failed to load env : %s", err) - } + pctx, ectx := prepTests(t) envSetting := os.Getenv("TEST_ONLY") if envSetting != "" { - if err := testOneParser(pctx, ectx, envSetting, t); err != nil { - t.Fatalf("Test '%s' failed : %s", envSetting, err) - } + err := testOneParser(t, pctx, ectx, envSetting, t) + require.NoError(t, err, "Test '%s' failed", envSetting) } else { fds, err := os.ReadDir("./tests/") - if err != nil { - t.Fatalf("Unable to read test directory : %s", err) - } + require.NoError(t, err, "Unable to read test directory") for _, fd := range fds { if !fd.IsDir() { @@ -96,14 +89,13 @@ func BenchmarkParser(t *testing.B) { fname := "./tests/" + fd.Name() log.Infof("Running test on %s", fname) - if err := testOneParser(pctx, ectx, fname, t); err != nil { - t.Fatalf("Test '%s' failed : %s", fname, err) - } + err := testOneParser(t, pctx, ectx, fname, t) + require.NoError(t, err, "Test '%s' failed", fname) } } } -func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing.B) error { +func testOneParser(t require.TestingT, pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing.B) error { var ( err error pnodes []Node @@ -143,7 +135,7 @@ func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing // TBD: Load post overflows // func testFile(t *testing.T, file string, pctx UnixParserCtx, nodes []Node) bool { parser_test_file := fmt.Sprintf("%s/test.yaml", dir) - tests := loadTestFile(parser_test_file) + tests := loadTestFile(t, parser_test_file) count := 1 if b != nil { @@ -152,7 +144,7 @@ func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing } for range count { - if !testFile(tests, *pctx, pnodes) { + if !testFile(t, tests, *pctx, pnodes) { return errors.New("test failed") } } @@ -161,7 +153,7 @@ func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing } // prepTests is going to do the initialisation of parser : it's going to load enrichment plugins and load the patterns. This is done here so that we don't redo it for each test -func prepTests() (*UnixParserCtx, EnricherCtx, error) { +func prepTests(t require.TestingT) (*UnixParserCtx, EnricherCtx) { var ( err error pctx *UnixParserCtx @@ -169,22 +161,16 @@ func prepTests() (*UnixParserCtx, EnricherCtx, error) { ) err = exprhelpers.Init(nil) - if err != nil { - return nil, ectx, fmt.Errorf("exprhelpers init failed: %w", err) - } + require.NoError(t, err, "exprhelpers init failed") // Load enrichment datadir := "./test_data/" err = exprhelpers.GeoIPInit(datadir) - if err != nil { - log.Fatalf("unable to initialize GeoIP: %s", err) - } + require.NoError(t, err, "geoip init failed") ectx, err = Loadplugin() - if err != nil { - return nil, ectx, fmt.Errorf("failed to load plugin geoip: %v", err) - } + require.NoError(t, err, "load plugin failed") log.Printf("Loaded -> %+v", ectx) @@ -194,18 +180,14 @@ func prepTests() (*UnixParserCtx, EnricherCtx, error) { /* this should be refactored to 2 lines :p */ // Init the parser pctx, err = Init(map[string]interface{}{"patterns": cfgdir + string("/patterns/"), "data": "./tests/"}) - if err != nil { - return nil, ectx, fmt.Errorf("failed to initialize parser: %v", err) - } + require.NoError(t, err, "parser init failed") - return pctx, ectx, nil + return pctx, ectx } -func loadTestFile(file string) []TestFile { +func loadTestFile(t require.TestingT, file string) []TestFile { yamlFile, err := os.Open(file) - if err != nil { - log.Fatalf("yamlFile.Get err #%v ", err) - } + require.NoError(t, err, "failed to open test file") dec := yaml.NewDecoder(yamlFile) dec.SetStrict(true) @@ -221,7 +203,7 @@ func loadTestFile(file string) []TestFile { break } - log.Fatalf("Failed to load testfile '%s' yaml error : %v", file, err) + require.NoError(t, err, "failed to load testfile '%s'", file) return nil } @@ -391,19 +373,14 @@ reCheck: return true, nil } -func testFile(testSet []TestFile, pctx UnixParserCtx, nodes []Node) bool { +func testFile(t require.TestingT, testSet []TestFile, pctx UnixParserCtx, nodes []Node) bool { log.Warning("Going to process one test set") for _, tf := range testSet { // func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error) { testOk, err := testSubSet(tf, pctx, nodes) - if err != nil { - log.Fatalf("test failed : %s", err) - } - - if !testOk { - log.Fatalf("failed test : %+v", tf) - } + require.NoError(t, err, "test failed") + assert.True(t, testOk, "failed test: %+v", tf) } return true @@ -427,9 +404,7 @@ func TestGeneratePatternsDoc(t *testing.T) { } pctx, err := Init(map[string]interface{}{"patterns": "../../config/patterns/", "data": "./tests/"}) - if err != nil { - t.Fatalf("unable to load patterns : %s", err) - } + require.NoError(t, err, "unable to load patterns") log.Infof("-> %s", spew.Sdump(pctx)) /*don't judge me, we do it for the users*/ diff --git a/pkg/setup/detect_test.go b/pkg/setup/detect_test.go index 588e74dab..553617032 100644 --- a/pkg/setup/detect_test.go +++ b/pkg/setup/detect_test.go @@ -54,7 +54,7 @@ func TestSetupHelperProcess(t *testing.T) { } fmt.Fprint(os.Stdout, fakeSystemctlOutput) - os.Exit(0) + os.Exit(0) //nolint:revive,deep-exit } func tempYAML(t *testing.T, content string) os.File {