fix: ensure the list sort query is validated to prevent SQL injection

Credits to @jorgectf for the advisories.
This commit is contained in:
Hintay 2023-12-20 04:52:02 +09:00
parent 827e76c46e
commit ec93ab05a3
No known key found for this signature in database
GPG key ID: 120FC7FF121F2F2D
2 changed files with 34 additions and 10 deletions

View file

@ -2,27 +2,39 @@ package cosy
import ( import (
"fmt" "fmt"
"github.com/0xJacky/Nginx-UI/internal/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/schema"
"sync"
) )
func (c *Ctx[T]) SortOrder() func(db *gorm.DB) *gorm.DB { func (c *Ctx[T]) SortOrder() func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB {
sort := c.ctx.DefaultQuery("order", "desc") sort := c.ctx.DefaultQuery("order", "desc")
order := fmt.Sprintf("%s %s", DefaultQuery(c.ctx, "sort_by", c.itemKey), sort) if sort != "desc" && sort != "asc" {
return db.Order(order) sort = "desc"
}
// check if the order field is valid
// todo: maybe we can use more generic way to check if the sort_by is valid
order := DefaultQuery(c.ctx, "sort_by", c.itemKey)
s, _ := schema.Parse(c.Model, &sync.Map{}, schema.NamingStrategy{})
if _, ok := s.FieldsByDBName[order]; ok {
order = fmt.Sprintf("%s %s", order, sort)
return db.Order(order)
} else {
logger.Error("invalid order field:", order)
}
return db
} }
} }
func (c *Ctx[T]) OrderAndPaginate() func(db *gorm.DB) *gorm.DB { func (c *Ctx[T]) OrderAndPaginate() func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB {
sort := c.ctx.DefaultQuery("order", "desc") db = c.SortOrder()(db)
order := fmt.Sprintf("%s %s", DefaultQuery(c.ctx, "sort_by", c.itemKey), sort)
db = db.Order(order)
_, offset, pageSize := GetPagingParams(c.ctx) _, offset, pageSize := GetPagingParams(c.ctx)
return db.Offset(offset).Limit(pageSize) return db.Offset(offset).Limit(pageSize)
} }
} }

View file

@ -10,8 +10,10 @@ import (
"gorm.io/gen" "gorm.io/gen"
"gorm.io/gorm" "gorm.io/gorm"
gormlogger "gorm.io/gorm/logger" gormlogger "gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"path" "path"
"strings" "strings"
"sync"
"time" "time"
) )
@ -100,9 +102,19 @@ func SortOrder(c *gin.Context) func(db *gorm.DB) *gorm.DB {
func OrderAndPaginate(c *gin.Context) func(db *gorm.DB) *gorm.DB { func OrderAndPaginate(c *gin.Context) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB {
sort := c.DefaultQuery("order", "desc") sort := c.DefaultQuery("order", "desc")
if sort != "desc" && sort != "asc" {
sort = "desc"
}
order := fmt.Sprintf("`%s` %s", DefaultQuery(c, "sort_by", "id"), sort) // check if the order field is valid
db = db.Order(order) order := c.DefaultQuery("sort_by", "id")
s, _ := schema.Parse(db.Model, &sync.Map{}, schema.NamingStrategy{})
if _, ok := s.FieldsByName[order]; ok {
order = fmt.Sprintf("%s %s", order, sort)
db = db.Order(order)
} else {
logger.Error("invalid order field: ", order)
}
page := cast.ToInt(c.Query("page")) page := cast.ToInt(c.Query("page"))
if page == 0 { if page == 0 {