refactor context (cscli, pkg/database) (#3071)

* cscli: helper require.DBClient()

* refactor pkg/database: explicit context to dbclient constructor

* lint
This commit is contained in:
mmetc 2024-06-11 12:13:18 +02:00 committed by GitHub
parent 24687e982a
commit bd4540b1bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 97 additions and 49 deletions

View file

@ -24,7 +24,6 @@ import (
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
)
@ -378,6 +377,7 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil {
return err
}
if ActiveDecision != nil {
alertDeleteFilter.ActiveDecisionEquals = ActiveDecision
}
@ -385,21 +385,27 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
if *alertDeleteFilter.ScopeEquals == "" {
alertDeleteFilter.ScopeEquals = nil
}
if *alertDeleteFilter.ValueEquals == "" {
alertDeleteFilter.ValueEquals = nil
}
if *alertDeleteFilter.ScenarioEquals == "" {
alertDeleteFilter.ScenarioEquals = nil
}
if *alertDeleteFilter.IPEquals == "" {
alertDeleteFilter.IPEquals = nil
}
if *alertDeleteFilter.RangeEquals == "" {
alertDeleteFilter.RangeEquals = nil
}
if contained != nil && *contained {
alertDeleteFilter.Contains = new(bool)
}
limit := 0
alertDeleteFilter.Limit = &limit
} else {
@ -419,6 +425,7 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
return fmt.Errorf("unable to delete alert: %w", err)
}
}
log.Infof("%s alert(s) deleted", alerts.NbDeleted)
return nil
@ -558,14 +565,14 @@ func (cli *cliAlerts) NewFlushCmd() *cobra.Command {
/!\ This command can be used only on the same machine than the local API`,
Example: `cscli alerts flush --max-items 1000 --max-age 7d`,
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
cfg := cli.cfg()
if err := require.LAPI(cfg); err != nil {
return err
}
db, err := database.NewClient(cfg.DbConfig)
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil {
return fmt.Errorf("unable to create new database client: %w", err)
return err
}
log.Info("Flushing alerts. !! This may take a long time !!")
err = db.FlushAlerts(maxAge, maxItems)

View file

@ -57,7 +57,7 @@ Note: This command requires database direct access, so is intended to be run on
Args: cobra.MinimumNArgs(1),
Aliases: []string{"bouncer"},
DisableAutoGenTag: true,
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
var err error
cfg := cli.cfg()
@ -66,9 +66,9 @@ Note: This command requires database direct access, so is intended to be run on
return err
}
cli.db, err = database.NewClient(cfg.DbConfig)
cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil {
return fmt.Errorf("can't connect to the database: %w", err)
return err
}
return nil

View file

@ -128,14 +128,14 @@ Note: This command requires database direct access, so is intended to be run on
Example: `cscli machines [action]`,
DisableAutoGenTag: true,
Aliases: []string{"machine"},
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
var err error
if err = require.LAPI(cli.cfg()); err != nil {
return err
}
cli.db, err = database.NewClient(cli.cfg().DbConfig)
cli.db, err = require.DBClient(cmd.Context(), cli.cfg().DbConfig)
if err != nil {
return fmt.Errorf("unable to create new database client: %w", err)
return err
}
return nil

View file

@ -12,7 +12,6 @@ import (
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
"github.com/crowdsecurity/crowdsec/pkg/database"
)
type cliPapi struct {
@ -56,12 +55,12 @@ func (cli *cliPapi) NewStatusCmd() *cobra.Command {
Short: "Get status of the Polling API",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
var err error
cfg := cli.cfg()
db, err := database.NewClient(cfg.DbConfig)
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil {
return fmt.Errorf("unable to initialize database client: %w", err)
return err
}
apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
@ -105,14 +104,14 @@ func (cli *cliPapi) NewSyncCmd() *cobra.Command {
Short: "Sync with the Polling API, pulling all non-expired orders for the instance",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
var err error
cfg := cli.cfg()
t := tomb.Tomb{}
db, err := database.NewClient(cfg.DbConfig)
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil {
return fmt.Errorf("unable to initialize database client: %w", err)
return err
}
apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)

View file

@ -10,6 +10,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
"github.com/crowdsecurity/crowdsec/pkg/database"
)
func LAPI(c *csconfig.Config) error {
@ -48,6 +49,15 @@ func CAPIRegistered(c *csconfig.Config) error {
return nil
}
func DBClient(ctx context.Context, dbcfg *csconfig.DatabaseCfg) (*database.Client, error) {
db, err := database.NewClient(ctx, dbcfg)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
return db, nil
}
func DB(c *csconfig.Config) error {
if err := c.LoadDBConfig(true); err != nil {
return fmt.Errorf("this command requires direct database access (must be run on the local API machine): %w", err)

View file

@ -463,9 +463,9 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
w := bytes.NewBuffer(nil)
zipWriter := zip.NewWriter(w)
db, err := database.NewClient(cfg.DbConfig)
db, err := require.DBClient(ctx, cfg.DbConfig)
if err != nil {
log.Warnf("Could not connect to database: %s", err)
log.Warn(err)
}
if err = cfg.LoadAPIServer(true); err != nil {

View file

@ -3,6 +3,7 @@
package main
import (
"context"
"fmt"
"runtime/pprof"
@ -41,9 +42,10 @@ func StartRunSvc() error {
var err error
if cConfig.DbConfig != nil {
dbClient, err = database.NewClient(cConfig.DbConfig)
ctx := context.TODO()
if cConfig.DbConfig != nil {
dbClient, err = database.NewClient(ctx, cConfig.DbConfig)
if err != nil {
return fmt.Errorf("unable to create database client: %w", err)
}

View file

@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"runtime/pprof"
@ -80,8 +81,10 @@ func WindowsRun() error {
var dbClient *database.Client
var err error
ctx := context.TODO()
if cConfig.DbConfig != nil {
dbClient, err = database.NewClient(cConfig.DbConfig)
dbClient, err = database.NewClient(ctx, cConfig.DbConfig)
if err != nil {
return fmt.Errorf("unable to create database client: %w", err)

View file

@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"os"
"os/signal"
@ -322,8 +323,10 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
crowdsecTomb = tomb.Tomb{}
pluginTomb = tomb.Tomb{}
ctx := context.TODO()
if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil {
dbClient, err := database.NewClient(cConfig.API.Server.DbConfig)
dbClient, err := database.NewClient(ctx, cConfig.API.Server.DbConfig)
if err != nil {
return fmt.Errorf("failed to get database client: %w", err)
}

View file

@ -38,9 +38,11 @@ import (
func getDBClient(t *testing.T) *database.Client {
t.Helper()
ctx := context.Background()
dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err)
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
dbClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{
Type: "sqlite",
DbName: "crowdsec",
DbPath: dbPath.Name(),
@ -56,7 +58,7 @@ func getAPIC(t *testing.T) *apic {
return &apic{
AlertsAddChan: make(chan []*models.Alert),
//DecisionDeleteChan: make(chan []*models.Decision),
// DecisionDeleteChan: make(chan []*models.Decision),
dbClient: dbClient,
mu: sync.Mutex{},
startup: true,
@ -176,10 +178,11 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
}
scenarios, err := api.FetchScenariosListFromDB()
require.NoError(t, err)
for machineID := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
}
require.NoError(t, err)
assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
})
@ -234,6 +237,7 @@ func TestNewAPIC(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
setConfig()
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("POST", "http://foobar/v3/watchers/login", httpmock.NewBytesResponder(
200, jsonMarshalX(
@ -353,6 +357,7 @@ func TestAPICGetMetrics(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
apiClient := getAPIC(t)
cleanUp(apiClient)
for i, machineID := range tc.machineIDs {
apiClient.dbClient.Ent.Machine.Create().
SetMachineId(machineID).
@ -548,7 +553,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
func TestAPICWhitelists(t *testing.T) {
api := getAPIC(t)
//one whitelist on IP, one on CIDR
// one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{}
api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4"))
@ -593,7 +598,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Value: ptr.Of("13.2.3.4"), //wl by cidr
Value: ptr.Of("13.2.3.4"), // wl by cidr
Duration: ptr.Of("24h"),
},
},
@ -614,7 +619,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Value: ptr.Of("13.2.3.5"), //wl by cidr
Value: ptr.Of("13.2.3.5"), // wl by cidr
Duration: ptr.Of("24h"),
},
},
@ -634,7 +639,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Value: ptr.Of("9.2.3.4"), //wl by ip
Value: ptr.Of("9.2.3.4"), // wl by ip
Duration: ptr.Of("24h"),
},
},
@ -685,7 +690,7 @@ func TestAPICWhitelists(t *testing.T) {
err = api.PullTop(false)
require.NoError(t, err)
assertTotalDecisionCount(t, api.dbClient, 5) //2 from FIRE + 2 from bl + 1 existing
assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
assertTotalValidDecisionCount(t, api.dbClient, 4)
assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list.
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
@ -1103,6 +1108,7 @@ func TestAPICPush(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
apic, err := apiclient.NewDefaultClient(
url,
"/api",

View file

@ -162,7 +162,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
var flushScheduler *gocron.Scheduler
dbClient, err := database.NewClient(config.DbConfig)
ctx := context.TODO()
dbClient, err := database.NewClient(ctx, config.DbConfig)
if err != nil {
return nil, fmt.Errorf("unable to init database client: %w", err)
}
@ -227,7 +229,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
controller := &controllers.Controller{
DBClient: dbClient,
Ectx: context.Background(),
Ectx: ctx,
Router: router,
Profiles: config.Profiles,
Log: clog,

View file

@ -1,6 +1,7 @@
package apiserver
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -161,7 +162,9 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
}
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
dbClient, err := database.NewClient(config)
ctx := context.Background()
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)
err = dbClient.ValidateMachine(machineID)
@ -169,7 +172,9 @@ func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCf
}
func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(config)
ctx := context.Background()
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)
machines, err := dbClient.ListMachines()
@ -260,7 +265,9 @@ func CreateTestMachine(t *testing.T, router *gin.Engine) string {
}
func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(config)
ctx := context.Background()
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)
apiKey, err := middlewares.GenerateAPIKey(keyLength)
@ -356,10 +363,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
req.Header.Set("User-Agent", UserAgent)
api.router.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code)
//wait for the request to happen
// wait for the request to happen
time.Sleep(500 * time.Millisecond)
//check file content
// check file content
data, err := os.ReadFile(expectedFile)
require.NoError(t, err)
@ -406,10 +413,10 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
req.Header.Set("User-Agent", UserAgent)
api.router.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
//wait for the request to happen
// wait for the request to happen
time.Sleep(500 * time.Millisecond)
//check file content
// check file content
x, err := os.ReadFile(expectedFile)
if err == nil {
require.Empty(t, x)

View file

@ -947,7 +947,7 @@ func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string
Count int
}
ctx := context.Background()
ctx := context.TODO()
query := c.Ent.Alert.Query()

View file

@ -48,7 +48,7 @@ func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.
return drv, nil
}
func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, error) {
var client *ent.Client
if config == nil {
@ -69,7 +69,7 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
typ, dia, err := config.ConnectionDialect()
if err != nil {
return nil, err //unsupported database caught here
return nil, err // unsupported database caught here
}
if config.Type == "sqlite" {
@ -103,13 +103,13 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
client = client.Debug()
}
if err = client.Schema.Create(context.Background()); err != nil {
if err = client.Schema.Create(ctx); err != nil {
return nil, fmt.Errorf("failed creating schema resources: %v", err)
}
return &Client{
Ent: client,
CTX: context.Background(),
CTX: ctx,
Log: clog,
CanFlush: true,
Type: config.Type,

View file

@ -29,7 +29,9 @@ func getDBClient(t *testing.T) *database.Client {
dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err)
testDBClient, err := database.NewClient(&csconfig.DatabaseCfg{
ctx := context.Background()
testDBClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{
Type: "sqlite",
DbName: "crowdsec",
DbPath: dbPath.Name(),
@ -215,7 +217,7 @@ func TestRegexpCacheBehavior(t *testing.T) {
err = FileInit(TestFolder, filename, "regex")
require.NoError(t, err)
//cache with no TTL
// cache with no TTL
err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(1)})
require.NoError(t, err)
@ -227,7 +229,7 @@ func TestRegexpCacheBehavior(t *testing.T) {
assert.True(t, ret.(bool))
assert.Equal(t, 1, dataFileRegexCache[filename].Len(false))
//cache with TTL
// cache with TTL
ttl := 500 * time.Millisecond
err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(2), TTL: &ttl})
require.NoError(t, err)
@ -994,6 +996,7 @@ func TestGetDecisionsCount(t *testing.T) {
log.Printf("test '%s' : OK", test.name)
}
}
func TestGetDecisionsSinceCount(t *testing.T) {
existingIP := "1.2.3.4"
unknownIP := "1.2.3.5"
@ -1365,6 +1368,7 @@ func TestGetActiveDecisionsTimeLeft(t *testing.T) {
require.NoError(t, err)
output, err := expr.Run(program, test.env)
require.NoError(t, err)
switch o := output.(type) {
case time.Duration:
require.LessOrEqual(t, int(o.Seconds()), int(test.max))
@ -1376,7 +1380,6 @@ func TestGetActiveDecisionsTimeLeft(t *testing.T) {
t.Fatalf("GetActiveDecisionsTimeLeft() should return a time.Duration or a float64")
}
}
}
func TestParseUnixTime(t *testing.T) {
@ -1415,9 +1418,11 @@ func TestParseUnixTime(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
output, err := ParseUnixTime(tc.value)
cstest.RequireErrorContains(t, err, tc.expectedErr)
if tc.expectedErr != "" {
return
}
require.WithinDuration(t, tc.expected, output.(time.Time), time.Second)
})
}
@ -1520,6 +1525,7 @@ func TestIsIp(t *testing.T) {
require.Error(t, err)
return
}
require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
require.NoError(t, err)
@ -1619,12 +1625,15 @@ func TestB64Decode(t *testing.T) {
require.Error(t, err)
return
}
require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
if tc.expectedRuntimeErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.expected, output)
})