context propagation: pkg/csplugin (#3273)

This commit is contained in:
mmetc 2024-10-10 17:18:59 +02:00 committed by GitHub
parent 50d115b914
commit 8ff58ee74e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 66 additions and 33 deletions

View file

@ -275,7 +275,8 @@ func (cli cliNotifications) newTestCmd() *cobra.Command {
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
ValidArgsFunction: cli.notificationConfigFilter,
PreRunE: func(_ *cobra.Command, args []string) error {
PreRunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
cfg := cli.cfg()
pconfigs, err := cli.getPluginConfigs()
if err != nil {
@ -286,7 +287,7 @@ func (cli cliNotifications) newTestCmd() *cobra.Command {
return fmt.Errorf("plugin name: '%s' does not exist", args[0])
}
// Create a single profile with plugin name as notification name
return pluginBroker.Init(cfg.PluginConfig, []*csconfig.ProfileCfg{
return pluginBroker.Init(ctx, cfg.PluginConfig, []*csconfig.ProfileCfg{
{
Notifications: []string{
pcfg.Name,
@ -377,12 +378,13 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
return nil
},
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
var (
pluginBroker csplugin.PluginBroker
pluginTomb tomb.Tomb
)
ctx := cmd.Context()
cfg := cli.cfg()
if alertOverride != "" {
@ -391,7 +393,7 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
}
}
err := pluginBroker.Init(cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths)
err := pluginBroker.Init(ctx, cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths)
if err != nil {
return fmt.Errorf("can't initialize plugins: %w", err)
}

View file

@ -40,7 +40,7 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP
return nil, errors.New("plugins are enabled, but config_paths.plugin_dir is not defined")
}
err = pluginBroker.Init(cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths)
err = pluginBroker.Init(ctx, cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths)
if err != nil {
return nil, fmt.Errorf("unable to run plugin broker: %w", err)
}

View file

@ -72,8 +72,8 @@ func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.Wat
}
func LoginToTestAPI(t *testing.T, ctx context.Context, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse {
body := CreateTestMachine(t, router, "")
ValidateMachine(t, "test", config.API.Server.DbConfig)
body := CreateTestMachine(t, ctx, router, "")
ValidateMachine(t, ctx, "test", config.API.Server.DbConfig)
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body))

View file

@ -180,9 +180,7 @@ func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csc
return router, config
}
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
ctx := context.TODO()
func ValidateMachine(t *testing.T, ctx context.Context, machineID string, config *csconfig.DatabaseCfg) {
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)
@ -269,7 +267,7 @@ func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map
return response, resp.Code
}
func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string {
func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, token string) string {
regReq := MachineTest
regReq.RegistrationToken = token
b, err := json.Marshal(regReq)
@ -277,8 +275,6 @@ func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string {
body := string(b)
ctx := context.Background()
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body))
req.Header.Set("User-Agent", UserAgent)

View file

@ -14,7 +14,7 @@ func TestLogin(t *testing.T) {
ctx := context.Background()
router, config := NewAPITest(t, ctx)
body := CreateTestMachine(t, router, "")
body := CreateTestMachine(t, ctx, router, "")
// Login with machine not validated yet
w := httptest.NewRecorder()
@ -53,7 +53,7 @@ func TestLogin(t *testing.T) {
assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String())
// Validate machine
ValidateMachine(t, "test", config.API.Server.DbConfig)
ValidateMachine(t, ctx, "test", config.API.Server.DbConfig)
// Login with invalid password
w = httptest.NewRecorder()

View file

@ -131,7 +131,7 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
ctx := context.Background()
router, _ := NewAPITest(t, ctx)
body := CreateTestMachine(t, router, "")
body := CreateTestMachine(t, ctx, router, "")
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body))

View file

@ -72,7 +72,7 @@ type ProfileAlert struct {
Alert *models.Alert
}
func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error {
func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error {
pb.PluginChannel = make(chan ProfileAlert)
pb.notificationConfigsByPluginType = make(map[string][][]byte)
pb.notificationPluginByName = make(map[string]protobufs.NotifierServer)
@ -85,7 +85,7 @@ func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*cs
if err := pb.loadConfig(configPaths.NotificationDir); err != nil {
return fmt.Errorf("while loading plugin config: %w", err)
}
if err := pb.loadPlugins(configPaths.PluginDir); err != nil {
if err := pb.loadPlugins(ctx, configPaths.PluginDir); err != nil {
return fmt.Errorf("while loading plugin: %w", err)
}
pb.watcher = PluginWatcher{}
@ -230,7 +230,7 @@ func (pb *PluginBroker) verifyPluginBinaryWithProfile() error {
return nil
}
func (pb *PluginBroker) loadPlugins(path string) error {
func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error {
binaryPaths, err := listFilesAtPath(path)
if err != nil {
return err
@ -265,7 +265,7 @@ func (pb *PluginBroker) loadPlugins(path string) error {
return err
}
data = []byte(csstring.StrictExpand(string(data), os.LookupEnv))
_, err = pluginClient.Configure(context.Background(), &protobufs.Config{Config: data})
_, err = pluginClient.Configure(ctx, &protobufs.Config{Config: data})
if err != nil {
return fmt.Errorf("while configuring %s: %w", pc.Name, err)
}

View file

@ -1,6 +1,7 @@
package csplugin
import (
"context"
"io"
"os"
"os/exec"
@ -96,6 +97,7 @@ func (s *PluginSuite) TearDownTest() {
func (s *PluginSuite) SetupSubTest() {
var err error
t := s.T()
s.runDir, err = os.MkdirTemp("", "cs_plugin_test")
@ -127,6 +129,7 @@ func (s *PluginSuite) SetupSubTest() {
func (s *PluginSuite) TearDownSubTest() {
t := s.T()
if s.pluginBroker != nil {
s.pluginBroker.Kill()
s.pluginBroker = nil
@ -140,19 +143,24 @@ func (s *PluginSuite) TearDownSubTest() {
os.Remove("./out")
}
func (s *PluginSuite) InitBroker(procCfg *csconfig.PluginCfg) (*PluginBroker, error) {
func (s *PluginSuite) InitBroker(ctx context.Context, procCfg *csconfig.PluginCfg) (*PluginBroker, error) {
pb := PluginBroker{}
if procCfg == nil {
procCfg = &csconfig.PluginCfg{}
}
profiles := csconfig.NewDefaultConfig().API.Server.Profiles
profiles = append(profiles, &csconfig.ProfileCfg{
Notifications: []string{"dummy_default"},
})
err := pb.Init(procCfg, profiles, &csconfig.ConfigurationPaths{
err := pb.Init(ctx, procCfg, profiles, &csconfig.ConfigurationPaths{
PluginDir: s.pluginDir,
NotificationDir: s.notifDir,
})
s.pluginBroker = &pb
return s.pluginBroker, err
}

View file

@ -4,6 +4,7 @@ package csplugin
import (
"bytes"
"context"
"encoding/json"
"io"
"os"
@ -53,6 +54,7 @@ func (s *PluginSuite) writeconfig(config PluginConfig) {
}
func (s *PluginSuite) TestBrokerInit() {
ctx := context.Background()
tests := []struct {
name string
action func(*testing.T)
@ -135,20 +137,22 @@ func (s *PluginSuite) TestBrokerInit() {
tc.action(t)
}
_, err := s.InitBroker(&tc.procCfg)
_, err := s.InitBroker(ctx, &tc.procCfg)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
}
}
func (s *PluginSuite) TestBrokerNoThreshold() {
ctx := context.Background()
var alerts []models.Alert
DefaultEmptyTicker = 50 * time.Millisecond
t := s.T()
pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)
tomb := tomb.Tomb{}
@ -187,6 +191,8 @@ func (s *PluginSuite) TestBrokerNoThreshold() {
}
func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
ctx := context.Background()
// test grouping by "time"
DefaultEmptyTicker = 50 * time.Millisecond
@ -198,7 +204,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
cfg.GroupWait = 1 * time.Second
s.writeconfig(cfg)
pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)
tomb := tomb.Tomb{}
@ -224,6 +230,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
}
func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
ctx := context.Background()
DefaultEmptyTicker = 50 * time.Millisecond
t := s.T()
@ -234,7 +241,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
cfg.GroupWait = 4 * time.Second
s.writeconfig(cfg)
pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)
tomb := tomb.Tomb{}
@ -264,6 +271,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
}
func (s *PluginSuite) TestBrokerRunGroupThreshold() {
ctx := context.Background()
// test grouping by "size"
DefaultEmptyTicker = 50 * time.Millisecond
@ -274,7 +282,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
cfg.GroupThreshold = 4
s.writeconfig(cfg)
pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)
tomb := tomb.Tomb{}
@ -318,6 +326,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
}
func (s *PluginSuite) TestBrokerRunTimeThreshold() {
ctx := context.Background()
DefaultEmptyTicker = 50 * time.Millisecond
t := s.T()
@ -327,7 +336,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
cfg.GroupWait = 1 * time.Second
s.writeconfig(cfg)
pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)
tomb := tomb.Tomb{}
@ -353,11 +362,12 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
}
func (s *PluginSuite) TestBrokerRunSimple() {
ctx := context.Background()
DefaultEmptyTicker = 50 * time.Millisecond
t := s.T()
pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)
tomb := tomb.Tomb{}

View file

@ -4,6 +4,7 @@ package csplugin
import (
"bytes"
"context"
"encoding/json"
"io"
"os"
@ -26,6 +27,7 @@ not if it will actually reject plugins with invalid permissions
*/
func (s *PluginSuite) TestBrokerInit() {
ctx := context.Background()
tests := []struct {
name string
action func(*testing.T)
@ -59,16 +61,17 @@ func (s *PluginSuite) TestBrokerInit() {
if tc.action != nil {
tc.action(t)
}
_, err := s.InitBroker(&tc.procCfg)
_, err := s.InitBroker(ctx, &tc.procCfg)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
}
}
func (s *PluginSuite) TestBrokerRun() {
ctx := context.Background()
t := s.T()
pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)
tomb := tomb.Tomb{}

View file

@ -15,11 +15,10 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models"
)
var ctx = context.Background()
func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) {
testTomb.Kill(nil)
<-pw.PluginEvents
if err := testTomb.Wait(); err != nil {
log.Fatal(err)
}
@ -46,13 +45,17 @@ func listenChannelWithTimeout(ctx context.Context, channel chan string) error {
case <-ctx.Done():
return ctx.Err()
}
return nil
}
func TestPluginWatcherInterval(t *testing.T) {
ctx := context.Background()
if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows because timing is not reliable")
}
pw := PluginWatcher{}
alertsByPluginName := make(map[string][]*models.Alert)
testTomb := tomb.Tomb{}
@ -66,6 +69,7 @@ func TestPluginWatcherInterval(t *testing.T) {
ct, cancel := context.WithTimeout(ctx, time.Microsecond)
defer cancel()
err := listenChannelWithTimeout(ct, pw.PluginEvents)
cstest.RequireErrorContains(t, err, "context deadline exceeded")
resetTestTomb(&testTomb, &pw)
@ -74,6 +78,7 @@ func TestPluginWatcherInterval(t *testing.T) {
ct, cancel = context.WithTimeout(ctx, time.Millisecond*5)
defer cancel()
err = listenChannelWithTimeout(ct, pw.PluginEvents)
require.NoError(t, err)
resetTestTomb(&testTomb, &pw)
@ -81,9 +86,12 @@ func TestPluginWatcherInterval(t *testing.T) {
}
func TestPluginAlertCountWatcher(t *testing.T) {
ctx := context.Background()
if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows because timing is not reliable")
}
pw := PluginWatcher{}
alertsByPluginName := make(map[string][]*models.Alert)
configs := map[string]PluginConfig{
@ -92,28 +100,34 @@ func TestPluginAlertCountWatcher(t *testing.T) {
},
}
testTomb := tomb.Tomb{}
pw.Init(configs, alertsByPluginName)
pw.Start(&testTomb)
// Channel won't contain any events since threshold is not crossed.
ct, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
err := listenChannelWithTimeout(ct, pw.PluginEvents)
cstest.RequireErrorContains(t, err, "context deadline exceeded")
// Channel won't contain any events since threshold is not crossed.
resetWatcherAlertCounter(&pw)
insertNAlertsToPlugin(&pw, 4, "testPlugin")
ct, cancel = context.WithTimeout(ctx, time.Second)
defer cancel()
err = listenChannelWithTimeout(ct, pw.PluginEvents)
cstest.RequireErrorContains(t, err, "context deadline exceeded")
// Channel will contain an event since threshold is crossed.
resetWatcherAlertCounter(&pw)
insertNAlertsToPlugin(&pw, 5, "testPlugin")
ct, cancel = context.WithTimeout(ctx, time.Second)
defer cancel()
err = listenChannelWithTimeout(ct, pw.PluginEvents)
require.NoError(t, err)
resetTestTomb(&testTomb, &pw)