Mirror of https://github.com/ollama/ollama.git (synced 2025-05-10 18:06:33 +02:00)
strip out thinking tags in message history for qwen3 & r1 (#10490)
* strip out thinking tags in message history for qwen3 & r1

  This is in advance of "proper" support where we'll make reasoning configurable and we'll parse out thinking/reasoning tags and provide them to the caller. These models expect there to be no thinking tags in the message history, so this should improve quality.

* parse model names instead of hacky prefix check
This commit is contained in:
parent
415c8fcc3d
commit
ad3c7c9bda
2 changed files with 148 additions and 0 deletions
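Before the diff itself, a minimal standalone sketch of what the commit message describes: assistant turns earlier in the history get their <think>...</think> blocks, plus any trailing newlines, stripped before the conversation is re-prompted. The regex is the same one added in the diff below; the sample message and the main function are illustrative only, not part of the change.

package main

import (
    "fmt"
    "regexp"
)

// Same pattern as the one added in the diff below: (?s) lets ".*?" span
// newlines, and any newlines directly after </think> are consumed as well.
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)

func main() {
    // An earlier assistant turn as it might appear in the message history.
    prev := "<think>Thinking...\n\nabout the answer</think>\n\nabc"
    fmt.Printf("%q\n", thinkTagRegexp.ReplaceAllString(prev, "")) // prints "abc"
}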
@@ -18,6 +18,7 @@ import (
     "os"
     "os/signal"
     "path/filepath"
+    "regexp"
     "slices"
     "strings"
     "syscall"
@@ -1512,6 +1513,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
     if req.Messages[0].Role != "system" && m.System != "" {
         msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
     }
+    msgs = filterThinkTags(msgs, m)

     prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
     if err != nil {
@@ -1640,3 +1642,23 @@ func handleScheduleError(c *gin.Context, name string, err error) {
         c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
     }
 }
+
+var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
+
+func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
+    if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
+        finalUserIndex := -1
+        for i, msg := range msgs {
+            if msg.Role == "user" {
+                finalUserIndex = i
+            }
+        }
+
+        for i, msg := range msgs {
+            if msg.Role == "assistant" && i < finalUserIndex {
+                msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
+            }
+        }
+    }
+    return msgs
+}
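Two details worth noting in the hunks above: the regex is compiled once at package level rather than per request, and only assistant messages that appear before the final user message are rewritten, so whatever the model produced after the last user turn keeps its thinking tags. The deepseek-r1 case is detected by parsing the model name, which is the "parse model names instead of hacky prefix check" item from the commit message. A small sketch of that check, runnable inside the ollama module; the types/model import path is an assumption, while the ParseName call itself is taken from the diff.

package main

import (
    "fmt"

    "github.com/ollama/ollama/types/model" // assumed import path for ParseName
)

func main() {
    // Fully qualified model name borrowed from the test case below.
    name := "registry.ollama.ai/library/deepseek-r1:latest"
    // Comparing the parsed Model component is more robust than checking for a
    // raw "deepseek-r1" string prefix on whatever name the user supplied.
    fmt.Println(model.ParseName(name).Model == "deepseek-r1") // true
}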
@@ -15,6 +15,7 @@ import (
     "net/http/httptest"
     "os"
     "path/filepath"
+    "reflect"
     "sort"
     "strings"
     "testing"
@@ -746,3 +747,128 @@ func TestNormalize(t *testing.T) {
         })
     }
 }
+
+func TestFilterThinkTags(t *testing.T) {
+    type testCase struct {
+        msgs  []api.Message
+        want  []api.Message
+        model *Model
+    }
+    testCases := []testCase{
+        {
+            msgs: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
+                {Role: "user", Content: "What is the answer?"},
+            },
+            want: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "abc"},
+                {Role: "user", Content: "What is the answer?"},
+            },
+            model: &Model{
+                Config: ConfigV2{
+                    ModelFamily: "qwen3",
+                },
+            },
+        },
+        // with newlines inside the think tag and newlines after
+        {
+            msgs: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "<think>Thinking... \n\nabout \nthe answer</think>\n\nabc\ndef"},
+                {Role: "user", Content: "What is the answer?"},
+            },
+            want: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "abc\ndef"},
+                {Role: "user", Content: "What is the answer?"},
+            },
+            model: &Model{
+                Config: ConfigV2{
+                    ModelFamily: "qwen3",
+                },
+            },
+        },
+        // should leave thinking tags if it's after the last user message
+        {
+            msgs: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "<think>Thinking...</think>after"},
+                {Role: "user", Content: "What is the answer?"},
+                {Role: "assistant", Content: "<think>thinking again</think>hjk"},
+                {Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
+            },
+            want: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "after"},
+                {Role: "user", Content: "What is the answer?"},
+                {Role: "assistant", Content: "<think>thinking again</think>hjk"},
+                {Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
+            },
+            model: &Model{
+                Config: ConfigV2{
+                    ModelFamily: "qwen3",
+                },
+            },
+        },
+        {
+            // shouldn't strip anything because the model family isn't one of the hardcoded ones
+            msgs: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
+                {Role: "user", Content: "What is the answer?"},
+            },
+            want: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
+                {Role: "user", Content: "What is the answer?"},
+            },
+            model: &Model{
+                Config: ConfigV2{
+                    ModelFamily: "llama3",
+                },
+            },
+        },
+        {
+            // deepseek-r1:-prefixed model
+            msgs: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
+                {Role: "user", Content: "What is the answer?"},
+            },
+            want: []api.Message{
+                {Role: "user", Content: "Hello, world!"},
+                {Role: "assistant", Content: "abc"},
+                {Role: "user", Content: "What is the answer?"},
+            },
+            model: &Model{
+                Name:      "registry.ollama.ai/library/deepseek-r1:latest",
+                ShortName: "deepseek-r1:7b",
+                Config:    ConfigV2{},
+            },
+        },
+    }
+
+    for i, tc := range testCases {
+        filtered := filterThinkTags(tc.msgs, tc.model)
+
+        if !reflect.DeepEqual(filtered, tc.want) {
+            t.Errorf("messages differ for case %d:", i)
+            for i := range tc.want {
+                if i >= len(filtered) {
+                    t.Errorf("  missing message %d: %+v", i, tc.want[i])
+                    continue
+                }
+                if !reflect.DeepEqual(filtered[i], tc.want[i]) {
+                    t.Errorf("  message %d:\n  want: %+v\n  got:  %+v", i, tc.want[i], filtered[i])
+                }
+            }
+            if len(filtered) > len(tc.want) {
+                for i := len(tc.want); i < len(filtered); i++ {
+                    t.Errorf("  extra message %d: %+v", i, filtered[i])
+                }
+            }
+        }
+    }
+}
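With the table-driven cases above in place, the filter can be exercised in isolation with something like go test -run TestFilterThinkTags ./server/ from the repository root (assuming, as the surrounding handler code suggests, that these tests live in the server package).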