enhance: add listen_socket to http acquisition (#3499)

* enhance: add listen_socket to http acquisition

* wrap error for long socket path

* enhance: Cancel early go routines if config is emtpy and add a socket test

* enhance: use temp dir for socket tests

* enhance: use mktemp instead of hardcoding

* enhance: mr linter pls be happy with me

---------

Co-authored-by: marco <marco@crowdsec.net>
This commit is contained in:
Laurence Jones 2025-04-29 17:19:10 +01:00 committed by GitHub
parent 46e6398868
commit 764deee1c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 96 additions and 9 deletions

View file

@ -22,6 +22,7 @@ import (
"github.com/crowdsecurity/go-cs-lib/trace"
"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
"github.com/crowdsecurity/crowdsec/pkg/csnet"
"github.com/crowdsecurity/crowdsec/pkg/types"
)
@ -38,6 +39,7 @@ type HttpConfiguration struct {
// IPFilter []string `yaml:"ip_filter"`
// ChunkSize *int64 `yaml:"chunk_size"`
ListenAddr string `yaml:"listen_addr"`
ListenSocket string `yaml:"listen_socket"`
Path string `yaml:"path"`
AuthType string `yaml:"auth_type"`
BasicAuth *BasicAuthConfig `yaml:"basic_auth"`
@ -89,8 +91,8 @@ func (h *HTTPSource) UnmarshalConfig(yamlConfig []byte) error {
}
func (hc *HttpConfiguration) Validate() error {
if hc.ListenAddr == "" {
return errors.New("listen_addr is required")
if hc.ListenAddr == "" && hc.ListenSocket == "" {
return errors.New("listen_addr or listen_socket is required")
}
if hc.Path == "" {
@ -350,6 +352,11 @@ func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error {
return
}
if r.RemoteAddr == "@" {
//We check if request came from unix socket and if so we set to loopback
r.RemoteAddr = "127.0.0.1:65535"
}
err := h.processRequest(w, r, &h.Config, out)
if err != nil {
h.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err)
@ -396,7 +403,38 @@ func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error {
}
t.Go(func() error {
defer trace.CatchPanic("crowdsec/acquis/http/server")
if h.Config.ListenSocket == "" {
return nil
}
defer trace.CatchPanic("crowdsec/acquis/http/server/unix")
h.logger.Infof("creating unix socket on %s", h.Config.ListenSocket)
_ = os.Remove(h.Config.ListenSocket)
listener, err := net.Listen("unix", h.Config.ListenSocket)
if err != nil {
return csnet.WrapSockErr(err, h.Config.ListenSocket)
}
if h.Config.TLS != nil {
err := h.Server.ServeTLS(listener, h.Config.TLS.ServerCert, h.Config.TLS.ServerKey)
if err != nil && err != http.ErrServerClosed {
return fmt.Errorf("https server failed: %w", err)
}
} else {
err := h.Server.Serve(listener)
if err != nil && err != http.ErrServerClosed {
return fmt.Errorf("http server failed: %w", err)
}
}
return nil
})
t.Go(func() error {
if h.Config.ListenAddr == "" {
return nil
}
defer trace.CatchPanic("crowdsec/acquis/http/server/tcp")
if h.Config.TLS != nil {
h.logger.Infof("start https server on %s", h.Config.ListenAddr)

View file

@ -2,13 +2,16 @@ package httpacquisition
import (
"compress/gzip"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
@ -37,7 +40,7 @@ func TestConfigure(t *testing.T) {
{
config: `
foobar: bla`,
expectedErr: "invalid configuration: listen_addr is required",
expectedErr: "invalid configuration: listen_addr or listen_socket is required",
},
{
config: `
@ -256,7 +259,7 @@ basic_auth:
ctx := t.Context()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr + "/test", http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr+"/test", http.NoBody)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
@ -284,7 +287,7 @@ basic_auth:
time.Sleep(1 * time.Second)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr + "/unknown", http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr+"/unknown", http.NoBody)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
@ -313,7 +316,7 @@ basic_auth:
client := &http.Client{}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test"))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
@ -321,7 +324,7 @@ basic_auth:
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
req, err = http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test"))
req, err = http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test"))
require.NoError(t, err)
req.SetBasicAuth("test", "WrongPassword")
@ -474,6 +477,52 @@ custom_headers:
require.NoError(t, err)
}
func TestAcquistionSocket(t *testing.T) {
tempDir := t.TempDir()
socketFile := filepath.Join(tempDir, "test.sock")
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
listen_socket: `+socketFile+`
path: /test
auth_type: headers
headers:
key: test`), 2)
time.Sleep(1 * time.Second)
rawEvt := `{"test": "test"}`
errChan := make(chan error)
go assertEvents(out, []string{rawEvt}, errChan)
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", socketFile)
},
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
require.NoError(t, err)
req.Header.Add("Key", "test")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
err = <-errChan
require.NoError(t, err)
assertMetrics(t, reg, h.GetMetrics(), 1)
h.Server.Close()
tomb.Kill(nil)
err = tomb.Wait()
require.NoError(t, err)
}
type slowReader struct {
delay time.Duration
body []byte
@ -582,7 +631,7 @@ tls:
time.Sleep(1 * time.Second)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test"))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")