diff --git a/server/images.go b/server/images.go
index bd6d92a6c..e779193c7 100644
--- a/server/images.go
+++ b/server/images.go
@@ -16,10 +16,12 @@ import (
"net/url"
"os"
"path/filepath"
+ "regexp"
"runtime"
"slices"
"strconv"
"strings"
+ "text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
@@ -62,6 +64,7 @@ type Model struct {
Digest string
Options map[string]any
Messages []api.Message
+ ToolPrefix string
Template *template.Template
}
@@ -350,9 +353,47 @@ func GetModel(name string) (*Model, error) {
}
}
+ if model.Template != nil && model.CheckCapabilities(CapabilityTools) == nil {
+ model.addToolPrefix()
+ }
+
return model, nil
}
+// HasToolPrefix checks if the completion starts with the tool prefix, ignoring whitespace
+func (m *Model) HasToolPrefix(sb strings.Builder) bool {
+ text := regexp.MustCompile(`\s+`).ReplaceAllString(sb.String(), "")
+ toolString := regexp.MustCompile(`\s+`).ReplaceAllString(m.ToolPrefix, "")
+
+ if len(text) < len(toolString) {
+ return text == toolString[:len(text)]
+ }
+ return text[:len(toolString)] == toolString
+}
+
+// Figure out what's between the start of the tools block, and the json response, and use it as a marker. Usually that's
+// {- if .ToolCalls}this text{ range .ToolCalls}or maybe this text{{.name}}
+func (m *Model) addToolPrefix() {
+ // create a subtree from the node that ranges over .ToolCalls
+ var previousNode parse.Node
+ toolCallsTemplate := m.Template.Subtree(func(node parse.Node) bool {
+ if rangeNode, ok := node.(*parse.RangeNode); ok {
+ return slices.Contains(template.Identifiers(rangeNode.Pipe), "ToolCalls")
+ }
+ previousNode = node
+ return false
+ })
+ if textNode, ok := previousNode.(*parse.TextNode); ok {
+ m.ToolPrefix = strings.TrimSpace(textNode.String())
+ }
+ if len(m.ToolPrefix) == 0 && len(toolCallsTemplate.Root.Nodes) > 0 {
+ rangeNode, ok := toolCallsTemplate.Root.Nodes[0].(*parse.RangeNode)
+ if ok && len(rangeNode.List.Nodes) > 0 {
+ m.ToolPrefix = rangeNode.List.Nodes[0].String()
+ }
+ }
+}
+
func CopyModel(src, dst model.Name) error {
if !dst.IsFullyQualified() {
return model.Unqualified(dst)
diff --git a/server/model_test.go b/server/model_test.go
index e5c2f2bb2..0f050011f 100644
--- a/server/model_test.go
+++ b/server/model_test.go
@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
+ "strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -28,19 +29,20 @@ func readFile(t *testing.T, base, name string) *bytes.Buffer {
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
- model string
- output string
- ok bool
+ model string
+ output string
+ ok bool
+ wellFormed bool
}{
- {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
- {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
+ {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true},
+ {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
-The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
- {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
+The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true, false},
+ {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false, false},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
- [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
- {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+ [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, false},
+ {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
{"command-r-plus", "Action: ```json" + `
[
{
@@ -58,16 +60,17 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
}
}
]
-` + "```", true},
- {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
- {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
- {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+` + "```", true, true},
+ {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
+ {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true},
+ {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
{"llama3-groq-tool-use", `
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
-`, true},
- {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
- {"nemotron", `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true},
+`, true, true},
+ {"xlam", `### Response:
+{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true, true},
+ {"nemotron", ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true, true},
}
var tools []api.Tool
@@ -119,6 +122,21 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
}
})
+ t.Run("prefix", func(t *testing.T) {
+ m := &Model{Template: tmpl}
+ m.addToolPrefix()
+
+ if tt.wellFormed {
+ if len(m.ToolPrefix) == 0 {
+ t.Fatalf("No tool prefix detected")
+ }
+
+ if !strings.HasPrefix(strings.TrimSpace(tt.output), m.ToolPrefix) {
+ t.Fatalf("incorrect tool prefix: \"%s\", \"%s\"", m.ToolPrefix, tt.output)
+ }
+ }
+ })
+
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
@@ -177,3 +195,64 @@ func TestParseObjects(t *testing.T) {
})
}
}
+
+func TestAddToolPrefix(t *testing.T) {
+ tests := []struct {
+ name string
+ template string
+ want string
+ }{
+ {
+ name: "prefix_from_previous_text_node",
+ template: `Previous text node{{- range .ToolCalls}}{{.name}}{{end}}`,
+ want: "Previous text node",
+ },
+ {
+ name: "prefix_from_range_node",
+ template: `{{- range .ToolCalls}}[TOOL_CALLS]{{.name}}{{end}}`,
+ want: "[TOOL_CALLS]",
+ },
+ {
+ name: "prefix_with_extra_whitespace",
+ template: ` Previous text with spaces {{- range .ToolCalls}}{{.name}}{{end}}`,
+ want: "Previous text with spaces",
+ },
+ {
+ name: "prefix_with_newlines",
+ template: "First line\nSecond line\n{{- range .ToolCalls}}{{.name}}{{end}}",
+ want: "First line\nSecond line",
+ },
+ {
+ name: "tool_calls_json_template",
+ template: `{{ if .Content }}{{ .Content }}{{- else if .ToolCalls }}
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}
+{{ end }}`,
+ want: ``,
+ },
+ {
+ name: "mistral_tool_calls_template",
+ template: `{{- if .Content }} {{ .Content }}
+{{- else if .ToolCalls }}[TOOL_CALLS] [
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{- end }}]
+{{- end }}`,
+ want: "[TOOL_CALLS] [",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpl, err := template.Parse(tt.template)
+ if err != nil {
+ t.Fatalf("failed to parse template: %v", err)
+ }
+
+ m := &Model{Template: tmpl}
+ m.addToolPrefix()
+
+ if m.ToolPrefix != tt.want {
+ t.Errorf("incorrect tool prefix:\ngot: %q\nwant: %q", m.ToolPrefix, tt.want)
+ }
+ })
+ }
+}
diff --git a/server/routes.go b/server/routes.go
index 906426b18..a10285797 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1526,6 +1526,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
var sb strings.Builder
var toolCallIndex int = 0
+ var mightBeTools bool = true
+ buf := make([]api.ChatResponse, 0)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@@ -1551,18 +1553,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
- // TODO: tool call checking and filtering should be moved outside of this callback once streaming
- // however this was a simple change for now without reworking streaming logic of this (and other)
- // handlers
- if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
+ // If we know we're not streaming
+ if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 || !mightBeTools {
ch <- res
return
}
+ sb.WriteString(r.Content)
+
+ // Buffer up responses while we're unsure whether to stream.
+ buf = append(buf, res)
+
+ // not a tools response, continue streaming.
+ if !m.HasToolPrefix(sb) {
+ mightBeTools = false
+ for _, item := range buf {
+ ch <- item
+ }
+ return
+ }
+
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
- sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
@@ -1573,8 +1586,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
sb.Reset()
ch <- res
return
+ } else {
+ if !strings.HasPrefix(sb.String(), "{") {
+ ch <- res
+ return
+ }
}
-
if r.Done {
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {