diff --git a/cmd/crowdsec-cli/clibouncer/add.go b/cmd/crowdsec-cli/clibouncer/add.go index 8c40507a9..7cc74e45f 100644 --- a/cmd/crowdsec-cli/clibouncer/add.go +++ b/cmd/crowdsec-cli/clibouncer/add.go @@ -24,7 +24,7 @@ func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) } } - _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) + _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType, false) if err != nil { return fmt.Errorf("unable to create bouncer: %w", err) } diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go index 876b613be..2b0a35568 100644 --- a/cmd/crowdsec-cli/clibouncer/bouncers.go +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -77,6 +77,7 @@ type bouncerInfo struct { AuthType string `json:"auth_type"` OS string `json:"os,omitempty"` Featureflags []string `json:"featureflags,omitempty"` + AutoCreated bool `json:"auto_created"` } func newBouncerInfo(b *ent.Bouncer) bouncerInfo { @@ -92,6 +93,7 @@ func newBouncerInfo(b *ent.Bouncer) bouncerInfo { AuthType: b.AuthType, OS: clientinfo.GetOSNameAndVersion(b), Featureflags: clientinfo.GetFeatureFlagList(b), + AutoCreated: b.AutoCreated, } } diff --git a/cmd/crowdsec-cli/clibouncer/delete.go b/cmd/crowdsec-cli/clibouncer/delete.go index 6e2f312d4..33419f483 100644 --- a/cmd/crowdsec-cli/clibouncer/delete.go +++ b/cmd/crowdsec-cli/clibouncer/delete.go @@ -4,25 +4,73 @@ import ( "context" "errors" "fmt" + "strings" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/types" ) -func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error { - for _, bouncerID := range bouncers { - if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil { - var notFoundErr *database.BouncerNotFoundError - if ignoreMissing && errors.As(err, ¬FoundErr) { - return nil - } +func (cli *cliBouncers) findParentBouncer(bouncerName string, bouncers []*ent.Bouncer) (string, error) { + bouncerPrefix := strings.Split(bouncerName, "@")[0] + for _, bouncer := range bouncers { + if strings.HasPrefix(bouncer.Name, bouncerPrefix) && !bouncer.AutoCreated { + return bouncer.Name, nil + } + } - return fmt.Errorf("unable to delete bouncer: %w", err) + return "", errors.New("no parent bouncer found") +} + +func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error { + allBouncers, err := cli.db.ListBouncers(ctx) + if err != nil { + return fmt.Errorf("unable to list bouncers: %w", err) + } + for _, bouncerName := range bouncers { + bouncer, err := cli.db.SelectBouncerByName(ctx, bouncerName) + if err != nil { + var notFoundErr *ent.NotFoundError + if ignoreMissing && errors.As(err, ¬FoundErr) { + continue + } + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) } - log.Infof("bouncer '%s' deleted successfully", bouncerID) + // For TLS bouncers, always delete them, they have no parents + if bouncer.AuthType == types.TlsAuthType { + if err := cli.db.DeleteBouncer(ctx, bouncerName); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) + } + continue + } + + if bouncer.AutoCreated { + parentBouncer, err := cli.findParentBouncer(bouncerName, allBouncers) + if err != nil { + log.Errorf("bouncer '%s' is auto-created, but couldn't find a parent bouncer", err) + continue + } + log.Warnf("bouncer '%s' is auto-created and cannot be deleted, delete parent bouncer %s instead", bouncerName, parentBouncer) + continue + } + //Try to find all child bouncers and delete them + for _, childBouncer := range allBouncers { + if strings.HasPrefix(childBouncer.Name, bouncerName+"@") && childBouncer.AutoCreated { + if err := cli.db.DeleteBouncer(ctx, childBouncer.Name); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", childBouncer.Name, err) + } + log.Infof("bouncer '%s' deleted successfully", childBouncer.Name) + } + } + + if err := cli.db.DeleteBouncer(ctx, bouncerName); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) + } + + log.Infof("bouncer '%s' deleted successfully", bouncerName) } return nil diff --git a/cmd/crowdsec-cli/clibouncer/inspect.go b/cmd/crowdsec-cli/clibouncer/inspect.go index 6dac386b8..b62344baa 100644 --- a/cmd/crowdsec-cli/clibouncer/inspect.go +++ b/cmd/crowdsec-cli/clibouncer/inspect.go @@ -40,6 +40,7 @@ func (cli *cliBouncers) inspectHuman(out io.Writer, bouncer *ent.Bouncer) { {"Last Pull", lastPull}, {"Auth type", bouncer.AuthType}, {"OS", clientinfo.GetOSNameAndVersion(bouncer)}, + {"Auto Created", bouncer.AutoCreated}, }) for _, ff := range clientinfo.GetFeatureFlagList(bouncer) { diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 4cc215c34..d86234e48 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -59,6 +59,9 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur t.Fatal("auth type not supported") } + // Port is required for gin to properly parse the client IP + req.RemoteAddr = "127.0.0.1:1234" + l.router.ServeHTTP(w, req) return w diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index e6ed68a6e..45c02c806 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -20,28 +20,74 @@ func TestAPIKey(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) - assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.JSONEq(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with invalid token w = httptest.NewRecorder() req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", "a1b2c3d4e5f6") + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) - assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.JSONEq(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with valid token w = httptest.NewRecorder() req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) + + // Login with valid token from another IP + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "4.3.2.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Make the requests multiple times to make sure we only create one + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "4.3.2.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Use the original bouncer again + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "127.0.0.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Check if our second bouncer was properly created + bouncers := GetBouncers(t, config.API.Server.DbConfig) + + assert.Len(t, bouncers, 2) + assert.Equal(t, "test@4.3.2.1", bouncers[1].Name) + assert.Equal(t, bouncers[0].APIKey, bouncers[1].APIKey) + assert.Equal(t, bouncers[0].AuthType, bouncers[1].AuthType) + assert.False(t, bouncers[0].AutoCreated) + assert.True(t, bouncers[1].AutoCreated) } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index cdf99462c..cf4c91ded 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -24,6 +24,7 @@ import ( middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -62,6 +63,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config { } apiServerConfig := csconfig.LocalApiServerCfg{ ListenURI: "http://127.0.0.1:8080", + LogLevel: ptr.Of(log.DebugLevel), DbConfig: &dbconfig, ProfilesPath: "./tests/profiles.yaml", ConsoleConfig: &csconfig.ConsoleConfig{ @@ -206,6 +208,18 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) return "" } +func GetBouncers(t *testing.T, config *csconfig.DatabaseCfg) []*ent.Bouncer { + ctx := context.Background() + + dbClient, err := database.NewClient(ctx, config) + require.NoError(t, err) + + bouncers, err := dbClient.ListBouncers(ctx) + require.NoError(t, err) + + return bouncers +} + func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader { alertContentBytes, err := os.ReadFile(path) require.NoError(t, err) @@ -290,7 +304,7 @@ func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.Datab apiKey, err := middlewares.GenerateAPIKey(keyLength) require.NoError(t, err) - _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) + _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType, false) require.NoError(t, err) return apiKey diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index d438c9b15..3c154be4f 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -89,7 +89,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType, true) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) return nil @@ -114,18 +114,69 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + clientIP := c.ClientIP() + ctx := c.Request.Context() hashStr := HashSHA512(val[0]) - bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr) + // Appsec case, we only care if the key is valid + // No content is returned, no last_pull update or anything + if c.Request.Method == http.MethodHead { + bouncer, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType) + if err != nil { + logger.Errorf("while fetching bouncer info: %s", err) + return nil + } + return bouncer[0] + } + + // most common case, check if this specific bouncer exists + bouncer, err := a.DbClient.SelectBouncerWithIP(ctx, hashStr, clientIP) + if err != nil && !ent.IsNotFound(err) { + logger.Errorf("while fetching bouncer info: %s", err) + return nil + } + + // We found the bouncer with key and IP, we can use it + if bouncer != nil { + if bouncer.AuthType != types.ApiKeyAuthType { + logger.Errorf("bouncer isn't allowed to auth by API key") + return nil + } + return bouncer + } + + // We didn't find the bouncer with key and IP, let's try to find it with the key only + bouncers, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType) if err != nil { logger.Errorf("while fetching bouncer info: %s", err) return nil } - if bouncer.AuthType != types.ApiKeyAuthType { - logger.Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType) + if len(bouncers) == 0 { + logger.Debugf("no bouncer found with this key") + return nil + } + + logger.Debugf("found %d bouncers with this key", len(bouncers)) + + // We only have one bouncer with this key and no IP + // This is the first request made by this bouncer, keep this one + if len(bouncers) == 1 && bouncers[0].IPAddress == "" { + return bouncers[0] + } + + // Bouncers are ordered by ID, first one *should* be the manually created one + // Can probably get a bit weird if the user deletes the manually created one + bouncerName := fmt.Sprintf("%s@%s", bouncers[0].Name, clientIP) + + logger.Infof("Creating bouncer %s", bouncerName) + + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, clientIP, hashStr, types.ApiKeyAuthType, true) + + if err != nil { + logger.Errorf("while creating bouncer db entry: %s", err) return nil } @@ -156,27 +207,20 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return } + // Appsec request, return immediately if we found something + if c.Request.Method == http.MethodHead { + c.Set(BouncerContextKey, bouncer) + return + } + logger = logger.WithField("name", bouncer.Name) + // 1st time we see this bouncer, we update its IP if bouncer.IPAddress == "" { if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() - - return - } - } - - // Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided - if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { - log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress) - - if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { - logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return } } diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index 04ef830ae..f9e62bc65 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -41,8 +41,19 @@ func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName strin return nil } -func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx) +func (c *Client) SelectBouncers(ctx context.Context, apiKeyHash string, authType string) ([]*ent.Bouncer, error) { + //Order by ID so manually created bouncer will be first in the list to use as the base name + //when automatically creating a new entry if API keys are shared + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash), bouncer.AuthTypeEQ(authType)).Order(ent.Asc(bouncer.FieldID)).All(ctx) + if err != nil { + return nil, err + } + + return result, nil +} + +func (c *Client) SelectBouncerWithIP(ctx context.Context, apiKeyHash string, clientIP string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash), bouncer.IPAddressEQ(clientIP)).First(ctx) if err != nil { return nil, err } @@ -68,13 +79,15 @@ func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { return result, nil } -func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { +func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string, autoCreated bool) (*ent.Bouncer, error) { bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). SetAuthType(authType). + SetIPAddress(ipAddr). + SetAutoCreated(autoCreated). Save(ctx) if err != nil { if ent.IsConstraintError(err) { diff --git a/pkg/database/ent/bouncer.go b/pkg/database/ent/bouncer.go index 3b4d619e3..197f61cde 100644 --- a/pkg/database/ent/bouncer.go +++ b/pkg/database/ent/bouncer.go @@ -43,6 +43,8 @@ type Bouncer struct { Osversion string `json:"osversion,omitempty"` // Featureflags holds the value of the "featureflags" field. Featureflags string `json:"featureflags,omitempty"` + // AutoCreated holds the value of the "auto_created" field. + AutoCreated bool `json:"auto_created"` selectValues sql.SelectValues } @@ -51,7 +53,7 @@ func (*Bouncer) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case bouncer.FieldRevoked: + case bouncer.FieldRevoked, bouncer.FieldAutoCreated: values[i] = new(sql.NullBool) case bouncer.FieldID: values[i] = new(sql.NullInt64) @@ -159,6 +161,12 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { } else if value.Valid { b.Featureflags = value.String } + case bouncer.FieldAutoCreated: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field auto_created", values[i]) + } else if value.Valid { + b.AutoCreated = value.Bool + } default: b.selectValues.Set(columns[i], values[i]) } @@ -234,6 +242,9 @@ func (b *Bouncer) String() string { builder.WriteString(", ") builder.WriteString("featureflags=") builder.WriteString(b.Featureflags) + builder.WriteString(", ") + builder.WriteString("auto_created=") + builder.WriteString(fmt.Sprintf("%v", b.AutoCreated)) builder.WriteByte(')') return builder.String() } diff --git a/pkg/database/ent/bouncer/bouncer.go b/pkg/database/ent/bouncer/bouncer.go index a6f62aead..f25b5a581 100644 --- a/pkg/database/ent/bouncer/bouncer.go +++ b/pkg/database/ent/bouncer/bouncer.go @@ -39,6 +39,8 @@ const ( FieldOsversion = "osversion" // FieldFeatureflags holds the string denoting the featureflags field in the database. FieldFeatureflags = "featureflags" + // FieldAutoCreated holds the string denoting the auto_created field in the database. + FieldAutoCreated = "auto_created" // Table holds the table name of the bouncer in the database. Table = "bouncers" ) @@ -59,6 +61,7 @@ var Columns = []string{ FieldOsname, FieldOsversion, FieldFeatureflags, + FieldAutoCreated, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -82,6 +85,8 @@ var ( DefaultIPAddress string // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string + // DefaultAutoCreated holds the default value on creation for the "auto_created" field. + DefaultAutoCreated bool ) // OrderOption defines the ordering options for the Bouncer queries. @@ -156,3 +161,8 @@ func ByOsversion(opts ...sql.OrderTermOption) OrderOption { func ByFeatureflags(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldFeatureflags, opts...).ToFunc() } + +// ByAutoCreated orders the results by the auto_created field. +func ByAutoCreated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAutoCreated, opts...).ToFunc() +} diff --git a/pkg/database/ent/bouncer/where.go b/pkg/database/ent/bouncer/where.go index e02199bc0..79b899935 100644 --- a/pkg/database/ent/bouncer/where.go +++ b/pkg/database/ent/bouncer/where.go @@ -119,6 +119,11 @@ func Featureflags(v string) predicate.Bouncer { return predicate.Bouncer(sql.FieldEQ(FieldFeatureflags, v)) } +// AutoCreated applies equality check predicate on the "auto_created" field. It's identical to AutoCreatedEQ. +func AutoCreated(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldAutoCreated, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Bouncer { return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) @@ -904,6 +909,16 @@ func FeatureflagsContainsFold(v string) predicate.Bouncer { return predicate.Bouncer(sql.FieldContainsFold(FieldFeatureflags, v)) } +// AutoCreatedEQ applies the EQ predicate on the "auto_created" field. +func AutoCreatedEQ(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldAutoCreated, v)) +} + +// AutoCreatedNEQ applies the NEQ predicate on the "auto_created" field. +func AutoCreatedNEQ(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNEQ(FieldAutoCreated, v)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Bouncer) predicate.Bouncer { return predicate.Bouncer(sql.AndPredicates(predicates...)) diff --git a/pkg/database/ent/bouncer_create.go b/pkg/database/ent/bouncer_create.go index 29b23f87c..9ff4c0e08 100644 --- a/pkg/database/ent/bouncer_create.go +++ b/pkg/database/ent/bouncer_create.go @@ -178,6 +178,20 @@ func (bc *BouncerCreate) SetNillableFeatureflags(s *string) *BouncerCreate { return bc } +// SetAutoCreated sets the "auto_created" field. +func (bc *BouncerCreate) SetAutoCreated(b bool) *BouncerCreate { + bc.mutation.SetAutoCreated(b) + return bc +} + +// SetNillableAutoCreated sets the "auto_created" field if the given value is not nil. +func (bc *BouncerCreate) SetNillableAutoCreated(b *bool) *BouncerCreate { + if b != nil { + bc.SetAutoCreated(*b) + } + return bc +} + // Mutation returns the BouncerMutation object of the builder. func (bc *BouncerCreate) Mutation() *BouncerMutation { return bc.mutation @@ -229,6 +243,10 @@ func (bc *BouncerCreate) defaults() { v := bouncer.DefaultAuthType bc.mutation.SetAuthType(v) } + if _, ok := bc.mutation.AutoCreated(); !ok { + v := bouncer.DefaultAutoCreated + bc.mutation.SetAutoCreated(v) + } } // check runs all checks and user-defined validators on the builder. @@ -251,6 +269,9 @@ func (bc *BouncerCreate) check() error { if _, ok := bc.mutation.AuthType(); !ok { return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Bouncer.auth_type"`)} } + if _, ok := bc.mutation.AutoCreated(); !ok { + return &ValidationError{Name: "auto_created", err: errors.New(`ent: missing required field "Bouncer.auto_created"`)} + } return nil } @@ -329,6 +350,10 @@ func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) { _spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value) _node.Featureflags = value } + if value, ok := bc.mutation.AutoCreated(); ok { + _spec.SetField(bouncer.FieldAutoCreated, field.TypeBool, value) + _node.AutoCreated = value + } return _node, _spec } diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go index 986f5bc8c..dae248c7f 100644 --- a/pkg/database/ent/migrate/schema.go +++ b/pkg/database/ent/migrate/schema.go @@ -74,6 +74,7 @@ var ( {Name: "osname", Type: field.TypeString, Nullable: true}, {Name: "osversion", Type: field.TypeString, Nullable: true}, {Name: "featureflags", Type: field.TypeString, Nullable: true}, + {Name: "auto_created", Type: field.TypeBool, Default: false}, } // BouncersTable holds the schema information for the "bouncers" table. BouncersTable = &schema.Table{ diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go index 5c6596f3d..fa1ccb3da 100644 --- a/pkg/database/ent/mutation.go +++ b/pkg/database/ent/mutation.go @@ -2471,6 +2471,7 @@ type BouncerMutation struct { osname *string osversion *string featureflags *string + auto_created *bool clearedFields map[string]struct{} done bool oldValue func(context.Context) (*Bouncer, error) @@ -3134,6 +3135,42 @@ func (m *BouncerMutation) ResetFeatureflags() { delete(m.clearedFields, bouncer.FieldFeatureflags) } +// SetAutoCreated sets the "auto_created" field. +func (m *BouncerMutation) SetAutoCreated(b bool) { + m.auto_created = &b +} + +// AutoCreated returns the value of the "auto_created" field in the mutation. +func (m *BouncerMutation) AutoCreated() (r bool, exists bool) { + v := m.auto_created + if v == nil { + return + } + return *v, true +} + +// OldAutoCreated returns the old "auto_created" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldAutoCreated(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAutoCreated is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAutoCreated requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAutoCreated: %w", err) + } + return oldValue.AutoCreated, nil +} + +// ResetAutoCreated resets all changes to the "auto_created" field. +func (m *BouncerMutation) ResetAutoCreated() { + m.auto_created = nil +} + // Where appends a list predicates to the BouncerMutation builder. func (m *BouncerMutation) Where(ps ...predicate.Bouncer) { m.predicates = append(m.predicates, ps...) @@ -3168,7 +3205,7 @@ func (m *BouncerMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *BouncerMutation) Fields() []string { - fields := make([]string, 0, 13) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, bouncer.FieldCreatedAt) } @@ -3208,6 +3245,9 @@ func (m *BouncerMutation) Fields() []string { if m.featureflags != nil { fields = append(fields, bouncer.FieldFeatureflags) } + if m.auto_created != nil { + fields = append(fields, bouncer.FieldAutoCreated) + } return fields } @@ -3242,6 +3282,8 @@ func (m *BouncerMutation) Field(name string) (ent.Value, bool) { return m.Osversion() case bouncer.FieldFeatureflags: return m.Featureflags() + case bouncer.FieldAutoCreated: + return m.AutoCreated() } return nil, false } @@ -3277,6 +3319,8 @@ func (m *BouncerMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldOsversion(ctx) case bouncer.FieldFeatureflags: return m.OldFeatureflags(ctx) + case bouncer.FieldAutoCreated: + return m.OldAutoCreated(ctx) } return nil, fmt.Errorf("unknown Bouncer field %s", name) } @@ -3377,6 +3421,13 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error { } m.SetFeatureflags(v) return nil + case bouncer.FieldAutoCreated: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAutoCreated(v) + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -3510,6 +3561,9 @@ func (m *BouncerMutation) ResetField(name string) error { case bouncer.FieldFeatureflags: m.ResetFeatureflags() return nil + case bouncer.FieldAutoCreated: + m.ResetAutoCreated() + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go index 154134906..49921a17b 100644 --- a/pkg/database/ent/runtime.go +++ b/pkg/database/ent/runtime.go @@ -76,6 +76,10 @@ func init() { bouncerDescAuthType := bouncerFields[9].Descriptor() // bouncer.DefaultAuthType holds the default value on creation for the auth_type field. bouncer.DefaultAuthType = bouncerDescAuthType.Default.(string) + // bouncerDescAutoCreated is the schema descriptor for auto_created field. + bouncerDescAutoCreated := bouncerFields[13].Descriptor() + // bouncer.DefaultAutoCreated holds the default value on creation for the auto_created field. + bouncer.DefaultAutoCreated = bouncerDescAutoCreated.Default.(bool) configitemFields := schema.ConfigItem{}.Fields() _ = configitemFields // configitemDescCreatedAt is the schema descriptor for created_at field. diff --git a/pkg/database/ent/schema/bouncer.go b/pkg/database/ent/schema/bouncer.go index 599c4c404..c176bf0f7 100644 --- a/pkg/database/ent/schema/bouncer.go +++ b/pkg/database/ent/schema/bouncer.go @@ -33,6 +33,8 @@ func (Bouncer) Fields() []ent.Field { field.String("osname").Optional(), field.String("osversion").Optional(), field.String("featureflags").Optional(), + // Old auto-created TLS bouncers will have a wrong value for this field + field.Bool("auto_created").StructTag(`json:"auto_created"`).Default(false).Immutable(), } } diff --git a/test/bats/10_bouncers.bats b/test/bats/10_bouncers.bats index f99913dce..b1c90116d 100644 --- a/test/bats/10_bouncers.bats +++ b/test/bats/10_bouncers.bats @@ -63,7 +63,7 @@ teardown() { @test "delete non-existent bouncer" { # this is a fatal error, which is not consistent with "machines delete" rune -1 cscli bouncers delete something - assert_stderr --partial "unable to delete bouncer: 'something' does not exist" + assert_stderr --partial "unable to delete bouncer something: ent: bouncer not found" rune -0 cscli bouncers delete something --ignore-missing refute_stderr } @@ -144,3 +144,56 @@ teardown() { rune -0 cscli bouncers prune assert_output 'No bouncers to prune.' } + +curl_localhost() { + [[ -z "$API_KEY" ]] && { fail "${FUNCNAME[0]}: missing API_KEY"; } + local path=$1 + shift + curl "localhost:8080$path" -sS --fail-with-body -H "X-Api-Key: $API_KEY" "$@" +} + +# We can't use curl-with-key here, as we want to query localhost, not 127.0.0.1 +@test "multiple bouncers sharing api key" { + export API_KEY=bouncerkey + + # crowdsec needs to listen on all interfaces + rune -0 ./instance-crowdsec stop + rune -0 config_set 'del(.api.server.listen_socket) | del(.api.server.listen_uri)' + echo "{'api':{'server':{'listen_uri':0.0.0.0:8080}}}" >"${CONFIG_YAML}.local" + + rune -0 ./instance-crowdsec start + + # add a decision for our bouncers + rune -0 cscli decisions add -i '1.2.3.5' + + rune -0 cscli bouncers add test-auto -k "$API_KEY" + + # query with 127.0.0.1 as source ip + rune -0 curl_localhost "/v1/decisions/stream" -4 + rune -0 jq -r '.new' <(output) + assert_output --partial '1.2.3.5' + + # now with ::1, we should get the same IP, even though we are using the same key + rune -0 curl_localhost "/v1/decisions/stream" -6 + rune -0 jq -r '.new' <(output) + assert_output --partial '1.2.3.5' + + rune -0 cscli bouncers list -o json + rune -0 jq -c '[.[] | [.name,.revoked,.ip_address,.auto_created]]' <(output) + assert_json '[["test-auto",false,"127.0.0.1",false],["test-auto@::1",false,"::1",true]]' + + # check the 2nd bouncer was created automatically + rune -0 cscli bouncers inspect "test-auto@::1" -o json + rune -0 jq -r '.ip_address' <(output) + assert_output --partial '::1' + + # attempt to delete the auto-created bouncer, it should fail + rune -0 cscli bouncers delete 'test-auto@::1' + assert_stderr --partial 'cannot be deleted' + + # delete the "real" bouncer, it should delete both + rune -0 cscli bouncers delete 'test-auto' + + rune -0 cscli bouncers list -o json + assert_json [] +} diff --git a/test/lib/init/crowdsec-daemon b/test/lib/init/crowdsec-daemon index a232f344b..ba8e98992 100755 --- a/test/lib/init/crowdsec-daemon +++ b/test/lib/init/crowdsec-daemon @@ -51,7 +51,11 @@ stop() { PGID="$(ps -o pgid= -p "$(cat "${DAEMON_PID}")" | tr -d ' ')" # ps above should work on linux, freebsd, busybox.. if [[ -n "${PGID}" ]]; then - kill -- "-${PGID}" + kill -- "-${PGID}" + + while pgrep -g "${PGID}" >/dev/null; do + sleep .05 + done fi rm -f -- "${DAEMON_PID}"