checkpoint

This commit is contained in:
ParthSareen 2025-05-09 17:05:16 -07:00
parent b6ca295f24
commit 46c95b25dd
3 changed files with 165 additions and 94 deletions

View file

@ -1527,7 +1527,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
// var sb strings.Builder
var toolParser *ToolParser
if len(req.Tools) > 0 {
toolParser = NewToolParser(m)
@ -1560,6 +1559,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 && !toolParser.Done {
toolCalls, leftover := toolParser.ParseToolCalls(r.Content)
// * This can be abstracted again to a .handleState(tp.state)
// * However, we'd need a flag to indicate whether to send the response or not
// * happy to take whatever is more idiomatic
switch toolParser.ParserState {
case ToolCallAccumulate:
// tokens are accumulated in the tool parser
@ -1568,7 +1570,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
// tokens are sent back in the response
case ToolCallSendPartial:
// tokens not needed for parsing are sent back in the response
res.Message.Content = leftover
if len(leftover) > 0 {
res.Message.Content = leftover
}
// ! state is needed as we need to not match on the other states
case ToolCallFound:
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
@ -1576,6 +1581,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
fmt.Println("sending response", res.Message.Content)
// * this is where we'd need the flag if we have a .handleState(tp.state)
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}

View file

@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"strings"
gotmpl "text/template"
@ -24,7 +23,9 @@ const (
GreedyToolNoPrefix
ForceTools
ToolSuffix
ContainsPartialPrefix
ContainsPrefix
PartialPrefix
NotPartialPrefix
Done
)
@ -64,9 +65,11 @@ func (s State) String() string {
return "ForceTools"
case ToolSuffix:
return "ToolSuffix"
case PartialPrefix:
return "PossiblePrefix"
case Done:
return "Done"
case ContainsPartialPrefix:
case ContainsPrefix:
return "PartialPrefix"
default:
return fmt.Sprintf("Unknown State (%d)", s)
@ -88,6 +91,8 @@ type ToolParser struct {
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s)
var b bytes.Buffer
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
@ -101,6 +106,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
},
},
}); err != nil {
fmt.Printf("failed to execute template: error=%v\n", err)
return nil, false, false
}
@ -108,6 +114,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
var temp any
err := jsonv2.Unmarshal(b.Bytes(), &temp)
if err != nil {
fmt.Printf("failed to unmarshal template: error=%v\n", err)
return nil, false, false
}
@ -125,6 +132,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
}
default:
// TODO: err or fallback
fmt.Printf("collect encountered unknown type: type=%T\n", obj)
return nil
}
@ -142,6 +150,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
templateObjects = collect(t)
}
if len(templateObjects) == 0 {
fmt.Println("no template objects found")
return nil, false, false
}
@ -151,12 +160,15 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
switch v.(type) {
case string:
name = k
fmt.Printf("found name field: key=%s\n", k)
case map[string]any:
arguments = k
fmt.Printf("found arguments field: key=%s\n", k)
}
}
if name == "" || arguments == "" {
fmt.Printf("missing required fields: name_found=%v arguments_found=%v\n", name != "", arguments != "")
return nil, false, false
}
@ -165,18 +177,17 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
dec := jsontext.NewDecoder(strings.NewReader(s))
if got, err := dec.ReadValue(); err == nil {
s = got.String()
fmt.Printf("decoded JSON value: value=%s\n", s)
}
var responseObjects any
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
if err != nil {
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
fmt.Println("Detected partial or incomplete JSON.")
fmt.Println("state", p.state)
fmt.Println("incomplete JSON detected")
return nil, true, false
} else {
fmt.Printf("Other error: %v\n", err)
fmt.Println("exiting from JSON parsing", p.state)
fmt.Printf("failed to unmarshal response: error=%v\n", err)
return nil, false, false
}
}
@ -187,14 +198,14 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
return nil, false, false
}
slog.Debug("collected objects", "count", len(objs))
fmt.Printf("collected objects: count=%d\n", len(objs))
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
slog.Debug("found valid tool call", "name", n)
fmt.Printf("found valid tool call: name=%s\n", n)
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
@ -204,84 +215,89 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
}
}
slog.Debug("parsed tool calls", "count", len(toolCalls))
fmt.Printf("parsed tool calls: count=%d\n", len(toolCalls))
return toolCalls, false, true
}
func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall) {
// TODO: clean up the boundary of internal and external state transitions
func (p *ToolParser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) {
fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state)
// state transition logic
switch {
case !ok && !partial && p.state == ForceTools:
fmt.Println("Case: !ok && !partial && ForceTools - staying in force tools, resetting buffer")
// force partial tool if we have a prefix
// no op and stay in force tools
p.sb.Reset()
case !ok && !partial:
fmt.Println("Case: !ok && !partial")
fmt.Println("state", p.state)
if p.state == GreedyToolNoPrefix {
fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done")
p.state = Done
// p.ParserState = DoneFR
p.ParserState = ToolCallSendTokens
// ? the output parser state is the same even though internal can we not leak the external state?
p.Done = true
}
if p.state == GreedyToolWithPrefix {
fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens")
p.state = SendTokens
p.ParserState = ToolCallSendTokens
}
p.sb.Reset()
if p.state == PartialPrefix {
p.state = NotPartialPrefix
}
case !ok && partial:
fmt.Println("Case: !ok && partial - accumulating partial content")
// ! accumulate
// accumulate
case len(tcs) > 0:
fmt.Println("Case: tool calls found")
// do not parse again in the greedy JSON case as soon as we have a tool call
if p.state == GreedyToolWithPrefix {
p.state = SendTokens
p.ParserState = ToolCallFound
p.state = Done
p.Done = true
} else if p.state == GreedyToolNoPrefix {
fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens")
p.state = Done
p.Done = true
}
p.sb.Reset()
}
p.updateExternalState(tcs)
fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState)
}
func (p *ToolParser) updateExternalState(tcs []api.ToolCall) {
if (p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix || p.state == ToolSuffix) || (p.state == ForceTools && len(tcs) == 0) {
p.ParserState = ToolCallAccumulate
} else if p.state == ContainsPartialPrefix {
p.ParserState = ToolCallSendPartial
} else if len(tcs) > 0 {
fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs))
switch {
case len(tcs) > 0:
// do not parse again in the greedy JSON case as soon as we have a tool call
if p.state == GreedyToolWithPrefix {
p.state = SendTokens
} else if p.state == GreedyToolNoPrefix {
p.state = Done
p.Done = true
}
p.ParserState = ToolCallFound
} else if p.state == SendTokens {
case p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix ||
p.state == ToolSuffix || p.state == PartialPrefix ||
(p.state == ForceTools && len(tcs) == 0):
p.ParserState = ToolCallAccumulate
case p.state == ContainsPrefix:
p.ParserState = ToolCallSendPartial
case p.state == SendTokens || p.state == Done:
p.ParserState = ToolCallSendTokens
case p.state == NotPartialPrefix:
p.ParserState = ToolCallSendPartial
default:
p.ParserState = ToolCallSendTokens
p.sb.Reset()
p.state = SendTokens
}
}
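// checkPrefix checks s against the tool prefix and returns the resulting string along with whether the caller should continue on to parsing.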
func (p *ToolParser) checkPrefix(s string) (string, bool) {
fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix)
if p.toolPrefix == "" {
return s, true
}
original := s
// s = strings.TrimSpace(s)
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
if hasPrefix {
fmt.Println("has prefix", s)
p.state = ForceTools
// partial tool possibly
} else if strings.HasPrefix(p.toolPrefix, s) {
slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s)
// TODO: could possibly err maybe this should be greedy instead?
p.state = ForceTools
// this would basically be a no op on rest of the input
fmt.Printf("found exact prefix match: remaining=%s\n", s)
// partial tool possibly - accumulate
} else if suffixOverlap(s, p.toolPrefix) > 0 {
p.state = PartialPrefix
fmt.Printf("found partial prefix: remaining=%s\n", s)
return "", false
// the case where "token<tool_call>" - send "token" back
// accounts for spaces in prefix or suffix to avoid breaking cache
@ -289,11 +305,13 @@ func (p *ToolParser) checkPrefix(s string) (string, bool) {
idx := strings.Index(original, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
p.state = ContainsPrefix
p.sb.Reset()
// todo: see if there is a simpler way for this
idx2 := strings.Index(s, p.toolPrefix)
// buffer now only has the prefix
p.sb.WriteString(s[idx2:])
fmt.Printf("found prefix in middle: prefix_start=%d content_before=%s\n", idx, original[:idx])
return original[:idx], false
}
}
@ -305,51 +323,71 @@ func (p *ToolParser) checkPrefix(s string) (string, bool) {
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
// Returns the parsed tool calls and any leftover content that is not part of a tool call.
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) {
fmt.Println("checking tool calls", s)
fmt.Println("external state", p.ParserState)
fmt.Println("internal state", p.state)
fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state)
p.sb.WriteString(s)
s = p.sb.String()
s = strings.TrimSpace(s)
fmt.Println("sb", s)
p.updateExternalState(nil)
if len(s) == 0 {
p.updateExternalState(nil)
return nil, ""
}
s, cont := p.checkPrefix(s)
if !cont {
p.updateExternalState(nil)
if p.state == ContainsPartialPrefix {
if p.state == ContainsPrefix {
fmt.Printf("returning partial prefix: remaining=%s\n", s)
return nil, s
}
// * we'd be returning here for just accumulating with possible prefix
// * ext state is accumulation
return nil, ""
}
// * lets say the check fails here and now we're still in external state accumulation here
// stay in SendTokens unless we have a prefix
if p.state == SendTokens {
fmt.Println("SendTokens - resetting buffer")
p.updateExternalState(nil)
p.sb.Reset()
return nil, ""
fmt.Printf("returning send tokens: remaining=%s\n", s)
return nil, s
}
// * we'd parse here as json to see if it's a tool call
tcs, partial, ok := p.parseJSONToolCalls(s)
p.updateOutputState(ok, partial, tcs)
fmt.Println("output state", p.ParserState, p.state)
// * it would not be a tool call here
p.updateStateAfterJSONParse(ok, partial, tcs)
if !ok {
fmt.Println("returning empty tool calls")
// * and so we should send the data here
// * we also need to move out of that internal state after sending the tokens
if p.state == NotPartialPrefix {
p.state = SendTokens
// the string would have acc until here
return nil, p.sb.String()
}
return nil, ""
}
for _, tc := range tcs {
tc.Function.Index = p.toolIndex
p.toolIndex++
}
fmt.Printf("finished parsing tool calls: tool_calls_found=%d\n", len(tcs))
return tcs, ""
}
// suffixOverlap reports the length of the longest suffix of s that is also a
// prefix of delim. It returns 0 when no suffix of s starts delim, which
// includes the case where either string is empty.
func suffixOverlap(s, delim string) int {
	// The overlap can never be longer than the shorter of the two strings.
	limit := len(s)
	if len(delim) < limit {
		limit = len(delim)
	}

	// Try the longest candidate first so the first hit is the maximum overlap.
	for n := limit; n >= 1; n-- {
		if strings.HasSuffix(s, delim[:n]) {
			return n
		}
	}
	return 0
}
func NewToolParser(model *Model) *ToolParser {
// TODO: use new template parsing to get all tokens for the prefix
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
@ -365,7 +403,7 @@ func NewToolParser(model *Model) *ToolParser {
} else {
state = GreedyToolWithPrefix
}
fmt.Println("setup state", state)
fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", templateToolPrefix, state)
return &ToolParser{
tmpl: tmpl,
sb: &strings.Builder{},

View file

@ -55,21 +55,21 @@ func TestParseToolCalls(t *testing.T) {
expectedTokens string
}{
{
name: "mistral invalid json",
name: "mistral malformed json with tool calls prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
},
{
name: "mistral multiple tool calls - no prefix",
name: "mistral multiple tool calls without prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "mistral tool calls with text in between - no prefix",
name: "mistral tool calls with text between no prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
@ -77,15 +77,14 @@ func TestParseToolCalls(t *testing.T) {
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
},
{
name: "mistral valid json - with prefix",
name: "mistral valid json with tool calls prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
// In this case we'd be ignoring the text in between and just returning the tool calls
name: "mistral valid json with text in between - with prefix",
name: "mistral multiple tool calls with text between and prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
@ -93,14 +92,14 @@ func TestParseToolCalls(t *testing.T) {
expectedTokens: "",
},
{
name: "mistral incomplete json",
name: "mistral incomplete json with tool calls prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
},
{
name: "mistral without tool token",
name: "mistral invalid tool call with explanatory text no prefix",
model: "mistral",
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
@ -109,14 +108,14 @@ func TestParseToolCalls(t *testing.T) {
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
},
{
name: "mistral without tool token - tool first",
name: "mistral tool calls without prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "command-r-plus with json block",
name: "command r plus tool calls with json block format",
model: "command-r-plus",
output: "Action: ```json" + `
[
@ -140,14 +139,14 @@ func TestParseToolCalls(t *testing.T) {
expectedTokens: "",
},
{
name: "firefunction with functools",
name: "firefunction tool calls with functools prefix",
model: "firefunction",
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "llama3 with tool call tags",
name: "llama3 groq single tool call with xml tags",
model: "llama3-groq-tool-use",
output: `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
@ -156,99 +155,126 @@ func TestParseToolCalls(t *testing.T) {
expectedTokens: "",
},
{
name: "xlam with tool_calls wrapper",
name: "xlam tool calls with wrapper object",
model: "xlam",
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "qwen2.5 with single tool call",
name: "qwen2.5-coder single tool call with prefix",
model: "qwen2.5-coder",
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
},
{
name: "qwen with no tool prefix",
name: "qwen2.5-coder multiple tool calls with and without prefix",
model: "qwen2.5-coder",
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1, t1, t2},
expectedTokens: "",
},
{
name: "qwen2.5-coder multiple tool calls without prefix",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "qwen with no tool calls",
name: "qwen2.5-coder plain text response no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
},
{
name: "qwen with no tool prefix",
name: "qwen2.5-coder tool calls with trailing text",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "some tokens after call",
},
{
name: "qwen with prefix",
name: "qwen2.5 tool calls with prefix and trailing text",
model: "qwen2.5-coder",
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
// tests the leftover logic as well
name: "qwen3 with single tool call and thinking",
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
},
{
name: "qwen3 with single tool call and thinking spaces",
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
},
{
name: "qwen3 testing",
name: "qwen3 empty think prefix without tool prefix and invalid tool call",
model: "qwen3",
output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "qwen3 testing 2",
name: "qwen3 empty think prefix with tool prefix and valid tool call",
model: "qwen3",
output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: `<think></think>`,
},
{
name: "llama3.2 with tool call - no prefix",
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)",
model: "qwen3",
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
model: "qwen3",
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "qwen3 invalid tool call with malformed tool prefix",
model: "qwen3",
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "llama3.2 valid tool call without prefix",
model: "llama3.2",
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
},
{
name: "llama3.2 with incomplete tool call - no prefix",
name: "llama3.2 incomplete tool call without prefix",
model: "llama3.2",
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
},
{
name: "llama3.2 with tool call - in middle",
name: "llama3.2 tool call with leading text",
model: "llama3.2",
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
},
{
name: "llama3.2 - fake tool prefix",
name: "llama3.2 tool call with invalid tool prefix (no prefix in template)",
model: "llama3.2",
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
@ -288,7 +314,7 @@ func TestParseToolCalls(t *testing.T) {
m := &Model{Template: tmpl}
tp := NewToolParser(m)
got := []api.ToolCall{}
var actualTokens strings.Builder
var gotTokens strings.Builder
tokens := strings.Fields(tt.output)
for _, tok := range tokens {
@ -302,17 +328,18 @@ func TestParseToolCalls(t *testing.T) {
got = append(got, toolCalls...)
add = false
case ToolCallSendTokens:
actualTokens.WriteString(s)
gotTokens.WriteString(s)
add = false
case ToolCallAccumulate:
add = false
case ToolCallSendPartial:
actualTokens.WriteString(" " + leftover)
t.Log("send partial", "leftover", leftover)
gotTokens.WriteString(" " + leftover)
add = false
}
}
if add {
actualTokens.WriteString(s)
gotTokens.WriteString(s)
}
}
@ -322,7 +349,7 @@ func TestParseToolCalls(t *testing.T) {
}
// Compare tokens if we expect any
stripped := strings.TrimSpace(actualTokens.String())
stripped := strings.TrimSpace(gotTokens.String())
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
t.Errorf("tokens mismatch (-got +want):\n%s", diff)