From b4cd1118abd932a0e754e7f35b3dd2e62694ceaa Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 24 Apr 2025 18:23:23 -0700 Subject: [PATCH] checkpoint for vscode --- server/model.go | 279 +++++++++++++++-------------------------------- server/routes.go | 25 +++-- 2 files changed, 101 insertions(+), 203 deletions(-) diff --git a/server/model.go b/server/model.go index 5d9609b45..c237b0aab 100644 --- a/server/model.go +++ b/server/model.go @@ -154,109 +154,99 @@ func parseObjects(s string) []map[string]any { return objs } -// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. -// mxyng: this only really works if the input contains tool calls in some JSON format -func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { - // create a subtree from the node that ranges over .ToolCalls +// Get tool call token from model template +func (m *Model) TemplateToolToken() (string, string, bool) { + // Try to detect the tool call format from the model's template tmpl := m.Template.Subtree(func(n parse.Node) bool { if t, ok := n.(*parse.RangeNode); ok { return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") } - return false }) - if tmpl == nil { - slog.Debug("parseToolCalls: no ToolCalls template found") - return nil, false - } - - slog.Debug("parseToolCalls: executing template with test data", "input", s) - - var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "@@name@@", - Arguments: api.ToolCallFunctionArguments{ - "@@argument@@": 1, + // fmt.Println("tool call template", tmpl) + if tmpl != nil { + // Execute template with test data to see the format + var b bytes.Buffer + if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + "ToolCalls": { + { + Function: api.ToolCallFunction{ + Name: "function_name", + Arguments: api.ToolCallFunctionArguments{ + "argument1": "value1", + // "argument2": "value2", + }, }, }, }, - }, - }); err != nil { - slog.Debug("parseToolCalls: template execution failed", "error", err) - return nil, false - } - - slog.Debug("parseToolCalls: template executed successfully", "output", b.String()) - - templateObjects := parseObjects(b.String()) - if len(templateObjects) == 0 { - return nil, false - } - - slog.Debug("parseToolCalls: template objects", "objects", templateObjects) - - // find the keys that correspond to the name and arguments fields - var name, arguments string - for k, v := range templateObjects[0] { - switch v.(type) { - case string: - name = k - case map[string]any: - arguments = k - } - } - - if name == "" || arguments == "" { - return nil, false - } - - responseObjects := parseObjects(s) - if len(responseObjects) == 0 { - return nil, false - } - - // collect all nested objects - var collect func(any) []map[string]any - collect = func(obj any) (all []map[string]any) { - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) + }); err == nil { + // Look for special tokens in the template output + output := strings.TrimSpace(b.String()) + slog.Debug("tool call template output", "output", output) + if strings.Contains(output, "<") { + // Extract the special token between < and > + start := strings.Index(output, "<") + end := strings.Index(output, ">") + if start >= 0 && end > start { + token := output[start : end+1] + return output, token, true + } + } else if strings.Contains(output, "[") { + // Check if it's a tool call token rather than JSON array + start := strings.Index(output, "[") + end := strings.Index(output, "]") + if start >= 0 && end > start { + token := output[start : end+1] + // Only consider it a token if it's not valid JSON + var jsonTest any + if err := json.Unmarshal([]byte(token), &jsonTest); err != nil { + return output, token, true + } + } } } - - return all } + return "", "", false +} - var objs []map[string]any - for _, p := range responseObjects { - objs = append(objs, collect(p)...) +func parsePythonFunctionCall(s string) ([]api.ToolCall, bool) { + re := regexp.MustCompile(`(\w+)\((.*?)\)`) + matches := re.FindAllStringSubmatchIndex(s, -1) + if len(matches) == 0 { + return nil, false } var toolCalls []api.ToolCall - for _, kv := range objs { - n, nok := kv[name].(string) - a, aok := kv[arguments].(map[string]any) - if nok && aok { + for _, match := range matches { + name := s[match[2]:match[3]] + args := s[match[4]:match[5]] + + arguments := make(api.ToolCallFunctionArguments) + if strings.Contains(args, "=") { // Keyword args + pairs := strings.SplitSeq(args, ",") + for pair := range pairs { + pair = strings.TrimSpace(pair) + kv := strings.Split(pair, "=") + if len(kv) == 2 { + key := strings.TrimSpace(kv[0]) + value := strings.TrimSpace(kv[1]) + arguments[key] = value + } + } toolCalls = append(toolCalls, api.ToolCall{ Function: api.ToolCallFunction{ - Name: n, - Arguments: a, + Name: name, + Arguments: arguments, }, }) } } - return toolCalls, len(toolCalls) > 0 + if len(toolCalls) > 0 { + return toolCalls, true + } + return nil, false } // ToolCallFormat represents different possible formats for tool calls @@ -377,100 +367,6 @@ func parseJSONToolCalls(obj map[string]any) ([]api.ToolCall, bool) { return nil, false } -func (m *Model) GetToolCallFormat() (string, string, bool) { - // Try to detect the tool call format from the model's template - tmpl := m.Template.Subtree(func(n parse.Node) bool { - if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") - } - return false - }) - - fmt.Println("tool call template", tmpl) - if tmpl != nil { - // Execute template with test data to see the format - var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "function_name", - Arguments: api.ToolCallFunctionArguments{ - "argument1": "value1", - // "argument2": "value2", - }, - }, - }, - }, - }); err == nil { - // Look for special tokens in the template output - output := strings.TrimSpace(b.String()) - slog.Debug("tool call template output", "output", output) - if strings.Contains(output, "<") { - // Extract the special token between < and > - start := strings.Index(output, "<") - end := strings.Index(output, ">") - if start >= 0 && end > start { - token := output[start : end+1] - return output, token, true - } - } else if strings.Contains(output, "[") { - // Check if it's a tool call token rather than JSON array - start := strings.Index(output, "[") - end := strings.Index(output, "]") - if start >= 0 && end > start { - token := output[start : end+1] - // Only consider it a token if it's not valid JSON - var jsonTest any - if err := json.Unmarshal([]byte(token), &jsonTest); err != nil { - return output, token, true - } - } - } - } - } - return "", "", false -} - -func parsePythonFunctionCall(s string) ([]api.ToolCall, bool) { - re := regexp.MustCompile(`(\w+)\((.*?)\)`) - matches := re.FindAllStringSubmatchIndex(s, -1) - if len(matches) == 0 { - return nil, false - } - - var toolCalls []api.ToolCall - for _, match := range matches { - name := s[match[2]:match[3]] - args := s[match[4]:match[5]] - - arguments := make(api.ToolCallFunctionArguments) - if strings.Contains(args, "=") { // Keyword args - pairs := strings.SplitSeq(args, ",") - for pair := range pairs { - pair = strings.TrimSpace(pair) - kv := strings.Split(pair, "=") - if len(kv) == 2 { - key := strings.TrimSpace(kv[0]) - value := strings.TrimSpace(kv[1]) - arguments[key] = value - } - } - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: name, - Arguments: arguments, - }, - }) - } - } - - if len(toolCalls) > 0 { - return toolCalls, true - } - return nil, false -} - // token, partial, success func deriveToolToken(s string, prefix string) (string, bool, bool) { // There shouldn't be spaces in a tool token @@ -488,27 +384,21 @@ func deriveToolToken(s string, prefix string) (string, bool, bool) { func parseJSON(s string) ([]api.ToolCall, bool) { objs := parseObjects(s) - var toolCalls []api.ToolCall + tcs := []api.ToolCall{} for _, obj := range objs { - if n, nok := obj["name"].(string); nok { - if a, aok := obj["arguments"].(map[string]any); aok { - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: n, - Arguments: a, - }, - }) - } + toolCalls, ok := parseJSONToolCalls(obj) + if ok { + tcs = append(tcs, toolCalls...) } } - if len(toolCalls) > 0 { - return toolCalls, true + if len(tcs) > 0 { + return tcs, true } return nil, false } // returns tool calls, partial, success -func (m *Model) ParseToolCallsNew(s string, toolToken *string) ([]api.ToolCall, bool, bool) { +func (m *Model) ParseToolCalls(s string, toolToken *string) ([]api.ToolCall, bool, bool) { // [ case can either be JSON, Python or a Tool Token s = strings.TrimSpace(s) fmt.Printf("ParseToolCallsNew input: %q\n", s) @@ -539,7 +429,7 @@ func (m *Model) ParseToolCallsNew(s string, toolToken *string) ([]api.ToolCall, } // Tool Token Case - this is okay if it's a real tool token and we couldn't get from template fmt.Println("Attempting to derive tool token") - if toolToken == nil || (toolToken != nil && *toolToken == "") { + if toolToken == nil || *toolToken == "" { toolTok, partial, ok := deriveToolToken(s, "[") if !ok { return nil, false, false @@ -552,21 +442,26 @@ func (m *Model) ParseToolCallsNew(s string, toolToken *string) ([]api.ToolCall, fmt.Printf("Found tool token: %q\n", *toolToken) s = strings.TrimSpace(s[len(*toolToken):]) fmt.Printf("Recursing with remaining string: %q\n", s) - if toolCalls, partial, ok := m.ParseToolCallsNew(s, toolToken); ok { + if toolCalls, partial, ok := m.ParseToolCalls(s, toolToken); ok { return toolCalls, partial, true } return nil, true, true } else if strings.HasPrefix(s, "{") || strings.HasPrefix(s, "```") { - fmt.Println("Found { prefix - attempting JSON parse") + // // TODO: temp fix + // if strings.HasPrefix(s, "```") && len(s) == 3 { + // return nil, false, false + // } + fmt.Println("Found { prefix - attempting JSON parse with ", s) if calls, ok := parseJSON(s); ok { fmt.Printf("Successfully parsed JSON object, found %d calls\n", len(calls)) return calls, false, true } + fmt.Println("Failed to parse JSON in JSON case") // TODO: possible case where it never finishes parsing - then what? return nil, true, true } else if strings.HasPrefix(s, "<") { fmt.Println("Found < prefix - attempting to derive tool token") - if toolToken == nil || (toolToken != nil && *toolToken == "") { + if toolToken == nil || *toolToken == "" { toolTok, partial, ok := deriveToolToken(s, "<") if !ok { return nil, false, false @@ -575,12 +470,12 @@ func (m *Model) ParseToolCallsNew(s string, toolToken *string) ([]api.ToolCall, return nil, true, true } *toolToken = toolTok - fmt.Printf("Found tool token: %q\n", toolToken) + fmt.Printf("Found tool token: %q\n", *toolToken) } fmt.Printf("Found tool token: %q\n", *toolToken) s = strings.TrimSpace(s[len(*toolToken):]) fmt.Printf("Recursing with remaining string: %q\n", s) - if toolCalls, partial, ok := m.ParseToolCallsNew(s, toolToken); ok { + if toolCalls, partial, ok := m.ParseToolCalls(s, toolToken); ok { return toolCalls, partial, true } return nil, true, true diff --git a/server/routes.go b/server/routes.go index 63214eb66..4afc3c613 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1529,8 +1529,8 @@ func (s *Server) ChatHandler(c *gin.Context) { var sentWithTools int = 0 // var prefix string // var templateToolToken string - _, templateToolToken, _ := m.GetToolCallFormat() - fmt.Println("special token", templateToolToken) + _, templateToolToken, _ := m.TemplateToolToken() + // fmt.Println("special token", templateToolToken) var minDuration time.Duration = math.MaxInt64 var maxDuration time.Duration @@ -1562,9 +1562,9 @@ func (s *Server) ChatHandler(c *gin.Context) { slog.Debug("total duration", "duration", totalDuration) slog.Debug("check count", "count", checkCount) // slog.Debug("average duration", "duration", totalDuration/time.Duration(checkCount)) - if sb.Len() > 0 { - res.Message.Content = sb.String() - } + // if sb.Len() > 0 { + // res.Message.Content = sb.String() + // } res.DoneReason = r.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -1582,12 +1582,10 @@ func (s *Server) ChatHandler(c *gin.Context) { // If tools are recognized, use a flag to track the sending of a tool downstream // This ensures that content is cleared from the message on the last chunk sent sb.WriteString(r.Content) - // TODO: here we want to prefix check the tool ideally or derive the tool token from the model - // TODO: if we are deriving the tool token, then a heuristic must be applied to stream eventually - // TODO: if the prefix check fails, send the content downstream and reset the builder startTime := time.Now() + // TODO: work max tool tok logic if len(req.Tools) > 0 && sentWithTools < maxToolTokens { - toolCalls, partial, ok := m.ParseToolCallsNew(sb.String(), &templateToolToken) + toolCalls, partial, ok := m.ParseToolCalls(sb.String(), &templateToolToken) duration := time.Since(startTime) checkCount++ minDuration = min(minDuration, duration) @@ -1600,6 +1598,7 @@ func (s *Server) ChatHandler(c *gin.Context) { // If the tool call is partial, we need to wait for the next chunk return } + slog.Debug("toolCalls", "toolCalls", toolCalls, "partial", partial, "ok", ok) res.Message.ToolCalls = toolCalls for i := range toolCalls { toolCalls[i].Function.Index = toolCallIndex @@ -1611,6 +1610,9 @@ func (s *Server) ChatHandler(c *gin.Context) { res.Message.Content = "" sb.Reset() ch <- res + // TODO: revisit this + sentWithTools++ + slog.Debug("fired on tool call", "toolCalls", toolCalls, "toolCallIndex", toolCallIndex) return } } @@ -1634,15 +1636,16 @@ func (s *Server) ChatHandler(c *gin.Context) { const MAX_TOOL_TOKENS = 1 sentWithTools := 0 var tb strings.Builder - _, templateToolToken, _ := m.GetToolCallFormat() + _, templateToolToken, _ := m.TemplateToolToken() for rr := range ch { switch t := rr.(type) { case api.ChatResponse: sb.WriteString(t.Message.Content) resp = t + // TODO: work max tool tok logic if len(req.Tools) > 0 && sentWithTools < MAX_TOOL_TOKENS { tb.WriteString(t.Message.Content) - if tcs, partial, ok := m.ParseToolCallsNew(tb.String(), &templateToolToken); ok { + if tcs, partial, ok := m.ParseToolCalls(tb.String(), &templateToolToken); ok { if !partial { // resp.Message.ToolCalls = toolCalls toolCalls = append(toolCalls, tcs...)