mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-11 12:25:53 +02:00
context propagation: pkg/apiserver (#3272)
* context propagation: apic.Push() * context propagation: NewServer() * lint
This commit is contained in:
parent
40021b6bcf
commit
b9bccfa56f
12 changed files with 59 additions and 61 deletions
|
@ -321,7 +321,7 @@ issues:
|
|||
# `err` is often shadowed, we may continue to do it
|
||||
- linters:
|
||||
- govet
|
||||
text: "shadow: declaration of \"err\" shadows declaration"
|
||||
text: "shadow: declaration of \"(err|ctx)\" shadows declaration"
|
||||
|
||||
- linters:
|
||||
- errcheck
|
||||
|
|
|
@ -127,7 +127,7 @@ func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client
|
|||
return fmt.Errorf("unable to initialize API client: %w", err)
|
||||
}
|
||||
|
||||
t.Go(apic.Push)
|
||||
t.Go(func() error { return apic.Push(ctx) })
|
||||
|
||||
papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel())
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
@ -14,12 +15,12 @@ import (
|
|||
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
||||
)
|
||||
|
||||
func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) {
|
||||
func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.APIServer, error) {
|
||||
if cConfig.API.Server.OnlineClient == nil || cConfig.API.Server.OnlineClient.Credentials == nil {
|
||||
log.Info("push and pull to Central API disabled")
|
||||
}
|
||||
|
||||
apiServer, err := apiserver.NewServer(cConfig.API.Server)
|
||||
apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to run local API: %w", err)
|
||||
}
|
||||
|
@ -58,11 +59,14 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) {
|
|||
|
||||
func serveAPIServer(apiServer *apiserver.APIServer) {
|
||||
apiReady := make(chan bool, 1)
|
||||
|
||||
apiTomb.Go(func() error {
|
||||
defer trace.CatchPanic("crowdsec/serveAPIServer")
|
||||
|
||||
go func() {
|
||||
defer trace.CatchPanic("crowdsec/runAPIServer")
|
||||
log.Debugf("serving API after %s ms", time.Since(crowdsecT0))
|
||||
|
||||
if err := apiServer.Run(apiReady); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -76,6 +80,7 @@ func serveAPIServer(apiServer *apiserver.APIServer) {
|
|||
<-apiTomb.Dying() // lock until go routine is dying
|
||||
pluginTomb.Kill(nil)
|
||||
log.Infof("serve: shutting down api server")
|
||||
|
||||
return apiServer.Shutdown()
|
||||
})
|
||||
<-apiReady
|
||||
|
@ -87,5 +92,6 @@ func hasPlugins(profiles []*csconfig.ProfileCfg) bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -52,6 +52,8 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error {
|
|||
func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
|
||||
var tmpFile string
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
// re-initialize tombs
|
||||
acquisTomb = tomb.Tomb{}
|
||||
parsersTomb = tomb.Tomb{}
|
||||
|
@ -74,7 +76,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
|
|||
cConfig.API.Server.OnlineClient = nil
|
||||
}
|
||||
|
||||
apiServer, err := initAPIServer(cConfig)
|
||||
apiServer, err := initAPIServer(ctx, cConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to init api server: %w", err)
|
||||
}
|
||||
|
@ -88,7 +90,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := hub.Load(); err != nil {
|
||||
if err = hub.Load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -374,7 +376,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
|
|||
cConfig.API.Server.OnlineClient = nil
|
||||
}
|
||||
|
||||
apiServer, err := initAPIServer(cConfig)
|
||||
apiServer, err := initAPIServer(ctx, cConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("api server init: %w", err)
|
||||
}
|
||||
|
@ -390,7 +392,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err := hub.Load(); err != nil {
|
||||
if err = hub.Load(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur
|
|||
}
|
||||
|
||||
func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) {
|
||||
router, config := NewAPITest(t)
|
||||
router, config := NewAPITest(t, ctx)
|
||||
loginResp := LoginToTestAPI(t, ctx, router, config)
|
||||
|
||||
return router, loginResp, config
|
||||
|
@ -137,7 +137,7 @@ func TestCreateAlert(t *testing.T) {
|
|||
|
||||
func TestCreateAlertChannels(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
apiServer, config := NewAPIServer(t)
|
||||
apiServer, config := NewAPIServer(t, ctx)
|
||||
apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert)
|
||||
apiServer.InitController()
|
||||
|
||||
|
@ -437,7 +437,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
|
|||
// cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"}
|
||||
cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"}
|
||||
cfg.API.Server.ListenURI = "::8080"
|
||||
server, err := NewServer(cfg.API.Server)
|
||||
server, err := NewServer(ctx, cfg.API.Server)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.InitController()
|
||||
|
|
|
@ -11,9 +11,8 @@ import (
|
|||
)
|
||||
|
||||
func TestAPIKey(t *testing.T) {
|
||||
router, config := NewAPITest(t)
|
||||
|
||||
ctx := context.Background()
|
||||
router, config := NewAPITest(t, ctx)
|
||||
|
||||
APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig)
|
||||
|
||||
|
|
|
@ -256,7 +256,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient
|
|||
}
|
||||
|
||||
// keep track of all alerts in cache and push it to CAPI every PushInterval.
|
||||
func (a *apic) Push() error {
|
||||
func (a *apic) Push(ctx context.Context) error {
|
||||
defer trace.CatchPanic("lapi/pushToAPIC")
|
||||
|
||||
var cache models.AddSignalsRequest
|
||||
|
@ -276,7 +276,7 @@ func (a *apic) Push() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
go a.Send(&cache)
|
||||
go a.Send(ctx, &cache)
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
|
@ -289,7 +289,7 @@ func (a *apic) Push() error {
|
|||
a.mu.Unlock()
|
||||
log.Infof("Signal push: %d signals to push", len(cacheCopy))
|
||||
|
||||
go a.Send(&cacheCopy)
|
||||
go a.Send(ctx, &cacheCopy)
|
||||
}
|
||||
case alerts := <-a.AlertsAddChan:
|
||||
var signals []*models.AddSignalsRequestItem
|
||||
|
@ -351,7 +351,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig
|
|||
return true
|
||||
}
|
||||
|
||||
func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
|
||||
func (a *apic) Send(ctx context.Context, cacheOrig *models.AddSignalsRequest) {
|
||||
/*we do have a problem with this :
|
||||
The apic.Push background routine reads from alertToPush chan.
|
||||
This chan is filled by Controller.CreateAlert
|
||||
|
@ -375,7 +375,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
|
|||
for {
|
||||
if pageEnd >= len(cache) {
|
||||
send = cache[pageStart:]
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
|
||||
defer cancel()
|
||||
|
||||
|
@ -389,7 +389,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
|
|||
}
|
||||
|
||||
send = cache[pageStart:pageEnd]
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
|
||||
defer cancel()
|
||||
|
||||
|
|
|
@ -1134,7 +1134,7 @@ func TestAPICPush(t *testing.T) {
|
|||
api.Shutdown()
|
||||
}()
|
||||
|
||||
err = api.Push()
|
||||
err = api.Push(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount())
|
||||
})
|
||||
|
|
|
@ -159,11 +159,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro
|
|||
|
||||
// NewServer creates a LAPI server.
|
||||
// It sets up a gin router, a database client, and a controller.
|
||||
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||
func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||
var flushScheduler *gocron.Scheduler
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
dbClient, err := database.NewClient(ctx, config.DbConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to init database client: %w", err)
|
||||
|
@ -300,8 +298,8 @@ func (s *APIServer) Router() (*gin.Engine, error) {
|
|||
return s.router, nil
|
||||
}
|
||||
|
||||
func (s *APIServer) apicPush() error {
|
||||
if err := s.apic.Push(); err != nil {
|
||||
func (s *APIServer) apicPush(ctx context.Context) error {
|
||||
if err := s.apic.Push(ctx); err != nil {
|
||||
log.Errorf("capi push: %s", err)
|
||||
return err
|
||||
}
|
||||
|
@ -337,7 +335,7 @@ func (s *APIServer) papiSync() error {
|
|||
}
|
||||
|
||||
func (s *APIServer) initAPIC(ctx context.Context) {
|
||||
s.apic.pushTomb.Go(s.apicPush)
|
||||
s.apic.pushTomb.Go(func() error { return s.apicPush(ctx) })
|
||||
s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) })
|
||||
|
||||
// csConfig.API.Server.ConsoleConfig.ShareCustomScenarios
|
||||
|
|
|
@ -3,7 +3,6 @@ package apiserver
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
|
@ -41,7 +40,7 @@ var (
|
|||
MachineID: &testMachineID,
|
||||
Password: &testPassword,
|
||||
}
|
||||
UserAgent = fmt.Sprintf("crowdsec-test/%s", version.Version)
|
||||
UserAgent = "crowdsec-test/" + version.Version
|
||||
emptyBody = strings.NewReader("")
|
||||
)
|
||||
|
||||
|
@ -135,12 +134,12 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
|
|||
return config
|
||||
}
|
||||
|
||||
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
|
||||
func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Config) {
|
||||
config := LoadTestConfig(t)
|
||||
|
||||
os.Remove("./ent")
|
||||
|
||||
apiServer, err := NewServer(config.API.Server)
|
||||
apiServer, err := NewServer(ctx, config.API.Server)
|
||||
require.NoError(t, err)
|
||||
|
||||
log.Printf("Creating new API server")
|
||||
|
@ -149,8 +148,8 @@ func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
|
|||
return apiServer, config
|
||||
}
|
||||
|
||||
func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
|
||||
apiServer, config := NewAPIServer(t)
|
||||
func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) {
|
||||
apiServer, config := NewAPIServer(t, ctx)
|
||||
|
||||
err := apiServer.InitController()
|
||||
require.NoError(t, err)
|
||||
|
@ -161,12 +160,12 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
|
|||
return router, config
|
||||
}
|
||||
|
||||
func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
|
||||
func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) {
|
||||
config := LoadTestConfigForwardedFor(t)
|
||||
|
||||
os.Remove("./ent")
|
||||
|
||||
apiServer, err := NewServer(config.API.Server)
|
||||
apiServer, err := NewServer(ctx, config.API.Server)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = apiServer.InitController()
|
||||
|
@ -302,28 +301,29 @@ func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.Datab
|
|||
}
|
||||
|
||||
func TestWithWrongDBConfig(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
config := LoadTestConfig(t)
|
||||
config.API.Server.DbConfig.Type = "test"
|
||||
apiServer, err := NewServer(config.API.Server)
|
||||
apiServer, err := NewServer(ctx, config.API.Server)
|
||||
|
||||
cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'")
|
||||
assert.Nil(t, apiServer)
|
||||
}
|
||||
|
||||
func TestWithWrongFlushConfig(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
config := LoadTestConfig(t)
|
||||
maxItems := -1
|
||||
config.API.Server.DbConfig.Flush.MaxItems = &maxItems
|
||||
apiServer, err := NewServer(config.API.Server)
|
||||
apiServer, err := NewServer(ctx, config.API.Server)
|
||||
|
||||
cstest.RequireErrorContains(t, err, "max_items can't be zero or negative")
|
||||
assert.Nil(t, apiServer)
|
||||
}
|
||||
|
||||
func TestUnknownPath(t *testing.T) {
|
||||
router, _ := NewAPITest(t)
|
||||
|
||||
ctx := context.Background()
|
||||
router, _ := NewAPITest(t, ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil)
|
||||
|
@ -349,6 +349,8 @@ ListenURI string `yaml:"listen_uri,omitempty"` //127.0
|
|||
*/
|
||||
|
||||
func TestLoggingDebugToFileConfig(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
/*declare settings*/
|
||||
maxAge := "1h"
|
||||
flushConfig := csconfig.FlushDBCfg{
|
||||
|
@ -370,7 +372,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
|
|||
LogDir: tempDir,
|
||||
DbConfig: &dbconfig,
|
||||
}
|
||||
expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
|
||||
expectedFile := filepath.Join(tempDir, "crowdsec_api.log")
|
||||
expectedLines := []string{"/test42"}
|
||||
cfg.LogLevel = ptr.Of(log.DebugLevel)
|
||||
|
||||
|
@ -378,12 +380,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
|
|||
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
api, err := NewServer(&cfg)
|
||||
api, err := NewServer(ctx, &cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, api)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
@ -402,6 +402,8 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLoggingErrorToFileConfig(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
/*declare settings*/
|
||||
maxAge := "1h"
|
||||
flushConfig := csconfig.FlushDBCfg{
|
||||
|
@ -423,19 +425,17 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
|
|||
LogDir: tempDir,
|
||||
DbConfig: &dbconfig,
|
||||
}
|
||||
expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
|
||||
expectedFile := filepath.Join(tempDir, "crowdsec_api.log")
|
||||
cfg.LogLevel = ptr.Of(log.ErrorLevel)
|
||||
|
||||
// Configure logging
|
||||
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
api, err := NewServer(&cfg)
|
||||
api, err := NewServer(ctx, &cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, api)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
|
|
@ -11,9 +11,8 @@ import (
|
|||
)
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
router, config := NewAPITest(t)
|
||||
|
||||
ctx := context.Background()
|
||||
router, config := NewAPITest(t, ctx)
|
||||
|
||||
body := CreateTestMachine(t, router, "")
|
||||
|
||||
|
|
|
@ -15,9 +15,8 @@ import (
|
|||
)
|
||||
|
||||
func TestCreateMachine(t *testing.T) {
|
||||
router, _ := NewAPITest(t)
|
||||
|
||||
ctx := context.Background()
|
||||
router, _ := NewAPITest(t, ctx)
|
||||
|
||||
// Create machine with invalid format
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -53,10 +52,9 @@ func TestCreateMachine(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateMachineWithForwardedFor(t *testing.T) {
|
||||
router, config := NewAPITestForwardedFor(t)
|
||||
router.TrustedPlatform = "X-Real-IP"
|
||||
|
||||
ctx := context.Background()
|
||||
router, config := NewAPITestForwardedFor(t, ctx)
|
||||
router.TrustedPlatform = "X-Real-IP"
|
||||
|
||||
// Create machine
|
||||
b, err := json.Marshal(MachineTest)
|
||||
|
@ -79,9 +77,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
|
||||
router, config := NewAPITest(t)
|
||||
|
||||
ctx := context.Background()
|
||||
router, config := NewAPITest(t, ctx)
|
||||
|
||||
// Create machine
|
||||
b, err := json.Marshal(MachineTest)
|
||||
|
@ -106,9 +103,8 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateMachineWithoutForwardedFor(t *testing.T) {
|
||||
router, config := NewAPITestForwardedFor(t)
|
||||
|
||||
ctx := context.Background()
|
||||
router, config := NewAPITestForwardedFor(t, ctx)
|
||||
|
||||
// Create machine
|
||||
b, err := json.Marshal(MachineTest)
|
||||
|
@ -132,9 +128,8 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateMachineAlreadyExist(t *testing.T) {
|
||||
router, _ := NewAPITest(t)
|
||||
|
||||
ctx := context.Background()
|
||||
router, _ := NewAPITest(t, ctx)
|
||||
|
||||
body := CreateTestMachine(t, router, "")
|
||||
|
||||
|
@ -153,9 +148,8 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAutoRegistration(t *testing.T) {
|
||||
router, _ := NewAPITest(t)
|
||||
|
||||
ctx := context.Background()
|
||||
router, _ := NewAPITest(t, ctx)
|
||||
|
||||
// Invalid registration token / valid source IP
|
||||
regReq := MachineTest
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue