This commit is contained in:
Michael Yang 2024-06-20 13:45:47 -07:00
parent 9e35d9bbee
commit d02bbebb11
7 changed files with 263 additions and 53 deletions

View file

@ -4,6 +4,7 @@ import (
"archive/zip"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@ -11,7 +12,11 @@ import (
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"text/template/parse"
"github.com/google/uuid"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
@ -289,3 +294,103 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil
}
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]map[string]any{
"ToolCalls": {
{
"Function": map[string]any{
"Name": "@@name@@",
"Arguments": "@@arguments@@",
},
},
},
}); err != nil {
return nil, false
}
var kv map[string]string
// execute the subtree with placeholders to identify the keys
if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v {
case "@@name@@":
name = k
case "@@arguments@@":
arguments = k
}
}
var sm []map[string]any
decoder := json.NewDecoder(strings.NewReader(s))
for {
// incrementally decode the JSON into a list of JSON objects
// skipping over any invalid tokens
if err := decoder.Decode(&sm); err != nil {
if errors.Is(err, io.EOF) {
break
}
if errors.As(err, new(*json.SyntaxError)) {
r := decoder.Buffered()
if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
break
}
decoder = json.NewDecoder(r)
continue
}
return nil, false
}
// break as soon as a valid object is decoded
break
}
var toolCalls []api.ToolCall
for _, kv := range sm {
call := api.ToolCall{
ID: uuid.New().String(),
Type: "function",
}
for k, v := range kv {
switch k {
case name:
call.Function.Name = v.(string)
case arguments:
call.Function.Arguments = v.(map[string]any)
}
}
toolCalls = append(toolCalls, call)
}
if len(toolCalls) > 0 {
return toolCalls, true
}
return nil, false
}