mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-11 04:15:54 +02:00
fix mysql client certificate support (#3575)
This commit is contained in:
parent
7e280b23af
commit
4004868245
2 changed files with 53 additions and 8 deletions
|
@ -1,12 +1,17 @@
|
||||||
package csconfig
|
package csconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"entgo.io/ent/dialect"
|
"entgo.io/ent/dialect"
|
||||||
|
"github.com/go-sql-driver/mysql"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/crowdsecurity/go-cs-lib/ptr"
|
"github.com/crowdsecurity/go-cs-lib/ptr"
|
||||||
|
@ -119,7 +124,7 @@ func (c *Config) LoadDBConfig(inCli bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DatabaseCfg) ConnectionString() string {
|
func (d *DatabaseCfg) ConnectionString() (string, error) {
|
||||||
connString := ""
|
connString := ""
|
||||||
|
|
||||||
switch d.Type {
|
switch d.Type {
|
||||||
|
@ -133,23 +138,59 @@ func (d *DatabaseCfg) ConnectionString() string {
|
||||||
|
|
||||||
connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters)
|
connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters)
|
||||||
case "mysql":
|
case "mysql":
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("parseTime", "True")
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
// This is just to get an initial value, don't care about the error
|
||||||
|
systemRootCAs, _ := x509.SystemCertPool()
|
||||||
|
if systemRootCAs != nil {
|
||||||
|
tlsConfig.RootCAs = systemRootCAs
|
||||||
|
}
|
||||||
|
|
||||||
if d.isSocketConfig() {
|
if d.isSocketConfig() {
|
||||||
connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", d.User, d.Password, d.DbPath, d.DbName)
|
connString = fmt.Sprintf("%s:%s@unix(%s)/%s", d.User, d.Password, d.DbPath, d.DbName)
|
||||||
} else {
|
} else {
|
||||||
connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName)
|
connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", d.User, d.Password, d.Host, d.Port, d.DbName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.SSLMode != "" {
|
if d.SSLMode != "" {
|
||||||
connString = fmt.Sprintf("%s&tls=%s", connString, d.SSLMode)
|
// This will be overridden if a CA or client cert is provided
|
||||||
|
params.Set("tls", d.SSLMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.SSLCACert != "" {
|
if d.SSLCACert != "" {
|
||||||
connString = fmt.Sprintf("%s&tls-ca=%s", connString, d.SSLCACert)
|
caCert, err := os.ReadFile(d.SSLCACert)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read CA cert file %s: %w", d.SSLCACert, err)
|
||||||
|
}
|
||||||
|
if tlsConfig.RootCAs == nil {
|
||||||
|
tlsConfig.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
if !tlsConfig.RootCAs.AppendCertsFromPEM(caCert) {
|
||||||
|
return "", fmt.Errorf("failed to append CA cert file %s: %w", d.SSLCACert, err)
|
||||||
|
}
|
||||||
|
params.Set("tls", "custom")
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.SSLClientCert != "" && d.SSLClientKey != "" {
|
if d.SSLClientCert != "" && d.SSLClientKey != "" {
|
||||||
connString = fmt.Sprintf("%s&tls-cert=%s&tls-key=%s", connString, d.SSLClientCert, d.SSLClientKey)
|
cert, err := tls.LoadX509KeyPair(d.SSLClientCert, d.SSLClientKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to load client cert/key pair: %w", err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
params.Set("tls", "custom")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if params.Get("tls") == "custom" {
|
||||||
|
// Register the custom TLS config
|
||||||
|
err := mysql.RegisterTLSConfig("custom", tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to register custom TLS config: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
connString = fmt.Sprintf("%s?%s", connString, params.Encode())
|
||||||
case "postgres", "postgresql", "pgx":
|
case "postgres", "postgresql", "pgx":
|
||||||
if d.isSocketConfig() {
|
if d.isSocketConfig() {
|
||||||
connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password)
|
connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password)
|
||||||
|
@ -170,7 +211,7 @@ func (d *DatabaseCfg) ConnectionString() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return connString
|
return connString, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DatabaseCfg) ConnectionDialect() (string, string, error) {
|
func (d *DatabaseCfg) ConnectionDialect() (string, string, error) {
|
||||||
|
|
|
@ -82,7 +82,11 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, err := getEntDriver(typ, dia, config.ConnectionString(), config)
|
dbConnectionString, err := config.ConnectionString()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate DB connection string: %w", err)
|
||||||
|
}
|
||||||
|
drv, err := getEntDriver(typ, dia, dbConnectionString, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err)
|
return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue