WIP thinking API support
- Allows specifying whether thinking mode should be on or not
- Templates get passed a new option so, e.g., qwen3's template can put `/think` or `/no_think` in the system prompt depending on the value of the setting
- Add parsing for thinking blocks in both streaming/non-streaming mode
- Update the CLI to make use of these changes

TODO:

- [ ] Don't parse thinking blocks when the user doesn't explicitly set the option, to maintain backwards compatibility
- [ ] Warn on the CLI when using a non-thinking/older version of a model (with an old template)
- [ ] Wire up capabilities fully
- [x] Unify parsing for streaming/non-streaming
- [ ] Update templates
- [ ] Update python/js libraries
- [ ] Decide how to handle differences in models with respect to defaults and whether the thinking ability can even be controlled. If not specified by the user, should there be a default, or should the template be able to check whether it was explicitly set?
Parent: a7835c6716
Commit: 77f4594e80

14 changed files with 513 additions and 12 deletions
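Before the per-file changes, here is a minimal sketch (not part of the diff) of what the new fields enable from a client's point of view. The model name is illustrative; `Thinking` on the request and `ThinkingBlock` on the message are the fields this commit adds:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model:    "qwen3", // illustrative; any model whose template handles the new Thinking variable
		Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
		Thinking: true, // new field introduced by this commit
	}

	// Each streamed chunk now carries thinking tokens separately from content.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		if resp.Message.ThinkingBlock != "" {
			fmt.Print(resp.Message.ThinkingBlock)
		}
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```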
api/types.go (14 changed lines)

```diff
@@ -83,6 +83,10 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]any `json:"options"`
+
+	// Thinking controls whether thinking/reasoning models will think before
+	// responding
+	Thinking bool `json:"thinking,omitempty"`
 }

 // ChatRequest describes a request sent by [Client.Chat].

@@ -108,6 +112,10 @@ type ChatRequest struct {
 	// Options lists model-specific options.
 	Options map[string]any `json:"options"`
+
+	// Thinking controls whether thinking/reasoning models will think before
+	// responding
+	Thinking bool `json:"thinking,omitempty"`
 }

 type Tools []Tool

@@ -130,6 +138,10 @@ type Message struct {
 	Content   string      `json:"content"`
 	Images    []ImageData `json:"images,omitempty"`
 	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
+
+	// ThinkingBlock contains the text that was inside <think> tags in the
+	// original model output when ChatRequest.Thinking was enabled.
+	ThinkingBlock string `json:"thinkingBlock,omitempty"`
 }

 func (m *Message) UnmarshalJSON(b []byte) error {

@@ -275,6 +287,8 @@ type Options struct {
 	MirostatTau float32  `json:"mirostat_tau,omitempty"`
 	MirostatEta float32  `json:"mirostat_eta,omitempty"`
 	Stop        []string `json:"stop,omitempty"`
+
+	Thinking bool `json:"thinking,omitempty"`
 }

 // Runner options which must be set when the model is loaded into memory
```
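For the wire format, a quick sketch of how `ThinkingBlock` serializes given the struct tags above (output shape assumed from default `encoding/json` marshaling of the field order shown):

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	msg := api.Message{
		Role:          "assistant",
		Content:       "The answer is 42.",
		ThinkingBlock: "Let me work through this...",
	}
	b, _ := json.Marshal(msg)
	// Prints roughly:
	// {"role":"assistant","content":"The answer is 42.","thinkingBlock":"Let me work through this..."}
	fmt.Println(string(b))
}
```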
cmd/cmd.go (58 changed lines)

```diff
@@ -38,12 +38,33 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/runner"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )

+// warnMissingThinking emits a warning if the model does not advertise thinking
+// support and opts.Thinking is set. Failures to query the capability are
+// ignored so this does not impact regular usage.
+func warnMissingThinking(ctx context.Context, client *api.Client, name string) {
+	if name == "" {
+		return
+	}
+	resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
+	if err != nil {
+		return
+	}
+	for _, cap := range resp.Capabilities {
+		if cap == model.CapabilityThinking {
+			return
+		}
+	}
+	fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
+}
+
 var errModelfileNotFound = errors.New("specified Modelfile wasn't found")

 func getModelfileName(cmd *cobra.Command) (string, error) {

@@ -243,6 +264,9 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
 	req := &api.GenerateRequest{
 		Model:     opts.Model,
 		KeepAlive: opts.KeepAlive,
+
+		// pass Thinking here so we fail before getting to the chat prompt if the model doesn't support it
+		Thinking: opts.Thinking,
 	}

 	return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })

@@ -277,6 +301,12 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}
 	opts.Format = format

+	thinkingFlag, err := cmd.Flags().GetBool("thinking")
+	if err != nil {
+		return err
+	}
+	opts.Thinking = thinkingFlag
+
 	keepAlive, err := cmd.Flags().GetString("keepalive")
 	if err != nil {
 		return err

@@ -361,6 +391,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	if err := loadOrUnloadModel(cmd, &opts); err != nil {
 		return err
 	}
+	warnMissingThinking(cmd.Context(), client, opts.Model)

 	for _, msg := range info.Messages {
 		switch msg.Role {

@@ -876,6 +907,7 @@ type runOptions struct {
 	Options    map[string]any
 	MultiModal bool
 	KeepAlive  *api.Duration
+	Thinking   bool
 }

 type displayResponseState struct {

@@ -958,6 +990,8 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	var latest api.ChatResponse
 	var fullResponse strings.Builder
 	var role string
+	var thinkTagOpened bool = false
+	var thinkTagClosed bool = false

 	fn := func(response api.ChatResponse) error {
 		p.StopAndClear()

@@ -965,7 +999,23 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		latest = response

 		role = response.Message.Role
+		if response.Message.ThinkingBlock != "" {
+			if !thinkTagOpened {
+				fmt.Print(readline.ColorGrey + readline.ColorBold + "<think>" + readline.ColorDefault + readline.ColorGrey)
+				thinkTagOpened = true
+			}
+			displayResponse(response.Message.ThinkingBlock, opts.WordWrap, state)
+		}
+
+		content := response.Message.Content
+		if !thinkTagClosed && thinkTagOpened && content != "" {
+			fmt.Print(readline.ColorGrey + readline.ColorBold + "</think>" + readline.ColorDefault)
+			thinkTagClosed = true
+		}
+		// purposefully not putting thinking blocks in the response, which would
+		// only be needed if we later added tool calling to the cli (they get
+		// filtered out anyway since current models don't expect them unless you're
+		// about to finish some tool calls)
-		fullResponse.WriteString(response.Message.Content)
+		fullResponse.WriteString(content)

 		displayResponse(content, opts.WordWrap, state)

@@ -982,6 +1032,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		Messages: opts.Messages,
 		Format:   json.RawMessage(opts.Format),
 		Options:  opts.Options,
+		Thinking: opts.Thinking,
 	}
+
+	if opts.Thinking {
+		warnMissingThinking(cmd.Context(), client, opts.Model)
+	}

 	if opts.KeepAlive != nil {

@@ -1075,6 +1130,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		System:    opts.System,
 		Options:   opts.Options,
 		KeepAlive: opts.KeepAlive,
+		Thinking:  opts.Thinking,
 	}

 	if err := client.Generate(ctx, &request, fn); err != nil {

@@ -1290,6 +1346,8 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
+	// TODO(drifkin): what should happen for an unsupported model? Warning? Fail hard?
+	runCmd.Flags().Bool("thinking", false, "Turn on thinking mode for supported models")

 	stopCmd := &cobra.Command{
 		Use:   "stop MODEL",
```
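Taken together, a hypothetical session with the new flag might look like this (the model output is invented; in a real terminal the `<think>`/`</think>` markers and thinking tokens render grey/bold using the escape codes added to readline later in this diff):

```
$ ollama run qwen3 --thinking
>>> Why is the sky blue?
<think>The user is asking about Rayleigh scattering...</think>
The sky appears blue because shorter wavelengths of sunlight scatter more...
```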
cmd/interactive.go

```diff
@@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "  /set noformat      Disable formatting")
 		fmt.Fprintln(os.Stderr, "  /set verbose       Show LLM stats")
 		fmt.Fprintln(os.Stderr, "  /set quiet         Disable LLM stats")
+		fmt.Fprintln(os.Stderr, "  /set thinking      Enable thinking")
+		fmt.Fprintln(os.Stderr, "  /set nothinking    Disable thinking")
 		fmt.Fprintln(os.Stderr, "")
 	}

@@ -260,6 +262,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 					return err
 				}
 				fmt.Println("Set 'quiet' mode.")
+			case "thinking":
+				opts.Thinking = true
+				if client, err := api.ClientFromEnvironment(); err == nil {
+					warnMissingThinking(cmd.Context(), client, opts.Model)
+				}
+				fmt.Println("Set 'thinking' mode.")
+			case "nothinking":
+				opts.Thinking = false
+				fmt.Println("Set 'nothinking' mode.")
 			case "format":
 				if len(args) < 3 || args[2] != "json" {
 					fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
```
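In an interactive session the same toggle is reachable via the new `/set` commands; the confirmation strings match the `fmt.Println` calls above:

```
>>> /set thinking
Set 'thinking' mode.
>>> /set nothinking
Set 'nothinking' mode.
```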
cmd/warn_thinking_test.go (new file, 64 lines)

```go
package cmd

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

// Test that a warning is printed when thinking is requested but not supported.
func TestWarnMissingThinking(t *testing.T) {
	cases := []struct {
		capabilities []model.Capability
		expectWarn   bool
	}{
		{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
		{capabilities: []model.Capability{}, expectWarn: true},
	}

	for _, tc := range cases {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
				t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
			}
			var req api.ShowRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Fatalf("decode request: %v", err)
			}
			resp := api.ShowResponse{Capabilities: tc.capabilities}
			if err := json.NewEncoder(w).Encode(resp); err != nil {
				t.Fatalf("encode response: %v", err)
			}
		}))
		defer srv.Close()

		t.Setenv("OLLAMA_HOST", srv.URL)
		client, err := api.ClientFromEnvironment()
		if err != nil {
			t.Fatal(err)
		}
		oldStderr := os.Stderr
		r, w, _ := os.Pipe()
		os.Stderr = w
		warnMissingThinking(context.Background(), client, "m")
		w.Close()
		os.Stderr = oldStderr
		out, _ := io.ReadAll(r)

		warned := strings.Contains(string(out), "warning:")
		if tc.expectWarn && !warned {
			t.Errorf("expected warning, got none")
		}
		if !tc.expectWarn && warned {
			t.Errorf("did not expect warning, got: %s", string(out))
		}
	}
}
```
readline/types.go

```diff
@@ -61,6 +61,8 @@ const (
 	ColorGrey    = Esc + "[38;5;245m"
 	ColorDefault = Esc + "[0m"
+
+	ColorBold = Esc + "[1m"

 	StartBracketedPaste = Esc + "[?2004h"
 	EndBracketedPaste   = Esc + "[?2004l"
 )
```
server/images.go

```diff
@@ -37,6 +37,7 @@ var (
 	errCapabilityInsert     = errors.New("insert")
 	errCapabilityVision     = errors.New("vision")
 	errCapabilityEmbedding  = errors.New("embedding")
+	errCapabilityThinking   = errors.New("thinking")
 	errInsecureProtocol     = errors.New("insecure protocol http")
 )

@@ -106,6 +107,11 @@ func (m *Model) Capabilities() []model.Capability {
 		capabilities = append(capabilities, model.CapabilityInsert)
 	}

+	// Check for thinking capability
+	if slices.Contains(m.Template.Vars(), "thinking") {
+		capabilities = append(capabilities, model.CapabilityThinking)
+	}
+
 	return capabilities
 }

@@ -122,6 +128,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
 		model.CapabilityInsert:    errCapabilityInsert,
 		model.CapabilityVision:    errCapabilityVision,
 		model.CapabilityEmbedding: errCapabilityEmbedding,
+		model.CapabilityThinking:  errCapabilityThinking,
 	}

 	for _, cap := range want {
```
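Since the capability check keys off `m.Template.Vars()` containing "thinking", any template that references `.Thinking` advertises the capability. Here is a standalone sketch using plain `text/template` (the real templates go through ollama's template package) of the qwen3-style `/think` / `/no_think` switch mentioned in the commit message; the fragment itself is hypothetical:

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// Hypothetical system-prompt fragment; referencing .Thinking is what would
	// make the template's variable list include "thinking".
	tmpl := template.Must(template.New("system").Parse(
		`{{ if .Thinking }}/think{{ else }}/no_think{{ end }}` + "\n"))
	_ = tmpl.Execute(os.Stdout, map[string]any{"Thinking": true}) // prints "/think"
}
```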
server/prompt.go

```diff
@@ -22,7 +22,7 @@ var errTooManyImages = errors.New("vision model only supports a single image per
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, thinking bool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message

 	isMllama := checkMllamaModelFamily(m)

@@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}

 		var b bytes.Buffer
-		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Thinking: thinking}); err != nil {
 			return "", nil, err
 		}

@@ -142,7 +142,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	// truncate any messages that do not fit into the context window
 	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Thinking: thinking}); err != nil {
 		return "", nil, err
 	}
```
server/prompt_test.go

```diff
@@ -318,7 +318,7 @@ func TestChatPrompt(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
-			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
+			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, false)
 			if tt.error == nil && err != nil {
 				t.Fatal(err)
 			} else if tt.error != nil && err != tt.error {
```
server/routes.go

```diff
@@ -18,7 +18,6 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
-	"regexp"
 	"slices"
 	"strings"
 	"syscall"

@@ -181,6 +180,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	if req.Suffix != "" {
 		caps = append(caps, model.CapabilityInsert)
 	}
+	if req.Thinking {
+		caps = append(caps, model.CapabilityThinking)
+	}

 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {

@@ -1475,6 +1477,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	if len(req.Tools) > 0 {
 		caps = append(caps, model.CapabilityTools)
 	}
+	if req.Thinking {
+		caps = append(caps, model.CapabilityThinking)
+	}

 	name := model.ParseName(req.Model)
 	if !name.IsValid() {

@@ -1515,7 +1520,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 	msgs = filterThinkTags(msgs, m)

-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Thinking)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

@@ -1529,6 +1534,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		defer close(ch)
 		var sb strings.Builder
 		var toolCallIndex int = 0
+		var thinkingState thinkingParser = thinkingParser{
+			openingTag: "<think>",
+			closingTag: "</think>",
+		}
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt: prompt,
 			Images: images,

@@ -1548,6 +1557,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				},
 			}

+			if req.Thinking {
+				thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
+				if thinkingContent == "" && remainingContent == "" && !r.Done {
+					// need to accumulate more to decide what to send
+					return
+				}
+				res.Message.Content = remainingContent
+				res.Message.ThinkingBlock = thinkingContent
+			}
+
 			if r.Done {
 				res.DoneReason = r.DoneReason.String()
 				res.TotalDuration = time.Since(checkpointStart)

@@ -1565,7 +1584,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			// Streaming tool calls:
 			// If tools are recognized, use a flag to track the sending of a tool downstream
 			// This ensures that content is cleared from the message on the last chunk sent
-			sb.WriteString(r.Content)
+			sb.WriteString(res.Message.Content)
 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 				res.Message.ToolCalls = toolCalls
 				for i := range toolCalls {

@@ -1613,9 +1632,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}

 		resp.Message.Content = sb.String()
+		if req.Thinking {
+			resp.Message.ThinkingBlock, resp.Message.Content = extractThinking(resp.Message.Content)
+		}

 		if len(req.Tools) > 0 {
-			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			if toolCalls, ok := m.parseToolCalls(resp.Message.Content); ok {
 				resp.Message.ToolCalls = toolCalls
 				resp.Message.Content = ""
 			}

@@ -1643,7 +1665,16 @@ func handleScheduleError(c *gin.Context, name string, err error) {
 	}
 }

-var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
+// returns (thinkingContent, content)
+func extractThinking(text string) (string, string) {
+	thinking := thinkingParser{
+		openingTag: "<think>",
+		closingTag: "</think>",
+	}
+
+	thinkingContent, content := thinking.addContent(text)
+	return thinkingContent, content
+}

 func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
 	if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {

@@ -1656,7 +1687,9 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
 		for i, msg := range msgs {
 			if msg.Role == "assistant" && i < finalUserIndex {
-				msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
+				thinking, content := extractThinking(msg.Content)
+				msgs[i].Content = content
+				msgs[i].ThinkingBlock = thinking
 			}
 		}
 	}
```
server/routes_test.go

```diff
@@ -143,6 +143,24 @@ func TestGenerateChat(t *testing.T) {
 		}
 	})

+	t.Run("missing thinking capability", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Thinking: true,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
 	t.Run("missing model", func(t *testing.T) {
 		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
 		if w.Code != http.StatusBadRequest {
```
server/thinking.go (new file, 127 lines)

```go
package server

import (
	"strings"
	"unicode"
)

type thinkingParseState int

const (
	thinkingParseState_LookingForOpening thinkingParseState = iota
	thinkingParseState_Thinking
	thinkingParseState_ThinkingDone
)

func (s thinkingParseState) String() string {
	switch s {
	case thinkingParseState_LookingForOpening:
		return "LookingForOpening"
	case thinkingParseState_Thinking:
		return "Thinking"
	case thinkingParseState_ThinkingDone:
		return "ThinkingDone"
	default:
		return "Unknown"
	}
}

type thinkingParser struct {
	state      thinkingParseState
	openingTag string
	closingTag string
	acc        strings.Builder
}

// returns the thinking content and the normal content that should be
// immediately sent to the user. It will internally buffer if it needs to see
// more content to disambiguate
func (s *thinkingParser) addContent(content string) (string, string) {
	s.acc.WriteString(content)

	var thinkingAcc, remainingAcc strings.Builder

	var thinking, remaining string
	keepLooping := true
	// we loop because we might pass through multiple parsing states in a single
	// call to addContent, and we want to make sure callers don't have to wait for
	// data that's already unambiguous
	for keepLooping {
		thinking, remaining, keepLooping = eat(s)
		thinkingAcc.WriteString(thinking)
		remainingAcc.WriteString(remaining)
	}

	return thinkingAcc.String(), remainingAcc.String()
}

// the additional bool return is true iff we should continue eating
func eat(s *thinkingParser) (string, string, bool) {
	switch s.state {
	case thinkingParseState_LookingForOpening:
		trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
		if strings.HasPrefix(trimmed, s.openingTag) {
			after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag)
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
			// after might contain more than just thinking tokens, so we continue
			// parsing instead of returning it as thinking tokens here
			s.acc.Reset()
			s.acc.WriteString(after)
			s.state = thinkingParseState_Thinking
			return "", "", true
		} else if strings.HasPrefix(s.openingTag, trimmed) {
			// partial opening seen, so let's keep accumulating
			return "", "", false
		} else if trimmed == "" {
			// saw whitespace only, so let's keep accumulating
			return "", "", false
		} else {
			// didn't see an opening tag, but we have content, so thinking was skipped
			s.state = thinkingParseState_ThinkingDone
			// note that we use the original content, not the trimmed one because we
			// don't want to eat any whitespace in the real content if there were no
			// thinking tags
			return "", s.acc.String(), false
		}
	case thinkingParseState_Thinking:
		acc := s.acc.String()
		if strings.Contains(acc, s.closingTag) {
			split := strings.Split(acc, s.closingTag)
			thinking := split[0]
			remaining := strings.Join(split[1:], s.closingTag)
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
			s.acc.Reset()
			s.state = thinkingParseState_ThinkingDone
			return thinking, remaining, false
		} else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 {
			thinking := acc[:len(acc)-overlapLen]
			remaining := acc[len(acc)-overlapLen:]
			s.acc.Reset()
			// keep track of the candidate closing tag. We have to buffer it until it
			// becomes disambiguated
			s.acc.WriteString(remaining)
			return thinking, "", false
		} else {
			// purely just thinking tokens, so we can return them
			s.acc.Reset()
			return acc, "", false
		}
	case thinkingParseState_ThinkingDone:
		acc := s.acc.String()
		s.acc.Reset()
		return "", acc, false
	default:
		panic("unknown state")
	}
}

// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
	max := min(len(delim), len(s))
	for i := max; i > 0; i-- {
		if strings.HasSuffix(s, delim[:i]) {
			return i
		}
	}
	return 0
}
```
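To make the buffering behavior concrete, here is a sketch of feeding the parser ambiguous chunks, written as a Go example test that would sit alongside the parser in package server (the expectations mirror the streaming tests below):

```go
package server

import "fmt"

func Example_thinkingParser() {
	parser := thinkingParser{openingTag: "<think>", closingTag: "</think>"}

	// The opening tag is complete, so thinking tokens flow through immediately.
	think, rest := parser.addContent("<think>step 1...")
	fmt.Printf("%q %q\n", think, rest)

	// "</th" could be the start of "</think>", so it is buffered, not emitted.
	think, rest = parser.addContent("done</th")
	fmt.Printf("%q %q\n", think, rest)

	// The closing tag is now unambiguous; the remainder is normal content.
	think, rest = parser.addContent("ink> The answer is 42.")
	fmt.Printf("%q %q\n", think, rest)

	// Output:
	// "step 1..." ""
	// "done" ""
	// "" "The answer is 42."
}
```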
server/thinking_test.go (new file, 161 lines)

```go
package server

import (
	"testing"
)

func TestExtractThinking(t *testing.T) {
	tests := []struct {
		in, wantContent, wantThink string
	}{
		{
			in:          "<think> internal </think> world",
			wantThink:   "internal ",
			wantContent: "world",
		},
		{
			in:          "<think>a</think><think>b</think>c",
			wantThink:   "a",
			wantContent: "<think>b</think>c",
		},
		{
			in:          "no think",
			wantThink:   "",
			wantContent: "no think",
		},
	}
	for i, tt := range tests {
		gotThinking, gotContent := extractThinking(tt.in)
		if gotContent != tt.wantContent || gotThinking != tt.wantThink {
			t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
		}
	}
}

func TestThinkingStreaming(t *testing.T) {
	type step struct {
		input          string
		wantThinking   string
		wantContent    string
		wantStateAfter thinkingParseState
	}

	cases := []struct {
		desc  string
		skip  bool
		steps []step
	}{
		{
			desc: "content without a thinking tag",
			steps: []step{
				{
					input:          " abc",
					wantThinking:   "",
					wantContent:    " abc",
					wantStateAfter: thinkingParseState_ThinkingDone,
				},
			},
		},
		{
			desc: "content before a thinking tag nerfs the thinking tag",
			steps: []step{
				{
					input:          " abc <think>def</think> ghi",
					wantThinking:   "",
					wantContent:    " abc <think>def</think> ghi",
					wantStateAfter: thinkingParseState_ThinkingDone,
				},
			},
		},
		{
			desc: "building up a thinking tag partially",
			steps: []step{
				{
					input:          " <th",
					wantThinking:   "",
					wantContent:    "",
					wantStateAfter: thinkingParseState_LookingForOpening,
				},
				{
					input:          "in",
					wantThinking:   "",
					wantContent:    "",
					wantStateAfter: thinkingParseState_LookingForOpening,
				},
				{
					input:          "k>a",
					wantThinking:   "a",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
			},
		},
		{
			desc: "partial closing tag",
			steps: []step{
				{
					input:          "<think>abc</th",
					wantThinking:   "abc",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
				{
					input:          "ink>def",
					wantThinking:   "",
					wantContent:    "def",
					wantStateAfter: thinkingParseState_ThinkingDone,
				},
			},
		},
		{
			desc: "partial closing tag fakeout",
			steps: []step{
				{
					input:          "<think>abc</th",
					wantThinking:   "abc",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
				{
					input:          "ing>def",
					wantThinking:   "</thing>def",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
				{
					input:          "ghi</thi",
					wantThinking:   "ghi",
					wantContent:    "",
					wantStateAfter: thinkingParseState_Thinking,
				},
				{
					input:          "nk>jkl",
					wantThinking:   "",
					wantContent:    "jkl",
					wantStateAfter: thinkingParseState_ThinkingDone,
				},
			},
		},
	}

	for _, c := range cases {
		parser := thinkingParser{
			openingTag: "<think>",
			closingTag: "</think>",
		}
		if c.skip {
			continue
		}
		for i, step := range c.steps {
			thinking, content := parser.addContent(step.input)
			if content != step.wantContent || thinking != step.wantThinking {
				t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
			}
			if parser.state != step.wantStateAfter {
				t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state.String(), step.wantStateAfter.String())
			}
		}
	}
}
```
template/template.go

```diff
@@ -165,8 +165,9 @@ func (t *Template) Vars() []string {
 type Values struct {
 	Messages []api.Message
 	api.Tools
-	Prompt string
-	Suffix string
+	Prompt   string
+	Suffix   string
+	Thinking bool

 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool

@@ -225,6 +226,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			"Prompt":   v.Prompt,
 			"Suffix":   v.Suffix,
 			"Response": "",
+			"Thinking": v.Thinking,
 		})
 	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{

@@ -232,6 +234,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			"Messages": messages,
 			"Tools":    v.Tools,
 			"Response": "",
+			"Thinking": v.Thinking,
 		})
 	}

@@ -244,6 +247,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			"System":   system,
 			"Prompt":   prompt,
 			"Response": response,
+			"Thinking": v.Thinking,
 		}); err != nil {
 			return err
 		}

@@ -289,6 +293,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			"System":   system,
 			"Prompt":   prompt,
 			"Response": response,
+			"Thinking": v.Thinking,
 		}); err != nil {
 			return err
 		}
```
types/model/capability.go

```diff
@@ -8,6 +8,7 @@ const (
 	CapabilityInsert    = Capability("insert")
 	CapabilityVision    = Capability("vision")
 	CapabilityEmbedding = Capability("embedding")
+	CapabilityThinking  = Capability("thinking")
 )

 func (c Capability) String() string {
```