diff --git a/api/types.go b/api/types.go index 7d8b6e532..59e5b6d7b 100644 --- a/api/types.go +++ b/api/types.go @@ -83,6 +83,10 @@ type GenerateRequest struct { // Options lists model-specific options. For example, temperature can be // set through this field, if the model supports it. Options map[string]any `json:"options"` + + // Thinking controls whether thinking/reasoning models will think before + // responding + Thinking bool `json:"thinking,omitempty"` } // ChatRequest describes a request sent by [Client.Chat]. @@ -108,6 +112,10 @@ type ChatRequest struct { // Options lists model-specific options. Options map[string]any `json:"options"` + + // Thinking controls whether thinking/reasoning models will think before + // responding + Thinking bool `json:"thinking,omitempty"` } type Tools []Tool @@ -130,6 +138,10 @@ type Message struct { Content string `json:"content"` Images []ImageData `json:"images,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // ThinkingBlock contains the text that was inside tags in the + // original model output when ChatRequest.Thinking was enabled. 
+ ThinkingBlock string `json:"thinkingBlock,omitempty"` } func (m *Message) UnmarshalJSON(b []byte) error { @@ -275,6 +287,8 @@ type Options struct { MirostatTau float32 `json:"mirostat_tau,omitempty"` MirostatEta float32 `json:"mirostat_eta,omitempty"` Stop []string `json:"stop,omitempty"` + + Thinking bool `json:"thinking,omitempty"` } // Runner options which must be set when the model is loaded into memory diff --git a/cmd/cmd.go b/cmd/cmd.go index 79ff87ac8..792cdb96f 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -38,12 +38,33 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" + "github.com/ollama/ollama/readline" "github.com/ollama/ollama/runner" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) +// warnMissingThinking emits a warning if the model does not advertise thinking +// support and opts.Thinking is set. Failures to query the capability are +// ignored so this does not impact regular usage. 
+func warnMissingThinking(ctx context.Context, client *api.Client, name string) {
+	if name == "" {
+		return
+	}
+	resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
+	if err != nil {
+		return
+	}
+	for _, cap := range resp.Capabilities {
+		if cap == model.CapabilityThinking {
+			return
+		}
+	}
+	fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
+}
+
 var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
 
 func getModelfileName(cmd *cobra.Command) (string, error) {
@@ -243,6 +263,9 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
 	req := &api.GenerateRequest{
 		Model:     opts.Model,
 		KeepAlive: opts.KeepAlive,
+
+		// pass Thinking here so we fail before getting to the chat prompt if the model doesn't support it
+		Thinking: opts.Thinking,
 	}
 
 	return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@@ -277,6 +300,12 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}
 	opts.Format = format
 
+	thinkingFlag, err := cmd.Flags().GetBool("thinking")
+	if err != nil {
+		return err
+	}
+	opts.Thinking = thinkingFlag
+
 	keepAlive, err := cmd.Flags().GetString("keepalive")
 	if err != nil {
 		return err
@@ -361,6 +390,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		if err := loadOrUnloadModel(cmd, &opts); err != nil {
 			return err
 		}
+		if opts.Thinking {
+			warnMissingThinking(cmd.Context(), client, opts.Model)
+		}
 
 		for _, msg := range info.Messages {
 			switch msg.Role {
@@ -876,6 +908,7 @@ type runOptions struct {
 	Options     map[string]any
 	MultiModal  bool
 	KeepAlive   *api.Duration
+	Thinking    bool
 }
 
 type displayResponseState struct {
@@ -958,6 +991,8 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	var latest api.ChatResponse
 	var fullResponse strings.Builder
 	var role string
+	var thinkTagOpened bool = false
+	var thinkTagClosed bool = false
 
 	fn := func(response api.ChatResponse) error {
 		p.StopAndClear()
@@ 
-965,7 +999,23 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 
 		latest = response
 		role = response.Message.Role
 
+		if response.Message.ThinkingBlock != "" {
+			if !thinkTagOpened {
+				fmt.Print(readline.ColorGrey + readline.ColorBold + "<think>" + readline.ColorDefault + readline.ColorGrey)
+				thinkTagOpened = true
+			}
+			displayResponse(response.Message.ThinkingBlock, opts.WordWrap, state)
+		}
+
 		content := response.Message.Content
+		if !thinkTagClosed && thinkTagOpened && content != "" {
+			fmt.Print(readline.ColorGrey + readline.ColorBold + "</think>" + readline.ColorDefault)
+			thinkTagClosed = true
+		}
+		// purposefully not putting thinking blocks in the response, which would
+		// only be needed if we later added tool calling to the cli (they get
+		// filtered out anyway since current models don't expect them unless you're
+		// about to finish some tool calls)
 		fullResponse.WriteString(content)
 		displayResponse(content, opts.WordWrap, state)
@@ -982,6 +1032,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		Messages: opts.Messages,
 		Format:   json.RawMessage(opts.Format),
 		Options:  opts.Options,
+		Thinking: opts.Thinking,
+	}
+
+	if opts.Thinking {
+		warnMissingThinking(cmd.Context(), client, opts.Model)
 	}
 
 	if opts.KeepAlive != nil {
@@ -1075,6 +1130,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		System:    opts.System,
 		Options:   opts.Options,
 		KeepAlive: opts.KeepAlive,
+		Thinking:  opts.Thinking,
 	}
 
 	if err := client.Generate(ctx, &request, fn); err != nil {
@@ -1290,6 +1346,7 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
+ runCmd.Flags().Bool("thinking", false, "Turn on thinking mode for supported models") stopCmd := &cobra.Command{ Use: "stop MODEL", diff --git a/cmd/interactive.go b/cmd/interactive.go index 82a3bfcbe..1edec51da 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Fprintln(os.Stderr, " /set noformat Disable formatting") fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats") fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats") + fmt.Fprintln(os.Stderr, " /set thinking Enable thinking") + fmt.Fprintln(os.Stderr, " /set nothinking Disable thinking") fmt.Fprintln(os.Stderr, "") } @@ -260,6 +262,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { return err } fmt.Println("Set 'quiet' mode.") + case "thinking": + opts.Thinking = true + if client, err := api.ClientFromEnvironment(); err == nil { + warnMissingThinking(cmd.Context(), client, opts.Model) + } + fmt.Println("Set 'thinking' mode.") + case "nothinking": + opts.Thinking = false + fmt.Println("Set 'nothinking' mode.") case "format": if len(args) < 3 || args[2] != "json" { fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") diff --git a/cmd/warn_thinking_test.go b/cmd/warn_thinking_test.go new file mode 100644 index 000000000..8c084a922 --- /dev/null +++ b/cmd/warn_thinking_test.go @@ -0,0 +1,64 @@ +package cmd + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/types/model" +) + +// Test that a warning is printed when thinking is requested but not supported. 
+func TestWarnMissingThinking(t *testing.T) { + cases := []struct { + capabilities []model.Capability + expectWarn bool + }{ + {capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false}, + {capabilities: []model.Capability{}, expectWarn: true}, + } + + for _, tc := range cases { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/show" || r.Method != http.MethodPost { + t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method) + } + var req api.ShowRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + resp := api.ShowResponse{Capabilities: tc.capabilities} + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("encode response: %v", err) + } + })) + defer srv.Close() + + t.Setenv("OLLAMA_HOST", srv.URL) + client, err := api.ClientFromEnvironment() + if err != nil { + t.Fatal(err) + } + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + warnMissingThinking(context.Background(), client, "m") + w.Close() + os.Stderr = oldStderr + out, _ := io.ReadAll(r) + + warned := strings.Contains(string(out), "warning:") + if tc.expectWarn && !warned { + t.Errorf("expected warning, got none") + } + if !tc.expectWarn && warned { + t.Errorf("did not expect warning, got: %s", string(out)) + } + } +} diff --git a/readline/types.go b/readline/types.go index e136d9962..f4efa8d92 100644 --- a/readline/types.go +++ b/readline/types.go @@ -61,6 +61,8 @@ const ( ColorGrey = Esc + "[38;5;245m" ColorDefault = Esc + "[0m" + ColorBold = Esc + "[1m" + StartBracketedPaste = Esc + "[?2004h" EndBracketedPaste = Esc + "[?2004l" ) diff --git a/server/images.go b/server/images.go index be629f4cb..5d2e3cd06 100644 --- a/server/images.go +++ b/server/images.go @@ -37,6 +37,7 @@ var ( errCapabilityInsert = errors.New("insert") errCapabilityVision = errors.New("vision") errCapabilityEmbedding = errors.New("embedding") + 
errCapabilityThinking = errors.New("thinking") errInsecureProtocol = errors.New("insecure protocol http") ) @@ -106,6 +107,11 @@ func (m *Model) Capabilities() []model.Capability { capabilities = append(capabilities, model.CapabilityInsert) } + // Check for thinking capability + if slices.Contains(m.Template.Vars(), "thinking") { + capabilities = append(capabilities, model.CapabilityThinking) + } + return capabilities } @@ -122,6 +128,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error { model.CapabilityInsert: errCapabilityInsert, model.CapabilityVision: errCapabilityVision, model.CapabilityEmbedding: errCapabilityEmbedding, + model.CapabilityThinking: errCapabilityThinking, } for _, cap := range want { diff --git a/server/prompt.go b/server/prompt.go index 5b5b958f1..23541357c 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -22,7 +22,7 @@ var errTooManyImages = errors.New("vision model only supports a single image per // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages -func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) { +func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, thinking bool) (prompt string, images []llm.ImageData, _ error) { var system []api.Message isMllama := checkMllamaModelFamily(m) @@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. 
} var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Thinking: thinking}); err != nil { return "", nil, err } @@ -142,7 +142,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. // truncate any messages that do not fit into the context window var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Thinking: thinking}); err != nil { return "", nil, err } diff --git a/server/prompt_test.go b/server/prompt_test.go index 62aec86a9..498d6cfe8 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -318,7 +318,7 @@ func TestChatPrompt(t *testing.T) { t.Run(tt.name, func(t *testing.T) { model := tt.model opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} - prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) + prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, false) if tt.error == nil && err != nil { t.Fatal(err) } else if tt.error != nil && err != tt.error { diff --git a/server/routes.go b/server/routes.go index 16f22cf93..04e5b6905 100644 --- a/server/routes.go +++ b/server/routes.go @@ -18,7 +18,6 @@ import ( "os" "os/signal" "path/filepath" - "regexp" "slices" "strings" "syscall" @@ -181,6 +180,9 @@ func (s *Server) GenerateHandler(c *gin.Context) { if req.Suffix != "" { caps = append(caps, model.CapabilityInsert) } + if req.Thinking { + caps = append(caps, model.CapabilityThinking) + } r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) 
{
@@ -1475,6 +1477,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	if len(req.Tools) > 0 {
 		caps = append(caps, model.CapabilityTools)
 	}
+	if req.Thinking {
+		caps = append(caps, model.CapabilityThinking)
+	}
 
 	name := model.ParseName(req.Model)
 	if !name.IsValid() {
@@ -1515,7 +1520,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	msgs = filterThinkTags(msgs, m)
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Thinking)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1529,6 +1534,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		defer close(ch)
 		var sb strings.Builder
 		var toolCallIndex int = 0
+		var thinkingState thinkingParser = thinkingParser{
+			openingTag: "<think>",
+			closingTag: "</think>",
+		}
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
@@ -1548,6 +1557,16 @@
 				},
 			}
 
+			if req.Thinking {
+				thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
+				if thinkingContent == "" && remainingContent == "" && !r.Done {
+					// need to accumulate more to decide what to send
+					return
+				}
+				res.Message.Content = remainingContent
+				res.Message.ThinkingBlock = thinkingContent
+			}
+
 			if r.Done {
 				res.DoneReason = r.DoneReason.String()
 				res.TotalDuration = time.Since(checkpointStart)
@@ -1565,7 +1584,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			// Streaming tool calls:
 			// If tools are recognized, use a flag to track the sending of a tool downstream
 			// This ensures that content is cleared from the message on the last chunk sent
-			sb.WriteString(r.Content)
+			sb.WriteString(res.Message.Content)
 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 				res.Message.ToolCalls = toolCalls
 				for i := range toolCalls {
@@ -1613,9 +1632,12 @@
func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	resp.Message.Content = sb.String()
+	if req.Thinking {
+		resp.Message.ThinkingBlock, resp.Message.Content = extractThinking(resp.Message.Content)
+	}
 
 	if len(req.Tools) > 0 {
-		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+		if toolCalls, ok := m.parseToolCalls(resp.Message.Content); ok {
 			resp.Message.ToolCalls = toolCalls
 			resp.Message.Content = ""
 		}
@@ -1643,7 +1665,16 @@ func handleScheduleError(c *gin.Context, name string, err error) {
 	}
 }
 
-var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
+// returns (thinkingContent, content)
+func extractThinking(text string) (string, string) {
+	thinking := thinkingParser{
+		openingTag: "<think>",
+		closingTag: "</think>",
+	}
+
+	thinkingContent, content := thinking.addContent(text)
+	return thinkingContent, content
+}
 
 func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
 	if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
@@ -1656,7 +1687,9 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
 
 	for i, msg := range msgs {
 		if msg.Role == "assistant" && i < finalUserIndex {
-			msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
+			thinking, content := extractThinking(msg.Content)
+			msgs[i].Content = content
+			msgs[i].ThinkingBlock = thinking
 		}
 	}
 }
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index 56121d41b..01f906d09 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -143,6 +143,24 @@ func TestGenerateChat(t *testing.T) {
 		}
 	})
 
+	t.Run("missing thinking capability", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Thinking: true,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support 
thinking"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + t.Run("missing model", func(t *testing.T) { w := createRequest(t, s.ChatHandler, api.ChatRequest{}) if w.Code != http.StatusBadRequest { diff --git a/server/thinking.go b/server/thinking.go new file mode 100644 index 000000000..d35c8dd27 --- /dev/null +++ b/server/thinking.go @@ -0,0 +1,127 @@ +package server + +import ( + "strings" + "unicode" +) + +type thinkingParseState int + +const ( + thinkingParseState_LookingForOpening thinkingParseState = iota + thinkingParseState_Thinking + thinkingParseState_ThinkingDone +) + +func (s thinkingParseState) String() string { + switch s { + case thinkingParseState_LookingForOpening: + return "LookingForOpening" + case thinkingParseState_Thinking: + return "Thinking" + case thinkingParseState_ThinkingDone: + return "ThinkingDone" + default: + return "Unknown" + } +} + +type thinkingParser struct { + state thinkingParseState + openingTag string + closingTag string + acc strings.Builder +} + +// returns the thinking content and the normal content that should be +// immediately sent to the user. 
It will internally buffer if it needs to see +// more content to disambiguate +func (s *thinkingParser) addContent(content string) (string, string) { + s.acc.WriteString(content) + + var thinkingAcc, remainingAcc strings.Builder + + var thinking, remaining string + keepLooping := true + // we loop because we might pass through multiple parsing states in a single + // call to addContent, and we want to make sure callers don't have to wait for + // data that's already unambiguous + for keepLooping { + thinking, remaining, keepLooping = eat(s) + thinkingAcc.WriteString(thinking) + remainingAcc.WriteString(remaining) + } + + return thinkingAcc.String(), remainingAcc.String() +} + +// the additional bool return is true iff we should continue eating +func eat(s *thinkingParser) (string, string, bool) { + switch s.state { + case thinkingParseState_LookingForOpening: + trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace) + if strings.HasPrefix(trimmed, s.openingTag) { + after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag) + after = strings.TrimLeftFunc(after, unicode.IsSpace) + // after might contain more than just thinking tokens, so we continue + // parsing instead of returning it as thinking tokens here + s.acc.Reset() + s.acc.WriteString(after) + s.state = thinkingParseState_Thinking + return "", "", true + } else if strings.HasPrefix(s.openingTag, trimmed) { + // partial opening seen, so let's keep accumulating + return "", "", false + } else if trimmed == "" { + // saw whitespace only, so let's keep accumulating + return "", "", false + } else { + // didn't see an opening tag, but we have content, so thinking was skipped + s.state = thinkingParseState_ThinkingDone + // note that we use the original content, not the trimmed one because we + // don't want to eat any whitespace in the real content if there were no + // thinking tags + return "", s.acc.String(), false + } + case thinkingParseState_Thinking: + acc := s.acc.String() + 
if strings.Contains(acc, s.closingTag) { + split := strings.Split(acc, s.closingTag) + thinking := split[0] + remaining := strings.Join(split[1:], s.closingTag) + remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace) + s.acc.Reset() + s.state = thinkingParseState_ThinkingDone + return thinking, remaining, false + } else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 { + thinking := acc[:len(acc)-overlapLen] + remaining := acc[len(acc)-overlapLen:] + s.acc.Reset() + // keep track of the candidate closing tag. We have to buffer it until it + // becomes disambiguated + s.acc.WriteString(remaining) + return thinking, "", false + } else { + // purely just thinking tokens, so we can return them + s.acc.Reset() + return acc, "", false + } + case thinkingParseState_ThinkingDone: + acc := s.acc.String() + s.acc.Reset() + return "", acc, false + default: + panic("unknown state") + } +} + +// longest overlap between suffix of s and prefix of delim +func overlap(s, delim string) int { + max := min(len(delim), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, delim[:i]) { + return i + } + } + return 0 +} diff --git a/server/thinking_test.go b/server/thinking_test.go new file mode 100644 index 000000000..4123f13f3 --- /dev/null +++ b/server/thinking_test.go @@ -0,0 +1,161 @@ +package server + +import ( + "testing" +) + +func TestExtractThinking(t *testing.T) { + tests := []struct { + in, wantContent, wantThink string + }{ + { + in: " internal world", + wantThink: "internal ", + wantContent: "world", + }, + { + in: "abc", + wantThink: "a", + wantContent: "bc", + }, + { + in: "no think", + wantThink: "", + wantContent: "no think", + }, + } + for i, tt := range tests { + gotThinking, gotContent := extractThinking(tt.in) + if gotContent != tt.wantContent || gotThinking != tt.wantThink { + t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent) + } + } +} + +func TestThinkingStreaming(t *testing.T) { + 
+ type step struct { + input string + wantThinking string + wantContent string + wantStateAfter thinkingParseState + } + + cases := []struct { + desc string + skip bool + steps []step + }{ + { + desc: "content without a thinking tag", + steps: []step{ + { + input: " abc", + wantThinking: "", + wantContent: " abc", + wantStateAfter: thinkingParseState_ThinkingDone, + }, + }, + }, + { + desc: "content before a thinking tag nerfs the thinking tag", + steps: []step{ + { + input: " abc def ghi", + wantThinking: "", + wantContent: " abc def ghi", + wantStateAfter: thinkingParseState_ThinkingDone, + }, + }, + }, + { + desc: "building up a thinking tag partially", + // skip: true, + steps: []step{ + { + input: " a", + wantThinking: "a", + wantContent: "", + wantStateAfter: thinkingParseState_Thinking, + }, + }, + }, + { + desc: "partial closing tag", + steps: []step{ + { + input: "abcdef", + wantThinking: "", + wantContent: "def", + wantStateAfter: thinkingParseState_ThinkingDone, + }, + }, + }, + { + desc: "partial closing tag fakeout", + steps: []step{ + { + input: "abcdef", + wantThinking: "def", + wantContent: "", + wantStateAfter: thinkingParseState_Thinking, + }, + { + input: "ghijkl", + wantThinking: "", + wantContent: "jkl", + wantStateAfter: thinkingParseState_ThinkingDone, + }, + }, + }, + } + + for _, c := range cases { + parser := thinkingParser{ + openingTag: "", + closingTag: "", + } + if c.skip { + continue + } + for i, step := range c.steps { + thinking, content := parser.addContent(step.input) + if content != step.wantContent || thinking != step.wantThinking { + t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking) + } + if parser.state != step.wantStateAfter { + t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state.String(), step.wantStateAfter.String()) + } + } + } +} diff --git a/template/template.go b/template/template.go index 5c886cac4..138e8cb41 100644 --- 
a/template/template.go +++ b/template/template.go @@ -165,8 +165,9 @@ func (t *Template) Vars() []string { type Values struct { Messages []api.Message api.Tools - Prompt string - Suffix string + Prompt string + Suffix string + Thinking bool // forceLegacy is a flag used to test compatibility with legacy templates forceLegacy bool @@ -225,6 +226,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "Prompt": v.Prompt, "Suffix": v.Suffix, "Response": "", + "Thinking": v.Thinking, }) } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { return t.Template.Execute(w, map[string]any{ @@ -232,6 +234,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "Messages": messages, "Tools": v.Tools, "Response": "", + "Thinking": v.Thinking, }) } @@ -244,6 +247,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "System": system, "Prompt": prompt, "Response": response, + "Thinking": v.Thinking, }); err != nil { return err } @@ -289,6 +293,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "System": system, "Prompt": prompt, "Response": response, + "Thinking": v.Thinking, }); err != nil { return err } diff --git a/types/model/capability.go b/types/model/capability.go index fb8689403..cde23cee7 100644 --- a/types/model/capability.go +++ b/types/model/capability.go @@ -8,6 +8,7 @@ const ( CapabilityInsert = Capability("insert") CapabilityVision = Capability("vision") CapabilityEmbedding = Capability("embedding") + CapabilityThinking = Capability("thinking") ) func (c Capability) String() string {