From 764deee1c0f3370eb724f4b521c577fd43e26931 Mon Sep 17 00:00:00 2001 From: Laurence Jones Date: Tue, 29 Apr 2025 17:19:10 +0100 Subject: [PATCH] enhance: add listen_socket to http acquisition (#3499) * enhance: add listen_socket to http acquisition * wrap error for long socket path * enhance: Cancel early go routines if config is emtpy and add a socket test * enhance: use temp dir for socket tests * enhance: use mktemp instead of hardcoding * enhance: mr linter pls be happy with me --------- Co-authored-by: marco --- pkg/acquisition/modules/http/http.go | 44 ++++++++++++++-- pkg/acquisition/modules/http/http_test.go | 61 ++++++++++++++++++++--- 2 files changed, 96 insertions(+), 9 deletions(-) diff --git a/pkg/acquisition/modules/http/http.go b/pkg/acquisition/modules/http/http.go index 76d7d06d2..4cf5d6bbf 100644 --- a/pkg/acquisition/modules/http/http.go +++ b/pkg/acquisition/modules/http/http.go @@ -22,6 +22,7 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/csnet" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -38,6 +39,7 @@ type HttpConfiguration struct { // IPFilter []string `yaml:"ip_filter"` // ChunkSize *int64 `yaml:"chunk_size"` ListenAddr string `yaml:"listen_addr"` + ListenSocket string `yaml:"listen_socket"` Path string `yaml:"path"` AuthType string `yaml:"auth_type"` BasicAuth *BasicAuthConfig `yaml:"basic_auth"` @@ -89,8 +91,8 @@ func (h *HTTPSource) UnmarshalConfig(yamlConfig []byte) error { } func (hc *HttpConfiguration) Validate() error { - if hc.ListenAddr == "" { - return errors.New("listen_addr is required") + if hc.ListenAddr == "" && hc.ListenSocket == "" { + return errors.New("listen_addr or listen_socket is required") } if hc.Path == "" { @@ -350,6 +352,11 @@ func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error { return } + if r.RemoteAddr == "@" { + //We check if request came from unix socket and if so we set to loopback + r.RemoteAddr = "127.0.0.1:65535" + } + err := h.processRequest(w, r, &h.Config, out) if err != nil { h.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err) @@ -396,7 +403,38 @@ func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error { } t.Go(func() error { - defer trace.CatchPanic("crowdsec/acquis/http/server") + if h.Config.ListenSocket == "" { + return nil + } + + defer trace.CatchPanic("crowdsec/acquis/http/server/unix") + h.logger.Infof("creating unix socket on %s", h.Config.ListenSocket) + _ = os.Remove(h.Config.ListenSocket) + listener, err := net.Listen("unix", h.Config.ListenSocket) + if err != nil { + return csnet.WrapSockErr(err, h.Config.ListenSocket) + } + if h.Config.TLS != nil { + err := h.Server.ServeTLS(listener, h.Config.TLS.ServerCert, h.Config.TLS.ServerKey) + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("https server failed: %w", err) + } + } else { + err := h.Server.Serve(listener) + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("http server failed: %w", err) + } + } + + return nil + }) + + t.Go(func() error { + if h.Config.ListenAddr == "" { + return nil + } + + defer trace.CatchPanic("crowdsec/acquis/http/server/tcp") if h.Config.TLS != nil { h.logger.Infof("start https server on %s", h.Config.ListenAddr) diff --git a/pkg/acquisition/modules/http/http_test.go b/pkg/acquisition/modules/http/http_test.go index 552fe90e3..ab55e956c 100644 --- a/pkg/acquisition/modules/http/http_test.go +++ b/pkg/acquisition/modules/http/http_test.go @@ -2,13 +2,16 @@ package httpacquisition import ( "compress/gzip" + "context" "crypto/tls" "crypto/x509" "errors" "fmt" "io" + "net" "net/http" "os" + "path/filepath" "strings" "testing" "time" @@ -37,7 +40,7 @@ func TestConfigure(t *testing.T) { { config: ` foobar: bla`, - expectedErr: "invalid configuration: listen_addr is required", + expectedErr: "invalid configuration: listen_addr or listen_socket is required", }, { config: ` @@ -256,7 +259,7 @@ basic_auth: ctx := t.Context() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr + "/test", http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr+"/test", http.NoBody) require.NoError(t, err) res, err := http.DefaultClient.Do(req) @@ -284,7 +287,7 @@ basic_auth: time.Sleep(1 * time.Second) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr + "/unknown", http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr+"/unknown", http.NoBody) require.NoError(t, err) res, err := http.DefaultClient.Do(req) @@ -313,7 +316,7 @@ basic_auth: client := &http.Client{} - req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test")) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test")) require.NoError(t, err) req.Header.Set("Content-Type", "application/json") @@ -321,7 +324,7 @@ basic_auth: require.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - req, err = http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test")) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test")) require.NoError(t, err) req.SetBasicAuth("test", "WrongPassword") @@ -474,6 +477,52 @@ custom_headers: require.NoError(t, err) } +func TestAcquistionSocket(t *testing.T) { + tempDir := t.TempDir() + socketFile := filepath.Join(tempDir, "test.sock") + + ctx := t.Context() + h := &HTTPSource{} + out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_socket: `+socketFile+` +path: /test +auth_type: headers +headers: + key: test`), 2) + + time.Sleep(1 * time.Second) + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", socketFile) + }, + }, + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt)) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, reg, h.GetMetrics(), 1) + + h.Server.Close() + tomb.Kill(nil) + err = tomb.Wait() + require.NoError(t, err) +} + type slowReader struct { delay time.Duration body []byte @@ -582,7 +631,7 @@ tls: time.Sleep(1 * time.Second) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test")) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test")) require.NoError(t, err) req.Header.Set("Content-Type", "application/json")