mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-11 12:25:53 +02:00
refactor context (cscli, pkg/database) (#3071)
* cscli: helper require.DBClient() * refactor pkg/database: explicit context to dbclient constructor * lint
This commit is contained in:
parent
24687e982a
commit
bd4540b1bf
15 changed files with 97 additions and 49 deletions
|
@ -24,7 +24,6 @@ import (
|
||||||
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
|
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
|
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
|
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database"
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
)
|
)
|
||||||
|
@ -378,6 +377,7 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
|
||||||
alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil {
|
alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ActiveDecision != nil {
|
if ActiveDecision != nil {
|
||||||
alertDeleteFilter.ActiveDecisionEquals = ActiveDecision
|
alertDeleteFilter.ActiveDecisionEquals = ActiveDecision
|
||||||
}
|
}
|
||||||
|
@ -385,21 +385,27 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
|
||||||
if *alertDeleteFilter.ScopeEquals == "" {
|
if *alertDeleteFilter.ScopeEquals == "" {
|
||||||
alertDeleteFilter.ScopeEquals = nil
|
alertDeleteFilter.ScopeEquals = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if *alertDeleteFilter.ValueEquals == "" {
|
if *alertDeleteFilter.ValueEquals == "" {
|
||||||
alertDeleteFilter.ValueEquals = nil
|
alertDeleteFilter.ValueEquals = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if *alertDeleteFilter.ScenarioEquals == "" {
|
if *alertDeleteFilter.ScenarioEquals == "" {
|
||||||
alertDeleteFilter.ScenarioEquals = nil
|
alertDeleteFilter.ScenarioEquals = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if *alertDeleteFilter.IPEquals == "" {
|
if *alertDeleteFilter.IPEquals == "" {
|
||||||
alertDeleteFilter.IPEquals = nil
|
alertDeleteFilter.IPEquals = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if *alertDeleteFilter.RangeEquals == "" {
|
if *alertDeleteFilter.RangeEquals == "" {
|
||||||
alertDeleteFilter.RangeEquals = nil
|
alertDeleteFilter.RangeEquals = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if contained != nil && *contained {
|
if contained != nil && *contained {
|
||||||
alertDeleteFilter.Contains = new(bool)
|
alertDeleteFilter.Contains = new(bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
limit := 0
|
limit := 0
|
||||||
alertDeleteFilter.Limit = &limit
|
alertDeleteFilter.Limit = &limit
|
||||||
} else {
|
} else {
|
||||||
|
@ -419,6 +425,7 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
|
||||||
return fmt.Errorf("unable to delete alert: %w", err)
|
return fmt.Errorf("unable to delete alert: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("%s alert(s) deleted", alerts.NbDeleted)
|
log.Infof("%s alert(s) deleted", alerts.NbDeleted)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -558,14 +565,14 @@ func (cli *cliAlerts) NewFlushCmd() *cobra.Command {
|
||||||
/!\ This command can be used only on the same machine than the local API`,
|
/!\ This command can be used only on the same machine than the local API`,
|
||||||
Example: `cscli alerts flush --max-items 1000 --max-age 7d`,
|
Example: `cscli alerts flush --max-items 1000 --max-age 7d`,
|
||||||
DisableAutoGenTag: true,
|
DisableAutoGenTag: true,
|
||||||
RunE: func(_ *cobra.Command, _ []string) error {
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
cfg := cli.cfg()
|
cfg := cli.cfg()
|
||||||
if err := require.LAPI(cfg); err != nil {
|
if err := require.LAPI(cfg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
db, err := database.NewClient(cfg.DbConfig)
|
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create new database client: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
log.Info("Flushing alerts. !! This may take a long time !!")
|
log.Info("Flushing alerts. !! This may take a long time !!")
|
||||||
err = db.FlushAlerts(maxAge, maxItems)
|
err = db.FlushAlerts(maxAge, maxItems)
|
||||||
|
|
|
@ -57,7 +57,7 @@ Note: This command requires database direct access, so is intended to be run on
|
||||||
Args: cobra.MinimumNArgs(1),
|
Args: cobra.MinimumNArgs(1),
|
||||||
Aliases: []string{"bouncer"},
|
Aliases: []string{"bouncer"},
|
||||||
DisableAutoGenTag: true,
|
DisableAutoGenTag: true,
|
||||||
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
|
PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
cfg := cli.cfg()
|
cfg := cli.cfg()
|
||||||
|
@ -66,9 +66,9 @@ Note: This command requires database direct access, so is intended to be run on
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cli.db, err = database.NewClient(cfg.DbConfig)
|
cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("can't connect to the database: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -128,14 +128,14 @@ Note: This command requires database direct access, so is intended to be run on
|
||||||
Example: `cscli machines [action]`,
|
Example: `cscli machines [action]`,
|
||||||
DisableAutoGenTag: true,
|
DisableAutoGenTag: true,
|
||||||
Aliases: []string{"machine"},
|
Aliases: []string{"machine"},
|
||||||
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
|
PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
var err error
|
var err error
|
||||||
if err = require.LAPI(cli.cfg()); err != nil {
|
if err = require.LAPI(cli.cfg()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cli.db, err = database.NewClient(cli.cfg().DbConfig)
|
cli.db, err = require.DBClient(cmd.Context(), cli.cfg().DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create new database client: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
|
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
|
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type cliPapi struct {
|
type cliPapi struct {
|
||||||
|
@ -56,12 +55,12 @@ func (cli *cliPapi) NewStatusCmd() *cobra.Command {
|
||||||
Short: "Get status of the Polling API",
|
Short: "Get status of the Polling API",
|
||||||
Args: cobra.MinimumNArgs(0),
|
Args: cobra.MinimumNArgs(0),
|
||||||
DisableAutoGenTag: true,
|
DisableAutoGenTag: true,
|
||||||
RunE: func(_ *cobra.Command, _ []string) error {
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
var err error
|
var err error
|
||||||
cfg := cli.cfg()
|
cfg := cli.cfg()
|
||||||
db, err := database.NewClient(cfg.DbConfig)
|
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to initialize database client: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
|
apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
|
||||||
|
@ -105,14 +104,14 @@ func (cli *cliPapi) NewSyncCmd() *cobra.Command {
|
||||||
Short: "Sync with the Polling API, pulling all non-expired orders for the instance",
|
Short: "Sync with the Polling API, pulling all non-expired orders for the instance",
|
||||||
Args: cobra.MinimumNArgs(0),
|
Args: cobra.MinimumNArgs(0),
|
||||||
DisableAutoGenTag: true,
|
DisableAutoGenTag: true,
|
||||||
RunE: func(_ *cobra.Command, _ []string) error {
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
var err error
|
var err error
|
||||||
cfg := cli.cfg()
|
cfg := cli.cfg()
|
||||||
t := tomb.Tomb{}
|
t := tomb.Tomb{}
|
||||||
|
|
||||||
db, err := database.NewClient(cfg.DbConfig)
|
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to initialize database client: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
|
apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
|
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
func LAPI(c *csconfig.Config) error {
|
func LAPI(c *csconfig.Config) error {
|
||||||
|
@ -48,6 +49,15 @@ func CAPIRegistered(c *csconfig.Config) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DBClient(ctx context.Context, dbcfg *csconfig.DatabaseCfg) (*database.Client, error) {
|
||||||
|
db, err := database.NewClient(ctx, dbcfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
func DB(c *csconfig.Config) error {
|
func DB(c *csconfig.Config) error {
|
||||||
if err := c.LoadDBConfig(true); err != nil {
|
if err := c.LoadDBConfig(true); err != nil {
|
||||||
return fmt.Errorf("this command requires direct database access (must be run on the local API machine): %w", err)
|
return fmt.Errorf("this command requires direct database access (must be run on the local API machine): %w", err)
|
||||||
|
|
|
@ -463,9 +463,9 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
|
||||||
w := bytes.NewBuffer(nil)
|
w := bytes.NewBuffer(nil)
|
||||||
zipWriter := zip.NewWriter(w)
|
zipWriter := zip.NewWriter(w)
|
||||||
|
|
||||||
db, err := database.NewClient(cfg.DbConfig)
|
db, err := require.DBClient(ctx, cfg.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Could not connect to database: %s", err)
|
log.Warn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = cfg.LoadAPIServer(true); err != nil {
|
if err = cfg.LoadAPIServer(true); err != nil {
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
|
|
||||||
|
@ -41,9 +42,10 @@ func StartRunSvc() error {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if cConfig.DbConfig != nil {
|
ctx := context.TODO()
|
||||||
dbClient, err = database.NewClient(cConfig.DbConfig)
|
|
||||||
|
|
||||||
|
if cConfig.DbConfig != nil {
|
||||||
|
dbClient, err = database.NewClient(ctx, cConfig.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create database client: %w", err)
|
return fmt.Errorf("unable to create database client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
|
|
||||||
|
@ -80,8 +81,10 @@ func WindowsRun() error {
|
||||||
var dbClient *database.Client
|
var dbClient *database.Client
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
ctx := context.TODO()
|
||||||
|
|
||||||
if cConfig.DbConfig != nil {
|
if cConfig.DbConfig != nil {
|
||||||
dbClient, err = database.NewClient(cConfig.DbConfig)
|
dbClient, err = database.NewClient(ctx, cConfig.DbConfig)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create database client: %w", err)
|
return fmt.Errorf("unable to create database client: %w", err)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
@ -322,8 +323,10 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
|
||||||
crowdsecTomb = tomb.Tomb{}
|
crowdsecTomb = tomb.Tomb{}
|
||||||
pluginTomb = tomb.Tomb{}
|
pluginTomb = tomb.Tomb{}
|
||||||
|
|
||||||
|
ctx := context.TODO()
|
||||||
|
|
||||||
if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil {
|
if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil {
|
||||||
dbClient, err := database.NewClient(cConfig.API.Server.DbConfig)
|
dbClient, err := database.NewClient(ctx, cConfig.API.Server.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get database client: %w", err)
|
return fmt.Errorf("failed to get database client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,9 +38,11 @@ import (
|
||||||
func getDBClient(t *testing.T) *database.Client {
|
func getDBClient(t *testing.T) *database.Client {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
dbPath, err := os.CreateTemp("", "*sqlite")
|
dbPath, err := os.CreateTemp("", "*sqlite")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
|
dbClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{
|
||||||
Type: "sqlite",
|
Type: "sqlite",
|
||||||
DbName: "crowdsec",
|
DbName: "crowdsec",
|
||||||
DbPath: dbPath.Name(),
|
DbPath: dbPath.Name(),
|
||||||
|
@ -56,7 +58,7 @@ func getAPIC(t *testing.T) *apic {
|
||||||
|
|
||||||
return &apic{
|
return &apic{
|
||||||
AlertsAddChan: make(chan []*models.Alert),
|
AlertsAddChan: make(chan []*models.Alert),
|
||||||
//DecisionDeleteChan: make(chan []*models.Decision),
|
// DecisionDeleteChan: make(chan []*models.Decision),
|
||||||
dbClient: dbClient,
|
dbClient: dbClient,
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
startup: true,
|
startup: true,
|
||||||
|
@ -176,10 +178,11 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
scenarios, err := api.FetchScenariosListFromDB()
|
scenarios, err := api.FetchScenariosListFromDB()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
for machineID := range tc.machineIDsWithScenarios {
|
for machineID := range tc.machineIDsWithScenarios {
|
||||||
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
|
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
|
||||||
}
|
}
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
|
assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
|
||||||
})
|
})
|
||||||
|
@ -234,6 +237,7 @@ func TestNewAPIC(t *testing.T) {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
setConfig()
|
setConfig()
|
||||||
httpmock.Activate()
|
httpmock.Activate()
|
||||||
|
|
||||||
defer httpmock.DeactivateAndReset()
|
defer httpmock.DeactivateAndReset()
|
||||||
httpmock.RegisterResponder("POST", "http://foobar/v3/watchers/login", httpmock.NewBytesResponder(
|
httpmock.RegisterResponder("POST", "http://foobar/v3/watchers/login", httpmock.NewBytesResponder(
|
||||||
200, jsonMarshalX(
|
200, jsonMarshalX(
|
||||||
|
@ -353,6 +357,7 @@ func TestAPICGetMetrics(t *testing.T) {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
apiClient := getAPIC(t)
|
apiClient := getAPIC(t)
|
||||||
cleanUp(apiClient)
|
cleanUp(apiClient)
|
||||||
|
|
||||||
for i, machineID := range tc.machineIDs {
|
for i, machineID := range tc.machineIDs {
|
||||||
apiClient.dbClient.Ent.Machine.Create().
|
apiClient.dbClient.Ent.Machine.Create().
|
||||||
SetMachineId(machineID).
|
SetMachineId(machineID).
|
||||||
|
@ -548,7 +553,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
|
||||||
|
|
||||||
func TestAPICWhitelists(t *testing.T) {
|
func TestAPICWhitelists(t *testing.T) {
|
||||||
api := getAPIC(t)
|
api := getAPIC(t)
|
||||||
//one whitelist on IP, one on CIDR
|
// one whitelist on IP, one on CIDR
|
||||||
api.whitelists = &csconfig.CapiWhitelist{}
|
api.whitelists = &csconfig.CapiWhitelist{}
|
||||||
api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4"))
|
api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4"))
|
||||||
|
|
||||||
|
@ -593,7 +598,7 @@ func TestAPICWhitelists(t *testing.T) {
|
||||||
Scope: ptr.Of("Ip"),
|
Scope: ptr.Of("Ip"),
|
||||||
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
|
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
|
||||||
{
|
{
|
||||||
Value: ptr.Of("13.2.3.4"), //wl by cidr
|
Value: ptr.Of("13.2.3.4"), // wl by cidr
|
||||||
Duration: ptr.Of("24h"),
|
Duration: ptr.Of("24h"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -614,7 +619,7 @@ func TestAPICWhitelists(t *testing.T) {
|
||||||
Scope: ptr.Of("Ip"),
|
Scope: ptr.Of("Ip"),
|
||||||
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
|
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
|
||||||
{
|
{
|
||||||
Value: ptr.Of("13.2.3.5"), //wl by cidr
|
Value: ptr.Of("13.2.3.5"), // wl by cidr
|
||||||
Duration: ptr.Of("24h"),
|
Duration: ptr.Of("24h"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -634,7 +639,7 @@ func TestAPICWhitelists(t *testing.T) {
|
||||||
Scope: ptr.Of("Ip"),
|
Scope: ptr.Of("Ip"),
|
||||||
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
|
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
|
||||||
{
|
{
|
||||||
Value: ptr.Of("9.2.3.4"), //wl by ip
|
Value: ptr.Of("9.2.3.4"), // wl by ip
|
||||||
Duration: ptr.Of("24h"),
|
Duration: ptr.Of("24h"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -685,7 +690,7 @@ func TestAPICWhitelists(t *testing.T) {
|
||||||
err = api.PullTop(false)
|
err = api.PullTop(false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assertTotalDecisionCount(t, api.dbClient, 5) //2 from FIRE + 2 from bl + 1 existing
|
assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
|
||||||
assertTotalValidDecisionCount(t, api.dbClient, 4)
|
assertTotalValidDecisionCount(t, api.dbClient, 4)
|
||||||
assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list.
|
assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list.
|
||||||
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
|
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
|
||||||
|
@ -1103,6 +1108,7 @@ func TestAPICPush(t *testing.T) {
|
||||||
|
|
||||||
httpmock.Activate()
|
httpmock.Activate()
|
||||||
defer httpmock.DeactivateAndReset()
|
defer httpmock.DeactivateAndReset()
|
||||||
|
|
||||||
apic, err := apiclient.NewDefaultClient(
|
apic, err := apiclient.NewDefaultClient(
|
||||||
url,
|
url,
|
||||||
"/api",
|
"/api",
|
||||||
|
|
|
@ -162,7 +162,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro
|
||||||
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||||
var flushScheduler *gocron.Scheduler
|
var flushScheduler *gocron.Scheduler
|
||||||
|
|
||||||
dbClient, err := database.NewClient(config.DbConfig)
|
ctx := context.TODO()
|
||||||
|
|
||||||
|
dbClient, err := database.NewClient(ctx, config.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to init database client: %w", err)
|
return nil, fmt.Errorf("unable to init database client: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -227,7 +229,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||||
|
|
||||||
controller := &controllers.Controller{
|
controller := &controllers.Controller{
|
||||||
DBClient: dbClient,
|
DBClient: dbClient,
|
||||||
Ectx: context.Background(),
|
Ectx: ctx,
|
||||||
Router: router,
|
Router: router,
|
||||||
Profiles: config.Profiles,
|
Profiles: config.Profiles,
|
||||||
Log: clog,
|
Log: clog,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package apiserver
|
package apiserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -161,7 +162,9 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
|
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
|
||||||
dbClient, err := database.NewClient(config)
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dbClient, err := database.NewClient(ctx, config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = dbClient.ValidateMachine(machineID)
|
err = dbClient.ValidateMachine(machineID)
|
||||||
|
@ -169,7 +172,9 @@ func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCf
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string {
|
func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string {
|
||||||
dbClient, err := database.NewClient(config)
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dbClient, err := database.NewClient(ctx, config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
machines, err := dbClient.ListMachines()
|
machines, err := dbClient.ListMachines()
|
||||||
|
@ -260,7 +265,9 @@ func CreateTestMachine(t *testing.T, router *gin.Engine) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
|
func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
|
||||||
dbClient, err := database.NewClient(config)
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dbClient, err := database.NewClient(ctx, config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
apiKey, err := middlewares.GenerateAPIKey(keyLength)
|
apiKey, err := middlewares.GenerateAPIKey(keyLength)
|
||||||
|
@ -356,10 +363,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
|
||||||
req.Header.Set("User-Agent", UserAgent)
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
api.router.ServeHTTP(w, req)
|
api.router.ServeHTTP(w, req)
|
||||||
assert.Equal(t, 404, w.Code)
|
assert.Equal(t, 404, w.Code)
|
||||||
//wait for the request to happen
|
// wait for the request to happen
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
//check file content
|
// check file content
|
||||||
data, err := os.ReadFile(expectedFile)
|
data, err := os.ReadFile(expectedFile)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -406,10 +413,10 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
|
||||||
req.Header.Set("User-Agent", UserAgent)
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
api.router.ServeHTTP(w, req)
|
api.router.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
//wait for the request to happen
|
// wait for the request to happen
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
//check file content
|
// check file content
|
||||||
x, err := os.ReadFile(expectedFile)
|
x, err := os.ReadFile(expectedFile)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
require.Empty(t, x)
|
require.Empty(t, x)
|
||||||
|
|
|
@ -947,7 +947,7 @@ func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string
|
||||||
Count int
|
Count int
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.TODO()
|
||||||
|
|
||||||
query := c.Ent.Alert.Query()
|
query := c.Ent.Alert.Query()
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.
|
||||||
return drv, nil
|
return drv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
|
func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, error) {
|
||||||
var client *ent.Client
|
var client *ent.Client
|
||||||
|
|
||||||
if config == nil {
|
if config == nil {
|
||||||
|
@ -69,7 +69,7 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
|
||||||
|
|
||||||
typ, dia, err := config.ConnectionDialect()
|
typ, dia, err := config.ConnectionDialect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err //unsupported database caught here
|
return nil, err // unsupported database caught here
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Type == "sqlite" {
|
if config.Type == "sqlite" {
|
||||||
|
@ -103,13 +103,13 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
|
||||||
client = client.Debug()
|
client = client.Debug()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = client.Schema.Create(context.Background()); err != nil {
|
if err = client.Schema.Create(ctx); err != nil {
|
||||||
return nil, fmt.Errorf("failed creating schema resources: %v", err)
|
return nil, fmt.Errorf("failed creating schema resources: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
Ent: client,
|
Ent: client,
|
||||||
CTX: context.Background(),
|
CTX: ctx,
|
||||||
Log: clog,
|
Log: clog,
|
||||||
CanFlush: true,
|
CanFlush: true,
|
||||||
Type: config.Type,
|
Type: config.Type,
|
||||||
|
|
|
@ -29,7 +29,9 @@ func getDBClient(t *testing.T) *database.Client {
|
||||||
dbPath, err := os.CreateTemp("", "*sqlite")
|
dbPath, err := os.CreateTemp("", "*sqlite")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testDBClient, err := database.NewClient(&csconfig.DatabaseCfg{
|
ctx := context.Background()
|
||||||
|
|
||||||
|
testDBClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{
|
||||||
Type: "sqlite",
|
Type: "sqlite",
|
||||||
DbName: "crowdsec",
|
DbName: "crowdsec",
|
||||||
DbPath: dbPath.Name(),
|
DbPath: dbPath.Name(),
|
||||||
|
@ -215,7 +217,7 @@ func TestRegexpCacheBehavior(t *testing.T) {
|
||||||
err = FileInit(TestFolder, filename, "regex")
|
err = FileInit(TestFolder, filename, "regex")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
//cache with no TTL
|
// cache with no TTL
|
||||||
err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(1)})
|
err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(1)})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -227,7 +229,7 @@ func TestRegexpCacheBehavior(t *testing.T) {
|
||||||
assert.True(t, ret.(bool))
|
assert.True(t, ret.(bool))
|
||||||
assert.Equal(t, 1, dataFileRegexCache[filename].Len(false))
|
assert.Equal(t, 1, dataFileRegexCache[filename].Len(false))
|
||||||
|
|
||||||
//cache with TTL
|
// cache with TTL
|
||||||
ttl := 500 * time.Millisecond
|
ttl := 500 * time.Millisecond
|
||||||
err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(2), TTL: &ttl})
|
err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(2), TTL: &ttl})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -994,6 +996,7 @@ func TestGetDecisionsCount(t *testing.T) {
|
||||||
log.Printf("test '%s' : OK", test.name)
|
log.Printf("test '%s' : OK", test.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetDecisionsSinceCount(t *testing.T) {
|
func TestGetDecisionsSinceCount(t *testing.T) {
|
||||||
existingIP := "1.2.3.4"
|
existingIP := "1.2.3.4"
|
||||||
unknownIP := "1.2.3.5"
|
unknownIP := "1.2.3.5"
|
||||||
|
@ -1365,6 +1368,7 @@ func TestGetActiveDecisionsTimeLeft(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
output, err := expr.Run(program, test.env)
|
output, err := expr.Run(program, test.env)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
switch o := output.(type) {
|
switch o := output.(type) {
|
||||||
case time.Duration:
|
case time.Duration:
|
||||||
require.LessOrEqual(t, int(o.Seconds()), int(test.max))
|
require.LessOrEqual(t, int(o.Seconds()), int(test.max))
|
||||||
|
@ -1376,7 +1380,6 @@ func TestGetActiveDecisionsTimeLeft(t *testing.T) {
|
||||||
t.Fatalf("GetActiveDecisionsTimeLeft() should return a time.Duration or a float64")
|
t.Fatalf("GetActiveDecisionsTimeLeft() should return a time.Duration or a float64")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseUnixTime(t *testing.T) {
|
func TestParseUnixTime(t *testing.T) {
|
||||||
|
@ -1415,9 +1418,11 @@ func TestParseUnixTime(t *testing.T) {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
output, err := ParseUnixTime(tc.value)
|
output, err := ParseUnixTime(tc.value)
|
||||||
cstest.RequireErrorContains(t, err, tc.expectedErr)
|
cstest.RequireErrorContains(t, err, tc.expectedErr)
|
||||||
|
|
||||||
if tc.expectedErr != "" {
|
if tc.expectedErr != "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
require.WithinDuration(t, tc.expected, output.(time.Time), time.Second)
|
require.WithinDuration(t, tc.expected, output.(time.Time), time.Second)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1520,6 +1525,7 @@ func TestIsIp(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
|
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -1619,12 +1625,15 @@ func TestB64Decode(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
|
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
|
||||||
if tc.expectedRuntimeErr {
|
if tc.expectedRuntimeErr {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tc.expected, output)
|
require.Equal(t, tc.expected, output)
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue