refactor context (cscli, pkg/database) (#3071)

* cscli: helper require.DBClient()

* refactor pkg/database: explicit context to dbclient constructor

* lint
This commit is contained in:
mmetc 2024-06-11 12:13:18 +02:00 committed by GitHub
parent 24687e982a
commit bd4540b1bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 97 additions and 49 deletions

View file

@ -24,7 +24,6 @@ import (
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
) )
@ -378,6 +377,7 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil { alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil {
return err return err
} }
if ActiveDecision != nil { if ActiveDecision != nil {
alertDeleteFilter.ActiveDecisionEquals = ActiveDecision alertDeleteFilter.ActiveDecisionEquals = ActiveDecision
} }
@ -385,21 +385,27 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
if *alertDeleteFilter.ScopeEquals == "" { if *alertDeleteFilter.ScopeEquals == "" {
alertDeleteFilter.ScopeEquals = nil alertDeleteFilter.ScopeEquals = nil
} }
if *alertDeleteFilter.ValueEquals == "" { if *alertDeleteFilter.ValueEquals == "" {
alertDeleteFilter.ValueEquals = nil alertDeleteFilter.ValueEquals = nil
} }
if *alertDeleteFilter.ScenarioEquals == "" { if *alertDeleteFilter.ScenarioEquals == "" {
alertDeleteFilter.ScenarioEquals = nil alertDeleteFilter.ScenarioEquals = nil
} }
if *alertDeleteFilter.IPEquals == "" { if *alertDeleteFilter.IPEquals == "" {
alertDeleteFilter.IPEquals = nil alertDeleteFilter.IPEquals = nil
} }
if *alertDeleteFilter.RangeEquals == "" { if *alertDeleteFilter.RangeEquals == "" {
alertDeleteFilter.RangeEquals = nil alertDeleteFilter.RangeEquals = nil
} }
if contained != nil && *contained { if contained != nil && *contained {
alertDeleteFilter.Contains = new(bool) alertDeleteFilter.Contains = new(bool)
} }
limit := 0 limit := 0
alertDeleteFilter.Limit = &limit alertDeleteFilter.Limit = &limit
} else { } else {
@ -419,6 +425,7 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
return fmt.Errorf("unable to delete alert: %w", err) return fmt.Errorf("unable to delete alert: %w", err)
} }
} }
log.Infof("%s alert(s) deleted", alerts.NbDeleted) log.Infof("%s alert(s) deleted", alerts.NbDeleted)
return nil return nil
@ -558,14 +565,14 @@ func (cli *cliAlerts) NewFlushCmd() *cobra.Command {
/!\ This command can be used only on the same machine than the local API`, /!\ This command can be used only on the same machine than the local API`,
Example: `cscli alerts flush --max-items 1000 --max-age 7d`, Example: `cscli alerts flush --max-items 1000 --max-age 7d`,
DisableAutoGenTag: true, DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
cfg := cli.cfg() cfg := cli.cfg()
if err := require.LAPI(cfg); err != nil { if err := require.LAPI(cfg); err != nil {
return err return err
} }
db, err := database.NewClient(cfg.DbConfig) db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil { if err != nil {
return fmt.Errorf("unable to create new database client: %w", err) return err
} }
log.Info("Flushing alerts. !! This may take a long time !!") log.Info("Flushing alerts. !! This may take a long time !!")
err = db.FlushAlerts(maxAge, maxItems) err = db.FlushAlerts(maxAge, maxItems)

View file

@ -57,7 +57,7 @@ Note: This command requires database direct access, so is intended to be run on
Args: cobra.MinimumNArgs(1), Args: cobra.MinimumNArgs(1),
Aliases: []string{"bouncer"}, Aliases: []string{"bouncer"},
DisableAutoGenTag: true, DisableAutoGenTag: true,
PersistentPreRunE: func(_ *cobra.Command, _ []string) error { PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
var err error var err error
cfg := cli.cfg() cfg := cli.cfg()
@ -66,9 +66,9 @@ Note: This command requires database direct access, so is intended to be run on
return err return err
} }
cli.db, err = database.NewClient(cfg.DbConfig) cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil { if err != nil {
return fmt.Errorf("can't connect to the database: %w", err) return err
} }
return nil return nil

View file

