From 128c90d3ac2c063663c51edcde9ec2959384dca5 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 24 Apr 2025 16:57:54 -0700 Subject: [PATCH] checkpoint!!! --- server/model.go | 289 +++++++++++++++++++++-------------------------- server/routes.go | 32 +++--- 2 files changed, 142 insertions(+), 179 deletions(-) diff --git a/server/model.go b/server/model.go index d37a4a553..5d9609b45 100644 --- a/server/model.go +++ b/server/model.go @@ -377,15 +377,16 @@ func parseJSONToolCalls(obj map[string]any) ([]api.ToolCall, bool) { return nil, false } -func (m *Model) GetToolCallFormat(s string) (string, string, bool) { +func (m *Model) GetToolCallFormat() (string, string, bool) { // Try to detect the tool call format from the model's template tmpl := m.Template.Subtree(func(n parse.Node) bool { if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "Content") + return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") } return false }) + fmt.Println("tool call template", tmpl) if tmpl != nil { // Execute template with test data to see the format var b bytes.Buffer @@ -431,20 +432,18 @@ func (m *Model) GetToolCallFormat(s string) (string, string, bool) { return "", "", false } -func parsePythonFunctionCall(s string) (api.ToolCall, bool) { +func parsePythonFunctionCall(s string) ([]api.ToolCall, bool) { re := regexp.MustCompile(`(\w+)\((.*?)\)`) - if match := re.FindStringSubmatchIndex(s); match != nil { + matches := re.FindAllStringSubmatchIndex(s, -1) + if len(matches) == 0 { + return nil, false + } + + var toolCalls []api.ToolCall + for _, match := range matches { name := s[match[2]:match[3]] args := s[match[4]:match[5]] - // Check if there's a < after the closing bracket - if idx := strings.Index(s[match[5]:], "<"); idx >= 0 { - // Wait for closing > by returning false - if !strings.Contains(s[match[5]+idx:], ">") { - return api.ToolCall{}, false - } - } - arguments := make(api.ToolCallFunctionArguments) if strings.Contains(args, "=") { // Keyword args pairs := strings.SplitSeq(args, ",") @@ -457,179 +456,145 @@ func parsePythonFunctionCall(s string) (api.ToolCall, bool) { arguments[key] = value } } - return api.ToolCall{ + toolCalls = append(toolCalls, api.ToolCall{ Function: api.ToolCallFunction{ Name: name, Arguments: arguments, }, - }, true + }) } } - return api.ToolCall{}, false + + if len(toolCalls) > 0 { + return toolCalls, true + } + return nil, false } -func (m *Model) ParseToolCallsStream(s string, prefix *string, specialToken *string) ([]api.ToolCall, bool, bool) { - // The prefix check for for the tags shouldn't really be used and we should be consuming this from the model - // Knowing what the tool token enables quicker and more reliable parsing - // TODO: not sure how we're going to handle chatting before the tool call - // TODO: detection would be relying on the model to know what the tool token is - // fmt.Println("parsing tool calls", s) - - if prefix == nil { - prefix = new(string) - *prefix = "" +// token, partial, success +func deriveToolToken(s string, prefix string) (string, bool, bool) { + // There shouldn't be spaces in a tool token + if len(strings.Fields(s)) > 1 { + return "", false, false } - if specialToken == nil { - specialToken = new(string) - *specialToken = "" - } - // TODO: cache this - // _, token, ok := m.GetToolCallFormat(s) - // if ok && token != "" { - // fmt.Println("token", token) - // *specialToken = token - // } - // fmt.Println("prefix", *prefix) - // fmt.Println("special token", *specialToken) - var partial bool + if prefix == "[" && len(s) > 1 && s[len(s)-1] == ']' { + return s, false, true + } else if prefix == "<" && len(s) > 1 && s[len(s)-1] == '>' { + return s, false, true + } + return "", true, true +} + +func parseJSON(s string) ([]api.ToolCall, bool) { + objs := parseObjects(s) + var toolCalls []api.ToolCall + for _, obj := range objs { + if n, nok := obj["name"].(string); nok { + if a, aok := obj["arguments"].(map[string]any); aok { + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: n, + Arguments: a, + }, + }) + } + } + } + if len(toolCalls) > 0 { + return toolCalls, true + } + return nil, false +} + +// returns tool calls, partial, success +func (m *Model) ParseToolCallsNew(s string, toolToken *string) ([]api.ToolCall, bool, bool) { + // [ case can either be JSON, Python or a Tool Token s = strings.TrimSpace(s) + fmt.Printf("ParseToolCallsNew input: %q\n", s) if len(s) == 0 { return nil, false, false } - if specialToken != nil && len(*specialToken) > 0 { - s2 := *specialToken - if strings.HasPrefix(s, string(s2[0])) { - // fmt.Println("prefix 1 is", string(s2[0])) - partial = true - *prefix = string(s2[0]) - } - } - - if len(s) > 0 { - if s[0] == '[' { - s = strings.ReplaceAll(s, "\n", "") - // tool call list with no special token - if len(s) > 1 && s[1] == '{' { - // fmt.Println("prefix 2 in [{", string(s[0])) - partial = true - *specialToken = "[{" - *prefix = "[{" - } else if *specialToken == "" { - // possible tool call with special token but not in template - // split s over spaces to check for special token - if len(s) > 0 && s[len(s)-1] == ']' { - partial = true - *specialToken = s - *prefix = "[" - } + if strings.HasPrefix(s, "[") { + fmt.Println("Found [ prefix") + // JSON case + // we do not consider array JSONs as tool calls + if strings.HasPrefix(s, "[{") { + fmt.Println("Found [{ prefix - attempting JSON parse") + // TODO: mark as JSON partial + if calls, ok := parseJSON(s); ok { + fmt.Printf("Successfully parsed JSON, found %d calls\n", len(calls)) + return calls, false, true } - } else if s[0] == '{' { - // fmt.Println("prefix 2 in {", string(s[0])) - partial = true - *specialToken = "{" - *prefix = "{" - } else if s[0] == '<' { - // TODO: the only issue here is that we might miss a > if the token is weird - // The 1 && s[1] == '/' { - // fmt.Println("prefix3 in <", string(s[0])) - // returning a partial here is a hack to ensure that we don't send the content downstream - return nil, true, true - // TODO: jank hack to get special token right - // special token might not be set yet - } else if s[len(s)-1] == '>' { - partial = true - *specialToken = s - *prefix = "<" - } else if specialToken != nil && *specialToken == "" { - partial = true - *specialToken = "<" - *prefix = "<" - } - } - } - - // fmt.Println("special token", *specialToken) - // fmt.Println("prefix", *prefix) - - if !partial { - return nil, false, false - } - // Look for tags - // fmt.Println("looking for special token", *specialToken) - start := strings.Index(s, *specialToken) - if start == -1 { - if partial { - // fmt.Println("did not find opening tag, partial match", *specialToken) return nil, true, true } - return nil, false, false - } - end := len(s) - - // Extract content between tags - var content string - // fmt.Println("prefix before is", *prefix) - if *prefix == "[{" || *prefix == "{" { - content = s[start:end] - } else { - content = s[start+len(*specialToken) : end] - } - content = strings.TrimSpace(content) - // fmt.Println("content", content) - - var toolCalls []api.ToolCall - - // Try parsing as JSON first - could be single object or array - var jsonObj any - if err := json.Unmarshal([]byte(content), &jsonObj); err == nil { - // Try as single object - if obj, ok := jsonObj.(map[string]any); ok { - // fmt.Println("obj", obj) - if calls, ok := parseJSONToolCalls(obj); ok { - toolCalls = append(toolCalls, calls...) - } + // Python Case + // We just do a full python check here + fmt.Println("Attempting Python function parse") + tc, ok := parsePythonFunctionCall(s) + if ok { + fmt.Printf("Successfully parsed Python function: %+v\n", tc) + return tc, false, true } - // Try as array of objects - if arr, ok := jsonObj.([]any); ok { - for _, item := range arr { - if obj, ok := item.(map[string]any); ok { - if calls, ok := parseJSONToolCalls(obj); ok { - toolCalls = append(toolCalls, calls...) - } - } - } - } - } else { - // TODO: review this case - // Check for partial JSON before trying Python style - if strings.HasPrefix(content, "{") || strings.HasPrefix(content, "[{") { - // We have an opening brace/bracket but failed to parse - likely partial JSON - return nil, true, true - } - - // Try parsing as Python function call - if toolCall, ok := parsePythonFunctionCall(content); ok { - toolCalls = append(toolCalls, toolCall) - } - } - - // Only return success if we found valid tool calls and no errors - if len(toolCalls) > 0 { - // Check if any of the tool calls are malformed - for _, call := range toolCalls { - if call.Function.Name == "" || len(call.Function.Arguments) == 0 { + // Tool Token Case - this is okay if it's a real tool token and we couldn't get from template + fmt.Println("Attempting to derive tool token") + if toolToken == nil || (toolToken != nil && *toolToken == "") { + toolTok, partial, ok := deriveToolToken(s, "[") + if !ok { return nil, false, false } + if partial { + return nil, true, true + } + *toolToken = toolTok } - return toolCalls, false, true - } - - // fmt.Println("no tool calls found, partial match", partial) - if partial { + fmt.Printf("Found tool token: %q\n", *toolToken) + s = strings.TrimSpace(s[len(*toolToken):]) + fmt.Printf("Recursing with remaining string: %q\n", s) + if toolCalls, partial, ok := m.ParseToolCallsNew(s, toolToken); ok { + return toolCalls, partial, true + } + return nil, true, true + } else if strings.HasPrefix(s, "{") || strings.HasPrefix(s, "```") { + fmt.Println("Found { prefix - attempting JSON parse") + if calls, ok := parseJSON(s); ok { + fmt.Printf("Successfully parsed JSON object, found %d calls\n", len(calls)) + return calls, false, true + } + // TODO: possible case where it never finishes parsing - then what? + return nil, true, true + } else if strings.HasPrefix(s, "<") { + fmt.Println("Found < prefix - attempting to derive tool token") + if toolToken == nil || (toolToken != nil && *toolToken == "") { + toolTok, partial, ok := deriveToolToken(s, "<") + if !ok { + return nil, false, false + } + if partial { + return nil, true, true + } + *toolToken = toolTok + fmt.Printf("Found tool token: %q\n", toolToken) + } + fmt.Printf("Found tool token: %q\n", *toolToken) + s = strings.TrimSpace(s[len(*toolToken):]) + fmt.Printf("Recursing with remaining string: %q\n", s) + if toolCalls, partial, ok := m.ParseToolCallsNew(s, toolToken); ok { + return toolCalls, partial, true + } + return nil, true, true + } else if strings.Contains(s, "(") || len(strings.Fields(s)) == 1 { + fmt.Println("Attempting Python function parse") + tc, ok := parsePythonFunctionCall(s) + if ok { + fmt.Printf("Successfully parsed Python function: %+v\n", tc) + return tc, false, true + } + fmt.Printf("Failed to parse Python function: %q, returning partial", s) return nil, true, true } + fmt.Println("No successful parse paths found") + fmt.Printf("failed string: %q\n", s) return nil, false, false } diff --git a/server/routes.go b/server/routes.go index 554dc4daf..63214eb66 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1527,15 +1527,16 @@ func (s *Server) ChatHandler(c *gin.Context) { var sb strings.Builder var toolCallIndex int = 0 var sentWithTools int = 0 - var prefix string - // var specialToken string - _, specialToken, _ := m.GetToolCallFormat(sb.String()) + // var prefix string + // var templateToolToken string + _, templateToolToken, _ := m.GetToolCallFormat() + fmt.Println("special token", templateToolToken) var minDuration time.Duration = math.MaxInt64 var maxDuration time.Duration var totalDuration time.Duration var checkCount int - const MAX_TOOL_TOKENS = 6 + const maxToolTokens = 1 if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1561,9 +1562,9 @@ func (s *Server) ChatHandler(c *gin.Context) { slog.Debug("total duration", "duration", totalDuration) slog.Debug("check count", "count", checkCount) // slog.Debug("average duration", "duration", totalDuration/time.Duration(checkCount)) - // if sb.Len() > 0 { - // res.Message.Content = sb.String() - // } + if sb.Len() > 0 { + res.Message.Content = sb.String() + } res.DoneReason = r.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -1585,8 +1586,8 @@ func (s *Server) ChatHandler(c *gin.Context) { // TODO: if we are deriving the tool token, then a heuristic must be applied to stream eventually // TODO: if the prefix check fails, send the content downstream and reset the builder startTime := time.Now() - if len(req.Tools) > 0 && sentWithTools < MAX_TOOL_TOKENS { - toolCalls, partial, ok := m.ParseToolCallsStream(sb.String(), &prefix, &specialToken) + if len(req.Tools) > 0 && sentWithTools < maxToolTokens { + toolCalls, partial, ok := m.ParseToolCallsNew(sb.String(), &templateToolToken) duration := time.Since(startTime) checkCount++ minDuration = min(minDuration, duration) @@ -1605,8 +1606,8 @@ func (s *Server) ChatHandler(c *gin.Context) { toolCallIndex++ } sentWithTools = 0 - prefix = "" - specialToken = "" + // prefix = "" + templateToolToken = "" res.Message.Content = "" sb.Reset() ch <- res @@ -1630,11 +1631,10 @@ func (s *Server) ChatHandler(c *gin.Context) { var resp api.ChatResponse var sb strings.Builder var toolCalls []api.ToolCall - var prefix string - var specialToken string - const MAX_TOOL_TOKENS = 6 + const MAX_TOOL_TOKENS = 1 sentWithTools := 0 var tb strings.Builder + _, templateToolToken, _ := m.GetToolCallFormat() for rr := range ch { switch t := rr.(type) { case api.ChatResponse: @@ -1642,14 +1642,12 @@ func (s *Server) ChatHandler(c *gin.Context) { resp = t if len(req.Tools) > 0 && sentWithTools < MAX_TOOL_TOKENS { tb.WriteString(t.Message.Content) - if tcs, partial, ok := m.ParseToolCallsStream(tb.String(), &prefix, &specialToken); ok { + if tcs, partial, ok := m.ParseToolCallsNew(tb.String(), &templateToolToken); ok { if !partial { // resp.Message.ToolCalls = toolCalls toolCalls = append(toolCalls, tcs...) resp.Message.Content = "" tb.Reset() - prefix = "" - specialToken = "" } } else { // equivalent to no partial - send the content downstream