mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
checkpoint
This commit is contained in:
parent
b6ca295f24
commit
46c95b25dd
3 changed files with 165 additions and 94 deletions
|
@ -1527,7 +1527,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
// var sb strings.Builder
|
||||
var toolParser *ToolParser
|
||||
if len(req.Tools) > 0 {
|
||||
toolParser = NewToolParser(m)
|
||||
|
@ -1560,6 +1559,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
|
||||
if len(req.Tools) > 0 && !toolParser.Done {
|
||||
toolCalls, leftover := toolParser.ParseToolCalls(r.Content)
|
||||
// * This can be abstracted again to a .handleState(tp.state)
|
||||
// * However, we'd need a flag to indicate whether to send the response or not
|
||||
// * happy to take whatever is more idiomatic
|
||||
switch toolParser.ParserState {
|
||||
case ToolCallAccumulate:
|
||||
// tokens are accumulated in the tool parser
|
||||
|
@ -1568,7 +1570,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
// tokens are sent back in the response
|
||||
case ToolCallSendPartial:
|
||||
// tokens not needed for parsing are sent back in the response
|
||||
res.Message.Content = leftover
|
||||
if len(leftover) > 0 {
|
||||
res.Message.Content = leftover
|
||||
}
|
||||
// ! state is needed as we need to not match on the other states
|
||||
case ToolCallFound:
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
|
@ -1576,6 +1581,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
fmt.Println("sending response", res.Message.Content)
|
||||
// * this is where we'd need the flag if we have a .handleState(tp.state)
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
|
|
158
server/tools.go
158
server/tools.go
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
|
||||
|
@ -24,7 +23,9 @@ const (
|
|||
GreedyToolNoPrefix
|
||||
ForceTools
|
||||
ToolSuffix
|
||||
ContainsPartialPrefix
|
||||
ContainsPrefix
|
||||
PartialPrefix
|
||||
NotPartialPrefix
|
||||
Done
|
||||
)
|
||||
|
||||
|
@ -64,9 +65,11 @@ func (s State) String() string {
|
|||
return "ForceTools"
|
||||
case ToolSuffix:
|
||||
return "ToolSuffix"
|
||||
case PartialPrefix:
|
||||
return "PossiblePrefix"
|
||||
case Done:
|
||||
return "Done"
|
||||
case ContainsPartialPrefix:
|
||||
case ContainsPrefix:
|
||||
return "PartialPrefix"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown State (%d)", s)
|
||||
|
@ -88,6 +91,8 @@ type ToolParser struct {
|
|||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
||||
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||
fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
|
@ -101,6 +106,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||
},
|
||||
},
|
||||
}); err != nil {
|
||||
fmt.Printf("failed to execute template: error=%v\n", err)
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
|
@ -108,6 +114,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||
var temp any
|
||||
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to unmarshal template: error=%v\n", err)
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
|
@ -125,6 +132,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||
}
|
||||
default:
|
||||
// TODO: err or fallback
|
||||
fmt.Printf("collect encountered unknown type: type=%T\n", obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -142,6 +150,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||
templateObjects = collect(t)
|
||||
}
|
||||
if len(templateObjects) == 0 {
|
||||
fmt.Println("no template objects found")
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
|
@ -151,12 +160,15 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||
switch v.(type) {
|
||||
case string:
|
||||
name = k
|
||||
fmt.Printf("found name field: key=%s\n", k)
|
||||
case map[string]any:
|
||||
arguments = k
|
||||
fmt.Printf("found arguments field: key=%s\n", k)
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" || arguments == "" {
|
||||
fmt.Printf("missing required fields: name_found=%v arguments_found=%v\n", name != "", arguments != "")
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
|
@ -165,18 +177,17 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||
dec := jsontext.NewDecoder(strings.NewReader(s))
|
||||
if got, err := dec.ReadValue(); err == nil {
|
||||
s = got.String()
|
||||
fmt.Printf("decoded JSON value: value=%s\n", s)
|
||||
}
|
||||
|
||||
var responseObjects any
|
||||
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
|
||||
fmt.Println("Detected partial or incomplete JSON.")
|
||||
fmt.Println("state", p.state)
|
||||
fmt.Println("incomplete JSON detected")
|
||||
return nil, true, false
|
||||
} else {
|
||||
fmt.Printf("Other error: %v\n", err)
|
||||
fmt.Println("exiting from JSON parsing", p.state)
|
||||
fmt.Printf("failed to unmarshal response: error=%v\n", err)
|
||||
return nil, false, false
|
||||
}
|
||||
}
|
||||
|
@ -187,14 +198,14 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||
return nil, false, false
|
||||
}
|
||||
|
||||
slog.Debug("collected objects", "count", len(objs))
|
||||
fmt.Printf("collected objects: count=%d\n", len(objs))
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
slog.Debug("found valid tool call", "name", n)
|
||||
fmt.Printf("found valid tool call: name=%s\n", n)
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
|
@ -204,84 +215,89 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||
}
|
||||
}
|
||||
|
||||
slog.Debug("parsed tool calls", "count", len(toolCalls))
|
||||
fmt.Printf("parsed tool calls: count=%d\n", len(toolCalls))
|
||||
return toolCalls, false, true
|
||||
}
|
||||
|
||||
func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall) {
|
||||
// TODO: clean up the boundary of internal and external state transitions
|
||||
func (p *ToolParser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) {
|
||||
fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state)
|
||||
|
||||
// state transition logic
|
||||
switch {
|
||||
case !ok && !partial && p.state == ForceTools:
|
||||
fmt.Println("Case: !ok && !partial && ForceTools - staying in force tools, resetting buffer")
|
||||
// force partial tool if we have a prefix
|
||||
// no op and stay in force tools
|
||||
p.sb.Reset()
|
||||
case !ok && !partial:
|
||||
fmt.Println("Case: !ok && !partial")
|
||||
fmt.Println("state", p.state)
|
||||
if p.state == GreedyToolNoPrefix {
|
||||
fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done")
|
||||
p.state = Done
|
||||
// p.ParserState = DoneFR
|
||||
p.ParserState = ToolCallSendTokens
|
||||
// ? the output parser state is the same even though internal can we not leak the external state?
|
||||
p.Done = true
|
||||
}
|
||||
if p.state == GreedyToolWithPrefix {
|
||||
fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens")
|
||||
p.state = SendTokens
|
||||
p.ParserState = ToolCallSendTokens
|
||||
}
|
||||
p.sb.Reset()
|
||||
if p.state == PartialPrefix {
|
||||
p.state = NotPartialPrefix
|
||||
}
|
||||
case !ok && partial:
|
||||
fmt.Println("Case: !ok && partial - accumulating partial content")
|
||||
// ! accumulate
|
||||
// accumulate
|
||||
|
||||
case len(tcs) > 0:
|
||||
fmt.Println("Case: tool calls found")
|
||||
// do not parse again in the greedy JSON case as soon as we have a tool call
|
||||
if p.state == GreedyToolWithPrefix {
|
||||
p.state = SendTokens
|
||||
p.ParserState = ToolCallFound
|
||||
p.state = Done
|
||||
p.Done = true
|
||||
} else if p.state == GreedyToolNoPrefix {
|
||||
fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens")
|
||||
p.state = Done
|
||||
p.Done = true
|
||||
}
|
||||
p.sb.Reset()
|
||||
}
|
||||
p.updateExternalState(tcs)
|
||||
fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState)
|
||||
}
|
||||
|
||||
func (p *ToolParser) updateExternalState(tcs []api.ToolCall) {
|
||||
if (p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix || p.state == ToolSuffix) || (p.state == ForceTools && len(tcs) == 0) {
|
||||
p.ParserState = ToolCallAccumulate
|
||||
} else if p.state == ContainsPartialPrefix {
|
||||
p.ParserState = ToolCallSendPartial
|
||||
} else if len(tcs) > 0 {
|
||||
fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs))
|
||||
|
||||
switch {
|
||||
case len(tcs) > 0:
|
||||
// do not parse again in the greedy JSON case as soon as we have a tool call
|
||||
if p.state == GreedyToolWithPrefix {
|
||||
p.state = SendTokens
|
||||
} else if p.state == GreedyToolNoPrefix {
|
||||
p.state = Done
|
||||
p.Done = true
|
||||
}
|
||||
p.ParserState = ToolCallFound
|
||||
} else if p.state == SendTokens {
|
||||
case p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix ||
|
||||
p.state == ToolSuffix || p.state == PartialPrefix ||
|
||||
(p.state == ForceTools && len(tcs) == 0):
|
||||
p.ParserState = ToolCallAccumulate
|
||||
case p.state == ContainsPrefix:
|
||||
p.ParserState = ToolCallSendPartial
|
||||
case p.state == SendTokens || p.state == Done:
|
||||
p.ParserState = ToolCallSendTokens
|
||||
case p.state == NotPartialPrefix:
|
||||
p.ParserState = ToolCallSendPartial
|
||||
default:
|
||||
p.ParserState = ToolCallSendTokens
|
||||
p.sb.Reset()
|
||||
p.state = SendTokens
|
||||
}
|
||||
}
|
||||
|
||||
// checkPrefix returns the string to keep processing and whether the caller should continue parsing it
|
||||
func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
||||
fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix)
|
||||
|
||||
if p.toolPrefix == "" {
|
||||
return s, true
|
||||
}
|
||||
original := s
|
||||
// s = strings.TrimSpace(s)
|
||||
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
|
||||
if hasPrefix {
|
||||
fmt.Println("has prefix", s)
|
||||
p.state = ForceTools
|
||||
// partial tool possibly
|
||||
} else if strings.HasPrefix(p.toolPrefix, s) {
|
||||
slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s)
|
||||
// TODO: could possibly err maybe this should be greedy instead?
|
||||
p.state = ForceTools
|
||||
// this would basically be a no op on rest of the input
|
||||
fmt.Printf("found exact prefix match: remaining=%s\n", s)
|
||||
// partial tool possibly - accumulate
|
||||
} else if suffixOverlap(s, p.toolPrefix) > 0 {
|
||||
p.state = PartialPrefix
|
||||
fmt.Printf("found partial prefix: remaining=%s\n", s)
|
||||
return "", false
|
||||
// the case where "token<tool_call>" - send "token" back
|
||||
// accounts for spaces in prefix or suffix to avoid breaking cache
|
||||
|
@ -289,11 +305,13 @@ func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
|||
idx := strings.Index(original, p.toolPrefix)
|
||||
if idx != -1 {
|
||||
// still keeps the prefix
|
||||
p.state = ContainsPartialPrefix
|
||||
p.state = ContainsPrefix
|
||||
p.sb.Reset()
|
||||
// todo: see if there is a simpler way for this
|
||||
idx2 := strings.Index(s, p.toolPrefix)
|
||||
// buffer now only has the prefix
|
||||
p.sb.WriteString(s[idx2:])
|
||||
fmt.Printf("found prefix in middle: prefix_start=%d content_before=%s\n", idx, original[:idx])
|
||||
return original[:idx], false
|
||||
}
|
||||
}
|
||||
|
@ -305,51 +323,71 @@ func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
|||
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
||||
// Returns the parsed tool calls and any leftover content that should be sent back in the response.
|
||||
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) {
|
||||
fmt.Println("checking tool calls", s)
|
||||
fmt.Println("external state", p.ParserState)
|
||||
fmt.Println("internal state", p.state)
|
||||
fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state)
|
||||
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
|
||||
s = strings.TrimSpace(s)
|
||||
fmt.Println("sb", s)
|
||||
|
||||
p.updateExternalState(nil)
|
||||
if len(s) == 0 {
|
||||
p.updateExternalState(nil)
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
s, cont := p.checkPrefix(s)
|
||||
if !cont {
|
||||
p.updateExternalState(nil)
|
||||
if p.state == ContainsPartialPrefix {
|
||||
if p.state == ContainsPrefix {
|
||||
fmt.Printf("returning partial prefix: remaining=%s\n", s)
|
||||
return nil, s
|
||||
}
|
||||
// * we'd be returning here for just accumulating with possible prefix
|
||||
// * ext state is accumulation
|
||||
return nil, ""
|
||||
}
|
||||
// * lets say the check fails here and now we're still in external state accumulation here
|
||||
|
||||
// stay in SendTokens unless we have a prefix
|
||||
if p.state == SendTokens {
|
||||
fmt.Println("SendTokens - resetting buffer")
|
||||
p.updateExternalState(nil)
|
||||
p.sb.Reset()
|
||||
return nil, ""
|
||||
fmt.Printf("returning send tokens: remaining=%s\n", s)
|
||||
return nil, s
|
||||
}
|
||||
|
||||
// * we'd parse here as json to see if it's a tool call
|
||||
tcs, partial, ok := p.parseJSONToolCalls(s)
|
||||
p.updateOutputState(ok, partial, tcs)
|
||||
fmt.Println("output state", p.ParserState, p.state)
|
||||
// * it would not be a tool call here
|
||||
p.updateStateAfterJSONParse(ok, partial, tcs)
|
||||
if !ok {
|
||||
fmt.Println("returning empty tool calls")
|
||||
// * and so we should send the data here
|
||||
// * we also need to move out of that internal state after sending the tokens
|
||||
if p.state == NotPartialPrefix {
|
||||
p.state = SendTokens
|
||||
// the string would have acc until here
|
||||
return nil, p.sb.String()
|
||||
}
|
||||
return nil, ""
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc.Function.Index = p.toolIndex
|
||||
p.toolIndex++
|
||||
}
|
||||
fmt.Printf("finished parsing tool calls: tool_calls_found=%d\n", len(tcs))
|
||||
return tcs, ""
|
||||
}
|
||||
|
||||
// suffixOverlap returns the length, in bytes, of the longest suffix of s that
// is also a prefix of delim. It returns 0 when s does not end with any
// non-empty prefix of delim, including when either argument is empty.
//
// A positive result means s may end with the beginning of delim, e.g.
// suffixOverlap("abc<tool", "<tool_call>") == 5.
func suffixOverlap(s, delim string) int {
	// Try the longest candidate first so the first hit is the maximal overlap.
	// The overlap can never exceed the shorter of the two strings.
	for n := min(len(delim), len(s)); n > 0; n-- {
		if strings.HasSuffix(s, delim[:n]) {
			return n
		}
	}
	return 0
}
|
||||
|
||||
func NewToolParser(model *Model) *ToolParser {
|
||||
// TODO: use new template parsing to get all tokens for the prefix
|
||||
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
|
||||
|
@ -365,7 +403,7 @@ func NewToolParser(model *Model) *ToolParser {
|
|||
} else {
|
||||
state = GreedyToolWithPrefix
|
||||
}
|
||||
fmt.Println("setup state", state)
|
||||
fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", templateToolPrefix, state)
|
||||
return &ToolParser{
|
||||
tmpl: tmpl,
|
||||
sb: &strings.Builder{},
|
||||
|
|
|
@ -55,21 +55,21 @@ func TestParseToolCalls(t *testing.T) {
|
|||
expectedTokens string
|
||||
}{
|
||||
{
|
||||
name: "mistral invalid json",
|
||||
name: "mistral malformed json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral multiple tool calls - no prefix",
|
||||
name: "mistral multiple tool calls without prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral tool calls with text in between - no prefix",
|
||||
name: "mistral tool calls with text between no prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
|
@ -77,15 +77,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "mistral valid json - with prefix",
|
||||
name: "mistral valid json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
// In this case we'd be ignoring the text in between and just returning the tool calls
|
||||
name: "mistral valid json with text in between - with prefix",
|
||||
name: "mistral multiple tool calls with text between and prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
|
@ -93,14 +92,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral incomplete json",
|
||||
name: "mistral incomplete json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral without tool token",
|
||||
name: "mistral invalid tool call with explanatory text no prefix",
|
||||
model: "mistral",
|
||||
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||
|
||||
|
@ -109,14 +108,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "mistral without tool token - tool first",
|
||||
name: "mistral tool calls without prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "command-r-plus with json block",
|
||||
name: "command r plus tool calls with json block format",
|
||||
model: "command-r-plus",
|
||||
output: "Action: ```json" + `
|
||||
[
|
||||
|
@ -140,14 +139,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "firefunction with functools",
|
||||
name: "firefunction tool calls with functools prefix",
|
||||
model: "firefunction",
|
||||
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "llama3 with tool call tags",
|
||||
name: "llama3 groq single tool call with xml tags",
|
||||
model: "llama3-groq-tool-use",
|
||||
output: `<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||
|
@ -156,99 +155,126 @@ func TestParseToolCalls(t *testing.T) {
|
|||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "xlam with tool_calls wrapper",
|
||||
name: "xlam tool calls with wrapper object",
|
||||
model: "xlam",
|
||||
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 with single tool call",
|
||||
name: "qwen2.5-coder single tool call with prefix",
|
||||
model: "qwen2.5-coder",
|
||||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen with no tool prefix",
|
||||
name: "qwen2.5-coder multiple tool calls with and without prefix",
|
||||
model: "qwen2.5-coder",
|
||||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5-coder multiple tool calls without prefix",
|
||||
model: "qwen2.5-coder",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen with no tool calls",
|
||||
name: "qwen2.5-coder plain text response no tool calls",
|
||||
model: "qwen2.5-coder",
|
||||
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||
},
|
||||
{
|
||||
name: "qwen with no tool prefix",
|
||||
name: "qwen2.5-coder tool calls with trailing text",
|
||||
model: "qwen2.5-coder",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens after call",
|
||||
},
|
||||
{
|
||||
name: "qwen with prefix",
|
||||
name: "qwen2.5 tool calls with prefix and trailing text",
|
||||
model: "qwen2.5-coder",
|
||||
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
// tests the leftover logic as well
|
||||
name: "qwen3 with single tool call and thinking",
|
||||
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
|
||||
model: "qwen3",
|
||||
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||
},
|
||||
{
|
||||
name: "qwen3 with single tool call and thinking spaces",
|
||||
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)",
|
||||
model: "qwen3",
|
||||
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||
},
|
||||
{
|
||||
name: "qwen3 testing",
|
||||
name: "qwen3 empty think prefix without tool prefix and invalid tool call",
|
||||
model: "qwen3",
|
||||
output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 testing 2",
|
||||
name: "qwen3 empty think prefix with tool prefix and valid tool call",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: `<think></think>`,
|
||||
},
|
||||
{
|
||||
name: "llama3.2 with tool call - no prefix",
|
||||
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)",
|
||||
model: "qwen3",
|
||||
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with malformed tool prefix",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "llama3.2 valid tool call without prefix",
|
||||
model: "llama3.2",
|
||||
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "llama3.2 with incomplete tool call - no prefix",
|
||||
name: "llama3.2 incomplete tool call without prefix",
|
||||
model: "llama3.2",
|
||||
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "llama3.2 with tool call - in middle",
|
||||
name: "llama3.2 tool call with leading text",
|
||||
model: "llama3.2",
|
||||
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||
},
|
||||
{
|
||||
name: "llama3.2 - fake tool prefix",
|
||||
name: "llama3.2 tool call with invalid tool prefix (no prefix in template)",
|
||||
model: "llama3.2",
|
||||
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
|
@ -288,7 +314,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||
m := &Model{Template: tmpl}
|
||||
tp := NewToolParser(m)
|
||||
got := []api.ToolCall{}
|
||||
var actualTokens strings.Builder
|
||||
var gotTokens strings.Builder
|
||||
|
||||
tokens := strings.Fields(tt.output)
|
||||
for _, tok := range tokens {
|
||||
|
@ -302,17 +328,18 @@ func TestParseToolCalls(t *testing.T) {
|
|||
got = append(got, toolCalls...)
|
||||
add = false
|
||||
case ToolCallSendTokens:
|
||||
actualTokens.WriteString(s)
|
||||
gotTokens.WriteString(s)
|
||||
add = false
|
||||
case ToolCallAccumulate:
|
||||
add = false
|
||||
case ToolCallSendPartial:
|
||||
actualTokens.WriteString(" " + leftover)
|
||||
t.Log("send partial", "leftover", leftover)
|
||||
gotTokens.WriteString(" " + leftover)
|
||||
add = false
|
||||
}
|
||||
}
|
||||
if add {
|
||||
actualTokens.WriteString(s)
|
||||
gotTokens.WriteString(s)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -322,7 +349,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||
}
|
||||
|
||||
// Compare tokens if we expect any
|
||||
stripped := strings.TrimSpace(actualTokens.String())
|
||||
stripped := strings.TrimSpace(gotTokens.String())
|
||||
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
|
||||
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
|
||||
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue