context propagation: pkg/database/machines (#3248)

This commit is contained in:
mmetc 2024-09-20 16:00:58 +02:00 committed by GitHub
parent e2196bdd66
commit fee3debdcc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 109 additions and 94 deletions

View file

@ -1,6 +1,7 @@
package climachine
import (
"context"
"encoding/csv"
"encoding/json"
"errors"
@ -210,11 +211,11 @@ func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error {
return nil
}
func (cli *cliMachines) List(out io.Writer, db *database.Client) error {
func (cli *cliMachines) List(ctx context.Context, out io.Writer, db *database.Client) error {
// XXX: must use the provided db object, the one in the struct might be nil
// (calling List directly skips the PersistentPreRunE)
machines, err := db.ListMachines()
machines, err := db.ListMachines(ctx)
if err != nil {
return fmt.Errorf("unable to list machines: %w", err)
}
@ -251,8 +252,8 @@ func (cli *cliMachines) newListCmd() *cobra.Command {
Example: `cscli machines list`,
Args: cobra.NoArgs,
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.List(color.Output, cli.db)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.List(cmd.Context(), color.Output, cli.db)
},
}
@ -278,8 +279,8 @@ func (cli *cliMachines) newAddCmd() *cobra.Command {
cscli machines add MyTestMachine --auto
cscli machines add MyTestMachine --password MyPassword
cscli machines add -f- --auto > /tmp/mycreds.yaml`,
RunE: func(_ *cobra.Command, args []string) error {
return cli.add(args, string(password), dumpFile, apiURL, interactive, autoAdd, force)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force)
},
}
@ -294,7 +295,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`,
return cmd
}
func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
var (
err error
machineID string
@ -353,7 +354,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri
password := strfmt.Password(machinePassword)
_, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType)
_, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType)
if err != nil {
return fmt.Errorf("unable to create machine: %w", err)
}
@ -399,6 +400,7 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
var err error
cfg := cli.cfg()
ctx := cmd.Context()
// need to load config and db because PersistentPreRunE is not called for completions
@ -407,13 +409,13 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
return nil, cobra.ShellCompDirectiveNoFileComp
}
cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
cli.db, err = require.DBClient(ctx, cfg.DbConfig)
if err != nil {
cobra.CompError("unable to list machines " + err.Error())
return nil, cobra.ShellCompDirectiveNoFileComp
}
machines, err := cli.db.ListMachines()
machines, err := cli.db.ListMachines(ctx)
if err != nil {
cobra.CompError("unable to list machines " + err.Error())
return nil, cobra.ShellCompDirectiveNoFileComp
@ -430,9 +432,9 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
return ret, cobra.ShellCompDirectiveNoFileComp
}
func (cli *cliMachines) delete(machines []string, ignoreMissing bool) error {
func (cli *cliMachines) delete(ctx context.Context, machines []string, ignoreMissing bool) error {
for _, machineID := range machines {
if err := cli.db.DeleteWatcher(machineID); err != nil {
if err := cli.db.DeleteWatcher(ctx, machineID); err != nil {
var notFoundErr *database.MachineNotFoundError
if ignoreMissing && errors.As(err, &notFoundErr) {
return nil
@ -460,8 +462,8 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command {
Aliases: []string{"remove"},
DisableAutoGenTag: true,
ValidArgsFunction: cli.validMachineID,
RunE: func(_ *cobra.Command, args []string) error {
return cli.delete(args, ignoreMissing)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.delete(cmd.Context(), args, ignoreMissing)
},
}
@ -471,7 +473,7 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command {
return cmd
}
func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force bool) error {
func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error {
if duration < 2*time.Minute && !notValidOnly {
if yes, err := ask.YesNo(
"The duration you provided is less than 2 minutes. "+
@ -484,12 +486,12 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
}
machines := []*ent.Machine{}
if pending, err := cli.db.QueryPendingMachine(); err == nil {
if pending, err := cli.db.QueryPendingMachine(ctx); err == nil {
machines = append(machines, pending...)
}
if !notValidOnly {
if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil {
if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil {
machines = append(machines, pending...)
}
}
@ -512,7 +514,7 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
}
}
deleted, err := cli.db.BulkDeleteWatchers(machines)
deleted, err := cli.db.BulkDeleteWatchers(ctx, machines)
if err != nil {
return fmt.Errorf("unable to prune machines: %w", err)
}
@ -540,8 +542,8 @@ cscli machines prune --duration 1h
cscli machines prune --not-validated-only --force`,
Args: cobra.NoArgs,
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.prune(duration, notValidOnly, force)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.prune(cmd.Context(), duration, notValidOnly, force)
},
}
@ -553,8 +555,8 @@ cscli machines prune --not-validated-only --force`,
return cmd
}
func (cli *cliMachines) validate(machineID string) error {
if err := cli.db.ValidateMachine(machineID); err != nil {
func (cli *cliMachines) validate(ctx context.Context, machineID string) error {
if err := cli.db.ValidateMachine(ctx, machineID); err != nil {
return fmt.Errorf("unable to validate machine '%s': %w", machineID, err)
}
@ -571,8 +573,8 @@ func (cli *cliMachines) newValidateCmd() *cobra.Command {
Example: `cscli machines validate "machine_name"`,
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, args []string) error {
return cli.validate(args[0])
RunE: func(cmd *cobra.Command, args []string) error {
return cli.validate(cmd.Context(), args[0])
},
}
@ -690,9 +692,11 @@ func (cli *cliMachines) newInspectCmd() *cobra.Command {
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
ValidArgsFunction: cli.validMachineID,
RunE: func(_ *cobra.Command, args []string) error {
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
machineID := args[0]
machine, err := cli.db.QueryMachineByID(machineID)
machine, err := cli.db.QueryMachineByID(ctx, machineID)
if err != nil {
return fmt.Errorf("unable to read machine data '%s': %w", machineID, err)
}

View file

@ -210,7 +210,7 @@ func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *dat
return nil
}
func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error {
func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error {
log.Info("Collecting agents")
if db == nil {
@ -220,7 +220,7 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error {
out := new(bytes.Buffer)
cm := climachine.New(cli.cfg)
if err := cm.List(out, db); err != nil {
if err := cm.List(ctx, out, db); err != nil {
return err
}
@ -529,7 +529,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
log.Warnf("could not collect bouncers information: %s", err)
}
if err = cli.dumpAgents(zipWriter, db); err != nil {
if err = cli.dumpAgents(ctx, zipWriter, db); err != nil {
log.Warnf("could not collect agents information: %s", err)
}

View file

@ -85,7 +85,9 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration {
func (a *apic) FetchScenariosListFromDB() ([]string, error) {
scenarios := make([]string, 0)
machines, err := a.dbClient.ListMachines()
ctx := context.TODO()
machines, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, fmt.Errorf("while listing machines: %w", err)
}

View file

@ -27,7 +27,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int,
allMetrics := &models.AllMetrics{}
metricsIds := make([]int, 0)
lps, err := a.dbClient.ListMachines()
lps, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, nil, err
}
@ -186,7 +186,7 @@ func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error {
}
func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) {
machines, err := a.dbClient.ListMachines()
machines, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, err
}
@ -230,8 +230,8 @@ func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) {
}, nil
}
func (a *apic) fetchMachineIDs() ([]string, error) {
machines, err := a.dbClient.ListMachines()
func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) {
machines, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, err
}
@ -277,7 +277,7 @@ func (a *apic) SendMetrics(stop chan (bool)) {
machineIDs := []string{}
reloadMachineIDs := func() {
ids, err := a.fetchMachineIDs()
ids, err := a.fetchMachineIDs(ctx)
if err != nil {
log.Debugf("unable to get machines (%s), will retry", err)

View file

@ -182,12 +182,12 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
}
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
ctx := context.Background()
ctx := context.TODO()
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)
err = dbClient.ValidateMachine(machineID)
err = dbClient.ValidateMachine(ctx, machineID)
require.NoError(t, err)
}
@ -197,7 +197,7 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg)
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)
machines, err := dbClient.ListMachines()
machines, err := dbClient.ListMachines(ctx)
require.NoError(t, err)
for _, machine := range machines {
@ -332,7 +332,7 @@ func TestUnknownPath(t *testing.T) {
req.Header.Set("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code)
assert.Equal(t, http.StatusNotFound, w.Code)
}
/*
@ -390,7 +390,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
req.Header.Set("User-Agent", UserAgent)
api.router.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code)
assert.Equal(t, http.StatusNotFound, w.Code)
// wait for the request to happen
time.Sleep(500 * time.Millisecond)

View file

@ -9,7 +9,9 @@ import (
func (c *Controller) HeartBeat(gctx *gin.Context) {
machineID, _ := getMachineIDFromContext(gctx)
if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil {
ctx := gctx.Request.Context()
if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil {
c.HandleDBErrors(gctx, err)
return
}

View file

@ -46,6 +46,8 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool,
}
func (c *Controller) CreateMachine(gctx *gin.Context) {
ctx := gctx.Request.Context()
var input models.WatcherRegistrationRequest
if err := gctx.ShouldBindJSON(&input); err != nil {
@ -66,7 +68,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) {
return
}
if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil {
if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil {
c.HandleDBErrors(gctx, err)
return
}

View file

@ -55,6 +55,7 @@ type authInput struct {
}
func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
ctx := c.Request.Context()
ret := authInput{}
if j.TlsAuth == nil {
@ -76,7 +77,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
Where(machine.MachineId(ret.machineID)).
First(j.DbClient.CTX)
First(ctx)
if ent.IsNotFound(err) {
// Machine was not found, let's create it
logger.Infof("machine %s not found, create it", ret.machineID)
@ -91,7 +92,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
password := strfmt.Password(pwd)
ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType)
if err != nil {
return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
}
@ -175,6 +176,8 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
auth *authInput
)
ctx := c.Request.Context()
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
auth, err = j.authTLS(c)
if err != nil {
@ -198,7 +201,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
}
}
err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID)
if err != nil {
log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
return nil, jwt.ErrFailedAuthentication
@ -208,7 +211,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
clientIP := c.ClientIP()
if auth.clientMachine.IpAddress == "" {
err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID)
if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err)
return nil, jwt.ErrFailedAuthentication
@ -218,7 +221,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" {
log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress)
err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID)
if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
return nil, jwt.ErrFailedAuthentication
@ -231,7 +234,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
return nil, jwt.ErrFailedAuthentication
}
if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil {
if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil {
log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err)
log.Errorf("bad user agent from : %s", clientIP)

View file

@ -30,7 +30,7 @@ func TestLPMetrics(t *testing.T) {
name: "empty metrics for LP",
body: `{
}`,
expectedStatusCode: 400,
expectedStatusCode: http.StatusBadRequest,
expectedResponse: "Missing log processor data",
authType: PASSWORD,
},
@ -50,7 +50,7 @@ func TestLPMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 201,
expectedStatusCode: http.StatusCreated,
expectedMetricsCount: 1,
expectedResponse: "",
expectedOSName: "foo",
@ -74,7 +74,7 @@ func TestLPMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 201,
expectedStatusCode: http.StatusCreated,
expectedMetricsCount: 1,
expectedResponse: "",
expectedOSName: "foo",
@ -98,7 +98,7 @@ func TestLPMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 400,
expectedStatusCode: http.StatusBadRequest,
expectedResponse: "Missing remediation component data",
authType: APIKEY,
},
@ -117,7 +117,7 @@ func TestLPMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 201,
expectedStatusCode: http.StatusCreated,
expectedResponse: "",
expectedMetricsCount: 1,
expectedFeatureFlags: "a,b,c",
@ -138,7 +138,7 @@ func TestLPMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 422,
expectedStatusCode: http.StatusUnprocessableEntity,
expectedResponse: "log_processors.0.datasources in body is required",
authType: PASSWORD,
},
@ -157,7 +157,7 @@ func TestLPMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 201,
expectedStatusCode: http.StatusCreated,
expectedMetricsCount: 1,
expectedOSName: "foo",
expectedOSVersion: "42",
@ -179,7 +179,7 @@ func TestLPMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 422,
expectedStatusCode: http.StatusUnprocessableEntity,
expectedResponse: "log_processors.0.os.name in body is required",
authType: PASSWORD,
},
@ -199,7 +199,7 @@ func TestLPMetrics(t *testing.T) {
assert.Equal(t, tt.expectedStatusCode, w.Code)
assert.Contains(t, w.Body.String(), tt.expectedResponse)
machine, _ := dbClient.QueryMachineByID("test")
machine, _ := dbClient.QueryMachineByID(ctx, "test")
metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test")
assert.Len(t, metrics, tt.expectedMetricsCount)
@ -233,7 +233,7 @@ func TestRCMetrics(t *testing.T) {
name: "empty metrics for RC",
body: `{
}`,
expectedStatusCode: 400,
expectedStatusCode: http.StatusBadRequest,
expectedResponse: "Missing remediation component data",
authType: APIKEY,
},
@ -251,7 +251,7 @@ func TestRCMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 201,
expectedStatusCode: http.StatusCreated,
expectedMetricsCount: 1,
expectedResponse: "",
expectedOSName: "foo",
@ -273,7 +273,7 @@ func TestRCMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 201,
expectedStatusCode: http.StatusCreated,
expectedMetricsCount: 1,
expectedResponse: "",
expectedOSName: "foo",
@ -295,7 +295,7 @@ func TestRCMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 400,
expectedStatusCode: http.StatusBadRequest,
expectedResponse: "Missing log processor data",
authType: PASSWORD,
},
@ -312,7 +312,7 @@ func TestRCMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 201,
expectedStatusCode: http.StatusCreated,
expectedResponse: "",
expectedMetricsCount: 1,
expectedFeatureFlags: "a,b,c",
@ -331,7 +331,7 @@ func TestRCMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 201,
expectedStatusCode: http.StatusCreated,
expectedMetricsCount: 1,
expectedOSName: "foo",
expectedOSVersion: "42",
@ -351,7 +351,7 @@ func TestRCMetrics(t *testing.T) {
}
]
}`,
expectedStatusCode: 422,
expectedStatusCode: http.StatusUnprocessableEntity,
expectedResponse: "remediation_components.0.os.name in body is required",
authType: APIKEY,
},

View file

@ -687,8 +687,10 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str
err error
)
ctx := context.TODO()
if machineID != "" {
owner, err = c.QueryMachineByID(machineID)
owner, err = c.QueryMachineByID(ctx, machineID)
if err != nil {
if !errors.Is(err, UserNotExists) {
return nil, fmt.Errorf("machine '%s': %w", machineID, err)

View file

@ -72,7 +72,7 @@ func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string,
return nil
}
func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) {
func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) {
hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil {
c.Log.Warningf("CreateMachine: %s", err)
@ -82,20 +82,20 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
machineExist, err := c.Ent.Machine.
Query().
Where(machine.MachineIdEQ(*machineID)).
Select(machine.FieldMachineId).Strings(c.CTX)
Select(machine.FieldMachineId).Strings(ctx)
if err != nil {
return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err)
}
if len(machineExist) > 0 {
if force {
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX)
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx)
if err != nil {
c.Log.Warningf("CreateMachine : %s", err)
return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID)
}
machine, err := c.QueryMachineByID(*machineID)
machine, err := c.QueryMachineByID(ctx, *machineID)
if err != nil {
return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err)
}
@ -113,7 +113,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
SetIpAddress(ipAddress).
SetIsValidated(isValidated).
SetAuthType(authType).
Save(c.CTX)
Save(ctx)
if err != nil {
c.Log.Warningf("CreateMachine : %s", err)
return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID)
@ -122,11 +122,11 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
return machine, nil
}
func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) {
func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.Machine, error) {
machine, err := c.Ent.Machine.
Query().
Where(machine.MachineIdEQ(machineID)).
Only(c.CTX)
Only(ctx)
if err != nil {
c.Log.Warningf("QueryMachineByID : %s", err)
return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID)
@ -135,8 +135,8 @@ func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) {
return machine, nil
}
func (c *Client) ListMachines() ([]*ent.Machine, error) {
machines, err := c.Ent.Machine.Query().All(c.CTX)
func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) {
machines, err := c.Ent.Machine.Query().All(ctx)
if err != nil {
return nil, errors.Wrapf(QueryFail, "listing machines: %s", err)
}
@ -144,8 +144,8 @@ func (c *Client) ListMachines() ([]*ent.Machine, error) {
return machines, nil
}
func (c *Client) ValidateMachine(machineID string) error {
rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(c.CTX)
func (c *Client) ValidateMachine(ctx context.Context, machineID string) error {
rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(ctx)
if err != nil {
return errors.Wrapf(UpdateFail, "validating machine: %s", err)
}
@ -157,8 +157,8 @@ func (c *Client) ValidateMachine(machineID string) error {
return nil
}
func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) {
machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX)
func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error) {
machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(ctx)
if err != nil {
c.Log.Warningf("QueryPendingMachine : %s", err)
return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err)
@ -167,11 +167,11 @@ func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) {
return machines, nil
}
func (c *Client) DeleteWatcher(name string) error {
func (c *Client) DeleteWatcher(ctx context.Context, name string) error {
nbDeleted, err := c.Ent.Machine.
Delete().
Where(machine.MachineIdEQ(name)).
Exec(c.CTX)
Exec(ctx)
if err != nil {
return err
}
@ -183,13 +183,13 @@ func (c *Client) DeleteWatcher(name string) error {
return nil
}
func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) {
func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine) (int, error) {
ids := make([]int, len(machines))
for i, b := range machines {
ids[i] = b.ID
}
nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(c.CTX)
nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(ctx)
if err != nil {
return nbDeleted, err
}
@ -197,8 +197,8 @@ func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) {
return nbDeleted, nil
}
func (c *Client) UpdateMachineLastHeartBeat(machineID string) error {
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX)
func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error {
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx)
if err != nil {
return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err)
}
@ -206,11 +206,11 @@ func (c *Client) UpdateMachineLastHeartBeat(machineID string) error {
return nil
}
func (c *Client) UpdateMachineScenarios(scenarios string, id int) error {
func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error {
_, err := c.Ent.Machine.UpdateOneID(id).
SetUpdatedAt(time.Now().UTC()).
SetScenarios(scenarios).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update machine in database: %w", err)
}
@ -218,10 +218,10 @@ func (c *Client) UpdateMachineScenarios(scenarios string, id int) error {
return nil
}
func (c *Client) UpdateMachineIP(ipAddr string, id int) error {
func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error {
_, err := c.Ent.Machine.UpdateOneID(id).
SetIpAddress(ipAddr).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update machine IP in database: %w", err)
}
@ -229,10 +229,10 @@ func (c *Client) UpdateMachineIP(ipAddr string, id int) error {
return nil
}
func (c *Client) UpdateMachineVersion(ipAddr string, id int) error {
func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error {
_, err := c.Ent.Machine.UpdateOneID(id).
SetVersion(ipAddr).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update machine version in database: %w", err)
}
@ -240,8 +240,8 @@ func (c *Client) UpdateMachineVersion(ipAddr string, id int) error {
return nil
}
func (c *Client) IsMachineRegistered(machineID string) (bool, error) {
exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX)
func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) {
exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx)
if err != nil {
return false, err
}
@ -257,11 +257,11 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) {
return false, nil
}
func (c *Client) QueryMachinesInactiveSince(t time.Time) ([]*ent.Machine, error) {
func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) {
return c.Ent.Machine.Query().Where(
machine.Or(
machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)),
machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)),
),
).All(c.CTX)
).All(ctx)
}