mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
preserve last system message from modelfile (#2289)
This commit is contained in:
parent
583950c828
commit
a896079705
2 changed files with 66 additions and 17 deletions
|
@ -256,15 +256,17 @@ func chatHistoryEqual(a, b ChatHistory) bool {
|
|||
|
||||
func TestChat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
msgs []api.Message
|
||||
want ChatHistory
|
||||
wantErr string
|
||||
name string
|
||||
model Model
|
||||
msgs []api.Message
|
||||
want ChatHistory
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "Single Message",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Single Message",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
|
@ -287,8 +289,10 @@ func TestChat(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
name: "Message History",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Message History",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
|
@ -323,8 +327,10 @@ func TestChat(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
name: "Assistant Only",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Assistant Only",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
|
@ -340,6 +346,51 @@ func TestChat(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Last system message is preserved from modelfile",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
System: "You are Mojo Jojo.",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hi",
|
||||
},
|
||||
},
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are Mojo Jojo.",
|
||||
Prompt: "hi",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are Mojo Jojo.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Last system message is preserved from messages",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
System: "You are Mojo Jojo.",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are Professor Utonium.",
|
||||
},
|
||||
},
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are Professor Utonium.",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are Professor Utonium.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid Role",
|
||||
msgs: []api.Message{
|
||||
|
@ -353,11 +404,8 @@ func TestChat(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
m := Model{
|
||||
Template: tt.template,
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := m.ChatPrompts(tt.msgs)
|
||||
got, err := tt.model.ChatPrompts(tt.msgs)
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("ChatPrompt() expected error, got nil")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue