From ad3c7c9bda3b2db9a6887f65e3134c093333d3d5 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Wed, 30 Apr 2025 13:57:45 -0700 Subject: [PATCH] strip out thinking tags in message history for qwen3 & r1 (#10490) * strip out thinking tags in message history for qwen3 & r1 This is in advance of "proper" support where we'll make reasoning configurable and we'll parse out thinking/reasoning tags and provide them to the caller. These models expect there to be no thinking tags in the message history, so this should improve quality * parse model names instead of hacky prefix check --- server/routes.go | 22 ++++++++ server/routes_test.go | 126 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+) diff --git a/server/routes.go b/server/routes.go index 31acd0d1a..16f22cf93 100644 --- a/server/routes.go +++ b/server/routes.go @@ -18,6 +18,7 @@ import ( "os" "os/signal" "path/filepath" + "regexp" "slices" "strings" "syscall" @@ -1512,6 +1513,7 @@ func (s *Server) ChatHandler(c *gin.Context) { if req.Messages[0].Role != "system" && m.System != "" { msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) 
} + msgs = filterThinkTags(msgs, m) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) if err != nil { @@ -1640,3 +1642,23 @@ func handleScheduleError(c *gin.Context, name string, err error) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } } + +var thinkTagRegexp = regexp.MustCompile(`(?s)<think>.*?</think>(\n)*`) + +func filterThinkTags(msgs []api.Message, m *Model) []api.Message { + if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" { + finalUserIndex := -1 + for i, msg := range msgs { + if msg.Role == "user" { + finalUserIndex = i + } + } + + for i, msg := range msgs { + if msg.Role == "assistant" && i < finalUserIndex { + msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "") + } + } + } + return msgs +} diff --git a/server/routes_test.go b/server/routes_test.go index e13c4b599..2894b1555 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -15,6 +15,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "sort" "strings" "testing" @@ -746,3 +747,128 @@ func TestNormalize(t *testing.T) { }) } } + +func TestFilterThinkTags(t *testing.T) { + type testCase struct { + msgs []api.Message + want []api.Message + model *Model + } + testCases := []testCase{ + { + msgs: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"}, + {Role: "user", Content: "What is the answer?"}, + }, + want: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "abc"}, + {Role: "user", Content: "What is the answer?"}, + }, + model: &Model{ + Config: ConfigV2{ + ModelFamily: "qwen3", + }, + }, + }, + // with newlines inside the think tag aned newlines after + { + msgs: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "<think>Thinking... 
\n\nabout \nthe answer</think>\n\nabc\ndef"}, + {Role: "user", Content: "What is the answer?"}, + }, + want: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "abc\ndef"}, + {Role: "user", Content: "What is the answer?"}, + }, + model: &Model{ + Config: ConfigV2{ + ModelFamily: "qwen3", + }, + }, + }, + // should leave thinking tags if it's after the last user message + { + msgs: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "<think>Thinking...</think>after"}, + {Role: "user", Content: "What is the answer?"}, + {Role: "assistant", Content: "<think>thinking again</think>hjk"}, + {Role: "assistant", Content: "<think>thinking yet again</think>hjk"}, + }, + want: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "after"}, + {Role: "user", Content: "What is the answer?"}, + {Role: "assistant", Content: "<think>thinking again</think>hjk"}, + {Role: "assistant", Content: "<think>thinking yet again</think>hjk"}, + }, + model: &Model{ + Config: ConfigV2{ + ModelFamily: "qwen3", + }, + }, + }, + { + // shouldn't strip anything because the model family isn't one of the hardcoded ones + msgs: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"}, + {Role: "user", Content: "What is the answer?"}, + }, + want: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"}, + {Role: "user", Content: "What is the answer?"}, + }, + model: &Model{ + Config: ConfigV2{ + ModelFamily: "llama3", + }, + }, + }, + { + // deepseek-r1:-prefixed model + msgs: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "<think>Thinking... 
about the answer</think>abc"}, + {Role: "user", Content: "What is the answer?"}, + }, + want: []api.Message{ + {Role: "user", Content: "Hello, world!"}, + {Role: "assistant", Content: "abc"}, + {Role: "user", Content: "What is the answer?"}, + }, + model: &Model{ + Name: "registry.ollama.ai/library/deepseek-r1:latest", + ShortName: "deepseek-r1:7b", + Config: ConfigV2{}, + }, + }, + } + + for i, tc := range testCases { + filtered := filterThinkTags(tc.msgs, tc.model) + + if !reflect.DeepEqual(filtered, tc.want) { + t.Errorf("messages differ for case %d:", i) + for i := range tc.want { + if i >= len(filtered) { + t.Errorf(" missing message %d: %+v", i, tc.want[i]) + continue + } + if !reflect.DeepEqual(filtered[i], tc.want[i]) { + t.Errorf(" message %d:\n want: %+v\n got: %+v", i, tc.want[i], filtered[i]) + } + } + if len(filtered) > len(tc.want) { + for i := len(tc.want); i < len(filtered); i++ { + t.Errorf(" extra message %d: %+v", i, filtered[i]) + } + } + } + } +}