diff --git a/api/certificate/issue.go b/api/certificate/issue.go index e61c694f..b631c642 100644 --- a/api/certificate/issue.go +++ b/api/certificate/issue.go @@ -3,13 +3,11 @@ package certificate import ( "github.com/0xJacky/Nginx-UI/internal/cert" "github.com/0xJacky/Nginx-UI/internal/logger" - "github.com/0xJacky/Nginx-UI/internal/nginx" "github.com/0xJacky/Nginx-UI/model" "github.com/gin-gonic/gin" "github.com/go-acme/lego/v4/certcrypto" "github.com/gorilla/websocket" "net/http" - "strings" ) const ( @@ -71,7 +69,6 @@ func IssueCert(c *gin.Context) { payload := &cert.ConfigPayload{} err = ws.ReadJSON(payload) - if err != nil { logger.Error(err) return @@ -122,14 +119,10 @@ func IssueCert(c *gin.Context) { return } - certDirName := strings.Join(payload.ServerName, "_") + "_" + string(payload.GetKeyType()) - sslCertificatePath := nginx.GetConfPath("ssl", certDirName, "fullchain.cer") - sslCertificateKeyPath := nginx.GetConfPath("ssl", certDirName, "private.key") - err = certModel.Updates(&model.Cert{ Domains: payload.ServerName, - SSLCertificatePath: sslCertificatePath, - SSLCertificateKeyPath: sslCertificateKeyPath, + SSLCertificatePath: payload.GetCertificatePath(), + SSLCertificateKeyPath: payload.GetCertificateKeyPath(), AutoCert: model.AutoCertEnabled, KeyType: payload.KeyType, ChallengeMethod: payload.ChallengeMethod, @@ -152,8 +145,8 @@ func IssueCert(c *gin.Context) { err = ws.WriteJSON(IssueCertResponse{ Status: Success, Message: "Issued certificate successfully", - SSLCertificate: sslCertificatePath, - SSLCertificateKey: sslCertificateKeyPath, + SSLCertificate: payload.GetCertificatePath(), + SSLCertificateKey: payload.GetCertificateKeyPath(), KeyType: payload.GetKeyType(), }) diff --git a/internal/cert/payload.go b/internal/cert/payload.go index 857e6301..83ded9b5 100644 --- a/internal/cert/payload.go +++ b/internal/cert/payload.go @@ -16,14 +16,17 @@ import ( ) type ConfigPayload struct { - CertID int `json:"cert_id"` - ServerName []string `json:"server_name"` - ChallengeMethod string `json:"challenge_method"` - DNSCredentialID int `json:"dns_credential_id"` - ACMEUserID int `json:"acme_user_id"` - KeyType certcrypto.KeyType `json:"key_type"` - Resource *model.CertificateResource `json:"resource,omitempty"` - NotBefore time.Time + CertID int `json:"cert_id"` + ServerName []string `json:"server_name"` + ChallengeMethod string `json:"challenge_method"` + DNSCredentialID int `json:"dns_credential_id"` + ACMEUserID int `json:"acme_user_id"` + KeyType certcrypto.KeyType `json:"key_type"` + Resource *model.CertificateResource `json:"resource,omitempty"` + NotBefore time.Time `json:"-"` + CertificateDir string `json:"-"` + SSLCertificatePath string `json:"-"` + SSLCertificateKeyPath string `json:"-"` } func (c *ConfigPayload) GetACMEUser() (user *model.AcmeUser, err error) { @@ -46,21 +49,38 @@ func (c *ConfigPayload) GetKeyType() certcrypto.KeyType { return helper.GetKeyType(c.KeyType) } -func (c *ConfigPayload) WriteFile(l *log.Logger, errChan chan error) { - name := strings.Join(c.ServerName, "_") - saveDir := nginx.GetConfPath("ssl/" + name + "_" + string(c.KeyType)) - if _, err := os.Stat(saveDir); os.IsNotExist(err) { - err = os.MkdirAll(saveDir, 0755) - if err != nil { - errChan <- errors.Wrap(err, "mkdir error") - return +func (c *ConfigPayload) mkCertificateDir() (err error) { + dir := c.getCertificateDirPath() + if _, err = os.Stat(dir); os.IsNotExist(err) { + err = os.MkdirAll(dir, 0755) + if err == nil { + return nil } } + // For windows, replace # with * (issue #403) + c.CertificateDir = strings.ReplaceAll(c.CertificateDir, "#", "*") + if _, err = os.Stat(c.CertificateDir); os.IsNotExist(err) { + err = os.MkdirAll(c.CertificateDir, 0755) + if err == nil { + return nil + } + } + + return +} + +func (c *ConfigPayload) WriteFile(l *log.Logger, errChan chan error) { + err := c.mkCertificateDir() + if err != nil { + errChan <- errors.Wrap(err, "make certificate dir error") + return + } + // Each certificate comes back with the cert bytes, the bytes of the client's // private key, and a certificate URL. SAVE THESE TO DISK. l.Println("[INFO] [Nginx UI] Writing certificate to disk") - err := os.WriteFile(filepath.Join(saveDir, "fullchain.cer"), + err = os.WriteFile(c.GetCertificatePath(), c.Resource.Certificate, 0644) if err != nil { @@ -69,7 +89,7 @@ func (c *ConfigPayload) WriteFile(l *log.Logger, errChan chan error) { } l.Println("[INFO] [Nginx UI] Writing certificate private key to disk") - err = os.WriteFile(filepath.Join(saveDir, "private.key"), + err = os.WriteFile(c.GetCertificateKeyPath(), c.Resource.PrivateKey, 0644) if err != nil { @@ -84,7 +104,31 @@ func (c *ConfigPayload) WriteFile(l *log.Logger, errChan chan error) { db := model.UseDB() db.Where("id = ?", c.CertID).Updates(&model.Cert{ - SSLCertificatePath: filepath.Join(saveDir, "fullchain.cer"), - SSLCertificateKeyPath: filepath.Join(saveDir, "private.key"), + SSLCertificatePath: c.GetCertificatePath(), + SSLCertificateKeyPath: c.GetCertificateKeyPath(), }) } + +func (c *ConfigPayload) getCertificateDirPath() string { + if c.CertificateDir != "" { + return c.CertificateDir + } + c.CertificateDir = nginx.GetConfPath("ssl", strings.Join(c.ServerName, "_")+"_"+string(c.GetKeyType())) + return c.CertificateDir +} + +func (c *ConfigPayload) GetCertificatePath() string { + if c.SSLCertificatePath != "" { + return c.SSLCertificatePath + } + c.SSLCertificatePath = filepath.Join(c.getCertificateDirPath(), "fullchain.cer") + return c.SSLCertificatePath +} + +func (c *ConfigPayload) GetCertificateKeyPath() string { + if c.SSLCertificateKeyPath != "" { + return c.SSLCertificateKeyPath + } + c.SSLCertificateKeyPath = filepath.Join(c.getCertificateDirPath(), "private.key") + return c.SSLCertificateKeyPath +}