fix mysql client certificate support (#3575)

This commit is contained in:
blotus 2025-04-16 14:39:26 +02:00 committed by GitHub
parent 7e280b23af
commit 4004868245
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 53 additions and 8 deletions

View file

@ -1,12 +1,17 @@
package csconfig package csconfig
import ( import (
"crypto/tls"
"crypto/x509"
"errors" "errors"
"fmt" "fmt"
"net/url"
"os"
"path/filepath" "path/filepath"
"time" "time"
"entgo.io/ent/dialect" "entgo.io/ent/dialect"
"github.com/go-sql-driver/mysql"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/ptr"
@ -119,7 +124,7 @@ func (c *Config) LoadDBConfig(inCli bool) error {
return nil return nil
} }
func (d *DatabaseCfg) ConnectionString() string { func (d *DatabaseCfg) ConnectionString() (string, error) {
connString := "" connString := ""
switch d.Type { switch d.Type {
@ -133,23 +138,59 @@ func (d *DatabaseCfg) ConnectionString() string {
connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters) connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters)
case "mysql": case "mysql":
params := url.Values{}
params.Add("parseTime", "True")
tlsConfig := &tls.Config{}
// This is just to get an initial value, don't care about the error
systemRootCAs, _ := x509.SystemCertPool()
if systemRootCAs != nil {
tlsConfig.RootCAs = systemRootCAs
}
if d.isSocketConfig() { if d.isSocketConfig() {
connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", d.User, d.Password, d.DbPath, d.DbName) connString = fmt.Sprintf("%s:%s@unix(%s)/%s", d.User, d.Password, d.DbPath, d.DbName)
} else { } else {
connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName) connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", d.User, d.Password, d.Host, d.Port, d.DbName)
} }
if d.SSLMode != "" { if d.SSLMode != "" {
connString = fmt.Sprintf("%s&tls=%s", connString, d.SSLMode) // This will be overridden if a CA or client cert is provided
params.Set("tls", d.SSLMode)
} }
if d.SSLCACert != "" { if d.SSLCACert != "" {
connString = fmt.Sprintf("%s&tls-ca=%s", connString, d.SSLCACert) caCert, err := os.ReadFile(d.SSLCACert)
if err != nil {
return "", fmt.Errorf("failed to read CA cert file %s: %w", d.SSLCACert, err)
}
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(caCert) {
return "", fmt.Errorf("failed to append CA cert file %s: %w", d.SSLCACert, err)
}
params.Set("tls", "custom")
} }
if d.SSLClientCert != "" && d.SSLClientKey != "" { if d.SSLClientCert != "" && d.SSLClientKey != "" {
connString = fmt.Sprintf("%s&tls-cert=%s&tls-key=%s", connString, d.SSLClientCert, d.SSLClientKey) cert, err := tls.LoadX509KeyPair(d.SSLClientCert, d.SSLClientKey)
if err != nil {
return "", fmt.Errorf("failed to load client cert/key pair: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
params.Set("tls", "custom")
} }
if params.Get("tls") == "custom" {
// Register the custom TLS config
err := mysql.RegisterTLSConfig("custom", tlsConfig)
if err != nil {
return "", fmt.Errorf("failed to register custom TLS config: %w", err)
}
}
connString = fmt.Sprintf("%s?%s", connString, params.Encode())
case "postgres", "postgresql", "pgx": case "postgres", "postgresql", "pgx":
if d.isSocketConfig() { if d.isSocketConfig() {
connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password) connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password)
@ -170,7 +211,7 @@ func (d *DatabaseCfg) ConnectionString() string {
} }
} }
return connString return connString, nil
} }
func (d *DatabaseCfg) ConnectionDialect() (string, string, error) { func (d *DatabaseCfg) ConnectionDialect() (string, string, error) {

View file

@ -82,7 +82,11 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro
} }
} }
drv, err := getEntDriver(typ, dia, config.ConnectionString(), config) dbConnectionString, err := config.ConnectionString()
if err != nil {
return nil, fmt.Errorf("failed to generate DB connection string: %w", err)
}
drv, err := getEntDriver(typ, dia, dbConnectionString, config)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err) return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err)
} }