From 4004868245070ae45918ff06ce79928ca120f15c Mon Sep 17 00:00:00 2001 From: blotus Date: Wed, 16 Apr 2025 14:39:26 +0200 Subject: [PATCH] fix mysql client certificate support (#3575) --- pkg/csconfig/database.go | 55 +++++++++++++++++++++++++++++++++++----- pkg/database/database.go | 6 ++++- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/pkg/csconfig/database.go b/pkg/csconfig/database.go index 29e8e4c33..26150eb2e 100644 --- a/pkg/csconfig/database.go +++ b/pkg/csconfig/database.go @@ -1,12 +1,17 @@ package csconfig import ( + "crypto/tls" + "crypto/x509" "errors" "fmt" + "net/url" + "os" "path/filepath" "time" "entgo.io/ent/dialect" + "github.com/go-sql-driver/mysql" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -119,7 +124,7 @@ func (c *Config) LoadDBConfig(inCli bool) error { return nil } -func (d *DatabaseCfg) ConnectionString() string { +func (d *DatabaseCfg) ConnectionString() (string, error) { connString := "" switch d.Type { @@ -133,23 +138,59 @@ func (d *DatabaseCfg) ConnectionString() string { connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters) case "mysql": + params := url.Values{} + params.Add("parseTime", "True") + + tlsConfig := &tls.Config{} + + // This is just to get an initial value, don't care about the error + systemRootCAs, _ := x509.SystemCertPool() + if systemRootCAs != nil { + tlsConfig.RootCAs = systemRootCAs + } + if d.isSocketConfig() { - connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", d.User, d.Password, d.DbPath, d.DbName) + connString = fmt.Sprintf("%s:%s@unix(%s)/%s", d.User, d.Password, d.DbPath, d.DbName) } else { - connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName) + connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", d.User, d.Password, d.Host, d.Port, d.DbName) } if d.SSLMode != "" { - connString = fmt.Sprintf("%s&tls=%s", connString, d.SSLMode) + // This will be overridden if a CA or client cert is provided + params.Set("tls", d.SSLMode) } if d.SSLCACert != "" { - connString = fmt.Sprintf("%s&tls-ca=%s", connString, d.SSLCACert) + caCert, err := os.ReadFile(d.SSLCACert) + if err != nil { + return "", fmt.Errorf("failed to read CA cert file %s: %w", d.SSLCACert, err) + } + if tlsConfig.RootCAs == nil { + tlsConfig.RootCAs = x509.NewCertPool() + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(caCert) { + return "", fmt.Errorf("failed to append CA cert file %s: %w", d.SSLCACert, err) + } + params.Set("tls", "custom") } if d.SSLClientCert != "" && d.SSLClientKey != "" { - connString = fmt.Sprintf("%s&tls-cert=%s&tls-key=%s", connString, d.SSLClientCert, d.SSLClientKey) + cert, err := tls.LoadX509KeyPair(d.SSLClientCert, d.SSLClientKey) + if err != nil { + return "", fmt.Errorf("failed to load client cert/key pair: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + params.Set("tls", "custom") } + + if params.Get("tls") == "custom" { + // Register the custom TLS config + err := mysql.RegisterTLSConfig("custom", tlsConfig) + if err != nil { + return "", fmt.Errorf("failed to register custom TLS config: %w", err) + } + } + connString = fmt.Sprintf("%s?%s", connString, params.Encode()) case "postgres", "postgresql", "pgx": if d.isSocketConfig() { connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password) @@ -170,7 +211,7 @@ func (d *DatabaseCfg) ConnectionString() string { } } - return connString + return connString, nil } func (d *DatabaseCfg) ConnectionDialect() (string, string, error) { diff --git a/pkg/database/database.go b/pkg/database/database.go index d5186a76d..3d7a4e1b0 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -82,7 +82,11 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro } } - drv, err := getEntDriver(typ, dia, config.ConnectionString(), config) + dbConnectionString, err := config.ConnectionString() + if err != nil { + return nil, fmt.Errorf("failed to generate DB connection string: %w", err) + } + drv, err := getEntDriver(typ, dia, dbConnectionString, config) if err != nil { return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err) }