mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-11 12:25:53 +02:00
context propagation: pkg/database/machines (#3248)
This commit is contained in:
parent
e2196bdd66
commit
fee3debdcc
11 changed files with 109 additions and 94 deletions
|
@ -1,6 +1,7 @@
|
|||
package climachine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -210,11 +211,11 @@ func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (cli *cliMachines) List(out io.Writer, db *database.Client) error {
|
||||
func (cli *cliMachines) List(ctx context.Context, out io.Writer, db *database.Client) error {
|
||||
// XXX: must use the provided db object, the one in the struct might be nil
|
||||
// (calling List directly skips the PersistentPreRunE)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
machines, err := db.ListMachines(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to list machines: %w", err)
|
||||
}
|
||||
|
@ -251,8 +252,8 @@ func (cli *cliMachines) newListCmd() *cobra.Command {
|
|||
Example: `cscli machines list`,
|
||||
Args: cobra.NoArgs,
|
||||
DisableAutoGenTag: true,
|
||||
RunE: func(_ *cobra.Command, _ []string) error {
|
||||
return cli.List(color.Output, cli.db)
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return cli.List(cmd.Context(), color.Output, cli.db)
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -278,8 +279,8 @@ func (cli *cliMachines) newAddCmd() *cobra.Command {
|
|||
cscli machines add MyTestMachine --auto
|
||||
cscli machines add MyTestMachine --password MyPassword
|
||||
cscli machines add -f- --auto > /tmp/mycreds.yaml`,
|
||||
RunE: func(_ *cobra.Command, args []string) error {
|
||||
return cli.add(args, string(password), dumpFile, apiURL, interactive, autoAdd, force)
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force)
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -294,7 +295,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`,
|
|||
return cmd
|
||||
}
|
||||
|
||||
func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
|
||||
func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
|
||||
var (
|
||||
err error
|
||||
machineID string
|
||||
|
@ -353,7 +354,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri
|
|||
|
||||
password := strfmt.Password(machinePassword)
|
||||
|
||||
_, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType)
|
||||
_, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create machine: %w", err)
|
||||
}
|
||||
|
@ -399,6 +400,7 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
|
|||
var err error
|
||||
|
||||
cfg := cli.cfg()
|
||||
ctx := cmd.Context()
|
||||
|
||||
// need to load config and db because PersistentPreRunE is not called for completions
|
||||
|
||||
|
@ -407,13 +409,13 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
|
|||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
|
||||
cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
|
||||
cli.db, err = require.DBClient(ctx, cfg.DbConfig)
|
||||
if err != nil {
|
||||
cobra.CompError("unable to list machines " + err.Error())
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
|
||||
machines, err := cli.db.ListMachines()
|
||||
machines, err := cli.db.ListMachines(ctx)
|
||||
if err != nil {
|
||||
cobra.CompError("unable to list machines " + err.Error())
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
|
@ -430,9 +432,9 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
|
|||
return ret, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
|
||||
func (cli *cliMachines) delete(machines []string, ignoreMissing bool) error {
|
||||
func (cli *cliMachines) delete(ctx context.Context, machines []string, ignoreMissing bool) error {
|
||||
for _, machineID := range machines {
|
||||
if err := cli.db.DeleteWatcher(machineID); err != nil {
|
||||
if err := cli.db.DeleteWatcher(ctx, machineID); err != nil {
|
||||
var notFoundErr *database.MachineNotFoundError
|
||||
if ignoreMissing && errors.As(err, ¬FoundErr) {
|
||||
return nil
|
||||
|
@ -460,8 +462,8 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command {
|
|||
Aliases: []string{"remove"},
|
||||
DisableAutoGenTag: true,
|
||||
ValidArgsFunction: cli.validMachineID,
|
||||
RunE: func(_ *cobra.Command, args []string) error {
|
||||
return cli.delete(args, ignoreMissing)
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cli.delete(cmd.Context(), args, ignoreMissing)
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -471,7 +473,7 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command {
|
|||
return cmd
|
||||
}
|
||||
|
||||
func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force bool) error {
|
||||
func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error {
|
||||
if duration < 2*time.Minute && !notValidOnly {
|
||||
if yes, err := ask.YesNo(
|
||||
"The duration you provided is less than 2 minutes. "+
|
||||
|
@ -484,12 +486,12 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
|
|||
}
|
||||
|
||||
machines := []*ent.Machine{}
|
||||
if pending, err := cli.db.QueryPendingMachine(); err == nil {
|
||||
if pending, err := cli.db.QueryPendingMachine(ctx); err == nil {
|
||||
machines = append(machines, pending...)
|
||||
}
|
||||
|
||||
if !notValidOnly {
|
||||
if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil {
|
||||
if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil {
|
||||
machines = append(machines, pending...)
|
||||
}
|
||||
}
|
||||
|
@ -512,7 +514,7 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
|
|||
}
|
||||
}
|
||||
|
||||
deleted, err := cli.db.BulkDeleteWatchers(machines)
|
||||
deleted, err := cli.db.BulkDeleteWatchers(ctx, machines)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to prune machines: %w", err)
|
||||
}
|
||||
|
@ -540,8 +542,8 @@ cscli machines prune --duration 1h
|
|||
cscli machines prune --not-validated-only --force`,
|
||||
Args: cobra.NoArgs,
|
||||
DisableAutoGenTag: true,
|
||||
RunE: func(_ *cobra.Command, _ []string) error {
|
||||
return cli.prune(duration, notValidOnly, force)
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return cli.prune(cmd.Context(), duration, notValidOnly, force)
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -553,8 +555,8 @@ cscli machines prune --not-validated-only --force`,
|
|||
return cmd
|
||||
}
|
||||
|
||||
func (cli *cliMachines) validate(machineID string) error {
|
||||
if err := cli.db.ValidateMachine(machineID); err != nil {
|
||||
func (cli *cliMachines) validate(ctx context.Context, machineID string) error {
|
||||
if err := cli.db.ValidateMachine(ctx, machineID); err != nil {
|
||||
return fmt.Errorf("unable to validate machine '%s': %w", machineID, err)
|
||||
}
|
||||
|
||||
|
@ -571,8 +573,8 @@ func (cli *cliMachines) newValidateCmd() *cobra.Command {
|
|||
Example: `cscli machines validate "machine_name"`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
DisableAutoGenTag: true,
|
||||
RunE: func(_ *cobra.Command, args []string) error {
|
||||
return cli.validate(args[0])
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cli.validate(cmd.Context(), args[0])
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -690,9 +692,11 @@ func (cli *cliMachines) newInspectCmd() *cobra.Command {
|
|||
Args: cobra.ExactArgs(1),
|
||||
DisableAutoGenTag: true,
|
||||
ValidArgsFunction: cli.validMachineID,
|
||||
RunE: func(_ *cobra.Command, args []string) error {
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
machineID := args[0]
|
||||
machine, err := cli.db.QueryMachineByID(machineID)
|
||||
|
||||
machine, err := cli.db.QueryMachineByID(ctx, machineID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read machine data '%s': %w", machineID, err)
|
||||
}
|
||||
|
|
|
@ -210,7 +210,7 @@ func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *dat
|
|||
return nil
|
||||
}
|
||||
|
||||
func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error {
|
||||
func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error {
|
||||
log.Info("Collecting agents")
|
||||
|
||||
if db == nil {
|
||||
|
@ -220,7 +220,7 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error {
|
|||
out := new(bytes.Buffer)
|
||||
cm := climachine.New(cli.cfg)
|
||||
|
||||
if err := cm.List(out, db); err != nil {
|
||||
if err := cm.List(ctx, out, db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -529,7 +529,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
|
|||
log.Warnf("could not collect bouncers information: %s", err)
|
||||
}
|
||||
|
||||
if err = cli.dumpAgents(zipWriter, db); err != nil {
|
||||
if err = cli.dumpAgents(ctx, zipWriter, db); err != nil {
|
||||
log.Warnf("could not collect agents information: %s", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -85,7 +85,9 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration {
|
|||
func (a *apic) FetchScenariosListFromDB() ([]string, error) {
|
||||
scenarios := make([]string, 0)
|
||||
|
||||
machines, err := a.dbClient.ListMachines()
|
||||
ctx := context.TODO()
|
||||
|
||||
machines, err := a.dbClient.ListMachines(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("while listing machines: %w", err)
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int,
|
|||
allMetrics := &models.AllMetrics{}
|
||||
metricsIds := make([]int, 0)
|
||||
|
||||
lps, err := a.dbClient.ListMachines()
|
||||
lps, err := a.dbClient.ListMachines(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -186,7 +186,7 @@ func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error {
|
|||
}
|
||||
|
||||
func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) {
|
||||
machines, err := a.dbClient.ListMachines()
|
||||
machines, err := a.dbClient.ListMachines(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -230,8 +230,8 @@ func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (a *apic) fetchMachineIDs() ([]string, error) {
|
||||
machines, err := a.dbClient.ListMachines()
|
||||
func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) {
|
||||
machines, err := a.dbClient.ListMachines(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -277,7 +277,7 @@ func (a *apic) SendMetrics(stop chan (bool)) {
|
|||
machineIDs := []string{}
|
||||
|
||||
reloadMachineIDs := func() {
|
||||
ids, err := a.fetchMachineIDs()
|
||||
ids, err := a.fetchMachineIDs(ctx)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get machines (%s), will retry", err)
|
||||
|
||||
|
|
|
@ -182,12 +182,12 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
|
|||
}
|
||||
|
||||
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
|
||||
ctx := context.Background()
|
||||
ctx := context.TODO()
|
||||
|
||||
dbClient, err := database.NewClient(ctx, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = dbClient.ValidateMachine(machineID)
|
||||
err = dbClient.ValidateMachine(ctx, machineID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
@ -197,7 +197,7 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg)
|
|||
dbClient, err := database.NewClient(ctx, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
machines, err := dbClient.ListMachines()
|
||||
machines, err := dbClient.ListMachines(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, machine := range machines {
|
||||
|
@ -332,7 +332,7 @@ func TestUnknownPath(t *testing.T) {
|
|||
req.Header.Set("User-Agent", UserAgent)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 404, w.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -390,7 +390,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
|
|||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
api.router.ServeHTTP(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
// wait for the request to happen
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
|
|
|
@ -9,7 +9,9 @@ import (
|
|||
func (c *Controller) HeartBeat(gctx *gin.Context) {
|
||||
machineID, _ := getMachineIDFromContext(gctx)
|
||||
|
||||
if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil {
|
||||
ctx := gctx.Request.Context()
|
||||
|
||||
if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil {
|
||||
c.HandleDBErrors(gctx, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -46,6 +46,8 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool,
|
|||
}
|
||||
|
||||
func (c *Controller) CreateMachine(gctx *gin.Context) {
|
||||
ctx := gctx.Request.Context()
|
||||
|
||||
var input models.WatcherRegistrationRequest
|
||||
|
||||
if err := gctx.ShouldBindJSON(&input); err != nil {
|
||||
|
@ -66,7 +68,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil {
|
||||
if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil {
|
||||
c.HandleDBErrors(gctx, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -55,6 +55,7 @@ type authInput struct {
|
|||
}
|
||||
|
||||
func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
|
||||
ctx := c.Request.Context()
|
||||
ret := authInput{}
|
||||
|
||||
if j.TlsAuth == nil {
|
||||
|
@ -76,7 +77,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
|
|||
|
||||
ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
|
||||
Where(machine.MachineId(ret.machineID)).
|
||||
First(j.DbClient.CTX)
|
||||
First(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
// Machine was not found, let's create it
|
||||
logger.Infof("machine %s not found, create it", ret.machineID)
|
||||
|
@ -91,7 +92,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
|
|||
|
||||
password := strfmt.Password(pwd)
|
||||
|
||||
ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
|
||||
ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
|
||||
}
|
||||
|
@ -175,6 +176,8 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
|||
auth *authInput
|
||||
)
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
|
||||
auth, err = j.authTLS(c)
|
||||
if err != nil {
|
||||
|
@ -198,7 +201,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
|||
}
|
||||
}
|
||||
|
||||
err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
|
||||
err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
|
@ -208,7 +211,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
|||
clientIP := c.ClientIP()
|
||||
|
||||
if auth.clientMachine.IpAddress == "" {
|
||||
err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
|
||||
err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err)
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
|
@ -218,7 +221,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
|||
if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" {
|
||||
log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress)
|
||||
|
||||
err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
|
||||
err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
|
@ -231,7 +234,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
|||
return nil, jwt.ErrFailedAuthentication
|
||||
}
|
||||
|
||||
if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil {
|
||||
if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil {
|
||||
log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err)
|
||||
log.Errorf("bad user agent from : %s", clientIP)
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
name: "empty metrics for LP",
|
||||
body: `{
|
||||
}`,
|
||||
expectedStatusCode: 400,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedResponse: "Missing log processor data",
|
||||
authType: PASSWORD,
|
||||
},
|
||||
|
@ -50,7 +50,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 201,
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
expectedMetricsCount: 1,
|
||||
expectedResponse: "",
|
||||
expectedOSName: "foo",
|
||||
|
@ -74,7 +74,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 201,
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
expectedMetricsCount: 1,
|
||||
expectedResponse: "",
|
||||
expectedOSName: "foo",
|
||||
|
@ -98,7 +98,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 400,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedResponse: "Missing remediation component data",
|
||||
authType: APIKEY,
|
||||
},
|
||||
|
@ -117,7 +117,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 201,
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
expectedResponse: "",
|
||||
expectedMetricsCount: 1,
|
||||
expectedFeatureFlags: "a,b,c",
|
||||
|
@ -138,7 +138,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 422,
|
||||
expectedStatusCode: http.StatusUnprocessableEntity,
|
||||
expectedResponse: "log_processors.0.datasources in body is required",
|
||||
authType: PASSWORD,
|
||||
},
|
||||
|
@ -157,7 +157,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 201,
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
expectedMetricsCount: 1,
|
||||
expectedOSName: "foo",
|
||||
expectedOSVersion: "42",
|
||||
|
@ -179,7 +179,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 422,
|
||||
expectedStatusCode: http.StatusUnprocessableEntity,
|
||||
expectedResponse: "log_processors.0.os.name in body is required",
|
||||
authType: PASSWORD,
|
||||
},
|
||||
|
@ -199,7 +199,7 @@ func TestLPMetrics(t *testing.T) {
|
|||
assert.Equal(t, tt.expectedStatusCode, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.expectedResponse)
|
||||
|
||||
machine, _ := dbClient.QueryMachineByID("test")
|
||||
machine, _ := dbClient.QueryMachineByID(ctx, "test")
|
||||
metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test")
|
||||
|
||||
assert.Len(t, metrics, tt.expectedMetricsCount)
|
||||
|
@ -233,7 +233,7 @@ func TestRCMetrics(t *testing.T) {
|
|||
name: "empty metrics for RC",
|
||||
body: `{
|
||||
}`,
|
||||
expectedStatusCode: 400,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedResponse: "Missing remediation component data",
|
||||
authType: APIKEY,
|
||||
},
|
||||
|
@ -251,7 +251,7 @@ func TestRCMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 201,
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
expectedMetricsCount: 1,
|
||||
expectedResponse: "",
|
||||
expectedOSName: "foo",
|
||||
|
@ -273,7 +273,7 @@ func TestRCMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 201,
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
expectedMetricsCount: 1,
|
||||
expectedResponse: "",
|
||||
expectedOSName: "foo",
|
||||
|
@ -295,7 +295,7 @@ func TestRCMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 400,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedResponse: "Missing log processor data",
|
||||
authType: PASSWORD,
|
||||
},
|
||||
|
@ -312,7 +312,7 @@ func TestRCMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 201,
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
expectedResponse: "",
|
||||
expectedMetricsCount: 1,
|
||||
expectedFeatureFlags: "a,b,c",
|
||||
|
@ -331,7 +331,7 @@ func TestRCMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 201,
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
expectedMetricsCount: 1,
|
||||
expectedOSName: "foo",
|
||||
expectedOSVersion: "42",
|
||||
|
@ -351,7 +351,7 @@ func TestRCMetrics(t *testing.T) {
|
|||
}
|
||||
]
|
||||
}`,
|
||||
expectedStatusCode: 422,
|
||||
expectedStatusCode: http.StatusUnprocessableEntity,
|
||||
expectedResponse: "remediation_components.0.os.name in body is required",
|
||||
authType: APIKEY,
|
||||
},
|
||||
|
|
|
@ -687,8 +687,10 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str
|
|||
err error
|
||||
)
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
if machineID != "" {
|
||||
owner, err = c.QueryMachineByID(machineID)
|
||||
owner, err = c.QueryMachineByID(ctx, machineID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, UserNotExists) {
|
||||
return nil, fmt.Errorf("machine '%s': %w", machineID, err)
|
||||
|
|
|
@ -72,7 +72,7 @@ func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) {
|
||||
func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) {
|
||||
hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
c.Log.Warningf("CreateMachine: %s", err)
|
||||
|
@ -82,20 +82,20 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
|
|||
machineExist, err := c.Ent.Machine.
|
||||
Query().
|
||||
Where(machine.MachineIdEQ(*machineID)).
|
||||
Select(machine.FieldMachineId).Strings(c.CTX)
|
||||
Select(machine.FieldMachineId).Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err)
|
||||
}
|
||||
|
||||
if len(machineExist) > 0 {
|
||||
if force {
|
||||
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX)
|
||||
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("CreateMachine : %s", err)
|
||||
return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID)
|
||||
}
|
||||
|
||||
machine, err := c.QueryMachineByID(*machineID)
|
||||
machine, err := c.QueryMachineByID(ctx, *machineID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err)
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
|
|||
SetIpAddress(ipAddress).
|
||||
SetIsValidated(isValidated).
|
||||
SetAuthType(authType).
|
||||
Save(c.CTX)
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("CreateMachine : %s", err)
|
||||
return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID)
|
||||
|
@ -122,11 +122,11 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
|
|||
return machine, nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) {
|
||||
func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.Machine, error) {
|
||||
machine, err := c.Ent.Machine.
|
||||
Query().
|
||||
Where(machine.MachineIdEQ(machineID)).
|
||||
Only(c.CTX)
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryMachineByID : %s", err)
|
||||
return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID)
|
||||
|
@ -135,8 +135,8 @@ func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) {
|
|||
return machine, nil
|
||||
}
|
||||
|
||||
func (c *Client) ListMachines() ([]*ent.Machine, error) {
|
||||
machines, err := c.Ent.Machine.Query().All(c.CTX)
|
||||
func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) {
|
||||
machines, err := c.Ent.Machine.Query().All(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(QueryFail, "listing machines: %s", err)
|
||||
}
|
||||
|
@ -144,8 +144,8 @@ func (c *Client) ListMachines() ([]*ent.Machine, error) {
|
|||
return machines, nil
|
||||
}
|
||||
|
||||
func (c *Client) ValidateMachine(machineID string) error {
|
||||
rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(c.CTX)
|
||||
func (c *Client) ValidateMachine(ctx context.Context, machineID string) error {
|
||||
rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrapf(UpdateFail, "validating machine: %s", err)
|
||||
}
|
||||
|
@ -157,8 +157,8 @@ func (c *Client) ValidateMachine(machineID string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) {
|
||||
machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX)
|
||||
func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error) {
|
||||
machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(ctx)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryPendingMachine : %s", err)
|
||||
return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err)
|
||||
|
@ -167,11 +167,11 @@ func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) {
|
|||
return machines, nil
|
||||
}
|
||||
|
||||
func (c *Client) DeleteWatcher(name string) error {
|
||||
func (c *Client) DeleteWatcher(ctx context.Context, name string) error {
|
||||
nbDeleted, err := c.Ent.Machine.
|
||||
Delete().
|
||||
Where(machine.MachineIdEQ(name)).
|
||||
Exec(c.CTX)
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -183,13 +183,13 @@ func (c *Client) DeleteWatcher(name string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) {
|
||||
func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine) (int, error) {
|
||||
ids := make([]int, len(machines))
|
||||
for i, b := range machines {
|
||||
ids[i] = b.ID
|
||||
}
|
||||
|
||||
nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(c.CTX)
|
||||
nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(ctx)
|
||||
if err != nil {
|
||||
return nbDeleted, err
|
||||
}
|
||||
|
@ -197,8 +197,8 @@ func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) {
|
|||
return nbDeleted, nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateMachineLastHeartBeat(machineID string) error {
|
||||
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX)
|
||||
func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error {
|
||||
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err)
|
||||
}
|
||||
|
@ -206,11 +206,11 @@ func (c *Client) UpdateMachineLastHeartBeat(machineID string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateMachineScenarios(scenarios string, id int) error {
|
||||
func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error {
|
||||
_, err := c.Ent.Machine.UpdateOneID(id).
|
||||
SetUpdatedAt(time.Now().UTC()).
|
||||
SetScenarios(scenarios).
|
||||
Save(c.CTX)
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to update machine in database: %w", err)
|
||||
}
|
||||
|
@ -218,10 +218,10 @@ func (c *Client) UpdateMachineScenarios(scenarios string, id int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateMachineIP(ipAddr string, id int) error {
|
||||
func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error {
|
||||
_, err := c.Ent.Machine.UpdateOneID(id).
|
||||
SetIpAddress(ipAddr).
|
||||
Save(c.CTX)
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to update machine IP in database: %w", err)
|
||||
}
|
||||
|
@ -229,10 +229,10 @@ func (c *Client) UpdateMachineIP(ipAddr string, id int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateMachineVersion(ipAddr string, id int) error {
|
||||
func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error {
|
||||
_, err := c.Ent.Machine.UpdateOneID(id).
|
||||
SetVersion(ipAddr).
|
||||
Save(c.CTX)
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to update machine version in database: %w", err)
|
||||
}
|
||||
|
@ -240,8 +240,8 @@ func (c *Client) UpdateMachineVersion(ipAddr string, id int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) IsMachineRegistered(machineID string) (bool, error) {
|
||||
exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX)
|
||||
func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) {
|
||||
exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -257,11 +257,11 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryMachinesInactiveSince(t time.Time) ([]*ent.Machine, error) {
|
||||
func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) {
|
||||
return c.Ent.Machine.Query().Where(
|
||||
machine.Or(
|
||||
machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)),
|
||||
machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)),
|
||||
),
|
||||
).All(c.CTX)
|
||||
).All(ctx)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue