WIP thinking API support

- Allows specifying whether thinking mode should be on or not
- Templates get passed a new option so, e.g., qwen3's template can put
  `/think` or `/no_think` in the system prompt depending on the value of
  the setting
- Add parsing for thinking blocks in both streaming/non-streaming mode
- Update the CLI to make use of these changes

TODO:

- [ ] Don't parse thinking blocks when the user doesn't explicitly set
      the option, to maintain backwards compatibility
- [ ] Warning on CLI when using a non-thinking/older version of a model
      (with an old template)
- [ ] Wire up capabilities fully
- [x] Unify parsing for streaming/non-streaming
- [ ] Update templates
- [ ] Update python/js libraries
- [ ] How to handle differences in models wrt defaults and whether or
      not the thinking ability can even be controlled. If not specified
      by the user, should there be a default or should the template be
      able to check if it was explicitly set?
This commit is contained in:
Devon Rifkin 2025-05-07 16:15:46 -07:00
parent a7835c6716
commit 77f4594e80
14 changed files with 513 additions and 12 deletions

View file

@ -83,6 +83,10 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be // Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it. // set through this field, if the model supports it.
Options map[string]any `json:"options"` Options map[string]any `json:"options"`
// Thinking controls whether thinking/reasoning models will think before
// responding
Thinking bool `json:"thinking,omitempty"`
} }
// ChatRequest describes a request sent by [Client.Chat]. // ChatRequest describes a request sent by [Client.Chat].
@ -108,6 +112,10 @@ type ChatRequest struct {
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]any `json:"options"` Options map[string]any `json:"options"`
// Thinking controls whether thinking/reasoning models will think before
// responding
Thinking bool `json:"thinking,omitempty"`
} }
type Tools []Tool type Tools []Tool
@ -130,6 +138,10 @@ type Message struct {
Content string `json:"content"` Content string `json:"content"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// ThinkingBlock contains the text that was inside <think> tags in the
// original model output when ChatRequest.Thinking was enabled.
ThinkingBlock string `json:"thinkingBlock,omitempty"`
} }
func (m *Message) UnmarshalJSON(b []byte) error { func (m *Message) UnmarshalJSON(b []byte) error {
@ -275,6 +287,8 @@ type Options struct {
MirostatTau float32 `json:"mirostat_tau,omitempty"` MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"` MirostatEta float32 `json:"mirostat_eta,omitempty"`
Stop []string `json:"stop,omitempty"` Stop []string `json:"stop,omitempty"`
Thinking bool `json:"thinking,omitempty"`
} }
// Runner options which must be set when the model is loaded into memory // Runner options which must be set when the model is loaded into memory

View file

