diff --git a/api/cosy/sort.go b/api/cosy/sort.go index d1894dd6..ad612bd7 100644 --- a/api/cosy/sort.go +++ b/api/cosy/sort.go @@ -2,27 +2,39 @@ package cosy import ( "fmt" + "github.com/0xJacky/Nginx-UI/internal/logger" "github.com/gin-gonic/gin" "gorm.io/gorm" + "gorm.io/gorm/schema" + "sync" ) func (c *Ctx[T]) SortOrder() func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { sort := c.ctx.DefaultQuery("order", "desc") - order := fmt.Sprintf("%s %s", DefaultQuery(c.ctx, "sort_by", c.itemKey), sort) - return db.Order(order) + if sort != "desc" && sort != "asc" { + sort = "desc" + } + + // check if the order field is valid + // todo: maybe we can use more generic way to check if the sort_by is valid + order := DefaultQuery(c.ctx, "sort_by", c.itemKey) + s, _ := schema.Parse(c.Model, &sync.Map{}, schema.NamingStrategy{}) + if _, ok := s.FieldsByDBName[order]; ok { + order = fmt.Sprintf("%s %s", order, sort) + return db.Order(order) + } else { + logger.Error("invalid order field:", order) + } + + return db } } func (c *Ctx[T]) OrderAndPaginate() func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { - sort := c.ctx.DefaultQuery("order", "desc") - - order := fmt.Sprintf("%s %s", DefaultQuery(c.ctx, "sort_by", c.itemKey), sort) - db = db.Order(order) - + db = c.SortOrder()(db) _, offset, pageSize := GetPagingParams(c.ctx) - return db.Offset(offset).Limit(pageSize) } } diff --git a/model/model.go b/model/model.go index 54ae22d4..b1ecbe93 100644 --- a/model/model.go +++ b/model/model.go @@ -10,8 +10,10 @@ import ( "gorm.io/gen" "gorm.io/gorm" gormlogger "gorm.io/gorm/logger" + "gorm.io/gorm/schema" "path" "strings" + "sync" "time" ) @@ -100,9 +102,19 @@ func SortOrder(c *gin.Context) func(db *gorm.DB) *gorm.DB { func OrderAndPaginate(c *gin.Context) func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { sort := c.DefaultQuery("order", "desc") + if sort != "desc" && sort != "asc" { + sort = "desc" + } - order := fmt.Sprintf("`%s` %s", DefaultQuery(c, "sort_by", "id"), sort) - db = db.Order(order) + // check if the order field is valid + order := c.DefaultQuery("sort_by", "id") + s, _ := schema.Parse(db.Model, &sync.Map{}, schema.NamingStrategy{}) + if _, ok := s.FieldsByName[order]; ok { + order = fmt.Sprintf("%s %s", order, sort) + db = db.Order(order) + } else { + logger.Error("invalid order field: ", order) + } page := cast.ToInt(c.Query("page")) if page == 0 {