diff --git a/api/api.go b/api/api.go index ddee8fc7..82e77fa8 100644 --- a/api/api.go +++ b/api/api.go @@ -4,28 +4,12 @@ import ( "errors" "github.com/0xJacky/Nginx-UI/internal/logger" "github.com/gin-gonic/gin" - "github.com/gin-gonic/gin/binding" val "github.com/go-playground/validator/v10" "net/http" "reflect" - "regexp" "strings" ) -func init() { - if v, ok := binding.Validator.Engine().(*val.Validate); ok { - err := v.RegisterValidation("alphanumdash", func(fl val.FieldLevel) bool { - return regexp.MustCompile(`^[a-zA-Z0-9-]+$`).MatchString(fl.Field().String()) - }) - - if err != nil { - logger.Fatal(err) - } - return - } - logger.Fatal("binding validator engine is not initialized") -} - func ErrHandler(c *gin.Context, err error) { logger.GetLogger().Errorln(err) c.JSON(http.StatusInternalServerError, gin.H{ @@ -54,11 +38,18 @@ func BindAndValid(c *gin.Context, target interface{}) bool { return false } - t := reflect.TypeOf(target).Elem() + t := reflect.TypeOf(target) errorsMap := make(map[string]interface{}) for _, value := range verrs { var path []string - getJsonPath(t, value.StructNamespace(), &path) + + namespace := strings.Split(value.StructNamespace(), ".") + + if t.Name() == "" && len(namespace) > 1 { + namespace = namespace[1:] + } + + getJsonPath(t.Elem(), namespace, &path) insertError(errorsMap, path, value.Tag()) } @@ -75,11 +66,7 @@ func BindAndValid(c *gin.Context, target interface{}) bool { } // findField recursively finds the field in a nested struct -func getJsonPath(t reflect.Type, namespace string, path *[]string) { - fields := strings.Split(namespace, ".") - if len(fields) == 0 { - return - } +func getJsonPath(t reflect.Type, fields []string, path *[]string) { f, ok := t.FieldByName(fields[0]) if !ok { return @@ -88,7 +75,7 @@ func getJsonPath(t reflect.Type, namespace string, path *[]string) { *path = append(*path, f.Tag.Get("json")) if len(fields) > 1 { - subFields := strings.Join(fields[1:], ".") + subFields := fields[1:] getJsonPath(f.Type, subFields, path) } } diff --git a/api/certificate/certificate.go b/api/certificate/certificate.go index d34c3d29..b5ff78cd 100644 --- a/api/certificate/certificate.go +++ b/api/certificate/certificate.go @@ -1,174 +1,174 @@ package certificate import ( - "github.com/0xJacky/Nginx-UI/api" - "github.com/0xJacky/Nginx-UI/api/cosy" - "github.com/0xJacky/Nginx-UI/internal/cert" - "github.com/0xJacky/Nginx-UI/model" - "github.com/0xJacky/Nginx-UI/query" - "github.com/gin-gonic/gin" - "github.com/spf13/cast" - "net/http" - "os" + "github.com/0xJacky/Nginx-UI/api" + "github.com/0xJacky/Nginx-UI/api/cosy" + "github.com/0xJacky/Nginx-UI/internal/cert" + "github.com/0xJacky/Nginx-UI/model" + "github.com/0xJacky/Nginx-UI/query" + "github.com/gin-gonic/gin" + "github.com/spf13/cast" + "net/http" + "os" ) type APICertificate struct { - *model.Cert - SSLCertificate string `json:"ssl_certificate,omitempty"` - SSLCertificateKey string `json:"ssl_certificate_key,omitempty"` - CertificateInfo *cert.Info `json:"certificate_info,omitempty"` + *model.Cert + SSLCertificate string `json:"ssl_certificate,omitempty"` + SSLCertificateKey string `json:"ssl_certificate_key,omitempty"` + CertificateInfo *cert.Info `json:"certificate_info,omitempty"` } func Transformer(certModel *model.Cert) (certificate *APICertificate) { - var sslCertificationBytes, sslCertificationKeyBytes []byte - var certificateInfo *cert.Info - if certModel.SSLCertificatePath != "" { - if _, err := os.Stat(certModel.SSLCertificatePath); err == nil { - sslCertificationBytes, _ = os.ReadFile(certModel.SSLCertificatePath) - } + var sslCertificationBytes, sslCertificationKeyBytes []byte + var certificateInfo *cert.Info + if certModel.SSLCertificatePath != "" { + if _, err := os.Stat(certModel.SSLCertificatePath); err == nil { + sslCertificationBytes, _ = os.ReadFile(certModel.SSLCertificatePath) + if !cert.IsPublicKey(string(sslCertificationBytes)) { + sslCertificationBytes = []byte{} + } + } - certificateInfo, _ = cert.GetCertInfo(certModel.SSLCertificatePath) - } + certificateInfo, _ = cert.GetCertInfo(certModel.SSLCertificatePath) + } - if certModel.SSLCertificateKeyPath != "" { - if _, err := os.Stat(certModel.SSLCertificateKeyPath); err == nil { - sslCertificationKeyBytes, _ = os.ReadFile(certModel.SSLCertificateKeyPath) - } - } + if certModel.SSLCertificateKeyPath != "" { + if _, err := os.Stat(certModel.SSLCertificateKeyPath); err == nil { + sslCertificationKeyBytes, _ = os.ReadFile(certModel.SSLCertificateKeyPath) + if !cert.IsPrivateKey(string(sslCertificationKeyBytes)) { + sslCertificationKeyBytes = []byte{} + } + } + } - return &APICertificate{ - Cert: certModel, - SSLCertificate: string(sslCertificationBytes), - SSLCertificateKey: string(sslCertificationKeyBytes), - CertificateInfo: certificateInfo, - } + return &APICertificate{ + Cert: certModel, + SSLCertificate: string(sslCertificationBytes), + SSLCertificateKey: string(sslCertificationKeyBytes), + CertificateInfo: certificateInfo, + } } func GetCertList(c *gin.Context) { - cosy.Core[model.Cert](c).SetFussy("name", "domain").SetTransformer(func(m *model.Cert) any { + cosy.Core[model.Cert](c).SetFussy("name", "domain").SetTransformer(func(m *model.Cert) any { - info, _ := cert.GetCertInfo(m.SSLCertificatePath) + info, _ := cert.GetCertInfo(m.SSLCertificatePath) - return APICertificate{ - Cert: m, - CertificateInfo: info, - } - }).PagingList() + return APICertificate{ + Cert: m, + CertificateInfo: info, + } + }).PagingList() } func GetCert(c *gin.Context) { - q := query.Cert + q := query.Cert - certModel, err := q.FirstByID(cast.ToInt(c.Param("id"))) + certModel, err := q.FirstByID(cast.ToInt(c.Param("id"))) - if err != nil { - api.ErrHandler(c, err) - return - } + if err != nil { + api.ErrHandler(c, err) + return + } - c.JSON(http.StatusOK, Transformer(certModel)) + c.JSON(http.StatusOK, Transformer(certModel)) +} + +type certJson struct { + Name string `json:"name"` + SSLCertificatePath string `json:"ssl_certificate_path" binding:"publickey_path"` + SSLCertificateKeyPath string `json:"ssl_certificate_key_path" binding:"privatekey_path"` + SSLCertificate string `json:"ssl_certificate" binding:"omitempty,publickey"` + SSLCertificateKey string `json:"ssl_certificate_key" binding:"omitempty,privatekey"` + ChallengeMethod string `json:"challenge_method"` + DnsCredentialID int `json:"dns_credential_id"` } func AddCert(c *gin.Context) { - var json struct { - Name string `json:"name"` - SSLCertificatePath string `json:"ssl_certificate_path" binding:"required"` - SSLCertificateKeyPath string `json:"ssl_certificate_key_path" binding:"required"` - SSLCertificate string `json:"ssl_certificate"` - SSLCertificateKey string `json:"ssl_certificate_key"` - ChallengeMethod string `json:"challenge_method"` - DnsCredentialID int `json:"dns_credential_id"` - } - if !api.BindAndValid(c, &json) { - return - } - certModel := &model.Cert{ - Name: json.Name, - SSLCertificatePath: json.SSLCertificatePath, - SSLCertificateKeyPath: json.SSLCertificateKeyPath, - ChallengeMethod: json.ChallengeMethod, - DnsCredentialID: json.DnsCredentialID, - } + var json certJson + if !api.BindAndValid(c, &json) { + return + } + certModel := &model.Cert{ + Name: json.Name, + SSLCertificatePath: json.SSLCertificatePath, + SSLCertificateKeyPath: json.SSLCertificateKeyPath, + ChallengeMethod: json.ChallengeMethod, + DnsCredentialID: json.DnsCredentialID, + } - err := certModel.Insert() + err := certModel.Insert() - if err != nil { - api.ErrHandler(c, err) - return - } + if err != nil { + api.ErrHandler(c, err) + return + } - content := &cert.Content{ - SSLCertificatePath: json.SSLCertificatePath, - SSLCertificateKeyPath: json.SSLCertificateKeyPath, - SSLCertificate: json.SSLCertificate, - SSLCertificateKey: json.SSLCertificateKey, - } + content := &cert.Content{ + SSLCertificatePath: json.SSLCertificatePath, + SSLCertificateKeyPath: json.SSLCertificateKeyPath, + SSLCertificate: json.SSLCertificate, + SSLCertificateKey: json.SSLCertificateKey, + } - err = content.WriteFile() + err = content.WriteFile() - if err != nil { - api.ErrHandler(c, err) - return - } + if err != nil { + api.ErrHandler(c, err) + return + } - c.JSON(http.StatusOK, Transformer(certModel)) + c.JSON(http.StatusOK, Transformer(certModel)) } func ModifyCert(c *gin.Context) { - id := cast.ToInt(c.Param("id")) + id := cast.ToInt(c.Param("id")) - var json struct { - Name string `json:"name"` - SSLCertificatePath string `json:"ssl_certificate_path" binding:"required"` - SSLCertificateKeyPath string `json:"ssl_certificate_key_path" binding:"required"` - SSLCertificate string `json:"ssl_certificate"` - SSLCertificateKey string `json:"ssl_certificate_key"` - ChallengeMethod string `json:"challenge_method"` - DnsCredentialID int `json:"dns_credential_id"` - } + var json certJson - if !api.BindAndValid(c, &json) { - return - } + if !api.BindAndValid(c, &json) { + return + } - q := query.Cert + q := query.Cert - certModel, err := q.FirstByID(id) - if err != nil { - api.ErrHandler(c, err) - return - } + certModel, err := q.FirstByID(id) + if err != nil { + api.ErrHandler(c, err) + return + } - err = certModel.Updates(&model.Cert{ - Name: json.Name, - SSLCertificatePath: json.SSLCertificatePath, - SSLCertificateKeyPath: json.SSLCertificateKeyPath, - ChallengeMethod: json.ChallengeMethod, - DnsCredentialID: json.DnsCredentialID, - }) + err = certModel.Updates(&model.Cert{ + Name: json.Name, + SSLCertificatePath: json.SSLCertificatePath, + SSLCertificateKeyPath: json.SSLCertificateKeyPath, + ChallengeMethod: json.ChallengeMethod, + DnsCredentialID: json.DnsCredentialID, + }) - if err != nil { - api.ErrHandler(c, err) - return - } + if err != nil { + api.ErrHandler(c, err) + return + } - content := &cert.Content{ - SSLCertificatePath: json.SSLCertificatePath, - SSLCertificateKeyPath: json.SSLCertificateKeyPath, - SSLCertificate: json.SSLCertificate, - SSLCertificateKey: json.SSLCertificateKey, - } + content := &cert.Content{ + SSLCertificatePath: json.SSLCertificatePath, + SSLCertificateKeyPath: json.SSLCertificateKeyPath, + SSLCertificate: json.SSLCertificate, + SSLCertificateKey: json.SSLCertificateKey, + } - err = content.WriteFile() + err = content.WriteFile() - if err != nil { - api.ErrHandler(c, err) - return - } + if err != nil { + api.ErrHandler(c, err) + return + } - GetCert(c) + GetCert(c) } func RemoveCert(c *gin.Context) { - cosy.Core[model.Cert](c).Destroy() + cosy.Core[model.Cert](c).Destroy() } diff --git a/internal/cert/helper.go b/internal/cert/helper.go new file mode 100644 index 00000000..8e1e449d --- /dev/null +++ b/internal/cert/helper.go @@ -0,0 +1,70 @@ +package cert + +import ( + "crypto/x509" + "encoding/pem" + "os" +) + +func IsPublicKey(pemStr string) bool { + block, _ := pem.Decode([]byte(pemStr)) + if block == nil { + return false + } + + _, err := x509.ParsePKIXPublicKey(block.Bytes) + return err == nil +} + +func IsPrivateKey(pemStr string) bool { + block, _ := pem.Decode([]byte(pemStr)) + if block == nil { + return false + } + + _, errRSA := x509.ParsePKCS1PrivateKey(block.Bytes) + if errRSA == nil { + return true + } + + _, errECDSA := x509.ParseECPrivateKey(block.Bytes) + return errECDSA == nil +} + +// IsPublicKeyPath checks if the file at the given path is a public key or not exists. +func IsPublicKeyPath(path string) bool { + _, err := os.Stat(path) + + if err != nil { + if os.IsNotExist(err) { + return true + } + return false + } + + bytes, err := os.ReadFile(path) + if err != nil { + return false + } + + return IsPublicKey(string(bytes)) +} + +// IsPrivateKeyPath checks if the file at the given path is a private key or not exists. +func IsPrivateKeyPath(path string) bool { + _, err := os.Stat(path) + + if err != nil { + if os.IsNotExist(err) { + return true + } + return false + } + + bytes, err := os.ReadFile(path) + if err != nil { + return false + } + + return IsPrivateKey(string(bytes)) +} diff --git a/internal/kernal/boot.go b/internal/kernal/boot.go index de951560..ae41e5f2 100644 --- a/internal/kernal/boot.go +++ b/internal/kernal/boot.go @@ -4,6 +4,7 @@ import ( "github.com/0xJacky/Nginx-UI/internal/analytic" "github.com/0xJacky/Nginx-UI/internal/cert" "github.com/0xJacky/Nginx-UI/internal/logger" + "github.com/0xJacky/Nginx-UI/internal/validation" "github.com/0xJacky/Nginx-UI/model" "github.com/0xJacky/Nginx-UI/query" "github.com/0xJacky/Nginx-UI/settings" @@ -21,6 +22,7 @@ func Boot() { InitJsExtensionType, InitDatabase, InitNodeSecret, + validation.Init, } syncs := []func(){ diff --git a/internal/validation/alphanumdash.go b/internal/validation/alphanumdash.go new file mode 100644 index 00000000..39f6ce75 --- /dev/null +++ b/internal/validation/alphanumdash.go @@ -0,0 +1,10 @@ +package validation + +import ( + val "github.com/go-playground/validator/v10" + "regexp" +) + +func alphaNumDash(fl val.FieldLevel) bool { + return regexp.MustCompile(`^[a-zA-Z0-9-]+$`).MatchString(fl.Field().String()) +} diff --git a/internal/validation/certificate.go b/internal/validation/certificate.go new file mode 100644 index 00000000..ed6db215 --- /dev/null +++ b/internal/validation/certificate.go @@ -0,0 +1,22 @@ +package validation + +import ( + "github.com/0xJacky/Nginx-UI/internal/cert" + val "github.com/go-playground/validator/v10" +) + +func isPublicKey(fl val.FieldLevel) bool { + return cert.IsPublicKey(fl.Field().String()) +} + +func isPrivateKey(fl val.FieldLevel) bool { + return cert.IsPrivateKey(fl.Field().String()) +} + +func isPublicKeyPath(fl val.FieldLevel) bool { + return cert.IsPublicKeyPath(fl.Field().String()) +} + +func isPrivateKeyPath(fl val.FieldLevel) bool { + return cert.IsPrivateKeyPath(fl.Field().String()) +} diff --git a/internal/validation/validation.go b/internal/validation/validation.go new file mode 100644 index 00000000..0389e827 --- /dev/null +++ b/internal/validation/validation.go @@ -0,0 +1,46 @@ +package validation + +import ( + "github.com/0xJacky/Nginx-UI/internal/logger" + "github.com/gin-gonic/gin/binding" + val "github.com/go-playground/validator/v10" +) + +func Init() { + v, ok := binding.Validator.Engine().(*val.Validate) + if !ok { + logger.Fatal("binding validator engine is not initialized") + } + + err := v.RegisterValidation("alphanumdash", alphaNumDash) + + if err != nil { + logger.Fatal(err) + } + + err = v.RegisterValidation("publickey", isPublicKey) + + if err != nil { + logger.Fatal(err) + } + + err = v.RegisterValidation("privatekey", isPrivateKey) + + if err != nil { + logger.Fatal(err) + } + + err = v.RegisterValidation("publickey_path", isPublicKeyPath) + + if err != nil { + logger.Fatal(err) + } + + err = v.RegisterValidation("privatekey_path", isPrivateKeyPath) + + if err != nil { + logger.Fatal(err) + } + + return +}