server: more support for mixed-case model names (#8017)

Fixes #7944
This commit is contained in:
Blake Mizerany 2024-12-11 15:29:59 -08:00 committed by GitHub
parent 36d111e788
commit b1fd7fef86
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 123 additions and 38 deletions

View file

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"net"
@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
model, err := GetModel(req.Model)
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
// what the API currently returns until we can change it.
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
// We cannot currently consolidate this into GetModel because all we'll
// induce infinite recursion given the current code structure.
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
model, err := GetModel(name.String())
if err != nil {
switch {
case os.IsNotExist(err):
case errors.Is(err, fs.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps = append(caps, CapabilityInsert)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
name, err := getExistingName(model.ParseName(req.Model))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
return
}
var model string
var mname string
if req.Model != "" {
model = req.Model
mname = req.Model
} else if req.Name != "" {
model = req.Name
mname = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
name, err := getExistingName(model.ParseName(mname))
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
streamResponse(c, ch)
}
// getExistingName returns the original, on disk name if the input name is a
// case-insensitive match, otherwise it returns the input name.
// getExistingName searches the models directory for the longest prefix match of
// the input name and returns the input name with all existing parts replaced
// with each part found. If no parts are found, the input name is returned as
// is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := Manifests(true)
if err != nil {
return zero, err
}
var set model.Name // tracks parts already canonicalized
for e := range existing {
if n.EqualFold(e) {
return e, nil
if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
n.Host = e.Host
}
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
n.Namespace = e.Namespace
}
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
n.Model = e.Model
}
if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
n.Tag = e.Tag
}
}
return n, nil
@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or Modelfile are required"})
return
}
@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
n, err := getExistingName(n)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
return
}
m, err := ParseNamedManifest(n)
if err != nil {
switch {
@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m, err := GetModel(req.Model)
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, errModelPathInvalid
}
name, err := getExistingName(name)
if err != nil {
return nil, err
}
m, err := GetModel(name.String())
if err != nil {
return nil, err
}
@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
n := model.ParseName(req.Model)
if !n.IsValid() {
return nil, errors.New("invalid model name")
}
manifest, err := ParseNamedManifest(n)
manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, err
}
@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
caps = append(caps, CapabilityTools)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return