@ -38,12 +38,33 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/runner" "github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
// warnMissingThinking prints a warning to stderr when the named model does
// not advertise the thinking capability. Callers should invoke it only after
// the user has requested thinking mode. Failures to query the capability are
// silently ignored so this never impacts regular usage.
func warnMissingThinking(ctx context.Context, client *api.Client, name string) {
	if name == "" {
		return
	}
	resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
	if err != nil {
		// best effort only: if /api/show fails we stay quiet
		return
	}
	for _, c := range resp.Capabilities {
		if c == model.CapabilityThinking {
			return
		}
	}
	fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
var errModelfileNotFound = errors.New("specified Modelfile wasn't found") var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) { func getModelfileName(cmd *cobra.Command) (string, error) {
@ -243,6 +264,9 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
req := &api.GenerateRequest{ req := &api.GenerateRequest{
Model: opts.Model, Model: opts.Model,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
// pass Thinking here so we fail before getting to the chat prompt if the model doesn't support it
Thinking: opts.Thinking,
} }
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil }) return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@ -277,6 +301,12 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
opts.Format = format opts.Format = format
thinkingFlag, err := cmd.Flags().GetBool("thinking")
if err != nil {
return err
}
opts.Thinking = thinkingFlag
keepAlive, err := cmd.Flags().GetString("keepalive") keepAlive, err := cmd.Flags().GetString("keepalive")
if err != nil { if err != nil {
return err return err
@ -361,6 +391,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
return err return err
} }
warnMissingThinking(cmd.Context(), client, opts.Model)
for _, msg := range info.Messages { for _, msg := range info.Messages {
switch msg.Role { switch msg.Role {
@ -876,6 +907,7 @@ type runOptions struct {
Options map[string]any Options map[string]any
MultiModal bool MultiModal bool
KeepAlive *api.Duration KeepAlive *api.Duration
Thinking bool
} }
type displayResponseState struct { type displayResponseState struct {
@ -958,6 +990,8 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var latest api.ChatResponse var latest api.ChatResponse
var fullResponse strings.Builder var fullResponse strings.Builder
var role string var role string
var thinkTagOpened bool = false
var thinkTagClosed bool = false
fn := func(response api.ChatResponse) error { fn := func(response api.ChatResponse) error {
p.StopAndClear() p.StopAndClear()
@ -965,7 +999,23 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest = response latest = response
role = response.Message.Role role = response.Message.Role
if response.Message.ThinkingBlock != "" {
if !thinkTagOpened {
fmt.Print(readline.ColorGrey + readline.ColorBold + "<think>" + readline.ColorDefault + readline.ColorGrey)
thinkTagOpened = true
}
displayResponse(response.Message.ThinkingBlock, opts.WordWrap, state)
}
content := response.Message.Content content := response.Message.Content
if !thinkTagClosed && thinkTagOpened && content != "" {
fmt.Print(readline.ColorGrey + readline.ColorBold + "</think>" + readline.ColorDefault)
thinkTagClosed = true
}
// purposefully not putting thinking blocks in the response, which would
// only be needed if we later added tool calling to the cli (they get
// filtered out anyway since current models don't expect them unless you're
// about to finish some tool calls)
fullResponse.WriteString(content) fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state) displayResponse(content, opts.WordWrap, state)
@ -982,6 +1032,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
Messages: opts.Messages, Messages: opts.Messages,
Format: json.RawMessage(opts.Format), Format: json.RawMessage(opts.Format),
Options: opts.Options, Options: opts.Options,
Thinking: opts.Thinking,
}
if opts.Thinking {
warnMissingThinking(cmd.Context(), client, opts.Model)
} }
if opts.KeepAlive != nil { if opts.KeepAlive != nil {
@ -1075,6 +1130,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
System: opts.System, System: opts.System,
Options: opts.Options, Options: opts.Options,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
Thinking: opts.Thinking,
} }
if err := client.Generate(ctx, &request, fn); err != nil { if err := client.Generate(ctx, &request, fn); err != nil {
@ -1290,6 +1346,8 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)") runCmd.Flags().String("format", "", "Response format (e.g. json)")
// TODO(drifkin): what should happen for an unsupported model? Warning? Fail hard?
runCmd.Flags().Bool("thinking", false, "Turn on thinking mode for supported models")
stopCmd := &cobra.Command{ stopCmd := &cobra.Command{
Use: "stop MODEL", Use: "stop MODEL",

View file

@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting") fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats") fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats") fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set thinking Enable thinking")
fmt.Fprintln(os.Stderr, " /set nothinking Disable thinking")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
@ -260,6 +262,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
fmt.Println("Set 'quiet' mode.") fmt.Println("Set 'quiet' mode.")
case "thinking":
opts.Thinking = true
if client, err := api.ClientFromEnvironment(); err == nil {
warnMissingThinking(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'thinking' mode.")
case "nothinking":
opts.Thinking = false
fmt.Println("Set 'nothinking' mode.")
case "format": case "format":
if len(args) < 3 || args[2] != "json" { if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")

64
cmd/warn_thinking_test.go Normal file
View file

@ -0,0 +1,64 @@
package cmd
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// TestWarnMissingThinking verifies that warnMissingThinking warns on stderr
// exactly when the model's reported capabilities lack thinking support.
func TestWarnMissingThinking(t *testing.T) {
	cases := []struct {
		name         string
		capabilities []model.Capability
		expectWarn   bool
	}{
		{name: "supported", capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
		{name: "unsupported", capabilities: []model.Capability{}, expectWarn: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Fake the /api/show endpoint, returning the capabilities under test.
			// Note: the handler runs on a server goroutine, so it must not call
			// t.Fatalf (only the test goroutine may); use t.Errorf and return.
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
					t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
					return
				}
				var req api.ShowRequest
				if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
					t.Errorf("decode request: %v", err)
					return
				}
				resp := api.ShowResponse{Capabilities: tc.capabilities}
				if err := json.NewEncoder(w).Encode(resp); err != nil {
					t.Errorf("encode response: %v", err)
				}
			}))
			// deferring inside the subtest closes the server per-case instead
			// of leaking every server until the whole test function returns.
			defer srv.Close()

			t.Setenv("OLLAMA_HOST", srv.URL)
			client, err := api.ClientFromEnvironment()
			if err != nil {
				t.Fatal(err)
			}

			// Capture stderr while calling the function under test.
			oldStderr := os.Stderr
			r, w, err := os.Pipe()
			if err != nil {
				t.Fatal(err)
			}
			os.Stderr = w
			warnMissingThinking(context.Background(), client, "m")
			w.Close()
			os.Stderr = oldStderr

			out, err := io.ReadAll(r)
			if err != nil {
				t.Fatal(err)
			}
			warned := strings.Contains(string(out), "warning:")
			if tc.expectWarn && !warned {
				t.Errorf("expected warning, got none")
			}
			if !tc.expectWarn && warned {
				t.Errorf("did not expect warning, got: %s", string(out))
			}
		})
	}
}

View file

@ -61,6 +61,8 @@ const (
ColorGrey = Esc + "[38;5;245m" ColorGrey = Esc + "[38;5;245m"
ColorDefault = Esc + "[0m" ColorDefault = Esc + "[0m"
ColorBold = Esc + "[1m"
StartBracketedPaste = Esc + "[?2004h" StartBracketedPaste = Esc + "[?2004h"
EndBracketedPaste = Esc + "[?2004l" EndBracketedPaste = Esc + "[?2004l"
) )

View file

@ -37,6 +37,7 @@ var (
errCapabilityInsert = errors.New("insert") errCapabilityInsert = errors.New("insert")
errCapabilityVision = errors.New("vision") errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding") errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errInsecureProtocol = errors.New("insecure protocol http") errInsecureProtocol = errors.New("insecure protocol http")
) )
@ -106,6 +107,11 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityInsert) capabilities = append(capabilities, model.CapabilityInsert)
} }
// Check for thinking capability
if slices.Contains(m.Template.Vars(), "thinking") {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities return capabilities
} }
@ -122,6 +128,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityInsert: errCapabilityInsert, model.CapabilityInsert: errCapabilityInsert,
model.CapabilityVision: errCapabilityVision, model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding, model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
} }
for _, cap := range want { for _, cap := range want {

View file

@ -22,7 +22,7 @@ var errTooManyImages = errors.New("vision model only supports a single image per
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages // latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) { func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, thinking bool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message var system []api.Message
isMllama := checkMllamaModelFamily(m) isMllama := checkMllamaModelFamily(m)
@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} }
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Thinking: thinking}); err != nil {
return "", nil, err return "", nil, err
} }
@ -142,7 +142,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Thinking: thinking}); err != nil {
return "", nil, err return "", nil, err
} }

