Michael Yang 2024-06-20 11:00:08 -07:00
parent 269ed6e6a2
commit 2c3fe1fd97
5 changed files with 224 additions and 113 deletions


@@ -7,15 +7,10 @@ import (
 	"testing"
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/template"
 )
-type mock struct {
-	llm.LlamaServer
-}
-func (m mock) Tokenize(_ context.Context, s string) (tokens []int, err error) {
+func tokenize(_ context.Context, s string) (tokens []int, err error) {
 	for range strings.Fields(s) {
 		tokens = append(tokens, len(tokens))
 	}
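The mock that embedded llm.LlamaServer is gone; the test now injects a plain tokenize function. Below is a minimal, self-contained sketch of how that stub behaves on its own, assuming only the function shape shown in the diff (the tokenizeFunc type name is an illustration, not a type from the server package):

package main

import (
	"context"
	"fmt"
	"strings"
)

// tokenizeFunc is an assumed name for the callback shape the test injects;
// only the signature is taken from the diff above.
type tokenizeFunc func(ctx context.Context, s string) ([]int, error)

// tokenize matches the stub in the diff: one sequential token ID per
// whitespace-separated field, which keeps prompt-length checks deterministic.
func tokenize(_ context.Context, s string) (tokens []int, err error) {
	for range strings.Fields(s) {
		tokens = append(tokens, len(tokens))
	}
	return tokens, nil
}

func main() {
	var fn tokenizeFunc = tokenize
	tokens, _ := fn(context.Background(), "You're a test, Harry!")
	fmt.Println(len(tokens), tokens) // 4 [0 1 2 3]
}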
@@ -48,7 +43,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages",
+			name:  "truncate messages",
 			limit: 1,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -60,7 +55,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages with image",
+			name:  "truncate messages with image",
 			limit: 64,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -75,7 +70,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages with images",
+			name:  "truncate messages with images",
 			limit: 64,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -90,7 +85,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "messages with images",
+			name:  "messages with images",
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -106,7 +101,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "message with image tag",
+			name:  "message with image tag",
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
@@ -122,7 +117,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "messages with interleaved images",
+			name:  "messages with interleaved images",
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -140,7 +135,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate message with interleaved images",
+			name:  "truncate message with interleaved images",
 			limit: 1024,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -157,7 +152,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "message with system prompt",
+			name:  "message with system prompt",
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "system", Content: "You are the Test Who Lived."},
@@ -181,14 +176,9 @@ func TestChatPrompt(t *testing.T) {
 	for _, tt := range cases {
 		t.Run(tt.name, func(t *testing.T) {
-			r := runnerRef{
-				llama:   mock{},
-				model:   &Model{Template: tmpl, ProjectorPaths: []string{"vision"}},
-				Options: &api.Options{},
-			}
-			r.NumCtx = tt.limit
-			prompt, images, err := chatPrompt(context.TODO(), &r, tt.msgs)
+			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
+			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
+			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
 			if err != nil {
 				t.Fatal(err)
 			}
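From the new call site, chatPrompt now takes the model, a tokenize callback, and the options directly instead of reaching through a runnerRef. Below is a rough, self-contained sketch of that dependency-injection pattern; every name in it (buildPrompt, message, options) is invented for illustration and is not taken from the server package, and the real truncation and templating logic is more involved:

package main

import (
	"context"
	"fmt"
	"strings"
)

// message and options stand in for api.Message and api.Options.
type message struct{ Role, Content string }
type options struct{ NumCtx int }

// buildPrompt is a hypothetical analogue of chatPrompt: it receives a
// tokenize callback and options as plain arguments, so a test can inject a
// deterministic stub. It keeps only the newest messages that fit in NumCtx.
func buildPrompt(ctx context.Context, tokenize func(context.Context, string) ([]int, error), opts *options, msgs []message) (string, error) {
	// Walk backwards from the newest message and find the oldest one that
	// still fits within the context window.
	start, used := len(msgs), 0
	for i := len(msgs) - 1; i >= 0; i-- {
		tokens, err := tokenize(ctx, msgs[i].Content)
		if err != nil {
			return "", err
		}
		if used+len(tokens) > opts.NumCtx {
			break
		}
		used += len(tokens)
		start = i
	}

	var b strings.Builder
	for _, m := range msgs[start:] {
		b.WriteString(m.Content + " ")
	}
	return b.String(), nil
}

func main() {
	// The same whitespace-counting stub used in the test above.
	tokenize := func(_ context.Context, s string) ([]int, error) {
		var tokens []int
		for range strings.Fields(s) {
			tokens = append(tokens, len(tokens))
		}
		return tokens, nil
	}

	msgs := []message{
		{Role: "user", Content: "You're a test, Harry!"},
		{Role: "assistant", Content: "I-I'm a what?"},
		{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
	}

	// With a 12-token window only the newest message (11 stub tokens) fits.
	prompt, err := buildPrompt(context.Background(), tokenize, &options{NumCtx: 12}, msgs)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%q\n", prompt)
}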