From b1fd7fef866bcd060810582c172cffe1185db077 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Wed, 11 Dec 2024 15:29:59 -0800 Subject: [PATCH] server: more support for mixed-case model names (#8017) Fixes #7944 --- cmd/cmd.go | 2 +- server/images.go | 4 ++ server/modelpath.go | 15 +++-- server/modelpath_test.go | 8 --- server/routes.go | 112 ++++++++++++++++++++++++++------- server/routes_generate_test.go | 2 +- server/routes_test.go | 14 +++++ types/model/name.go | 4 +- 8 files changed, 123 insertions(+), 38 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index f934a2679..2f77640c2 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error { var data [][]string for _, m := range models.Models { - if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) { + if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) { data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")}) } } diff --git a/server/images.go b/server/images.go index 29877db33..4006584fa 100644 --- a/server/images.go +++ b/server/images.go @@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio switch command { case "model", "adapter": if name := model.ParseName(c.Args); name.IsValid() && command == "model" { + name, err := getExistingName(name) + if err != nil { + return err + } baseLayers, err = parseFromModel(ctx, name, fn) if err != nil { return err diff --git a/server/modelpath.go b/server/modelpath.go index d498c4678..5a96dec57 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -3,6 +3,7 @@ package server import ( "errors" "fmt" + "io/fs" "net/url" "os" "path/filepath" @@ -10,6 +11,7 @@ import ( "strings" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/types/model" ) type ModelPath struct { @@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string { // 
GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist. func (mp ModelPath) GetManifestPath() (string, error) { - if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) { - return filepath.Join(envconfig.Models(), "manifests", p), nil + name := model.Name{ + Host: mp.Registry, + Namespace: mp.Namespace, + Model: mp.Repository, + Tag: mp.Tag, } - - return "", errModelPathInvalid + if !name.IsValid() { + return "", fs.ErrNotExist + } + return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil } func (mp ModelPath) BaseURL() *url.URL { diff --git a/server/modelpath_test.go b/server/modelpath_test.go index ef26266bd..849e0fa73 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -1,7 +1,6 @@ package server import ( - "errors" "os" "path/filepath" "testing" @@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) { }) } } - -func TestInsecureModelpath(t *testing.T) { - mp := ParseModelPath("../../..:something") - if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) { - t.Errorf("expected error: %v", err) - } -} diff --git a/server/routes.go b/server/routes.go index 593d372e5..d7a1b88db 100644 --- a/server/routes.go +++ b/server/routes.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "io/fs" "log/slog" "math" "net" @@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - model, err := GetModel(req.Model) + name := model.ParseName(req.Model) + if !name.IsValid() { + // Ideally this is "invalid model name" but we're keeping with + // what the API currently returns until we can change it. + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + return + } + + // We cannot currently consolidate this into GetModel because we'll + // induce infinite recursion given the current code structure. 
+ name, err := getExistingName(name) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + return + } + + model, err := GetModel(name.String()) if err != nil { switch { - case os.IsNotExist(err): + case errors.Is(err, fs.ErrNotExist): c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) case err.Error() == "invalid model name": c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { caps = append(caps, CapabilityInsert) } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) return @@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) { } } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) + name, err := getExistingName(model.ParseName(req.Model)) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + return + } + + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) + name := model.ParseName(req.Model) + if !name.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive) if err != nil { 
handleScheduleError(c, req.Model, err) return @@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) { return } - var model string + var mname string if req.Model != "" { - model = req.Model + mname = req.Model } else if req.Name != "" { - model = req.Name + mname = req.Name } else { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return @@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := PushModel(ctx, model, regOpts, fn); err != nil { + name, err := getExistingName(model.ParseName(mname)) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } + + if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) { streamResponse(c, ch) } -// getExistingName returns the original, on disk name if the input name is a -// case-insensitive match, otherwise it returns the input name. +// getExistingName searches the models directory for the longest prefix match of +// the input name and returns the input name with all existing parts replaced +// with each part found. If no parts are found, the input name is returned as +// is. 
func getExistingName(n model.Name) (model.Name, error) { var zero model.Name existing, err := Manifests(true) if err != nil { return zero, err } + var set model.Name // NOTE(review): meant to track parts already canonicalized, but it is never assigned, so every set.X == "" guard below is always true. for e := range existing { - if n.EqualFold(e) { - return e, nil + if set.Host == "" && strings.EqualFold(e.Host, n.Host) { + n.Host = e.Host + } + if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) { + n.Namespace = e.Namespace + } + if set.Model == "" && strings.EqualFold(e.Model, n.Model) { + n.Model = e.Model + } + if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) { + n.Tag = e.Tag } } return n, nil @@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) { } if r.Path == "" && r.Modelfile == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or Modelfile are required"}) return } @@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) { return } + n, err := getExistingName(n) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))}) + return + } + m, err := ParseNamedManifest(n) if err != nil { switch { @@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) { } func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { - m, err := GetModel(req.Model) + name := model.ParseName(req.Model) + if !name.IsValid() { + return nil, errModelPathInvalid + } + name, err := getExistingName(name) + if err != nil { + return nil, err + } + + m, err := GetModel(name.String()) if err != nil { return nil, err } @@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { msgs[i] = api.Message{Role: msg.Role, Content: msg.Content} } - n := model.ParseName(req.Model) - if !n.IsValid() { - return nil, errors.New("invalid model name") - } - - manifest, err := 
ParseNamedManifest(name) if err != nil { return nil, err } @@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) { caps = append(caps, CapabilityTools) } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) + name := model.ParseName(req.Model) + if !name.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + name, err := getExistingName(name) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) return diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 4bde55bb4..737fa79c1 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) { t.Errorf("expected status 400, got %d", w.Code) } - if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" { + if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } }) diff --git a/server/routes_test.go b/server/routes_test.go index 1daf36f1a..bc007714c 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) { wantStableName := name() + t.Logf("stable name: %s", wantStableName) + // checkManifestList tests that there is strictly one manifest in the // models directory, and that the manifest is for the model under test. 
checkManifestList := func() { @@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) { Destination: name(), })) checkManifestList() + + t.Logf("pushing") + rr := createRequest(t, s.PushHandler, api.PushRequest{ + Model: name(), + Insecure: true, + Username: "alice", + Password: "x", + }) + checkOK(rr) + if !strings.Contains(rr.Body.String(), `"status":"success"`) { + t.Errorf("got = %q, want success", rr.Body.String()) + } } func TestShow(t *testing.T) { diff --git a/types/model/name.go b/types/model/name.go index 9d819f100..a46f3e28d 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -223,12 +223,12 @@ func (n Name) String() string { func (n Name) DisplayShortest() string { var sb strings.Builder - if n.Host != defaultHost { + if !strings.EqualFold(n.Host, defaultHost) { sb.WriteString(n.Host) sb.WriteByte('/') sb.WriteString(n.Namespace) sb.WriteByte('/') - } else if n.Namespace != defaultNamespace { + } else if !strings.EqualFold(n.Namespace, defaultNamespace) { sb.WriteString(n.Namespace) sb.WriteByte('/') }