context propagation: pkg/apiserver (#3272)

* context propagation: apic.Push()

* context propagation: NewServer()

* lint
This commit is contained in:
mmetc 2024-10-09 13:06:03 +02:00 committed by GitHub
parent 40021b6bcf
commit b9bccfa56f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 59 additions and 61 deletions

View file

@ -321,7 +321,7 @@ issues:
# `err` is often shadowed, we may continue to do it
- linters:
- govet
text: "shadow: declaration of \"err\" shadows declaration"
text: "shadow: declaration of \"(err|ctx)\" shadows declaration"
- linters:
- errcheck

View file

@ -127,7 +127,7 @@ func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client
return fmt.Errorf("unable to initialize API client: %w", err)
}
t.Go(apic.Push)
t.Go(func() error { return apic.Push(ctx) })
papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel())
if err != nil {

View file

@ -1,6 +1,7 @@
package main
import (
"context"
"errors"
"fmt"
"runtime"
@ -14,12 +15,12 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
)
func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) {
func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.APIServer, error) {
if cConfig.API.Server.OnlineClient == nil || cConfig.API.Server.OnlineClient.Credentials == nil {
log.Info("push and pull to Central API disabled")
}
apiServer, err := apiserver.NewServer(cConfig.API.Server)
apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server)
if err != nil {
return nil, fmt.Errorf("unable to run local API: %w", err)
}
@ -58,11 +59,14 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) {
func serveAPIServer(apiServer *apiserver.APIServer) {
apiReady := make(chan bool, 1)
apiTomb.Go(func() error {
defer trace.CatchPanic("crowdsec/serveAPIServer")
go func() {
defer trace.CatchPanic("crowdsec/runAPIServer")
log.Debugf("serving API after %s ms", time.Since(crowdsecT0))
if err := apiServer.Run(apiReady); err != nil {
log.Fatal(err)
}
@ -76,6 +80,7 @@ func serveAPIServer(apiServer *apiserver.APIServer) {
<-apiTomb.Dying() // lock until go routine is dying
pluginTomb.Kill(nil)
log.Infof("serve: shutting down api server")
return apiServer.Shutdown()
})
<-apiReady
@ -87,5 +92,6 @@ func hasPlugins(profiles []*csconfig.ProfileCfg) bool {
return true
}
}
return false
}

View file

@ -52,6 +52,8 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error {
func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
var tmpFile string
ctx := context.TODO()
// re-initialize tombs
acquisTomb = tomb.Tomb{}
parsersTomb = tomb.Tomb{}
@ -74,7 +76,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
cConfig.API.Server.OnlineClient = nil
}
apiServer, err := initAPIServer(cConfig)
apiServer, err := initAPIServer(ctx, cConfig)
if err != nil {
return nil, fmt.Errorf("unable to init api server: %w", err)
}
@ -88,7 +90,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
return nil, err
}
if err := hub.Load(); err != nil {
if err = hub.Load(); err != nil {
return nil, err
}
@ -374,7 +376,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
cConfig.API.Server.OnlineClient = nil
}
apiServer, err := initAPIServer(cConfig)
apiServer, err := initAPIServer(ctx, cConfig)
if err != nil {
return fmt.Errorf("api server init: %w", err)
}
@ -390,7 +392,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
return err
}
if err := hub.Load(); err != nil {
if err = hub.Load(); err != nil {
return err
}

View file

@ -65,7 +65,7 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur
}
func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) {
router, config := NewAPITest(t)
router, config := NewAPITest(t, ctx)
loginResp := LoginToTestAPI(t, ctx, router, config)
return router, loginResp, config
@ -137,7 +137,7 @@ func TestCreateAlert(t *testing.T) {
func TestCreateAlertChannels(t *testing.T) {
ctx := context.Background()
apiServer, config := NewAPIServer(t)
apiServer, config := NewAPIServer(t, ctx)
apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert)
apiServer.InitController()
@ -437,7 +437,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
// cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"}
cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"}
cfg.API.Server.ListenURI = "::8080"
server, err := NewServer(cfg.API.Server)
server, err := NewServer(ctx, cfg.API.Server)
require.NoError(t, err)
err = server.InitController()

View file

@ -11,9 +11,8 @@ import (
)
func TestAPIKey(t *testing.T) {
router, config := NewAPITest(t)
ctx := context.Background()
router, config := NewAPITest(t, ctx)
APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig)

View file

@ -256,7 +256,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient
}
// keep track of all alerts in cache and push it to CAPI every PushInterval.
func (a *apic) Push() error {
func (a *apic) Push(ctx context.Context) error {
defer trace.CatchPanic("lapi/pushToAPIC")
var cache models.AddSignalsRequest
@ -276,7 +276,7 @@ func (a *apic) Push() error {
return nil
}
go a.Send(&cache)
go a.Send(ctx, &cache)
return nil
case <-ticker.C:
@ -289,7 +289,7 @@ func (a *apic) Push() error {
a.mu.Unlock()
log.Infof("Signal push: %d signals to push", len(cacheCopy))
go a.Send(&cacheCopy)
go a.Send(ctx, &cacheCopy)
}
case alerts := <-a.AlertsAddChan:
var signals []*models.AddSignalsRequestItem
@ -351,7 +351,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig
return true
}
func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
func (a *apic) Send(ctx context.Context, cacheOrig *models.AddSignalsRequest) {
/*we do have a problem with this :
The apic.Push background routine reads from alertToPush chan.
This chan is filled by Controller.CreateAlert
@ -375,7 +375,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
for {
if pageEnd >= len(cache) {
send = cache[pageStart:]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
@ -389,7 +389,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
}
send = cache[pageStart:pageEnd]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

View file

@ -1134,7 +1134,7 @@ func TestAPICPush(t *testing.T) {
api.Shutdown()
}()
err = api.Push()
err = api.Push(ctx)
require.NoError(t, err)
assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount())
})

View file

@ -159,11 +159,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro
// NewServer creates a LAPI server.
// It sets up a gin router, a database client, and a controller.
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APIServer, error) {
var flushScheduler *gocron.Scheduler
ctx := context.TODO()
dbClient, err := database.NewClient(ctx, config.DbConfig)
if err != nil {
return nil, fmt.Errorf("unable to init database client: %w", err)
@ -300,8 +298,8 @@ func (s *APIServer) Router() (*gin.Engine, error) {
return s.router, nil
}
func (s *APIServer) apicPush() error {
if err := s.apic.Push(); err != nil {
func (s *APIServer) apicPush(ctx context.Context) error {
if err := s.apic.Push(ctx); err != nil {
log.Errorf("capi push: %s", err)
return err
}
@ -337,7 +335,7 @@ func (s *APIServer) papiSync() error {
}
func (s *APIServer) initAPIC(ctx context.Context) {
s.apic.pushTomb.Go(s.apicPush)
s.apic.pushTomb.Go(func() error { return s.apicPush(ctx) })
s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) })
// csConfig.API.Server.ConsoleConfig.ShareCustomScenarios

View file

@ -3,7 +3,6 @@ package apiserver
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
@ -41,7 +40,7 @@ var (
MachineID: &testMachineID,
Password: &testPassword,
}
UserAgent = fmt.Sprintf("crowdsec-test/%s", version.Version)
UserAgent = "crowdsec-test/" + version.Version
emptyBody = strings.NewReader("")
)
@ -135,12 +134,12 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
return config
}
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Config) {
config := LoadTestConfig(t)
os.Remove("./ent")
apiServer, err := NewServer(config.API.Server)
apiServer, err := NewServer(ctx, config.API.Server)
require.NoError(t, err)
log.Printf("Creating new API server")
@ -149,8 +148,8 @@ func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
return apiServer, config
}
func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
apiServer, config := NewAPIServer(t)
func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) {
apiServer, config := NewAPIServer(t, ctx)
err := apiServer.InitController()
require.NoError(t, err)
@ -161,12 +160,12 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
return router, config
}
func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) {
config := LoadTestConfigForwardedFor(t)
os.Remove("./ent")
apiServer, err := NewServer(config.API.Server)
apiServer, err := NewServer(ctx, config.API.Server)
require.NoError(t, err)
err = apiServer.InitController()
@ -302,28 +301,29 @@ func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.Datab
}
func TestWithWrongDBConfig(t *testing.T) {
ctx := context.Background()
config := LoadTestConfig(t)
config.API.Server.DbConfig.Type = "test"
apiServer, err := NewServer(config.API.Server)
apiServer, err := NewServer(ctx, config.API.Server)
cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'")
assert.Nil(t, apiServer)
}
func TestWithWrongFlushConfig(t *testing.T) {
ctx := context.Background()
config := LoadTestConfig(t)
maxItems := -1
config.API.Server.DbConfig.Flush.MaxItems = &maxItems
apiServer, err := NewServer(config.API.Server)
apiServer, err := NewServer(ctx, config.API.Server)
cstest.RequireErrorContains(t, err, "max_items can't be zero or negative")
assert.Nil(t, apiServer)
}
func TestUnknownPath(t *testing.T) {
router, _ := NewAPITest(t)
ctx := context.Background()
router, _ := NewAPITest(t, ctx)
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil)
@ -349,6 +349,8 @@ ListenURI string `yaml:"listen_uri,omitempty"` //127.0
*/
func TestLoggingDebugToFileConfig(t *testing.T) {
ctx := context.Background()
/*declare settings*/
maxAge := "1h"
flushConfig := csconfig.FlushDBCfg{
@ -370,7 +372,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
LogDir: tempDir,
DbConfig: &dbconfig,
}
expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
expectedFile := filepath.Join(tempDir, "crowdsec_api.log")
expectedLines := []string{"/test42"}
cfg.LogLevel = ptr.Of(log.DebugLevel)
@ -378,12 +380,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
require.NoError(t, err)
api, err := NewServer(&cfg)
api, err := NewServer(ctx, &cfg)
require.NoError(t, err)
require.NotNil(t, api)
ctx := context.Background()
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
req.Header.Set("User-Agent", UserAgent)
@ -402,6 +402,8 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
}
func TestLoggingErrorToFileConfig(t *testing.T) {
ctx := context.Background()
/*declare settings*/
maxAge := "1h"
flushConfig := csconfig.FlushDBCfg{
@ -423,19 +425,17 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
LogDir: tempDir,
DbConfig: &dbconfig,
}
expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
expectedFile := filepath.Join(tempDir, "crowdsec_api.log")
cfg.LogLevel = ptr.Of(log.ErrorLevel)
// Configure logging
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
require.NoError(t, err)
api, err := NewServer(&cfg)
api, err := NewServer(ctx, &cfg)
require.NoError(t, err)
require.NotNil(t, api)
ctx := context.Background()
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
req.Header.Set("User-Agent", UserAgent)

View file

@ -11,9 +11,8 @@ import (
)
func TestLogin(t *testing.T) {
router, config := NewAPITest(t)
ctx := context.Background()
router, config := NewAPITest(t, ctx)
body := CreateTestMachine(t, router, "")

View file

@ -15,9 +15,8 @@ import (
)
func TestCreateMachine(t *testing.T) {
router, _ := NewAPITest(t)
ctx := context.Background()
router, _ := NewAPITest(t, ctx)
// Create machine with invalid format
w := httptest.NewRecorder()
@ -53,10 +52,9 @@ func TestCreateMachine(t *testing.T) {
}
func TestCreateMachineWithForwardedFor(t *testing.T) {
router, config := NewAPITestForwardedFor(t)
router.TrustedPlatform = "X-Real-IP"
ctx := context.Background()
router, config := NewAPITestForwardedFor(t, ctx)
router.TrustedPlatform = "X-Real-IP"
// Create machine
b, err := json.Marshal(MachineTest)
@ -79,9 +77,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) {
}
func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
router, config := NewAPITest(t)
ctx := context.Background()
router, config := NewAPITest(t, ctx)
// Create machine
b, err := json.Marshal(MachineTest)
@ -106,9 +103,8 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
}
func TestCreateMachineWithoutForwardedFor(t *testing.T) {
router, config := NewAPITestForwardedFor(t)
ctx := context.Background()
router, config := NewAPITestForwardedFor(t, ctx)
// Create machine
b, err := json.Marshal(MachineTest)
@ -132,9 +128,8 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) {
}
func TestCreateMachineAlreadyExist(t *testing.T) {
router, _ := NewAPITest(t)
ctx := context.Background()
router, _ := NewAPITest(t, ctx)
body := CreateTestMachine(t, router, "")
@ -153,9 +148,8 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
}
func TestAutoRegistration(t *testing.T) {
router, _ := NewAPITest(t)
ctx := context.Background()
router, _ := NewAPITest(t, ctx)
// Invalid registration token / valid source IP
regReq := MachineTest