mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-11 04:15:54 +02:00
enhance: add listen_socket to http acquisition (#3499)
* enhance: add listen_socket to http acquisition * wrap error for long socket path * enhance: Cancel early go routines if config is emtpy and add a socket test * enhance: use temp dir for socket tests * enhance: use mktemp instead of hardcoding * enhance: mr linter pls be happy with me --------- Co-authored-by: marco <marco@crowdsec.net>
This commit is contained in:
parent
46e6398868
commit
764deee1c0
2 changed files with 96 additions and 9 deletions
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/crowdsecurity/go-cs-lib/trace"
|
"github.com/crowdsecurity/go-cs-lib/trace"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
|
"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/csnet"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,6 +39,7 @@ type HttpConfiguration struct {
|
||||||
// IPFilter []string `yaml:"ip_filter"`
|
// IPFilter []string `yaml:"ip_filter"`
|
||||||
// ChunkSize *int64 `yaml:"chunk_size"`
|
// ChunkSize *int64 `yaml:"chunk_size"`
|
||||||
ListenAddr string `yaml:"listen_addr"`
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
|
ListenSocket string `yaml:"listen_socket"`
|
||||||
Path string `yaml:"path"`
|
Path string `yaml:"path"`
|
||||||
AuthType string `yaml:"auth_type"`
|
AuthType string `yaml:"auth_type"`
|
||||||
BasicAuth *BasicAuthConfig `yaml:"basic_auth"`
|
BasicAuth *BasicAuthConfig `yaml:"basic_auth"`
|
||||||
|
@ -89,8 +91,8 @@ func (h *HTTPSource) UnmarshalConfig(yamlConfig []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hc *HttpConfiguration) Validate() error {
|
func (hc *HttpConfiguration) Validate() error {
|
||||||
if hc.ListenAddr == "" {
|
if hc.ListenAddr == "" && hc.ListenSocket == "" {
|
||||||
return errors.New("listen_addr is required")
|
return errors.New("listen_addr or listen_socket is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
if hc.Path == "" {
|
if hc.Path == "" {
|
||||||
|
@ -350,6 +352,11 @@ func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.RemoteAddr == "@" {
|
||||||
|
//We check if request came from unix socket and if so we set to loopback
|
||||||
|
r.RemoteAddr = "127.0.0.1:65535"
|
||||||
|
}
|
||||||
|
|
||||||
err := h.processRequest(w, r, &h.Config, out)
|
err := h.processRequest(w, r, &h.Config, out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err)
|
h.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err)
|
||||||
|
@ -396,7 +403,38 @@ func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Go(func() error {
|
t.Go(func() error {
|
||||||
defer trace.CatchPanic("crowdsec/acquis/http/server")
|
if h.Config.ListenSocket == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defer trace.CatchPanic("crowdsec/acquis/http/server/unix")
|
||||||
|
h.logger.Infof("creating unix socket on %s", h.Config.ListenSocket)
|
||||||
|
_ = os.Remove(h.Config.ListenSocket)
|
||||||
|
listener, err := net.Listen("unix", h.Config.ListenSocket)
|
||||||
|
if err != nil {
|
||||||
|
return csnet.WrapSockErr(err, h.Config.ListenSocket)
|
||||||
|
}
|
||||||
|
if h.Config.TLS != nil {
|
||||||
|
err := h.Server.ServeTLS(listener, h.Config.TLS.ServerCert, h.Config.TLS.ServerKey)
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
return fmt.Errorf("https server failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := h.Server.Serve(listener)
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
return fmt.Errorf("http server failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Go(func() error {
|
||||||
|
if h.Config.ListenAddr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defer trace.CatchPanic("crowdsec/acquis/http/server/tcp")
|
||||||
|
|
||||||
if h.Config.TLS != nil {
|
if h.Config.TLS != nil {
|
||||||
h.logger.Infof("start https server on %s", h.Config.ListenAddr)
|
h.logger.Infof("start https server on %s", h.Config.ListenAddr)
|
||||||
|
|
|
@ -2,13 +2,16 @@ package httpacquisition
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -37,7 +40,7 @@ func TestConfigure(t *testing.T) {
|
||||||
{
|
{
|
||||||
config: `
|
config: `
|
||||||
foobar: bla`,
|
foobar: bla`,
|
||||||
expectedErr: "invalid configuration: listen_addr is required",
|
expectedErr: "invalid configuration: listen_addr or listen_socket is required",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
config: `
|
config: `
|
||||||
|
@ -256,7 +259,7 @@ basic_auth:
|
||||||
|
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr + "/test", http.NoBody)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr+"/test", http.NoBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
@ -284,7 +287,7 @@ basic_auth:
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr + "/unknown", http.NoBody)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testHTTPServerAddr+"/unknown", http.NoBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
@ -313,7 +316,7 @@ basic_auth:
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test"))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
@ -321,7 +324,7 @@ basic_auth:
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
|
||||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test"))
|
req, err = http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.SetBasicAuth("test", "WrongPassword")
|
req.SetBasicAuth("test", "WrongPassword")
|
||||||
|
|
||||||
|
@ -474,6 +477,52 @@ custom_headers:
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAcquistionSocket(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
socketFile := filepath.Join(tempDir, "test.sock")
|
||||||
|
|
||||||
|
ctx := t.Context()
|
||||||
|
h := &HTTPSource{}
|
||||||
|
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
|
||||||
|
source: http
|
||||||
|
listen_socket: `+socketFile+`
|
||||||
|
path: /test
|
||||||
|
auth_type: headers
|
||||||
|
headers:
|
||||||
|
key: test`), 2)
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
rawEvt := `{"test": "test"}`
|
||||||
|
errChan := make(chan error)
|
||||||
|
go assertEvents(out, []string{rawEvt}, errChan)
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return net.Dial("unix", socketFile)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req.Header.Add("Key", "test")
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
err = <-errChan
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assertMetrics(t, reg, h.GetMetrics(), 1)
|
||||||
|
|
||||||
|
h.Server.Close()
|
||||||
|
tomb.Kill(nil)
|
||||||
|
err = tomb.Wait()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
type slowReader struct {
|
type slowReader struct {
|
||||||
delay time.Duration
|
delay time.Duration
|
||||||
body []byte
|
body []byte
|
||||||
|
@ -582,7 +631,7 @@ tls:
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr + "/test", strings.NewReader("test"))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, testHTTPServerAddr+"/test", strings.NewReader("test"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue