diff --git a/Makefile b/Makefile
index 60014f393..3a04f174c 100644
--- a/Makefile
+++ b/Makefile
@@ -279,11 +279,6 @@ cscli: ## Build cscli
 crowdsec: ## Build crowdsec
 	@$(MAKE) -C $(CROWDSEC_FOLDER) build $(MAKE_FLAGS)
 
-# for the tests with localstack
-export AWS_ENDPOINT_FORCE=http://localhost:4566
-export AWS_ACCESS_KEY_ID=test
-export AWS_SECRET_ACCESS_KEY=test
-
 testenv:
 ifeq ($(TEST_LOCAL_ONLY),)
 	@echo 'NOTE: You need to run "make localstack" in a separate shell, "make localstack-stop" to terminate it; or define the envvar TEST_LOCAL_ONLY to some value.'
diff --git a/go.mod b/go.mod
index 28540d62e..1eafe2786 100644
--- a/go.mod
+++ b/go.mod
@@ -24,7 +24,7 @@ require (
 	github.com/coreos/go-systemd/v22 v22.5.0 // indirect
 	github.com/creack/pty v1.1.21 // indirect
 	github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26
-	github.com/crowdsecurity/go-cs-lib v0.0.17
+	github.com/crowdsecurity/go-cs-lib v0.0.18
 	github.com/crowdsecurity/grokky v0.2.2
 	github.com/crowdsecurity/machineid v1.0.2
 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
diff --git a/go.sum b/go.sum
index ca1093431..fe4a7d74d 100644
--- a/go.sum
+++ b/go.sum
@@ -111,8 +111,8 @@ github.com/crowdsecurity/coraza/v3 v3.0.0-20250320231801-749b8bded21a h1:2Nyr+47
 github.com/crowdsecurity/coraza/v3 v3.0.0-20250320231801-749b8bded21a/go.mod h1:xSaXWOhFMSbrV8qOOfBKAyw3aOqfwaSaOy5BgSF8XlA=
 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen+7WnLs4xDScS/Ex988+id2k6mDf8psU=
 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk=
-github.com/crowdsecurity/go-cs-lib v0.0.17 h1:VM++7EDa34kVCXsCRwOjaua3XHru8FVfKUAbqEoQPas=
-github.com/crowdsecurity/go-cs-lib v0.0.17/go.mod h1:XwGcvTt4lMq4Tm1IRMSKMDf0CVrnytTU8Uoofa7AR+g=
+github.com/crowdsecurity/go-cs-lib v0.0.18 h1:GNyvaag5MXfuapIy4E30pIOvIE5AyHoanJBNSMA1cmE=
+github.com/crowdsecurity/go-cs-lib v0.0.18/go.mod h1:XwGcvTt4lMq4Tm1IRMSKMDf0CVrnytTU8Uoofa7AR+g=
 github.com/crowdsecurity/grokky v0.2.2 h1:yALsI9zqpDArYzmSSxfBq2dhYuGUTKMJq8KOEIAsuo4=
 github.com/crowdsecurity/grokky v0.2.2/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM=
 github.com/crowdsecurity/machineid v1.0.2 h1:wpkpsUghJF8Khtmn/tg6GxgdhLA1Xflerh5lirI+bdc=
diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go
index 49d306631..1a2702208 100644
--- a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go
+++ b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go
@@ -1,19 +1,14 @@
 package cloudwatchacquisition
 
 import (
-	"errors"
-	"fmt"
-	"net"
-	"os"
-	"runtime"
-	"strings"
 	"testing"
 	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
-	log "github.com/sirupsen/logrus"
+	"github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
 	"gopkg.in/tomb.v2"
 
 	"github.com/crowdsecurity/go-cs-lib/cstest"
@@ -30,6 +25,21 @@ import (
 	- check shutdown/restart
 */
 
+func createLogGroup(t *testing.T, cw *CloudwatchSource, group string) {
+	_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
+		LogGroupName: aws.String(group),
+	})
+	require.NoError(t, err)
+}
+
+func createLogStream(t *testing.T, cw *CloudwatchSource, group string, stream string) {
+	_, err := cw.cwClient.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
+		LogGroupName:  aws.String(group),
+		LogStreamName: aws.String(stream),
+	})
+	require.NoError(t, err)
+}
+
 func deleteAllLogGroups(t *testing.T, cw *CloudwatchSource) {
 	input := &cloudwatchlogs.DescribeLogGroupsInput{}
 	result, err := cw.cwClient.DescribeLogGroups(input)
@@ -43,112 +53,69 @@ func deleteAllLogGroups(t *testing.T, cw *CloudwatchSource) {
 	}
 }
 
-func checkForLocalStackAvailability() error {
-	v := os.Getenv("AWS_ENDPOINT_FORCE")
-	if v == "" {
-		return errors.New("missing aws endpoint for tests : AWS_ENDPOINT_FORCE")
-	}
-
-	v = strings.TrimPrefix(v, "http://")
-
-	_, err := net.Dial("tcp", v)
-	if err != nil {
-		return fmt.Errorf("while dialing %s: %w: aws endpoint isn't available", v, err)
-	}
-
-	return nil
+type CloudwatchSuite struct {
+	suite.Suite
 }
 
-func TestMain(m *testing.M) {
-	if runtime.GOOS == "windows" {
-		os.Exit(0)
-	}
-
-	if os.Getenv("TEST_LOCAL_ONLY") != "" {
-		os.Exit(0)
-	}
-
-	if err := checkForLocalStackAvailability(); err != nil {
-		log.Fatalf("local stack error : %s", err)
-	}
-
+func (s *CloudwatchSuite) SetupSuite() {
 	def_PollNewStreamInterval = 1 * time.Second
 	def_PollStreamInterval = 1 * time.Second
 	def_StreamReadTimeout = 10 * time.Second
 	def_MaxStreamAge = 5 * time.Second
 	def_PollDeadStreamInterval = 5 * time.Second
-
-	os.Exit(m.Run())
 }
 
-func TestWatchLogGroupForStreams(t *testing.T) {
-	ctx := t.Context()
+func TestCloudwatchSuite(t *testing.T) {
+	cstest.SetAWSTestEnv(t)
+	suite.Run(t, new(CloudwatchSuite))
+}
 
-	cstest.SkipOnWindows(t)
-
-	log.SetLevel(log.DebugLevel)
+func (s *CloudwatchSuite) TestWatchLogGroupForStreams() {
+	logrus.SetLevel(logrus.DebugLevel)
 
 	tests := []struct {
-		config              []byte
+		config              string
 		expectedCfgErr      string
 		expectedStartErr    string
 		name                string
 		setup               func(*testing.T, *CloudwatchSource)
 		run                 func(*testing.T, *CloudwatchSource)
 		teardown            func(*testing.T, *CloudwatchSource)
-		expectedResLen      int
 		expectedResMessages []string
 	}{
 		// require a group name that doesn't exist
 		{
-			name: "group_does_not_exists",
-			config: []byte(`
+			name: "group_does_not_exist",
+			config: `
 source: cloudwatch
 aws_region: us-east-1
 labels:
  type: test_source
 group_name: b
-stream_name: test_stream`),
+stream_name: test_stream`,
 			expectedStartErr: "The specified log group does not exist",
 			setup: func(t *testing.T, cw *CloudwatchSource) {
 				deleteAllLogGroups(t, cw)
-				_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
-					LogGroupName: aws.String("test_group_not_used_1"),
-				})
-				require.NoError(t, err)
-			},
-			teardown: func(t *testing.T, cw *CloudwatchSource) {
-				_, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{
-					LogGroupName: aws.String("test_group_not_used_1"),
-				})
-				require.NoError(t, err)
+				createLogGroup(t, cw, "test_group_not_used_1")
 			},
 		},
 		// test stream mismatch
 		{
 			name: "group_exists_bad_stream_name",
-			config: []byte(`
+			config: `
 source: cloudwatch
 aws_region: us-east-1
 labels:
  type: test_source
 group_name: test_group1
-stream_name: test_stream_bad`),
+stream_name: test_stream_bad`,
 			setup: func(t *testing.T, cw *CloudwatchSource) {
 				deleteAllLogGroups(t, cw)
-				_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
-					LogGroupName: aws.String("test_group1"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
-					LogGroupName:  aws.String("test_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
+				createLogGroup(t, cw, "test_group1")
+				createLogStream(t, cw, "test_group1", "test_stream")
 				// have a message before we start - won't be popped, but will trigger stream monitoring
-				_, err = cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
+				_, err := cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
 					LogGroupName:  aws.String("test_group1"),
 					LogStreamName: aws.String("test_stream"),
 					LogEvents: []*cloudwatchlogs.InputLogEvent{
@@ -160,39 +127,25 @@ stream_name: test_stream_bad`),
 				})
 				require.NoError(t, err)
 			},
-			teardown: func(t *testing.T, cw *CloudwatchSource) {
-				_, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{
-					LogGroupName: aws.String("test_group1"),
-				})
-				require.NoError(t, err)
-			},
-			expectedResLen: 0,
+			expectedResMessages: []string{},
 		},
 		// test stream mismatch
 		{
 			name: "group_exists_bad_stream_regexp",
-			config: []byte(`
+			config: `
 source: cloudwatch
 aws_region: us-east-1
 labels:
  type: test_source
 group_name: test_group1
-stream_regexp: test_bad[0-9]+`),
+stream_regexp: test_bad[0-9]+`,
 			setup: func(t *testing.T, cw *CloudwatchSource) {
 				deleteAllLogGroups(t, cw)
-				_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
-					LogGroupName: aws.String("test_group1"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
-					LogGroupName:  aws.String("test_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
+				createLogGroup(t, cw, "test_group1")
+				createLogStream(t, cw, "test_group1", "test_stream")
 				// have a message before we start - won't be popped, but will trigger stream monitoring
-				_, err = cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
+				_, err := cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
 					LogGroupName:  aws.String("test_group1"),
 					LogStreamName: aws.String("test_stream"),
 					LogEvents: []*cloudwatchlogs.InputLogEvent{
@@ -204,41 +157,27 @@ stream_regexp: test_bad[0-9]+`),
 				})
 				require.NoError(t, err)
 			},
-			teardown: func(t *testing.T, cw *CloudwatchSource) {
-				_, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{
-					LogGroupName: aws.String("test_group1"),
-				})
-				require.NoError(t, err)
-			},
-			expectedResLen: 0,
+			expectedResMessages: []string{},
 		},
 		// require a group name that does exist and contains a stream in which we are going to put events
 		{
 			name: "group_exists_stream_exists_has_events",
-			config: []byte(`
+			config: `
 source: cloudwatch
 aws_region: us-east-1
 labels:
  type: test_source
 group_name: test_log_group1
 log_level: trace
-stream_name: test_stream`),
+stream_name: test_stream`,
 			// expectedStartErr: "The specified log group does not exist",
 			setup: func(t *testing.T, cw *CloudwatchSource) {
 				deleteAllLogGroups(t, cw)
-				_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
-					LogGroupName:  aws.String("test_log_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
+				createLogGroup(t, cw, "test_log_group1")
+				createLogStream(t, cw, "test_log_group1", "test_stream")
 				// have a message before we start - won't be popped, but will trigger stream monitoring
-				_, err = cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
+				_, err := cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
 					LogGroupName:  aws.String("test_log_group1"),
 					LogStreamName: aws.String("test_stream"),
 					LogEvents: []*cloudwatchlogs.InputLogEvent{
@@ -271,48 +210,27 @@ stream_name: test_stream`),
 				})
 				require.NoError(t, err)
 			},
-			teardown: func(t *testing.T, cw *CloudwatchSource) {
-				_, err := cw.cwClient.DeleteLogStream(&cloudwatchlogs.DeleteLogStreamInput{
-					LogGroupName:  aws.String("test_log_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-			},
-			expectedResLen:      3,
 			expectedResMessages: []string{"test_message_1", "test_message_4", "test_message_5"},
 		},
 		// have a stream generate events, reach time-out and gets polled again
 		{
 			name: "group_exists_stream_exists_has_events+timeout",
-			config: []byte(`
+			config: `
 source: cloudwatch
 aws_region: us-east-1
 labels:
  type: test_source
 group_name: test_log_group1
 log_level: trace
-stream_name: test_stream`),
+stream_name: test_stream`,
 			// expectedStartErr: "The specified log group does not exist",
 			setup: func(t *testing.T, cw *CloudwatchSource) {
 				deleteAllLogGroups(t, cw)
-				_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
-					LogGroupName:  aws.String("test_log_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
+				createLogGroup(t, cw, "test_log_group1")
+				createLogStream(t, cw, "test_log_group1", "test_stream")
 				// have a message before we start - won't be popped, but will trigger stream monitoring
-				_, err = cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
+				_, err := cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
 					LogGroupName:  aws.String("test_log_group1"),
 					LogStreamName: aws.String("test_stream"),
 					LogEvents: []*cloudwatchlogs.InputLogEvent{
@@ -358,48 +276,27 @@ stream_name: test_stream`),
 				time.Sleep(def_PollNewStreamInterval + (1 * time.Second))
 				time.Sleep(def_PollStreamInterval + (1 * time.Second))
 			},
-			teardown: func(t *testing.T, cw *CloudwatchSource) {
-				_, err := cw.cwClient.DeleteLogStream(&cloudwatchlogs.DeleteLogStreamInput{
-					LogGroupName:  aws.String("test_log_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-			},
-			expectedResLen:      3,
 			expectedResMessages: []string{"test_message_1", "test_message_41", "test_message_51"},
 		},
 		// have a stream generate events, reach time-out and dead body collection
 		{
 			name: "group_exists_stream_exists_has_events+timeout+GC",
-			config: []byte(`
+			config: `
 source: cloudwatch
 aws_region: us-east-1
 labels:
  type: test_source
 group_name: test_log_group1
 log_level: trace
-stream_name: test_stream`),
+stream_name: test_stream`,
 			// expectedStartErr: "The specified log group does not exist",
 			setup: func(t *testing.T, cw *CloudwatchSource) {
 				deleteAllLogGroups(t, cw)
-				_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
-					LogGroupName:  aws.String("test_log_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
+				createLogGroup(t, cw, "test_log_group1")
+				createLogStream(t, cw, "test_log_group1", "test_stream")
 				// have a message before we start - won't be popped, but will trigger stream monitoring
-				_, err = cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
+				_, err := cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
 					LogGroupName:  aws.String("test_log_group1"),
 					LogStreamName: aws.String("test_stream"),
 					LogEvents: []*cloudwatchlogs.InputLogEvent{
@@ -417,31 +314,19 @@ stream_name: test_stream`),
 				time.Sleep(def_PollStreamInterval + (1 * time.Second))
 				time.Sleep(def_PollDeadStreamInterval + (1 * time.Second))
 			},
-			teardown: func(t *testing.T, cw *CloudwatchSource) {
-				_, err := cw.cwClient.DeleteLogStream(&cloudwatchlogs.DeleteLogStreamInput{
-					LogGroupName:  aws.String("test_log_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-			},
-			expectedResLen: 1,
+			expectedResMessages: []string{"test_message_1"},
 		},
 	}
 
 	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			dbgLogger := log.New().WithField("test", tc.name)
-			dbgLogger.Logger.SetLevel(log.DebugLevel)
+		s.Run(tc.name, func() {
+			dbgLogger := logrus.New().WithField("test", tc.name)
+			dbgLogger.Logger.SetLevel(logrus.DebugLevel)
 			dbgLogger.Infof("starting test")
 
 			cw := CloudwatchSource{}
-			err := cw.Configure(tc.config, dbgLogger, configuration.METRICS_NONE)
-			cstest.RequireErrorContains(t, err, tc.expectedCfgErr)
+			err := cw.Configure(([]byte)(tc.config), dbgLogger, configuration.METRICS_NONE)
+			cstest.RequireErrorContains(s.T(), err, tc.expectedCfgErr)
 
 			if tc.expectedCfgErr != "" {
 				return
@@ -449,41 +334,42 @@ stream_name: test_stream`),
 
 			// run pre-routine : tests use it to set group & streams etc.
 			if tc.setup != nil {
-				tc.setup(t, &cw)
+				tc.setup(s.T(), &cw)
 			}
 
 			out := make(chan types.Event)
 			tmb := tomb.Tomb{}
-			rcvdEvts := []types.Event{}
 
 			dbgLogger.Infof("running StreamingAcquisition")
 
 			actmb := tomb.Tomb{}
 			actmb.Go(func() error {
-				err := cw.StreamingAcquisition(ctx, out, &actmb)
+				err := cw.StreamingAcquisition(s.T().Context(), out, &actmb)
 				dbgLogger.Infof("acquis done")
-				cstest.RequireErrorContains(t, err, tc.expectedStartErr)
+				cstest.RequireErrorContains(s.T(), err, tc.expectedStartErr)
 
 				return nil
 			})
 
+			got := []string{}
+
 			// let's empty output chan
 			tmb.Go(func() error {
 				for {
 					select {
 					case in := <-out:
-						log.Debugf("received event %+v", in)
-						rcvdEvts = append(rcvdEvts, in)
+						dbgLogger.Debugf("received event %+v", in)
+						got = append(got, in.Line.Raw)
 					case <-tmb.Dying():
-						log.Debugf("pumper died")
+						dbgLogger.Debug("pumper died")
 						return nil
 					}
 				}
 			})
 
 			if tc.run != nil {
-				tc.run(t, &cw)
+				tc.run(s.T(), &cw)
 			} else {
 				dbgLogger.Warning("no code to run")
 			}
@@ -495,96 +381,69 @@ stream_name: test_stream`),
 			dbgLogger.Infof("killing datasource")
 			actmb.Kill(nil)
 			<-actmb.Dead()
-			// dbgLogger.Infof("collected events : %d -> %+v", len(rcvd_evts), rcvd_evts)
-			// check results
-			if tc.expectedResLen != -1 {
-				if tc.expectedResLen != len(rcvdEvts) {
-					t.Fatalf("%s : expected %d results got %d -> %v", tc.name, tc.expectedResLen, len(rcvdEvts), rcvdEvts)
-				}
-				dbgLogger.Debugf("got %d expected messages", len(rcvdEvts))
-			}
-
-			if len(tc.expectedResMessages) != 0 {
-				res := tc.expectedResMessages
-				for idx, v := range rcvdEvts {
-					if len(res) == 0 {
-						t.Fatalf("result %d/%d : received '%s', didn't expect anything (recvd:%d, expected:%d)", idx, len(rcvdEvts), v.Line.Raw, len(rcvdEvts), len(tc.expectedResMessages))
-					}
-
-					if res[0] != v.Line.Raw {
-						t.Fatalf("result %d/%d : expected '%s', received '%s' (recvd:%d, expected:%d)", idx, len(rcvdEvts), res[0], v.Line.Raw, len(rcvdEvts), len(tc.expectedResMessages))
-					}
-
-					dbgLogger.Debugf("got message '%s'", res[0])
-					res = res[1:]
-				}
-
-				if len(res) != 0 {
-					t.Fatalf("leftover unmatched results : %v", res)
-				}
+			if len(tc.expectedResMessages) == 0 {
+				s.Empty(got, "unexpected events")
+			} else {
+				s.Equal(tc.expectedResMessages, got, "mismatched events")
 			}
 
 			if tc.teardown != nil {
-				tc.teardown(t, &cw)
+				tc.teardown(s.T(), &cw)
 			}
 		})
 	}
 }
 
-func TestConfiguration(t *testing.T) {
-	ctx := t.Context()
-
-	cstest.SkipOnWindows(t)
-
-	log.SetLevel(log.DebugLevel)
+func (s *CloudwatchSuite) TestConfiguration() {
+	logrus.SetLevel(logrus.DebugLevel)
 
 	tests := []struct {
-		config           []byte
+		config           string
 		expectedCfgErr   string
 		expectedStartErr string
 		name             string
 	}{
 		{
-			name: "group_does_not_exists",
-			config: []byte(`
+			name: "group_does_not_exist",
+			config: `
 source: cloudwatch
 aws_region: us-east-1
 labels:
  type: test_source
 group_name: test_group
-stream_name: test_stream`),
+stream_name: test_stream`,
 			expectedStartErr: "The specified log group does not exist",
 		},
 		{
-			config: []byte(`
+			config: `
 xxx: cloudwatch
 labels:
  type: test_source
 group_name: test_group
-stream_name: test_stream`),
+stream_name: test_stream`,
 			expectedCfgErr: "field xxx not found in type",
 		},
 		{
 			name: "missing_group_name",
-			config: []byte(`
+			config: `
 source: cloudwatch
 aws_region: us-east-1
 labels:
  type: test_source
-stream_name: test_stream`),
+stream_name: test_stream`,
 			expectedCfgErr: "group_name is mandatory for CloudwatchSource",
 		},
 	}
 
 	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			dbgLogger := log.New().WithField("test", tc.name)
-			dbgLogger.Logger.SetLevel(log.DebugLevel)
+		s.Run(tc.name, func() {
+			dbgLogger := logrus.New().WithField("test", tc.name)
+			dbgLogger.Logger.SetLevel(logrus.DebugLevel)
 
 			cw := CloudwatchSource{}
-			err := cw.Configure(tc.config, dbgLogger, configuration.METRICS_NONE)
-			cstest.RequireErrorContains(t, err, tc.expectedCfgErr)
+			err := cw.Configure(([]byte)(tc.config), dbgLogger, configuration.METRICS_NONE)
+			cstest.RequireErrorContains(s.T(), err, tc.expectedCfgErr)
 
 			if tc.expectedCfgErr != "" {
 				return
@@ -595,25 +454,23 @@ stream_name: test_stream`),
 
 			switch cw.GetMode() {
 			case "tail":
-				err = cw.StreamingAcquisition(ctx, out, &tmb)
+				err = cw.StreamingAcquisition(s.T().Context(), out, &tmb)
 			case "cat":
-				err = cw.OneShotAcquisition(ctx, out, &tmb)
+				err = cw.OneShotAcquisition(s.T().Context(), out, &tmb)
 			}
 
-			cstest.RequireErrorContains(t, err, tc.expectedStartErr)
+			cstest.RequireErrorContains(s.T(), err, tc.expectedStartErr)
 
-			log.Debugf("killing ...")
+			dbgLogger.Debugf("killing ...")
 			tmb.Kill(nil)
 			<-tmb.Dead()
-			log.Debugf("dead :)")
+			dbgLogger.Debugf("dead :)")
 		})
 	}
 }
 
-func TestConfigureByDSN(t *testing.T) {
-	cstest.SkipOnWindows(t)
-
-	log.SetLevel(log.DebugLevel)
+func (s *CloudwatchSuite) TestConfigureByDSN() {
+	logrus.SetLevel(logrus.DebugLevel)
 
 	tests := []struct {
 		dsn string
@@ -644,23 +501,19 @@ func TestConfigureByDSN(t *testing.T) {
 	}
 
 	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			dbgLogger := log.New().WithField("test", tc.name)
-			dbgLogger.Logger.SetLevel(log.DebugLevel)
+		s.Run(tc.name, func() {
+			dbgLogger := logrus.New().WithField("test", tc.name)
+			dbgLogger.Logger.SetLevel(logrus.DebugLevel)
 
 			cw := CloudwatchSource{}
 			err := cw.ConfigureByDSN(tc.dsn, tc.labels, dbgLogger, "")
-			cstest.RequireErrorContains(t, err, tc.expectedCfgErr)
+			cstest.RequireErrorContains(s.T(), err, tc.expectedCfgErr)
 		})
 	}
 }
 
-func TestOneShotAcquisition(t *testing.T) {
-	ctx := t.Context()
-
-	cstest.SkipOnWindows(t)
-
-	log.SetLevel(log.DebugLevel)
+func (s *CloudwatchSuite) TestOneShotAcquisition() {
+	logrus.SetLevel(logrus.DebugLevel)
 
 	tests := []struct {
 		dsn string
@@ -670,7 +523,6 @@ func TestOneShotAcquisition(t *testing.T) {
 		setup func(*testing.T, *CloudwatchSource)
 		run func(*testing.T, *CloudwatchSource)
 		teardown func(*testing.T, *CloudwatchSource)
-		expectedResLen int
 		expectedResMessages []string
 	}{
 		// stream with no data
@@ -680,24 +532,10 @@ func TestOneShotAcquisition(t *testing.T) {
 			// expectedStartErr: "The specified log group does not exist",
 			setup: func(t *testing.T, cw *CloudwatchSource) {
 				deleteAllLogGroups(t, cw)
-				_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
-					LogGroupName:  aws.String("test_log_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
+				createLogGroup(t, cw, "test_log_group1")
+				createLogStream(t, cw, "test_log_group1", "test_stream")
 			},
-			teardown: func(t *testing.T, cw *CloudwatchSource) {
-				_, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-			},
-			expectedResLen: 0,
+			expectedResMessages: []string{},
 		},
 		// stream with one event
 		{
@@ -706,19 +544,11 @@ func TestOneShotAcquisition(t *testing.T) {
 			// expectedStartErr: "The specified log group does not exist",
 			setup: func(t *testing.T, cw *CloudwatchSource) {
 				deleteAllLogGroups(t, cw)
-				_, err := cw.cwClient.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-
-				_, err = cw.cwClient.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
-					LogGroupName:  aws.String("test_log_group1"),
-					LogStreamName: aws.String("test_stream"),
-				})
-				require.NoError(t, err)
+				createLogGroup(t, cw, "test_log_group1")
+				createLogStream(t, cw, "test_log_group1", "test_stream")
 				// this one is too much in the back
-				_, err = cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
+				_, err := cw.cwClient.PutLogEvents(&cloudwatchlogs.PutLogEventsInput{
 					LogGroupName:  aws.String("test_log_group1"),
 					LogStreamName: aws.String("test_stream"),
 					LogEvents: []*cloudwatchlogs.InputLogEvent{
@@ -756,26 +586,19 @@ func TestOneShotAcquisition(t *testing.T) {
 				})
 				require.NoError(t, err)
 			},
-			teardown: func(t *testing.T, cw *CloudwatchSource) {
-				_, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{
-					LogGroupName: aws.String("test_log_group1"),
-				})
-				require.NoError(t, err)
-			},
-			expectedResLen:      1,
 			expectedResMessages: []string{"test_message_2"},
 		},
 	}
 
 	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			dbgLogger := log.New().WithField("test", tc.name)
-			dbgLogger.Logger.SetLevel(log.DebugLevel)
+		s.Run(tc.name, func() {
+			dbgLogger := logrus.New().WithField("test", tc.name)
+			dbgLogger.Logger.SetLevel(logrus.DebugLevel)
 			dbgLogger.Infof("starting test")
 
 			cw := CloudwatchSource{}
 			err := cw.ConfigureByDSN(tc.dsn, map[string]string{"type": "test"}, dbgLogger, "")
-			cstest.RequireErrorContains(t, err, tc.expectedCfgErr)
+			cstest.RequireErrorContains(s.T(), err, tc.expectedCfgErr)
 
 			if tc.expectedCfgErr != "" {
 				return
@@ -784,61 +607,39 @@ func TestOneShotAcquisition(t *testing.T) {
 
 			dbgLogger.Infof("config done test")
 
 			// run pre-routine : tests use it to set group & streams etc.
 			if tc.setup != nil {
-				tc.setup(t, &cw)
+				tc.setup(s.T(), &cw)
 			}
 
 			out := make(chan types.Event, 100)
 			tmb := tomb.Tomb{}
-			rcvdEvts := []types.Event{}
-			dbgLogger.Infof("running StreamingAcquisition")
+			dbgLogger.Infof("running OneShotAcquisition")
 
-			err = cw.OneShotAcquisition(ctx, out, &tmb)
-			cstest.RequireErrorContains(t, err, tc.expectedStartErr)
+			err = cw.OneShotAcquisition(s.T().Context(), out, &tmb)
+			cstest.RequireErrorContains(s.T(), err, tc.expectedStartErr)
 
 			dbgLogger.Infof("acquis done")
 			close(out)
 
 			// let's empty output chan
+			got := []string{}
 			for evt := range out {
-				rcvdEvts = append(rcvdEvts, evt)
+				got = append(got, evt.Line.Raw)
 			}
 
 			if tc.run != nil {
-				tc.run(t, &cw)
+				tc.run(s.T(), &cw)
 			} else {
 				dbgLogger.Warning("no code to run")
 			}
 
-			if tc.expectedResLen != -1 {
-				if tc.expectedResLen != len(rcvdEvts) {
-					t.Fatalf("%s : expected %d results got %d -> %v", tc.name, tc.expectedResLen, len(rcvdEvts), rcvdEvts)
-				} else {
-					dbgLogger.Debugf("got %d expected messages", len(rcvdEvts))
-				}
-			}
-
-			if len(tc.expectedResMessages) != 0 {
-				res := tc.expectedResMessages
-				for idx, v := range rcvdEvts {
-					if len(res) == 0 {
-						t.Fatalf("result %d/%d : received '%s', didn't expect anything (recvd:%d, expected:%d)", idx, len(rcvdEvts), v.Line.Raw, len(rcvdEvts), len(tc.expectedResMessages))
-					}
-
-					if res[0] != v.Line.Raw {
-						t.Fatalf("result %d/%d : expected '%s', received '%s' (recvd:%d, expected:%d)", idx, len(rcvdEvts), res[0], v.Line.Raw, len(rcvdEvts), len(tc.expectedResMessages))
-					}
-
-					dbgLogger.Debugf("got message '%s'", res[0])
-					res = res[1:]
-				}
-
-				if len(res) != 0 {
-					t.Fatalf("leftover unmatched results : %v", res)
-				}
+			if len(tc.expectedResMessages) == 0 {
+				s.Empty(got, "unexpected events")
+			} else {
+				s.Equal(tc.expectedResMessages, got, "mismatched events")
 			}
 
 			if tc.teardown != nil {
-				tc.teardown(t, &cw)
+				tc.teardown(s.T(), &cw)
 			}
 		})
 	}
diff --git a/pkg/acquisition/modules/kafka/kafka_test.go b/pkg/acquisition/modules/kafka/kafka_test.go
index 186cd19bc..8206d1833 100644
--- a/pkg/acquisition/modules/kafka/kafka_test.go
+++ b/pkg/acquisition/modules/kafka/kafka_test.go
@@ -127,8 +127,7 @@ func createTopic(topic string, broker string) {
 }
 
 func TestStreamingAcquisition(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
@@ -200,8 +199,7 @@ topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE)
 }
 
 func TestStreamingAcquisitionWithSSL(t *testing.T) {
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
-	cstest.SkipOnWindows(t)
+	cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go
index 7b0a26674..faeb14656 100644
--- a/pkg/acquisition/modules/kinesis/kinesis_test.go
+++ b/pkg/acquisition/modules/kinesis/kinesis_test.go
@@ -5,10 +5,8 @@ import (
 	"compress/gzip"
 	"encoding/json"
 	"fmt"
-	"net"
 	"os"
 	"strconv"
-	"strings"
 	"testing"
 	"time"
 
@@ -26,21 +24,6 @@ import (
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 )
 
-func getLocalStackEndpoint() (string, error) {
-	endpoint := "http://localhost:4566"
-
-	if v := os.Getenv("AWS_ENDPOINT_FORCE"); v != "" {
-		v = strings.TrimPrefix(v, "http://")
-
-		_, err := net.Dial("tcp", v)
-		if err != nil {
-			return "", fmt.Errorf("while dialing %s: %w: aws endpoint isn't available", v, err)
-		}
-	}
-
-	return endpoint, nil
-}
-
 func GenSubObject(t *testing.T, i int) []byte {
 	r := CloudWatchSubscriptionRecord{
 		MessageType: "subscription",
@@ -69,10 +52,7 @@ func GenSubObject(t *testing.T, i int) []byte {
 	return b.Bytes()
 }
 
-func WriteToStream(t *testing.T, streamName string, count int, shards int, sub bool) {
-	endpoint, err := getLocalStackEndpoint()
-	require.NoError(t, err)
-
+func WriteToStream(t *testing.T, endpoint string, streamName string, count int, shards int, sub bool) {
 	sess := session.Must(session.NewSession())
 	kinesisClient := kinesis.New(sess, aws.NewConfig().WithEndpoint(endpoint).WithRegion("us-east-1"))
 
@@ -90,7 +70,7 @@ func WriteToStream(t *testing.T, streamName string, count int, shards int, sub b
 			data = []byte(strconv.Itoa(i))
 		}
 
-		_, err = kinesisClient.PutRecord(&kinesis.PutRecordInput{
+		_, err := kinesisClient.PutRecord(&kinesis.PutRecordInput{
 			Data:         data,
 			PartitionKey: aws.String(partition),
 			StreamName:   aws.String(streamName),
@@ -153,8 +133,7 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`,
 }
 
 func TestReadFromStream(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	endpoint := cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
@@ -172,8 +151,6 @@ stream_name: stream-1-shard`,
 			shards: 1,
 		},
 	}
-	endpoint, _ := getLocalStackEndpoint()
-
 	for _, test := range tests {
 		f := KinesisSource{}
 		config := fmt.Sprintf(test.config, endpoint)
@@ -186,7 +163,7 @@ stream_name: stream-1-shard`,
 		require.NoError(t, err)
 		// Allow the datasource to start listening to the stream
 		time.Sleep(4 * time.Second)
-		WriteToStream(t, f.Config.StreamName, test.count, test.shards, false)
+		WriteToStream(t, endpoint, f.Config.StreamName, test.count, test.shards, false)
 
 		for i := range test.count {
 			e := <-out
@@ -200,8 +177,7 @@ stream_name: stream-1-shard`,
 }
 
 func TestReadFromMultipleShards(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	endpoint := cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
@@ -219,7 +195,6 @@ stream_name: stream-2-shards`,
 			shards: 2,
 		},
 	}
-	endpoint, _ := getLocalStackEndpoint()
 
 	for _, test := range tests {
 		f := KinesisSource{}
@@ -233,7 +208,7 @@ stream_name: stream-2-shards`,
 		require.NoError(t, err)
 		// Allow the datasource to start listening to the stream
 		time.Sleep(4 * time.Second)
-		WriteToStream(t, f.Config.StreamName, test.count, test.shards, false)
+		WriteToStream(t, endpoint, f.Config.StreamName, test.count, test.shards, false)
 
 		c := 0
 
@@ -250,8 +225,7 @@ stream_name: stream-2-shards`,
 }
 
 func TestFromSubscription(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	endpoint := cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
@@ -270,7 +244,6 @@ from_subscription: true`,
 			shards: 1,
 		},
 	}
-	endpoint, _ := getLocalStackEndpoint()
 
 	for _, test := range tests {
 		f := KinesisSource{}
@@ -284,7 +257,7 @@ from_subscription: true`,
 		require.NoError(t, err)
 		// Allow the datasource to start listening to the stream
 		time.Sleep(4 * time.Second)
-		WriteToStream(t, f.Config.StreamName, test.count, test.shards, true)
+		WriteToStream(t, endpoint, f.Config.StreamName, test.count, test.shards, true)
 
 		for i := range test.count {
 			e := <-out
diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go
index 5bfd6ff89..bb7b6e0eb 100644
--- a/pkg/acquisition/modules/loki/loki_test.go
+++ b/pkg/acquisition/modules/loki/loki_test.go
@@ -331,8 +331,7 @@ func feedLoki(ctx context.Context, logger *log.Entry, n int, title string) error
 }
 
 func TestOneShotAcquisition(t *testing.T) {
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
-	cstest.SkipOnWindows(t)
+	cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
@@ -392,8 +391,7 @@ since: 1h
 }
 
 func TestStreamingAcquisition(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
@@ -507,8 +505,7 @@ query: >
 }
 
 func TestStopStreaming(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
diff --git a/pkg/acquisition/modules/victorialogs/victorialogs_test.go b/pkg/acquisition/modules/victorialogs/victorialogs_test.go
index e8e43cdba..018f19a71 100644
--- a/pkg/acquisition/modules/victorialogs/victorialogs_test.go
+++ b/pkg/acquisition/modules/victorialogs/victorialogs_test.go
@@ -253,8 +253,7 @@ func feedVLogs(ctx context.Context, logger *log.Entry, n int, title string) erro
 }
 
 func TestOneShotAcquisition(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
@@ -317,8 +316,7 @@ since: 1h
 }
 
 func TestStreamingAcquisition(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
 
@@ -428,8 +426,7 @@ query: >
 }
 
 func TestStopStreaming(t *testing.T) {
-	cstest.SkipOnWindows(t)
-	cstest.SkipIfDefined(t, "TEST_LOCAL_ONLY")
+	cstest.SetAWSTestEnv(t)
 
 	ctx := t.Context()
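Note: every test above now calls cstest.SetAWSTestEnv(t) from go-cs-lib v0.0.18 (which this patch pins in go.mod) instead of the per-test cstest.SkipOnWindows/cstest.SkipIfDefined calls, the AWS_* exports removed from the Makefile, and the local checkForLocalStackAvailability/getLocalStackEndpoint helpers deleted here. The helper's implementation is not part of this diff; from its call sites we only know it takes a *testing.T and, in the kinesis tests, returns the localstack endpoint as a string. The following is a minimal sketch reconstructed from the deleted code, not the actual go-cs-lib source; everything in it other than the SetAWSTestEnv name and signature is an assumption.

    // Hypothetical sketch only - the real SetAWSTestEnv lives in
    // github.com/crowdsecurity/go-cs-lib/cstest and may differ.
    package cstest

    import (
        "net"
        "os"
        "runtime"
        "strings"
        "testing"
    )

    // SetAWSTestEnv presumably skips the test when it cannot run against
    // localstack and points the AWS SDK at the local endpoint, returning
    // the endpoint so callers can interpolate it into their configs.
    func SetAWSTestEnv(t *testing.T) string {
        t.Helper()

        // previously handled by cstest.SkipOnWindows / cstest.SkipIfDefined
        if runtime.GOOS == "windows" {
            t.Skip("skipping localstack-based test on windows")
        }

        if os.Getenv("TEST_LOCAL_ONLY") != "" {
            t.Skip("TEST_LOCAL_ONLY is set")
        }

        // endpoint and credentials previously exported from the Makefile
        endpoint := os.Getenv("AWS_ENDPOINT_FORCE")
        if endpoint == "" {
            endpoint = "http://localhost:4566"
            t.Setenv("AWS_ENDPOINT_FORCE", endpoint)
        }

        t.Setenv("AWS_ACCESS_KEY_ID", "test")
        t.Setenv("AWS_SECRET_ACCESS_KEY", "test")

        // fail early if localstack is unreachable, as checkForLocalStackAvailability did
        conn, err := net.Dial("tcp", strings.TrimPrefix(endpoint, "http://"))
        if err != nil {
            t.Fatalf("aws endpoint %s isn't available: %s", endpoint, err)
        }
        conn.Close()

        return endpoint
    }

If the helper does use t.Setenv as sketched, the variables stay scoped to the test binary, which would explain why the Makefile no longer needs to export them globally.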