View file

@ -318,7 +318,7 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
model := tt.model model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, false)
if tt.error == nil && err != nil { if tt.error == nil && err != nil {
t.Fatal(err) t.Fatal(err)
} else if tt.error != nil && err != tt.error { } else if tt.error != nil && err != tt.error {

View file

@ -18,7 +18,6 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"slices" "slices"
"strings" "strings"
"syscall" "syscall"
@ -181,6 +180,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" { if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert) caps = append(caps, model.CapabilityInsert)
} }
if req.Thinking {
caps = append(caps, model.CapabilityThinking)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
@ -1475,6 +1477,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools) caps = append(caps, model.CapabilityTools)
} }
if req.Thinking {
caps = append(caps, model.CapabilityThinking)
}
name := model.ParseName(req.Model) name := model.ParseName(req.Model)
if !name.IsValid() { if !name.IsValid() {
@ -1515,7 +1520,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
msgs = filterThinkTags(msgs, m) msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Thinking)
if err != nil { if err != nil {
slog.Error("chat prompt error", "error", err) slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -1529,6 +1534,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch) defer close(ch)
var sb strings.Builder var sb strings.Builder
var toolCallIndex int = 0 var toolCallIndex int = 0
var thinkingState thinkingParser = thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
@ -1548,6 +1557,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
}, },
} }
if req.Thinking {
thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.ThinkingBlock = thinkingContent
}
if r.Done { if r.Done {
res.DoneReason = r.DoneReason.String() res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart) res.TotalDuration = time.Since(checkpointStart)
@ -1565,7 +1584,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
// Streaming tool calls: // Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream // If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent // This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content) sb.WriteString(res.Message.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok { if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls res.Message.ToolCalls = toolCalls
for i := range toolCalls { for i := range toolCalls {
@ -1613,9 +1632,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
resp.Message.Content = sb.String() resp.Message.Content = sb.String()
if req.Thinking {
resp.Message.ThinkingBlock, resp.Message.Content = extractThinking(resp.Message.Content)
}
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok { if toolCalls, ok := m.parseToolCalls(resp.Message.Content); ok {
resp.Message.ToolCalls = toolCalls resp.Message.ToolCalls = toolCalls
resp.Message.Content = "" resp.Message.Content = ""
} }
@ -1643,7 +1665,16 @@ func handleScheduleError(c *gin.Context, name string, err error) {
} }
} }
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`) // returns (thinkingContent, content)
func extractThinking(text string) (string, string) {
thinking := thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
thinkingContent, content := thinking.addContent(text)
return thinkingContent, content
}
func filterThinkTags(msgs []api.Message, m *Model) []api.Message { func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" { if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
@ -1656,7 +1687,9 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
for i, msg := range msgs { for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex { if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "") thinking, content := extractThinking(msg.Content)
msg.Content = content
msg.ThinkingBlock = thinking
} }
} }
} }

View file

@ -143,6 +143,24 @@ func TestGenerateChat(t *testing.T) {
} }
}) })
t.Run("missing thinking capability", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Thinking: true,
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) { t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{}) w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest { if w.Code != http.StatusBadRequest {

127
server/thinking.go Normal file
View file

@ -0,0 +1,127 @@
package server
import (
"strings"
"unicode"
)
// thinkingParseState tracks which phase of the <think>...</think> stream the
// parser is currently in.
type thinkingParseState int

const (
	thinkingParseState_LookingForOpening thinkingParseState = iota // no opening tag seen yet
	thinkingParseState_Thinking                                    // inside the thinking block
	thinkingParseState_ThinkingDone                                // closing tag seen, or thinking was skipped
)

// String returns a human-readable name for the state, mainly for tests and
// debug output.
func (s thinkingParseState) String() string {
	names := map[thinkingParseState]string{
		thinkingParseState_LookingForOpening: "LookingForOpening",
		thinkingParseState_Thinking:          "Thinking",
		thinkingParseState_ThinkingDone:      "ThinkingDone",
	}
	if name, ok := names[s]; ok {
		return name
	}
	return "Unknown"
}
// thinkingParser incrementally splits streamed model output into "thinking"
// text (between openingTag and closingTag) and regular response text,
// buffering input that is still ambiguous (e.g. a partially received tag).
type thinkingParser struct {
	state      thinkingParseState // current phase of parsing
	openingTag string             // e.g. "<think>"
	closingTag string             // e.g. "</think>"
	acc        strings.Builder    // buffered input not yet classified
}
// addContent feeds a chunk of model output into the parser. It returns the
// thinking content and the normal content that should be immediately sent to
// the user. Input that is still ambiguous (such as a partial tag) is buffered
// internally until a later chunk disambiguates it.
func (s *thinkingParser) addContent(content string) (string, string) {
	s.acc.WriteString(content)

	var thinkingOut, contentOut strings.Builder
	// A single chunk can carry the parser through several states (for example
	// "<think>abc</think>def"), so keep consuming until eat reports there is
	// nothing unambiguous left, ensuring callers never wait on data that is
	// already decidable.
	for {
		thinking, remaining, more := eat(s)
		thinkingOut.WriteString(thinking)
		contentOut.WriteString(remaining)
		if !more {
			break
		}
	}
	return thinkingOut.String(), contentOut.String()
}
// eat consumes as much of the parser's buffered input as can be unambiguously
// classified right now. It returns (thinking, content, keepGoing); keepGoing
// is true iff the state advanced and the caller should call eat again to
// process data still in the buffer.
func eat(s *thinkingParser) (string, string, bool) {
	switch s.state {
	case thinkingParseState_LookingForOpening:
		trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
		switch {
		case strings.HasPrefix(trimmed, s.openingTag):
			// Strip the tag and any whitespace immediately after it. What
			// follows may contain more than thinking tokens (e.g. a closing
			// tag), so continue parsing instead of emitting it as thinking.
			after := strings.TrimLeftFunc(strings.TrimPrefix(trimmed, s.openingTag), unicode.IsSpace)
			s.acc.Reset()
			s.acc.WriteString(after)
			s.state = thinkingParseState_Thinking
			return "", "", true
		case strings.HasPrefix(s.openingTag, trimmed):
			// Partial opening tag so far — this also covers the
			// whitespace-only case, since trimmed is then "" and the empty
			// string is a prefix of everything. Keep accumulating.
			return "", "", false
		default:
			// Real content arrived before any opening tag, so thinking was
			// skipped. Emit the original (untrimmed) buffer so no leading
			// whitespace of real content is lost, and reset the buffer so the
			// same content is not emitted again on the next call (previously
			// the buffer was left intact here, duplicating the first chunk in
			// tag-less streams).
			content := s.acc.String()
			s.acc.Reset()
			s.state = thinkingParseState_ThinkingDone
			return "", content, false
		}
	case thinkingParseState_Thinking:
		acc := s.acc.String()
		if thinking, remaining, found := strings.Cut(acc, s.closingTag); found {
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
			s.acc.Reset()
			s.state = thinkingParseState_ThinkingDone
			return thinking, remaining, false
		}
		if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 {
			// A suffix of the buffer may be the start of the closing tag; hold
			// the candidate back until a later chunk disambiguates it.
			thinking := acc[:len(acc)-overlapLen]
			candidate := acc[len(acc)-overlapLen:]
			s.acc.Reset()
			s.acc.WriteString(candidate)
			return thinking, "", false
		}
		// Purely thinking tokens, safe to emit them all.
		s.acc.Reset()
		return acc, "", false
	case thinkingParseState_ThinkingDone:
		// Everything after the thinking block is plain content; pass through.
		acc := s.acc.String()
		s.acc.Reset()
		return "", acc, false
	default:
		panic("unknown state")
	}
}
// overlap reports the length of the longest suffix of s that is also a
// prefix of delim (used to detect a partially streamed closing tag).
func overlap(s, delim string) int {
	best := 0
	for n := 1; n <= min(len(s), len(delim)); n++ {
		if strings.HasSuffix(s, delim[:n]) {
			best = n
		}
	}
	return best
}

161
server/thinking_test.go Normal file
View file

@ -0,0 +1,161 @@
package server
import (
"testing"
)
// TestExtractThinking exercises the non-streaming helper that splits a full
// response into its thinking block and the remaining content.
func TestExtractThinking(t *testing.T) {
	cases := []struct {
		input       string
		thinkWant   string
		contentWant string
	}{
		// whitespace after the closing tag is stripped from the content side
		{input: "<think> internal </think> world", thinkWant: "internal ", contentWant: "world"},
		// only the first think block is extracted; later ones stay in content
		{input: "<think>a</think><think>b</think>c", thinkWant: "a", contentWant: "<think>b</think>c"},
		// no tags at all: everything is content
		{input: "no think", thinkWant: "", contentWant: "no think"},
	}
	for i, tc := range cases {
		think, content := extractThinking(tc.input)
		if content != tc.contentWant || think != tc.thinkWant {
			t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, think, content, tc.thinkWant, tc.contentWant)
		}
	}
}
// TestThinkingStreaming feeds chunks into thinkingParser one at a time and
// checks both the emitted output and the parser state after each step,
// covering partial tags, fakeout closing tags, and tag-less content.
func TestThinkingStreaming(t *testing.T) {
	type step struct {
		input          string
		wantThinking   string
		wantContent    string
		wantStateAfter thinkingParseState
	}
	cases := []struct {
		desc  string
		skip  bool
		steps []step
	}{
		{
			desc: "content without a thinking tag",
			steps: []step{
				{
					input:          " abc",
					wantThinking:   "",
					wantContent:    " abc",
					wantStateAfter: thinkingParseState_ThinkingDone,
				},
			},
		},
		{
			desc: "content before a thinking tag nerfs the thinking tag",
			steps: []step{
				{
					input:          " abc <think>def</think> ghi",
					wantThinking:   "",
					wantContent:    " abc <think>def</think> ghi",
					wantStateAfter: thinkingParseState_ThinkingDone,
				},
			},
		},
		{
			desc: "building up a thinking tag partially",
			steps: []step{
				{
					input:          " <th",
					wantThinking:   "",
					wantContent:    "",
					wantStateAfter: thinkingParseState_LookingForOpening,
				},
				{
					input:          "in",
					wantThinking:   "",
					wantContent:    "",
					wantStateAfter: thinkingParseState_LookingForOpening,
				},
				{
					input:          "k>a",
					wantThinking:   "a",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
			},
		},
		{
			desc: "partial closing tag",
			steps: []step{
				{
					input:          "<think>abc</th",
					wantThinking:   "abc",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
				{
					input:          "ink>def",
					wantThinking:   "",
					wantContent:    "def",
					wantStateAfter: thinkingParseState_ThinkingDone,
				},
			},
		},
		{
			desc: "partial closing tag fakeout",
			steps: []step{
				{
					input:          "<think>abc</th",
					wantThinking:   "abc",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
				{
					// the held-back "</th" turns out not to be a closing tag
					input:          "ing>def",
					wantThinking:   "</thing>def",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
				{
					input:          "ghi</thi",
					wantThinking:   "ghi",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
				{
					input:          "nk>jkl",
					wantThinking:   "",
					wantContent:    "jkl",
					wantStateAfter: thinkingParseState_ThinkingDone,
				},
			},
		},
	}
	for _, c := range cases {
		c := c
		t.Run(c.desc, func(t *testing.T) {
			// skip before doing any per-case setup (the original constructed
			// the parser first, then skipped silently via continue)
			if c.skip {
				t.Skip("case marked skip")
			}
			parser := thinkingParser{
				openingTag: "<think>",
				closingTag: "</think>",
			}
			for i, step := range c.steps {
				thinking, content := parser.addContent(step.input)
				if content != step.wantContent || thinking != step.wantThinking {
					t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
				}
				if parser.state != step.wantStateAfter {
					t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state.String(), step.wantStateAfter.String())
				}
			}
		})
	}
}

View file

@ -165,8 +165,9 @@ func (t *Template) Vars() []string {
type Values struct { type Values struct {
Messages []api.Message Messages []api.Message
api.Tools api.Tools
Prompt string Prompt string
Suffix string Suffix string
Thinking bool
// forceLegacy is a flag used to test compatibility with legacy templates // forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool forceLegacy bool
@ -225,6 +226,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
"Prompt": v.Prompt, "Prompt": v.Prompt,
"Suffix": v.Suffix, "Suffix": v.Suffix,
"Response": "", "Response": "",
"Thinking": v.Thinking,
}) })
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
@ -232,6 +234,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
"Messages": messages, "Messages": messages,
"Tools": v.Tools, "Tools": v.Tools,
"Response": "", "Response": "",
"Thinking": v.Thinking,
}) })
} }
@ -244,6 +247,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
"System": system, "System": system,
"Prompt": prompt, "Prompt": prompt,
"Response": response, "Response": response,
"Thinking": v.Thinking,
}); err != nil { }); err != nil {
return err return err
} }
@ -289,6 +293,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
"System": system, "System": system,
"Prompt": prompt, "Prompt": prompt,
"Response": response, "Response": response,
"Thinking": v.Thinking,
}); err != nil { }); err != nil {
return err return err
} }

View file

@ -8,6 +8,7 @@ const (
CapabilityInsert = Capability("insert") CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision") CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding") CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
) )
func (c Capability) String() string { func (c Capability) String() string {