@ -128,14 +128,14 @@ Note: This command requires database direct access, so is intended to be run on
Example: `cscli machines [action]`, Example: `cscli machines [action]`,
DisableAutoGenTag: true, DisableAutoGenTag: true,
Aliases: []string{"machine"}, Aliases: []string{"machine"},
PersistentPreRunE: func(_ *cobra.Command, _ []string) error { PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
var err error var err error
if err = require.LAPI(cli.cfg()); err != nil { if err = require.LAPI(cli.cfg()); err != nil {
return err return err
} }
cli.db, err = database.NewClient(cli.cfg().DbConfig) cli.db, err = require.DBClient(cmd.Context(), cli.cfg().DbConfig)
if err != nil { if err != nil {
return fmt.Errorf("unable to create new database client: %w", err) return err
} }
return nil return nil

View file

@ -12,7 +12,6 @@ import (
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiserver" "github.com/crowdsecurity/crowdsec/pkg/apiserver"
"github.com/crowdsecurity/crowdsec/pkg/database"
) )
type cliPapi struct { type cliPapi struct {
@ -56,12 +55,12 @@ func (cli *cliPapi) NewStatusCmd() *cobra.Command {
Short: "Get status of the Polling API", Short: "Get status of the Polling API",
Args: cobra.MinimumNArgs(0), Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true, DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
var err error var err error
cfg := cli.cfg() cfg := cli.cfg()
db, err := database.NewClient(cfg.DbConfig) db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil { if err != nil {
return fmt.Errorf("unable to initialize database client: %w", err) return err
} }
apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
@ -105,14 +104,14 @@ func (cli *cliPapi) NewSyncCmd() *cobra.Command {
Short: "Sync with the Polling API, pulling all non-expired orders for the instance", Short: "Sync with the Polling API, pulling all non-expired orders for the instance",
Args: cobra.MinimumNArgs(0), Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true, DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
var err error var err error
cfg := cli.cfg() cfg := cli.cfg()
t := tomb.Tomb{} t := tomb.Tomb{}
db, err := database.NewClient(cfg.DbConfig) db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil { if err != nil {
return fmt.Errorf("unable to initialize database client: %w", err) return err
} }
apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)

View file

@ -10,6 +10,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cwhub" "github.com/crowdsecurity/crowdsec/pkg/cwhub"
"github.com/crowdsecurity/crowdsec/pkg/database"
) )
func LAPI(c *csconfig.Config) error { func LAPI(c *csconfig.Config) error {
@ -48,6 +49,15 @@ func CAPIRegistered(c *csconfig.Config) error {
return nil return nil
} }
func DBClient(ctx context.Context, dbcfg *csconfig.DatabaseCfg) (*database.Client, error) {
db, err := database.NewClient(ctx, dbcfg)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
return db, nil
}
func DB(c *csconfig.Config) error { func DB(c *csconfig.Config) error {
if err := c.LoadDBConfig(true); err != nil { if err := c.LoadDBConfig(true); err != nil {
return fmt.Errorf("this command requires direct database access (must be run on the local API machine): %w", err) return fmt.Errorf("this command requires direct database access (must be run on the local API machine): %w", err)

View file

@ -463,9 +463,9 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
w := bytes.NewBuffer(nil) w := bytes.NewBuffer(nil)
zipWriter := zip.NewWriter(w) zipWriter := zip.NewWriter(w)
db, err := database.NewClient(cfg.DbConfig) db, err := require.DBClient(ctx, cfg.DbConfig)
if err != nil { if err != nil {
log.Warnf("Could not connect to database: %s", err) log.Warn(err)
} }
if err = cfg.LoadAPIServer(true); err != nil { if err = cfg.LoadAPIServer(true); err != nil {

View file

@ -3,6 +3,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"runtime/pprof" "runtime/pprof"
@ -41,9 +42,10 @@ func StartRunSvc() error {
var err error var err error
if cConfig.DbConfig != nil { ctx := context.TODO()
dbClient, err = database.NewClient(cConfig.DbConfig)
if cConfig.DbConfig != nil {
dbClient, err = database.NewClient(ctx, cConfig.DbConfig)
if err != nil { if err != nil {
return fmt.Errorf("unable to create database client: %w", err) return fmt.Errorf("unable to create database client: %w", err)
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"runtime/pprof" "runtime/pprof"
@ -80,8 +81,10 @@ func WindowsRun() error {
var dbClient *database.Client var dbClient *database.Client
var err error var err error
ctx := context.TODO()
if cConfig.DbConfig != nil { if cConfig.DbConfig != nil {
dbClient, err = database.NewClient(cConfig.DbConfig) dbClient, err = database.NewClient(ctx, cConfig.DbConfig)
if err != nil { if err != nil {
return fmt.Errorf("unable to create database client: %w", err) return fmt.Errorf("unable to create database client: %w", err)

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
@ -322,8 +323,10 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
crowdsecTomb = tomb.Tomb{} crowdsecTomb = tomb.Tomb{}
pluginTomb = tomb.Tomb{} pluginTomb = tomb.Tomb{}
ctx := context.TODO()
if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil { if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil {
dbClient, err := database.NewClient(cConfig.API.Server.DbConfig) dbClient, err := database.NewClient(ctx, cConfig.API.Server.DbConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to get database client: %w", err) return fmt.Errorf("failed to get database client: %w", err)
} }

View file

@ -38,9 +38,11 @@ import (
func getDBClient(t *testing.T) *database.Client { func getDBClient(t *testing.T) *database.Client {
t.Helper() t.Helper()
ctx := context.Background()
dbPath, err := os.CreateTemp("", "*sqlite") dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err) require.NoError(t, err)
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{ dbClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{
Type: "sqlite", Type: "sqlite",
DbName: "crowdsec", DbName: "crowdsec",
DbPath: dbPath.Name(), DbPath: dbPath.Name(),
@ -56,7 +58,7 @@ func getAPIC(t *testing.T) *apic {
return &apic{ return &apic{
AlertsAddChan: make(chan []*models.Alert), AlertsAddChan: make(chan []*models.Alert),
//DecisionDeleteChan: make(chan []*models.Decision), // DecisionDeleteChan: make(chan []*models.Decision),
dbClient: dbClient, dbClient: dbClient,
mu: sync.Mutex{}, mu: sync.Mutex{},
startup: true, startup: true,
@ -176,10 +178,11 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
} }
scenarios, err := api.FetchScenariosListFromDB() scenarios, err := api.FetchScenariosListFromDB()
require.NoError(t, err)
for machineID := range tc.machineIDsWithScenarios { for machineID := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background()) api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
} }
require.NoError(t, err)
assert.ElementsMatch(t, tc.expectedScenarios, scenarios) assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
}) })
@ -234,6 +237,7 @@ func TestNewAPIC(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
setConfig() setConfig()
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("POST", "http://foobar/v3/watchers/login", httpmock.NewBytesResponder( httpmock.RegisterResponder("POST", "http://foobar/v3/watchers/login", httpmock.NewBytesResponder(
200, jsonMarshalX( 200, jsonMarshalX(
@ -353,6 +357,7 @@ func TestAPICGetMetrics(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
apiClient := getAPIC(t) apiClient := getAPIC(t)
cleanUp(apiClient) cleanUp(apiClient)
for i, machineID := range tc.machineIDs { for i, machineID := range tc.machineIDs {
apiClient.dbClient.Ent.Machine.Create(). apiClient.dbClient.Ent.Machine.Create().
SetMachineId(machineID). SetMachineId(machineID).
@ -548,7 +553,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
func TestAPICWhitelists(t *testing.T) { func TestAPICWhitelists(t *testing.T) {
api := getAPIC(t) api := getAPIC(t)
//one whitelist on IP, one on CIDR // one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{} api.whitelists = &csconfig.CapiWhitelist{}
api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4")) api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4"))
@ -593,7 +598,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"), Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{ {
Value: ptr.Of("13.2.3.4"), //wl by cidr Value: ptr.Of("13.2.3.4"), // wl by cidr
Duration: ptr.Of("24h"), Duration: ptr.Of("24h"),
}, },
}, },
@ -614,7 +619,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"), Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{ {
Value: ptr.Of("13.2.3.5"), //wl by cidr Value: ptr.Of("13.2.3.5"), // wl by cidr
Duration: ptr.Of("24h"), Duration: ptr.Of("24h"),
}, },
}, },
@ -634,7 +639,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"), Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{ {
Value: ptr.Of("9.2.3.4"), //wl by ip Value: ptr.Of("9.2.3.4"), // wl by ip
Duration: ptr.Of("24h"), Duration: ptr.Of("24h"),
}, },
}, },
@ -685,7 +690,7 @@ func TestAPICWhitelists(t *testing.T) {
err = api.PullTop(false) err = api.PullTop(false)
require.NoError(t, err) require.NoError(t, err)
assertTotalDecisionCount(t, api.dbClient, 5) //2 from FIRE + 2 from bl + 1 existing assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
assertTotalValidDecisionCount(t, api.dbClient, 4) assertTotalValidDecisionCount(t, api.dbClient, 4)
assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list. assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list.
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
@ -1103,6 +1108,7 @@ func TestAPICPush(t *testing.T) {
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
apic, err := apiclient.NewDefaultClient( apic, err := apiclient.NewDefaultClient(
url, url,
"/api", "/api",

View file

@ -162,7 +162,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
var flushScheduler *gocron.Scheduler var flushScheduler *gocron.Scheduler
dbClient, err := database.NewClient(config.DbConfig) ctx := context.TODO()
dbClient, err := database.NewClient(ctx, config.DbConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to init database client: %w", err) return nil, fmt.Errorf("unable to init database client: %w", err)
} }
@ -227,7 +229,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
controller := &controllers.Controller{ controller := &controllers.Controller{
DBClient: dbClient, DBClient: dbClient,
Ectx: context.Background(), Ectx: ctx,
Router: router, Router: router,
Profiles: config.Profiles, Profiles: config.Profiles,
Log: clog, Log: clog,

View file

@ -1,6 +1,7 @@
package apiserver package apiserver
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -161,7 +162,9 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
} }
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) { func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
dbClient, err := database.NewClient(config) ctx := context.Background()
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err) require.NoError(t, err)
err = dbClient.ValidateMachine(machineID) err = dbClient.ValidateMachine(machineID)
@ -169,7 +172,9 @@ func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCf
} }
func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string { func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(config) ctx := context.Background()
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err) require.NoError(t, err)
machines, err := dbClient.ListMachines() machines, err := dbClient.ListMachines()
@ -260,7 +265,9 @@ func CreateTestMachine(t *testing.T, router *gin.Engine) string {
} }
func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string { func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(config) ctx := context.Background()
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err) require.NoError(t, err)
apiKey, err := middlewares.GenerateAPIKey(keyLength) apiKey, err := middlewares.GenerateAPIKey(keyLength)
@ -356,10 +363,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
api.router.ServeHTTP(w, req) api.router.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
//wait for the request to happen // wait for the request to happen
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
//check file content // check file content
data, err := os.ReadFile(expectedFile) data, err := os.ReadFile(expectedFile)
require.NoError(t, err) require.NoError(t, err)
@ -406,10 +413,10 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
api.router.ServeHTTP(w, req) api.router.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code) assert.Equal(t, http.StatusNotFound, w.Code)
//wait for the request to happen // wait for the request to happen
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
//check file content // check file content
x, err := os.ReadFile(expectedFile) x, err := os.ReadFile(expectedFile)
if err == nil { if err == nil {
require.Empty(t, x) require.Empty(t, x)

View file

@ -947,7 +947,7 @@ func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string
Count int Count int
} }
ctx := context.Background() ctx := context.TODO()
query := c.Ent.Alert.Query() query := c.Ent.Alert.Query()

View file

@ -48,7 +48,7 @@ func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.
return drv, nil return drv, nil
} }
func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, error) {
var client *ent.Client var client *ent.Client
if config == nil { if config == nil {
@ -69,7 +69,7 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
typ, dia, err := config.ConnectionDialect() typ, dia, err := config.ConnectionDialect()
if err != nil { if err != nil {
return nil, err //unsupported database caught here return nil, err // unsupported database caught here
} }
if config.Type == "sqlite" { if config.Type == "sqlite" {
@ -103,13 +103,13 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
client = client.Debug() client = client.Debug()
} }
if err = client.Schema.Create(context.Background()); err != nil { if err = client.Schema.Create(ctx); err != nil {
return nil, fmt.Errorf("failed creating schema resources: %v", err) return nil, fmt.Errorf("failed creating schema resources: %v", err)
} }
return &Client{ return &Client{
Ent: client, Ent: client,
CTX: context.Background(), CTX: ctx,
Log: clog, Log: clog,
CanFlush: true, CanFlush: true,
Type: config.Type, Type: config.Type,

View file

@ -29,7 +29,9 @@ func getDBClient(t *testing.T) *database.Client {
dbPath, err := os.CreateTemp("", "*sqlite") dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err) require.NoError(t, err)
testDBClient, err := database.NewClient(&csconfig.DatabaseCfg{ ctx := context.Background()
testDBClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{
Type: "sqlite", Type: "sqlite",
DbName: "crowdsec", DbName: "crowdsec",
DbPath: dbPath.Name(), DbPath: dbPath.Name(),
@ -215,7 +217,7 @@ func TestRegexpCacheBehavior(t *testing.T) {
err = FileInit(TestFolder, filename, "regex") err = FileInit(TestFolder, filename, "regex")
require.NoError(t, err) require.NoError(t, err)
//cache with no TTL // cache with no TTL
err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(1)}) err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(1)})
require.NoError(t, err) require.NoError(t, err)
@ -227,7 +229,7 @@ func TestRegexpCacheBehavior(t *testing.T) {
assert.True(t, ret.(bool)) assert.True(t, ret.(bool))
assert.Equal(t, 1, dataFileRegexCache[filename].Len(false)) assert.Equal(t, 1, dataFileRegexCache[filename].Len(false))
//cache with TTL // cache with TTL
ttl := 500 * time.Millisecond ttl := 500 * time.Millisecond
err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(2), TTL: &ttl}) err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(2), TTL: &ttl})
require.NoError(t, err) require.NoError(t, err)
@ -994,6 +996,7 @@ func TestGetDecisionsCount(t *testing.T) {
log.Printf("test '%s' : OK", test.name) log.Printf("test '%s' : OK", test.name)
} }
} }
func TestGetDecisionsSinceCount(t *testing.T) { func TestGetDecisionsSinceCount(t *testing.T) {
existingIP := "1.2.3.4" existingIP := "1.2.3.4"
unknownIP := "1.2.3.5" unknownIP := "1.2.3.5"
@ -1365,6 +1368,7 @@ func TestGetActiveDecisionsTimeLeft(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
output, err := expr.Run(program, test.env) output, err := expr.Run(program, test.env)
require.NoError(t, err) require.NoError(t, err)
switch o := output.(type) { switch o := output.(type) {
case time.Duration: case time.Duration:
require.LessOrEqual(t, int(o.Seconds()), int(test.max)) require.LessOrEqual(t, int(o.Seconds()), int(test.max))
@ -1376,7 +1380,6 @@ func TestGetActiveDecisionsTimeLeft(t *testing.T) {
t.Fatalf("GetActiveDecisionsTimeLeft() should return a time.Duration or a float64") t.Fatalf("GetActiveDecisionsTimeLeft() should return a time.Duration or a float64")
} }
} }
} }
func TestParseUnixTime(t *testing.T) { func TestParseUnixTime(t *testing.T) {
@ -1415,9 +1418,11 @@ func TestParseUnixTime(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
output, err := ParseUnixTime(tc.value) output, err := ParseUnixTime(tc.value)
cstest.RequireErrorContains(t, err, tc.expectedErr) cstest.RequireErrorContains(t, err, tc.expectedErr)
if tc.expectedErr != "" { if tc.expectedErr != "" {
return return
} }
require.WithinDuration(t, tc.expected, output.(time.Time), time.Second) require.WithinDuration(t, tc.expected, output.(time.Time), time.Second)
}) })
} }
@ -1520,6 +1525,7 @@ func TestIsIp(t *testing.T) {
require.Error(t, err) require.Error(t, err)
return return
} }
require.NoError(t, err) require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
require.NoError(t, err) require.NoError(t, err)
@ -1619,12 +1625,15 @@ func TestB64Decode(t *testing.T) {
require.Error(t, err) require.Error(t, err)
return return
} }
require.NoError(t, err) require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
if tc.expectedRuntimeErr { if tc.expectedRuntimeErr {
require.Error(t, err) require.Error(t, err)
return return
} }
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tc.expected, output) require.Equal(t, tc.expected, output)
}) })