diff --git a/api/user/otp.go b/api/user/otp.go index 5211091c..e5fda405 100644 --- a/api/user/otp.go +++ b/api/user/otp.go @@ -87,7 +87,7 @@ func EnrollTOTP(c *gin.Context) { return } - t := time.Now() + t := time.Now().Unix() recoveryCodes := model.RecoveryCodes{Codes: generateRecoveryCodes(16), LastViewed: &t} codesJson, err := json.Marshal(&recoveryCodes) if err != nil { diff --git a/api/user/recovery.go b/api/user/recovery.go index 79408a1c..81ac8f39 100644 --- a/api/user/recovery.go +++ b/api/user/recovery.go @@ -1,7 +1,6 @@ package user import ( - "encoding/json" "fmt" "math/rand" "net/http" @@ -23,10 +22,12 @@ func generateRecoveryCode() string { return fmt.Sprintf("%05x-%05x", rand.Intn(0x100000), rand.Intn(0x100000)) } -func generateRecoveryCodes(count int) []model.RecoveryCode { - recoveryCodes := make([]model.RecoveryCode, count) +func generateRecoveryCodes(count int) []*model.RecoveryCode { + recoveryCodes := make([]*model.RecoveryCode, count) for i := 0; i < count; i++ { - recoveryCodes[i].Code = generateRecoveryCode() + recoveryCodes[i] = &model.RecoveryCode{ + Code: generateRecoveryCode(), + } } return recoveryCodes } @@ -34,17 +35,11 @@ func generateRecoveryCodes(count int) []model.RecoveryCode { func ViewRecoveryCodes(c *gin.Context) { user := api.CurrentUser(c) - u := query.User - user, err := u.Where(u.ID.Eq(user.ID)).First() - if err != nil { - api.ErrHandler(c, err) - return - } - // update last viewed time - t := time.Now() + u := query.User + t := time.Now().Unix() user.RecoveryCodes.LastViewed = &t - _, err = u.Where(u.ID.Eq(user.ID)).Updates(user) + _, err := u.Where(u.ID.Eq(user.ID)).Updates(user) if err != nil { api.ErrHandler(c, err) return @@ -59,16 +54,12 @@ func ViewRecoveryCodes(c *gin.Context) { func GenerateRecoveryCodes(c *gin.Context) { user := api.CurrentUser(c) - t := time.Now() + t := time.Now().Unix() recoveryCodes := model.RecoveryCodes{Codes: generateRecoveryCodes(16), LastViewed: &t} - codesJson, err := json.Marshal(&recoveryCodes) - if err != nil { - api.ErrHandler(c, err) - return - } + user.RecoveryCodes = recoveryCodes u := query.User - _, err = u.Where(u.ID.Eq(user.ID)).Update(u.RecoveryCodes, codesJson) + _, err := u.Where(u.ID.Eq(user.ID)).Updates(user) if err != nil { api.ErrHandler(c, err) return diff --git a/internal/crypto/aes.go b/internal/crypto/aes.go index fdbf4428..87ef71e4 100644 --- a/internal/crypto/aes.go +++ b/internal/crypto/aes.go @@ -1,12 +1,17 @@ package crypto import ( + "context" "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/base64" - "github.com/0xJacky/Nginx-UI/settings" + "encoding/json" "io" + "reflect" + + "github.com/0xJacky/Nginx-UI/settings" + "gorm.io/gorm/schema" ) // AesEncrypt encrypts text and given key with AES. @@ -55,3 +60,49 @@ func AesDecrypt(text []byte) ([]byte, error) { return data, nil } + +type JSONAesSerializer struct{} + +func (JSONAesSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + bytes, err = json.Marshal(v) + if err != nil { + return err + } + } + + if len(bytes) > 0 { + bytes, err = AesDecrypt(bytes) + if err != nil { + return err + } + err = json.Unmarshal(bytes, fieldValue.Interface()) + } + } + + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (JSONAesSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + result, err := json.Marshal(fieldValue) + if string(result) == "null" { + if field.TagSettings["NOT NULL"] != "" { + return "", nil + } + return nil, err + } + + encrypt, err := AesEncrypt(result) + return string(encrypt), err +} diff --git a/internal/user/otp.go b/internal/user/otp.go index 6339800c..91ecf11f 100644 --- a/internal/user/otp.go +++ b/internal/user/otp.go @@ -52,7 +52,7 @@ func VerifyOTP(user *model.User, otp, recoveryCode string) (err error) { // check recovery code for _, code := range user.RecoveryCodes.Codes { if code.Code == recoveryCode && code.UsedTime == nil { - t := time.Now() + t := time.Now().Unix() code.UsedTime = &t _, err = u.Where(u.ID.Eq(user.ID)).Updates(user) return diff --git a/model/user.go b/model/user.go index e814ee18..8cd3b344 100644 --- a/model/user.go +++ b/model/user.go @@ -1,22 +1,26 @@ package model import ( - "time" - + "github.com/0xJacky/Nginx-UI/internal/crypto" "github.com/go-webauthn/webauthn/webauthn" "github.com/spf13/cast" "gorm.io/gorm" + "gorm.io/gorm/schema" ) +func init() { + schema.RegisterSerializer("json[aes]", crypto.JSONAesSerializer{}) +} + type RecoveryCode struct { - Code string `json:"code"` - UsedTime *time.Time `json:"used_time,omitempty" gorm:"type:datetime;default:null"` + Code string `json:"code"` + UsedTime *int64 `json:"used_time,omitempty" gorm:"type:datetime;default:null"` } type RecoveryCodes struct { - Codes []RecoveryCode `json:"codes"` - LastViewed *time.Time `json:"last_viewed,omitempty" gorm:"type:datetime;default:null"` - LastDownloaded *time.Time `json:"last_downloaded,omitempty" gorm:"type:datetime;default:null"` + Codes []*RecoveryCode `json:"codes"` + LastViewed *int64 `json:"last_viewed,omitempty" gorm:"serializer:unixtime;type:datetime;default:null"` + LastDownloaded *int64 `json:"last_downloaded,omitempty" gorm:"serializer:unixtime;type:datetime;default:null"` } type User struct { @@ -26,7 +30,7 @@ type User struct { Password string `json:"-" cosy:"json:password;add:required,max=20;update:omitempty,max=20"` Status bool `json:"status" gorm:"default:1"` OTPSecret []byte `json:"-" gorm:"type:blob"` - RecoveryCodes RecoveryCodes `json:"-" gorm:"serializer:json"` + RecoveryCodes RecoveryCodes `json:"-" gorm:"serializer:json[aes]"` EnabledTwoFA bool `json:"enabled_2fa" gorm:"-"` }