rebase + fix tests

This commit is contained in:
ParthSareen 2025-04-03 17:23:38 -07:00
parent 4053c489b4
commit 3bc9d42e2e

View file

@ -16,7 +16,6 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"slices" "slices"
"strconv" "strconv"
@ -263,7 +262,7 @@ func GetModel(name string) (*Model, error) {
return nil, err return nil, err
} }
model := &Model{ m := &Model{
Name: mp.GetFullTagname(), Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(), ShortName: mp.GetShortTagname(),
Digest: digest, Digest: digest,
@ -282,7 +281,7 @@ func GetModel(name string) (*Model, error) {
} }
defer configFile.Close() defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil { if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
return nil, err return nil, err
} }
} }
@ -295,16 +294,16 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename m.ModelPath = filename
model.ParentModel = layer.From m.ParentModel = layer.From
case "application/vnd.ollama.image.embed": case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2 // Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version // TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.") slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter": case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename) m.AdapterPaths = append(m.AdapterPaths, filename)
case "application/vnd.ollama.image.projector": case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename) m.ProjectorPaths = append(m.ProjectorPaths, filename)
case "application/vnd.ollama.image.prompt", case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template": "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
@ -312,7 +311,7 @@ func GetModel(name string) (*Model, error) {
return nil, err return nil, err
} }
model.Template, err = template.Parse(string(bts)) m.Template, err = template.Parse(string(bts))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -322,7 +321,7 @@ func GetModel(name string) (*Model, error) {
return nil, err return nil, err
} }
model.System = string(bts) m.System = string(bts)
case "application/vnd.ollama.image.params": case "application/vnd.ollama.image.params":
params, err := os.Open(filename) params, err := os.Open(filename)
if err != nil { if err != nil {
@ -331,7 +330,7 @@ func GetModel(name string) (*Model, error) {
defer params.Close() defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly // parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&model.Options); err != nil { if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
return nil, err return nil, err
} }
case "application/vnd.ollama.image.messages": case "application/vnd.ollama.image.messages":
@ -341,7 +340,7 @@ func GetModel(name string) (*Model, error) {
} }
defer msgs.Close() defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil { if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
return nil, err return nil, err
} }
case "application/vnd.ollama.image.license": case "application/vnd.ollama.image.license":
@ -349,21 +348,22 @@ func GetModel(name string) (*Model, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
model.License = append(model.License, string(bts)) m.License = append(m.License, string(bts))
} }
} }
if model.Template != nil && model.CheckCapabilities(CapabilityTools) == nil { capabilities := m.Capabilities()
model.addToolPrefix() if slices.Contains(capabilities, model.CapabilityTools) {
m.addToolPrefix()
} }
return model, nil return m, nil
} }
// HasToolPrefix checks if the completion starts with the tool prefix, ignoring whitespace // HasToolPrefix checks if the completion starts with the tool prefix, ignoring whitespace
func (m *Model) HasToolPrefix(sb strings.Builder) bool { func (m *Model) HasToolPrefix(sb strings.Builder) bool {
text := regexp.MustCompile(`\s+`).ReplaceAllString(sb.String(), "") text := strings.ReplaceAll(strings.TrimSpace(sb.String()), " ", "")
toolString := regexp.MustCompile(`\s+`).ReplaceAllString(m.ToolPrefix, "") toolString := strings.ReplaceAll(strings.TrimSpace(m.ToolPrefix), " ", "")
if len(text) < len(toolString) { if len(text) < len(toolString) {
return text == toolString[:len(text)] return text == toolString[:len(text)]