mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-10 20:05:55 +02:00
context propagation: pkg/csplugin (#3273)
This commit is contained in:
parent
50d115b914
commit
8ff58ee74e
11 changed files with 66 additions and 33 deletions
|
@ -275,7 +275,8 @@ func (cli cliNotifications) newTestCmd() *cobra.Command {
|
|||
Args: cobra.ExactArgs(1),
|
||||
DisableAutoGenTag: true,
|
||||
ValidArgsFunction: cli.notificationConfigFilter,
|
||||
PreRunE: func(_ *cobra.Command, args []string) error {
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
cfg := cli.cfg()
|
||||
pconfigs, err := cli.getPluginConfigs()
|
||||
if err != nil {
|
||||
|
@ -286,7 +287,7 @@ func (cli cliNotifications) newTestCmd() *cobra.Command {
|
|||
return fmt.Errorf("plugin name: '%s' does not exist", args[0])
|
||||
}
|
||||
// Create a single profile with plugin name as notification name
|
||||
return pluginBroker.Init(cfg.PluginConfig, []*csconfig.ProfileCfg{
|
||||
return pluginBroker.Init(ctx, cfg.PluginConfig, []*csconfig.ProfileCfg{
|
||||
{
|
||||
Notifications: []string{
|
||||
pcfg.Name,
|
||||
|
@ -377,12 +378,13 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
|
|||
|
||||
return nil
|
||||
},
|
||||
RunE: func(_ *cobra.Command, _ []string) error {
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
var (
|
||||
pluginBroker csplugin.PluginBroker
|
||||
pluginTomb tomb.Tomb
|
||||
)
|
||||
|
||||
ctx := cmd.Context()
|
||||
cfg := cli.cfg()
|
||||
|
||||
if alertOverride != "" {
|
||||
|
@ -391,7 +393,7 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
|
|||
}
|
||||
}
|
||||
|
||||
err := pluginBroker.Init(cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths)
|
||||
err := pluginBroker.Init(ctx, cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't initialize plugins: %w", err)
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP
|
|||
return nil, errors.New("plugins are enabled, but config_paths.plugin_dir is not defined")
|
||||
}
|
||||
|
||||
err = pluginBroker.Init(cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths)
|
||||
err = pluginBroker.Init(ctx, cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to run plugin broker: %w", err)
|
||||
}
|
||||
|
|
|
@ -72,8 +72,8 @@ func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.Wat
|
|||
}
|
||||
|
||||
func LoginToTestAPI(t *testing.T, ctx context.Context, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse {
|
||||
body := CreateTestMachine(t, router, "")
|
||||
ValidateMachine(t, "test", config.API.Server.DbConfig)
|
||||
body := CreateTestMachine(t, ctx, router, "")
|
||||
ValidateMachine(t, ctx, "test", config.API.Server.DbConfig)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body))
|
||||
|
|
|
@ -180,9 +180,7 @@ func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csc
|
|||
return router, config
|
||||
}
|
||||
|
||||
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
|
||||
ctx := context.TODO()
|
||||
|
||||
func ValidateMachine(t *testing.T, ctx context.Context, machineID string, config *csconfig.DatabaseCfg) {
|
||||
dbClient, err := database.NewClient(ctx, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -269,7 +267,7 @@ func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map
|
|||
return response, resp.Code
|
||||
}
|
||||
|
||||
func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string {
|
||||
func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, token string) string {
|
||||
regReq := MachineTest
|
||||
regReq.RegistrationToken = token
|
||||
b, err := json.Marshal(regReq)
|
||||
|
@ -277,8 +275,6 @@ func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string {
|
|||
|
||||
body := string(b)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body))
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
|
|
@ -14,7 +14,7 @@ func TestLogin(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
router, config := NewAPITest(t, ctx)
|
||||
|
||||
body := CreateTestMachine(t, router, "")
|
||||
body := CreateTestMachine(t, ctx, router, "")
|
||||
|
||||
// Login with machine not validated yet
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -53,7 +53,7 @@ func TestLogin(t *testing.T) {
|
|||
assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String())
|
||||
|
||||
// Validate machine
|
||||
ValidateMachine(t, "test", config.API.Server.DbConfig)
|
||||
ValidateMachine(t, ctx, "test", config.API.Server.DbConfig)
|
||||
|
||||
// Login with invalid password
|
||||
w = httptest.NewRecorder()
|
||||
|
|
|
@ -131,7 +131,7 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
router, _ := NewAPITest(t, ctx)
|
||||
|
||||
body := CreateTestMachine(t, router, "")
|
||||
body := CreateTestMachine(t, ctx, router, "")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body))
|
||||
|
|
|
@ -72,7 +72,7 @@ type ProfileAlert struct {
|
|||
Alert *models.Alert
|
||||
}
|
||||
|
||||
func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error {
|
||||
func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error {
|
||||
pb.PluginChannel = make(chan ProfileAlert)
|
||||
pb.notificationConfigsByPluginType = make(map[string][][]byte)
|
||||
pb.notificationPluginByName = make(map[string]protobufs.NotifierServer)
|
||||
|
@ -85,7 +85,7 @@ func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*cs
|
|||
if err := pb.loadConfig(configPaths.NotificationDir); err != nil {
|
||||
return fmt.Errorf("while loading plugin config: %w", err)
|
||||
}
|
||||
if err := pb.loadPlugins(configPaths.PluginDir); err != nil {
|
||||
if err := pb.loadPlugins(ctx, configPaths.PluginDir); err != nil {
|
||||
return fmt.Errorf("while loading plugin: %w", err)
|
||||
}
|
||||
pb.watcher = PluginWatcher{}
|
||||
|
@ -230,7 +230,7 @@ func (pb *PluginBroker) verifyPluginBinaryWithProfile() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (pb *PluginBroker) loadPlugins(path string) error {
|
||||
func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error {
|
||||
binaryPaths, err := listFilesAtPath(path)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -265,7 +265,7 @@ func (pb *PluginBroker) loadPlugins(path string) error {
|
|||
return err
|
||||
}
|
||||
data = []byte(csstring.StrictExpand(string(data), os.LookupEnv))
|
||||
_, err = pluginClient.Configure(context.Background(), &protobufs.Config{Config: data})
|
||||
_, err = pluginClient.Configure(ctx, &protobufs.Config{Config: data})
|
||||
if err != nil {
|
||||
return fmt.Errorf("while configuring %s: %w", pc.Name, err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package csplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
@ -96,6 +97,7 @@ func (s *PluginSuite) TearDownTest() {
|
|||
|
||||
func (s *PluginSuite) SetupSubTest() {
|
||||
var err error
|
||||
|
||||
t := s.T()
|
||||
|
||||
s.runDir, err = os.MkdirTemp("", "cs_plugin_test")
|
||||
|
@ -127,6 +129,7 @@ func (s *PluginSuite) SetupSubTest() {
|
|||
|
||||
func (s *PluginSuite) TearDownSubTest() {
|
||||
t := s.T()
|
||||
|
||||
if s.pluginBroker != nil {
|
||||
s.pluginBroker.Kill()
|
||||
s.pluginBroker = nil
|
||||
|
@ -140,19 +143,24 @@ func (s *PluginSuite) TearDownSubTest() {
|
|||
os.Remove("./out")
|
||||
}
|
||||
|
||||
func (s *PluginSuite) InitBroker(procCfg *csconfig.PluginCfg) (*PluginBroker, error) {
|
||||
func (s *PluginSuite) InitBroker(ctx context.Context, procCfg *csconfig.PluginCfg) (*PluginBroker, error) {
|
||||
pb := PluginBroker{}
|
||||
|
||||
if procCfg == nil {
|
||||
procCfg = &csconfig.PluginCfg{}
|
||||
}
|
||||
|
||||
profiles := csconfig.NewDefaultConfig().API.Server.Profiles
|
||||
profiles = append(profiles, &csconfig.ProfileCfg{
|
||||
Notifications: []string{"dummy_default"},
|
||||
})
|
||||
err := pb.Init(procCfg, profiles, &csconfig.ConfigurationPaths{
|
||||
|
||||
err := pb.Init(ctx, procCfg, profiles, &csconfig.ConfigurationPaths{
|
||||
PluginDir: s.pluginDir,
|
||||
NotificationDir: s.notifDir,
|
||||
})
|
||||
|
||||
s.pluginBroker = &pb
|
||||
|
||||
return s.pluginBroker, err
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ package csplugin
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
|
@ -53,6 +54,7 @@ func (s *PluginSuite) writeconfig(config PluginConfig) {
|
|||
}
|
||||
|
||||
func (s *PluginSuite) TestBrokerInit() {
|
||||
ctx := context.Background()
|
||||
tests := []struct {
|
||||
name string
|
||||
action func(*testing.T)
|
||||
|
@ -135,20 +137,22 @@ func (s *PluginSuite) TestBrokerInit() {
|
|||
tc.action(t)
|
||||
}
|
||||
|
||||
_, err := s.InitBroker(&tc.procCfg)
|
||||
_, err := s.InitBroker(ctx, &tc.procCfg)
|
||||
cstest.RequireErrorContains(t, err, tc.expectedErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PluginSuite) TestBrokerNoThreshold() {
|
||||
ctx := context.Background()
|
||||
|
||||
var alerts []models.Alert
|
||||
|
||||
DefaultEmptyTicker = 50 * time.Millisecond
|
||||
|
||||
t := s.T()
|
||||
|
||||
pb, err := s.InitBroker(nil)
|
||||
pb, err := s.InitBroker(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tomb := tomb.Tomb{}
|
||||
|
@ -187,6 +191,8 @@ func (s *PluginSuite) TestBrokerNoThreshold() {
|
|||
}
|
||||
|
||||
func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
|
||||
ctx := context.Background()
|
||||
|
||||
// test grouping by "time"
|
||||
DefaultEmptyTicker = 50 * time.Millisecond
|
||||
|
||||
|
@ -198,7 +204,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
|
|||
cfg.GroupWait = 1 * time.Second
|
||||
s.writeconfig(cfg)
|
||||
|
||||
pb, err := s.InitBroker(nil)
|
||||
pb, err := s.InitBroker(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tomb := tomb.Tomb{}
|
||||
|
@ -224,6 +230,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
|
|||
}
|
||||
|
||||
func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
|
||||
ctx := context.Background()
|
||||
DefaultEmptyTicker = 50 * time.Millisecond
|
||||
|
||||
t := s.T()
|
||||
|
@ -234,7 +241,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
|
|||
cfg.GroupWait = 4 * time.Second
|
||||
s.writeconfig(cfg)
|
||||
|
||||
pb, err := s.InitBroker(nil)
|
||||
pb, err := s.InitBroker(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tomb := tomb.Tomb{}
|
||||
|
@ -264,6 +271,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
|
|||
}
|
||||
|
||||
func (s *PluginSuite) TestBrokerRunGroupThreshold() {
|
||||
ctx := context.Background()
|
||||
// test grouping by "size"
|
||||
DefaultEmptyTicker = 50 * time.Millisecond
|
||||
|
||||
|
@ -274,7 +282,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
|
|||
cfg.GroupThreshold = 4
|
||||
s.writeconfig(cfg)
|
||||
|
||||
pb, err := s.InitBroker(nil)
|
||||
pb, err := s.InitBroker(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tomb := tomb.Tomb{}
|
||||
|
@ -318,6 +326,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
|
|||
}
|
||||
|
||||
func (s *PluginSuite) TestBrokerRunTimeThreshold() {
|
||||
ctx := context.Background()
|
||||
DefaultEmptyTicker = 50 * time.Millisecond
|
||||
|
||||
t := s.T()
|
||||
|
@ -327,7 +336,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
|
|||
cfg.GroupWait = 1 * time.Second
|
||||
s.writeconfig(cfg)
|
||||
|
||||
pb, err := s.InitBroker(nil)
|
||||
pb, err := s.InitBroker(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tomb := tomb.Tomb{}
|
||||
|
@ -353,11 +362,12 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
|
|||
}
|
||||
|
||||
func (s *PluginSuite) TestBrokerRunSimple() {
|
||||
ctx := context.Background()
|
||||
DefaultEmptyTicker = 50 * time.Millisecond
|
||||
|
||||
t := s.T()
|
||||
|
||||
pb, err := s.InitBroker(nil)
|
||||
pb, err := s.InitBroker(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tomb := tomb.Tomb{}
|
||||
|
|
|
@ -4,6 +4,7 @@ package csplugin
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
|
@ -26,6 +27,7 @@ not if it will actually reject plugins with invalid permissions
|
|||
*/
|
||||
|
||||
func (s *PluginSuite) TestBrokerInit() {
|
||||
ctx := context.Background()
|
||||
tests := []struct {
|
||||
name string
|
||||
action func(*testing.T)
|
||||
|
@ -59,16 +61,17 @@ func (s *PluginSuite) TestBrokerInit() {
|
|||
if tc.action != nil {
|
||||
tc.action(t)
|
||||
}
|
||||
_, err := s.InitBroker(&tc.procCfg)
|
||||
_, err := s.InitBroker(ctx, &tc.procCfg)
|
||||
cstest.RequireErrorContains(t, err, tc.expectedErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PluginSuite) TestBrokerRun() {
|
||||
ctx := context.Background()
|
||||
t := s.T()
|
||||
|
||||
pb, err := s.InitBroker(nil)
|
||||
pb, err := s.InitBroker(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tomb := tomb.Tomb{}
|
||||
|
|
|
@ -15,11 +15,10 @@ import (
|
|||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) {
|
||||
testTomb.Kill(nil)
|
||||
<-pw.PluginEvents
|
||||
|
||||
if err := testTomb.Wait(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -46,13 +45,17 @@ func listenChannelWithTimeout(ctx context.Context, channel chan string) error {
|
|||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPluginWatcherInterval(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test on windows because timing is not reliable")
|
||||
}
|
||||
|
||||
pw := PluginWatcher{}
|
||||
alertsByPluginName := make(map[string][]*models.Alert)
|
||||
testTomb := tomb.Tomb{}
|
||||
|
@ -66,6 +69,7 @@ func TestPluginWatcherInterval(t *testing.T) {
|
|||
|
||||
ct, cancel := context.WithTimeout(ctx, time.Microsecond)
|
||||
defer cancel()
|
||||
|
||||
err := listenChannelWithTimeout(ct, pw.PluginEvents)
|
||||
cstest.RequireErrorContains(t, err, "context deadline exceeded")
|
||||
resetTestTomb(&testTomb, &pw)
|
||||
|
@ -74,6 +78,7 @@ func TestPluginWatcherInterval(t *testing.T) {
|
|||
|
||||
ct, cancel = context.WithTimeout(ctx, time.Millisecond*5)
|
||||
defer cancel()
|
||||
|
||||
err = listenChannelWithTimeout(ct, pw.PluginEvents)
|
||||
require.NoError(t, err)
|
||||
resetTestTomb(&testTomb, &pw)
|
||||
|
@ -81,9 +86,12 @@ func TestPluginWatcherInterval(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPluginAlertCountWatcher(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test on windows because timing is not reliable")
|
||||
}
|
||||
|
||||
pw := PluginWatcher{}
|
||||
alertsByPluginName := make(map[string][]*models.Alert)
|
||||
configs := map[string]PluginConfig{
|
||||
|
@ -92,28 +100,34 @@ func TestPluginAlertCountWatcher(t *testing.T) {
|
|||
},
|
||||
}
|
||||
testTomb := tomb.Tomb{}
|
||||
|
||||
pw.Init(configs, alertsByPluginName)
|
||||
pw.Start(&testTomb)
|
||||
|
||||
// Channel won't contain any events since threshold is not crossed.
|
||||
ct, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := listenChannelWithTimeout(ct, pw.PluginEvents)
|
||||
cstest.RequireErrorContains(t, err, "context deadline exceeded")
|
||||
|
||||
// Channel won't contain any events since threshold is not crossed.
|
||||
resetWatcherAlertCounter(&pw)
|
||||
insertNAlertsToPlugin(&pw, 4, "testPlugin")
|
||||
|
||||
ct, cancel = context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = listenChannelWithTimeout(ct, pw.PluginEvents)
|
||||
cstest.RequireErrorContains(t, err, "context deadline exceeded")
|
||||
|
||||
// Channel will contain an event since threshold is crossed.
|
||||
resetWatcherAlertCounter(&pw)
|
||||
insertNAlertsToPlugin(&pw, 5, "testPlugin")
|
||||
|
||||
ct, cancel = context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = listenChannelWithTimeout(ct, pw.PluginEvents)
|
||||
require.NoError(t, err)
|
||||
resetTestTomb(&testTomb, &pw)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue