fix mysql client certificate support (#3575)

This commit is contained in:
blotus 2025-04-16 14:39:26 +02:00 committed by GitHub
parent 7e280b23af
commit 4004868245
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 53 additions and 8 deletions

View file

@ -1,12 +1,17 @@
package csconfig
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"time"
"entgo.io/ent/dialect"
"github.com/go-sql-driver/mysql"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/go-cs-lib/ptr"
@ -119,7 +124,7 @@ func (c *Config) LoadDBConfig(inCli bool) error {
return nil
}
func (d *DatabaseCfg) ConnectionString() string {
func (d *DatabaseCfg) ConnectionString() (string, error) {
connString := ""
switch d.Type {
@ -133,23 +138,59 @@ func (d *DatabaseCfg) ConnectionString() string {
connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters)
case "mysql":
params := url.Values{}
params.Add("parseTime", "True")
tlsConfig := &tls.Config{}
// This is just to get an initial value, don't care about the error
systemRootCAs, _ := x509.SystemCertPool()
if systemRootCAs != nil {
tlsConfig.RootCAs = systemRootCAs
}
if d.isSocketConfig() {
connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", d.User, d.Password, d.DbPath, d.DbName)
connString = fmt.Sprintf("%s:%s@unix(%s)/%s", d.User, d.Password, d.DbPath, d.DbName)
} else {
connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName)
connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", d.User, d.Password, d.Host, d.Port, d.DbName)
}
if d.SSLMode != "" {
connString = fmt.Sprintf("%s&tls=%s", connString, d.SSLMode)
// This will be overridden if a CA or client cert is provided
params.Set("tls", d.SSLMode)
}
if d.SSLCACert != "" {
connString = fmt.Sprintf("%s&tls-ca=%s", connString, d.SSLCACert)
caCert, err := os.ReadFile(d.SSLCACert)
if err != nil {
return "", fmt.Errorf("failed to read CA cert file %s: %w", d.SSLCACert, err)
}
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(caCert) {
return "", fmt.Errorf("failed to append CA cert file %s: %w", d.SSLCACert, err)
}
params.Set("tls", "custom")
}
if d.SSLClientCert != "" && d.SSLClientKey != "" {
connString = fmt.Sprintf("%s&tls-cert=%s&tls-key=%s", connString, d.SSLClientCert, d.SSLClientKey)
cert, err := tls.LoadX509KeyPair(d.SSLClientCert, d.SSLClientKey)
if err != nil {
return "", fmt.Errorf("failed to load client cert/key pair: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
params.Set("tls", "custom")
}
if params.Get("tls") == "custom" {
// Register the custom TLS config
err := mysql.RegisterTLSConfig("custom", tlsConfig)
if err != nil {
return "", fmt.Errorf("failed to register custom TLS config: %w", err)
}
}
connString = fmt.Sprintf("%s?%s", connString, params.Encode())
case "postgres", "postgresql", "pgx":
if d.isSocketConfig() {
connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password)
@ -170,7 +211,7 @@ func (d *DatabaseCfg) ConnectionString() string {
}
}
return connString
return connString, nil
}
func (d *DatabaseCfg) ConnectionDialect() (string, string, error) {

View file

@ -82,7 +82,11 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro
}
}
drv, err := getEntDriver(typ, dia, config.ConnectionString(), config)
dbConnectionString, err := config.ConnectionString()
if err != nil {
return nil, fmt.Errorf("failed to generate DB connection string: %w", err)
}
drv, err := getEntDriver(typ, dia, dbConnectionString, config)
if err != nil {
return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err)
}