mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-10 20:05:55 +02:00
fix mysql client certificate support (#3575)
This commit is contained in:
parent
7e280b23af
commit
4004868245
2 changed files with 53 additions and 8 deletions
|
@ -1,12 +1,17 @@
|
|||
package csconfig
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/crowdsecurity/go-cs-lib/ptr"
|
||||
|
@ -119,7 +124,7 @@ func (c *Config) LoadDBConfig(inCli bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *DatabaseCfg) ConnectionString() string {
|
||||
func (d *DatabaseCfg) ConnectionString() (string, error) {
|
||||
connString := ""
|
||||
|
||||
switch d.Type {
|
||||
|
@ -133,23 +138,59 @@ func (d *DatabaseCfg) ConnectionString() string {
|
|||
|
||||
connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters)
|
||||
case "mysql":
|
||||
params := url.Values{}
|
||||
params.Add("parseTime", "True")
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
// This is just to get an initial value, don't care about the error
|
||||
systemRootCAs, _ := x509.SystemCertPool()
|
||||
if systemRootCAs != nil {
|
||||
tlsConfig.RootCAs = systemRootCAs
|
||||
}
|
||||
|
||||
if d.isSocketConfig() {
|
||||
connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", d.User, d.Password, d.DbPath, d.DbName)
|
||||
connString = fmt.Sprintf("%s:%s@unix(%s)/%s", d.User, d.Password, d.DbPath, d.DbName)
|
||||
} else {
|
||||
connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName)
|
||||
connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", d.User, d.Password, d.Host, d.Port, d.DbName)
|
||||
}
|
||||
|
||||
if d.SSLMode != "" {
|
||||
connString = fmt.Sprintf("%s&tls=%s", connString, d.SSLMode)
|
||||
// This will be overridden if a CA or client cert is provided
|
||||
params.Set("tls", d.SSLMode)
|
||||
}
|
||||
|
||||
if d.SSLCACert != "" {
|
||||
connString = fmt.Sprintf("%s&tls-ca=%s", connString, d.SSLCACert)
|
||||
caCert, err := os.ReadFile(d.SSLCACert)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read CA cert file %s: %w", d.SSLCACert, err)
|
||||
}
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(caCert) {
|
||||
return "", fmt.Errorf("failed to append CA cert file %s: %w", d.SSLCACert, err)
|
||||
}
|
||||
params.Set("tls", "custom")
|
||||
}
|
||||
|
||||
if d.SSLClientCert != "" && d.SSLClientKey != "" {
|
||||
connString = fmt.Sprintf("%s&tls-cert=%s&tls-key=%s", connString, d.SSLClientCert, d.SSLClientKey)
|
||||
cert, err := tls.LoadX509KeyPair(d.SSLClientCert, d.SSLClientKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load client cert/key pair: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
params.Set("tls", "custom")
|
||||
}
|
||||
|
||||
if params.Get("tls") == "custom" {
|
||||
// Register the custom TLS config
|
||||
err := mysql.RegisterTLSConfig("custom", tlsConfig)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to register custom TLS config: %w", err)
|
||||
}
|
||||
}
|
||||
connString = fmt.Sprintf("%s?%s", connString, params.Encode())
|
||||
case "postgres", "postgresql", "pgx":
|
||||
if d.isSocketConfig() {
|
||||
connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password)
|
||||
|
@ -170,7 +211,7 @@ func (d *DatabaseCfg) ConnectionString() string {
|
|||
}
|
||||
}
|
||||
|
||||
return connString
|
||||
return connString, nil
|
||||
}
|
||||
|
||||
func (d *DatabaseCfg) ConnectionDialect() (string, string, error) {
|
||||
|
|
|
@ -82,7 +82,11 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro
|
|||
}
|
||||
}
|
||||
|
||||
drv, err := getEntDriver(typ, dia, config.ConnectionString(), config)
|
||||
dbConnectionString, err := config.ConnectionString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate DB connection string: %w", err)
|
||||
}
|
||||
drv, err := getEntDriver(typ, dia, dbConnectionString, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue