* increase default context length to 4096

  We lower the default numParallel from 4 to 2 and use these "savings" to
  double the default context length from 2048 to 4096. We're memory neutral
  in cases when we previously would've used numParallel == 4, but we add the
  following mitigation to handle some cases where we would have previously
  fallen back to 1x2048 due to low VRAM: we decide between 2048 and 4096
  using a runtime check, choosing 2048 if we're on a one-GPU system with
  total VRAM of <= 4 GB. We purposefully don't check the available VRAM
  because we don't want the context window size to change unexpectedly based
  on the available VRAM. We plan on making the default even larger, but this
  is a relatively low-risk change we can make to quickly double it.

* fix tests

  Add an explicit context length so they don't get truncated. The code that
  converts -1 from being a signal for doing a runtime check isn't running as
  part of these tests.

* tweak small gpu message

* clarify context length default

  Also make it actually show up in `ollama serve --help`
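The runtime check described above could look something like the following minimal sketch. It only illustrates the decision rule as stated in the commit message (single GPU, total VRAM <= 4 GB, keyed off total rather than available VRAM); the type and function names here are hypothetical stand-ins, not Ollama's actual API.

package main

import "fmt"

// gpuInfo is a hypothetical stand-in for the fields the check needs;
// Ollama's real discover.GpuInfo carries much more than this.
type gpuInfo struct {
	TotalMemory uint64 // total VRAM in bytes, not currently-free VRAM
}

// defaultContextLength mirrors the rule from the commit message: keep the
// old 2048-token default on a single-GPU system with at most 4 GiB of total
// VRAM, otherwise use the new 4096-token default. Total (not available)
// VRAM is checked on purpose so the context window doesn't change
// unexpectedly between loads as free memory fluctuates.
func defaultContextLength(gpus []gpuInfo) int {
	if len(gpus) == 1 && gpus[0].TotalMemory <= 4<<30 {
		return 2048
	}
	return 4096
}

func main() {
	small := []gpuInfo{{TotalMemory: 4 << 30}}  // one 4 GiB card
	large := []gpuInfo{{TotalMemory: 16 << 30}} // one 16 GiB card
	fmt.Println(defaultContextLength(small)) // 2048
	fmt.Println(defaultContextLength(large)) // 4096
}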
979 lines · 27 KiB · Go
package server

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
)

type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest
	llm.CompletionResponse
	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
}

func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
	m.CompletionRequest = r
	if m.CompletionFn != nil {
		return m.CompletionFn(ctx, r, fn)
	}
	fn(m.CompletionResponse)
	return nil
}

func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
	for range strings.Fields(s) {
		tokens = append(tokens, len(tokens))
	}

	return
}

func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
	return func(_ discover.GpuInfoList, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
		return mock, nil
	}
}

func TestGenerateChat(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         llm.DoneReasonStop,
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Files: map[string]string{"file.gguf": digest},
		Template: `
{{- if .Tools }}
{{ .Tools }}
{{ end }}
{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}`,
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, nil)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing capabilities chat", func(t *testing.T) {
		_, digest := createBinFile(t, ggml.KV{
			"general.architecture": "bert",
			"bert.pooling_type":    uint32(0),
		}, []ggml.Tensor{})
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model:  "bert",
			Files:  map[string]string{"bert.gguf": digest},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.ChatResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})

	checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.ChatResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model %s, got %s", model, actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

if actual.DoneReason != "stop" {
|
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
|
}
|
|
|
|
if diff := cmp.Diff(actual.Message, api.Message{
|
|
Role: "assistant",
|
|
Content: content,
|
|
}); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
if actual.PromptEvalCount == 0 {
|
|
t.Errorf("expected prompt eval count > 0, got 0")
|
|
}
|
|
|
|
if actual.PromptEvalDuration == 0 {
|
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
|
}
|
|
|
|
if actual.EvalCount == 0 {
|
|
t.Errorf("expected eval count > 0, got 0")
|
|
}
|
|
|
|
if actual.EvalDuration == 0 {
|
|
t.Errorf("expected eval duration > 0, got 0")
|
|
}
|
|
|
|
if actual.LoadDuration == 0 {
|
|
t.Errorf("expected load duration > 0, got 0")
|
|
}
|
|
|
|
if actual.TotalDuration == 0 {
|
|
t.Errorf("expected total duration > 0, got 0")
|
|
}
|
|
}
|
|
|
|
mock.CompletionResponse.Content = "Hi!"
|
|
t.Run("messages", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello!"},
|
|
},
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkChatResponse(t, w.Body, "test", "Hi!")
|
|
})
|
|
|
|
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test-system",
|
|
From: "test",
|
|
System: "You are a helpful assistant.",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("messages with model system", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello!"},
|
|
},
|
|
Stream: &stream,
|
|
Options: map[string]any{
|
|
"num_ctx": 1024,
|
|
},
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkChatResponse(t, w.Body, "test-system", "Hi!")
|
|
})
|
|
|
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
|
t.Run("messages with system", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "system", Content: "You can perform magic tricks."},
|
|
{Role: "user", Content: "Hello!"},
|
|
},
|
|
Stream: &stream,
|
|
Options: map[string]any{
|
|
"num_ctx": 1024,
|
|
},
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
})
|
|
|
|
t.Run("messages with interleaved system", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello!"},
|
|
{Role: "assistant", Content: "I can help you with that."},
|
|
{Role: "system", Content: "You can perform magic tricks."},
|
|
{Role: "user", Content: "Help me write tests."},
|
|
},
|
|
Stream: &stream,
|
|
Options: map[string]any{
|
|
"num_ctx": 1024,
|
|
},
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
})
|
|
|
|
t.Run("messages with tools (non-streaming)", func(t *testing.T) {
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("failed to create test-system model: %d", w.Code)
|
|
}
|
|
|
|
tools := []api.Tool{
|
|
{
|
|
Type: "function",
|
|
Function: api.ToolFunction{
|
|
Name: "get_weather",
|
|
Description: "Get the current weather",
|
|
Parameters: struct {
|
|
Type string `json:"type"`
|
|
Defs any `json:"$defs,omitempty"`
|
|
Items any `json:"items,omitempty"`
|
|
Required []string `json:"required"`
|
|
Properties map[string]struct {
|
|
Type api.PropertyType `json:"type"`
|
|
Items any `json:"items,omitempty"`
|
|
Description string `json:"description"`
|
|
Enum []any `json:"enum,omitempty"`
|
|
} `json:"properties"`
|
|
}{
|
|
Type: "object",
|
|
Required: []string{"location"},
|
|
Properties: map[string]struct {
|
|
Type api.PropertyType `json:"type"`
|
|
Items any `json:"items,omitempty"`
|
|
Description string `json:"description"`
|
|
Enum []any `json:"enum,omitempty"`
|
|
}{
|
|
"location": {
|
|
Type: api.PropertyType{"string"},
|
|
Description: "The city and state",
|
|
},
|
|
"unit": {
|
|
Type: api.PropertyType{"string"},
|
|
Enum: []any{"celsius", "fahrenheit"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
mock.CompletionResponse = llm.CompletionResponse{
|
|
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
|
|
Done: true,
|
|
DoneReason: llm.DoneReasonStop,
|
|
PromptEvalCount: 1,
|
|
PromptEvalDuration: 1,
|
|
EvalCount: 1,
|
|
EvalDuration: 1,
|
|
}
|
|
|
|
streamRequest := true
|
|
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "What's the weather in Seattle?"},
|
|
},
|
|
Tools: tools,
|
|
Stream: &streamRequest,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
var errResp struct {
|
|
Error string `json:"error"`
|
|
}
|
|
if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
|
|
t.Logf("Failed to decode error response: %v", err)
|
|
} else {
|
|
t.Logf("Error response: %s", errResp.Error)
|
|
}
|
|
}
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
var resp api.ChatResponse
|
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if resp.Message.ToolCalls == nil {
|
|
t.Error("expected tool calls, got nil")
|
|
}
|
|
|
|
expectedToolCall := api.ToolCall{
|
|
Function: api.ToolCallFunction{
|
|
Name: "get_weather",
|
|
Arguments: api.ToolCallFunctionArguments{
|
|
"location": "Seattle, WA",
|
|
"unit": "celsius",
|
|
},
|
|
},
|
|
}
|
|
|
|
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
|
|
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("messages with tools (streaming)", func(t *testing.T) {
|
|
tools := []api.Tool{
|
|
{
|
|
Type: "function",
|
|
Function: api.ToolFunction{
|
|
Name: "get_weather",
|
|
Description: "Get the current weather",
|
|
Parameters: struct {
|
|
Type string `json:"type"`
|
|
Defs any `json:"$defs,omitempty"`
|
|
Items any `json:"items,omitempty"`
|
|
Required []string `json:"required"`
|
|
Properties map[string]struct {
|
|
Type api.PropertyType `json:"type"`
|
|
Items any `json:"items,omitempty"`
|
|
Description string `json:"description"`
|
|
Enum []any `json:"enum,omitempty"`
|
|
} `json:"properties"`
|
|
}{
|
|
Type: "object",
|
|
Required: []string{"location"},
|
|
Properties: map[string]struct {
|
|
Type api.PropertyType `json:"type"`
|
|
Items any `json:"items,omitempty"`
|
|
Description string `json:"description"`
|
|
Enum []any `json:"enum,omitempty"`
|
|
}{
|
|
"location": {
|
|
Type: api.PropertyType{"string"},
|
|
Description: "The city and state",
|
|
},
|
|
"unit": {
|
|
Type: api.PropertyType{"string"},
|
|
Enum: []any{"celsius", "fahrenheit"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// Simulate streaming response with multiple chunks
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
|
|
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
|
defer wg.Done()
|
|
|
|
// Send chunks with small delays to simulate streaming
|
|
responses := []llm.CompletionResponse{
|
|
{
|
|
Content: `{"name":"get_`,
|
|
Done: false,
|
|
PromptEvalCount: 1,
|
|
PromptEvalDuration: 1,
|
|
},
|
|
{
|
|
Content: `weather","arguments":{"location":"Seattle`,
|
|
Done: false,
|
|
PromptEvalCount: 2,
|
|
PromptEvalDuration: 1,
|
|
},
|
|
{
|
|
Content: `, WA","unit":"celsius"}}`,
|
|
Done: true,
|
|
DoneReason: llm.DoneReasonStop,
|
|
PromptEvalCount: 3,
|
|
PromptEvalDuration: 1,
|
|
},
|
|
}
|
|
|
|
for _, resp := range responses {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
fn(resp)
|
|
time.Sleep(10 * time.Millisecond) // Small delay between chunks
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "What's the weather in Seattle?"},
|
|
},
|
|
Tools: tools,
|
|
Stream: &stream,
|
|
})
|
|
|
|
wg.Wait()
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
// Read and validate the streamed responses
|
|
decoder := json.NewDecoder(w.Body)
|
|
var finalToolCall api.ToolCall
|
|
|
|
for {
|
|
var resp api.ChatResponse
|
|
if err := decoder.Decode(&resp); err == io.EOF {
|
|
break
|
|
} else if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if resp.Done {
|
|
if len(resp.Message.ToolCalls) != 1 {
|
|
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
|
|
}
|
|
finalToolCall = resp.Message.ToolCalls[0]
|
|
}
|
|
}
|
|
|
|
expectedToolCall := api.ToolCall{
|
|
Function: api.ToolCallFunction{
|
|
Name: "get_weather",
|
|
Arguments: api.ToolCallFunctionArguments{
|
|
"location": "Seattle, WA",
|
|
"unit": "celsius",
|
|
},
|
|
},
|
|
}
|
|
|
|
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
|
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGenerate(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
mock := mockRunner{
|
|
CompletionResponse: llm.CompletionResponse{
|
|
Done: true,
|
|
DoneReason: llm.DoneReasonStop,
|
|
PromptEvalCount: 1,
|
|
PromptEvalDuration: 1,
|
|
EvalCount: 1,
|
|
EvalDuration: 1,
|
|
},
|
|
}
|
|
|
|
s := Server{
|
|
sched: &Scheduler{
|
|
pendingReqCh: make(chan *LlmRequest, 1),
|
|
finishedReqCh: make(chan *LlmRequest, 1),
|
|
expiredCh: make(chan *runnerRef, 1),
|
|
unloadedCh: make(chan any, 1),
|
|
loaded: make(map[string]*runnerRef),
|
|
newServerFn: newMockServer(&mock),
|
|
getGpuFn: discover.GetGPUInfo,
|
|
getCpuFn: discover.GetCPUInfo,
|
|
reschedDelay: 250 * time.Millisecond,
|
|
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
|
// add small delay to simulate loading
|
|
time.Sleep(time.Millisecond)
|
|
req.successCh <- &runnerRef{
|
|
llama: &mock,
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
go s.sched.Run(context.TODO())
|
|
|
|
_, digest := createBinFile(t, ggml.KV{
|
|
"general.architecture": "llama",
|
|
"llama.block_count": uint32(1),
|
|
"llama.context_length": uint32(8192),
|
|
"llama.embedding_length": uint32(4096),
|
|
"llama.attention.head_count": uint32(32),
|
|
"llama.attention.head_count_kv": uint32(8),
|
|
"tokenizer.ggml.tokens": []string{""},
|
|
"tokenizer.ggml.scores": []float32{0},
|
|
"tokenizer.ggml.token_type": []int32{0},
|
|
}, []ggml.Tensor{
|
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
})
|
|
|
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test",
|
|
Files: map[string]string{"file.gguf": digest},
|
|
Template: `
|
|
{{- if .System }}System: {{ .System }} {{ end }}
|
|
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
|
{{- if .Response }}Assistant: {{ .Response }} {{ end }}
|
|
`,
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("missing body", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, nil)
|
|
if w.Code != http.StatusNotFound {
|
|
t.Errorf("expected status 404, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing model", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
|
|
if w.Code != http.StatusNotFound {
|
|
t.Errorf("expected status 404, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing capabilities generate", func(t *testing.T) {
|
|
_, digest := createBinFile(t, ggml.KV{
|
|
"general.architecture": "bert",
|
|
"bert.pooling_type": uint32(0),
|
|
}, []ggml.Tensor{})
|
|
|
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "bert",
|
|
Files: map[string]string{"file.gguf": digest},
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "bert",
|
|
})
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing capabilities suffix", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test",
|
|
Prompt: "def add(",
|
|
Suffix: " return c",
|
|
})
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("load model", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
var actual api.GenerateResponse
|
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if actual.Model != "test" {
|
|
t.Errorf("expected model test, got %s", actual.Model)
|
|
}
|
|
|
|
if !actual.Done {
|
|
t.Errorf("expected done true, got false")
|
|
}
|
|
|
|
if actual.DoneReason != "load" {
|
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
|
}
|
|
})
|
|
|
|
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
|
|
t.Helper()
|
|
|
|
var actual api.GenerateResponse
|
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
		if actual.Model != model {
			t.Errorf("expected model %s, got %s", model, actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

if actual.DoneReason != "stop" {
|
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
|
}
|
|
|
|
if actual.Response != content {
|
|
t.Errorf("expected response %s, got %s", content, actual.Response)
|
|
}
|
|
|
|
if actual.Context == nil {
|
|
t.Errorf("expected context not nil")
|
|
}
|
|
|
|
if actual.PromptEvalCount == 0 {
|
|
t.Errorf("expected prompt eval count > 0, got 0")
|
|
}
|
|
|
|
if actual.PromptEvalDuration == 0 {
|
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
|
}
|
|
|
|
if actual.EvalCount == 0 {
|
|
t.Errorf("expected eval count > 0, got 0")
|
|
}
|
|
|
|
if actual.EvalDuration == 0 {
|
|
t.Errorf("expected eval duration > 0, got 0")
|
|
}
|
|
|
|
if actual.LoadDuration == 0 {
|
|
t.Errorf("expected load duration > 0, got 0")
|
|
}
|
|
|
|
if actual.TotalDuration == 0 {
|
|
t.Errorf("expected total duration > 0, got 0")
|
|
}
|
|
}
|
|
|
|
mock.CompletionResponse.Content = "Hi!"
|
|
t.Run("prompt", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test",
|
|
Prompt: "Hello!",
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkGenerateResponse(t, w.Body, "test", "Hi!")
|
|
})
|
|
|
|
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test-system",
|
|
From: "test",
|
|
System: "You are a helpful assistant.",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("prompt with model system", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-system",
|
|
Prompt: "Hello!",
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
|
|
})
|
|
|
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
|
t.Run("prompt with system", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-system",
|
|
Prompt: "Hello!",
|
|
System: "You can perform magic tricks.",
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
})
|
|
|
|
t.Run("prompt with template", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-system",
|
|
Prompt: "Help me write tests.",
|
|
System: "You can perform magic tricks.",
|
|
Template: `{{- if .System }}{{ .System }} {{ end }}
|
|
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
|
|
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
})
|
|
|
|
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test-suffix",
|
|
Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
|
{{- else }}{{ .Prompt }}
|
|
{{- end }}`,
|
|
From: "test",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("prompt with suffix", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-suffix",
|
|
Prompt: "def add(",
|
|
Suffix: " return c",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("prompt without suffix", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-suffix",
|
|
Prompt: "def add(",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("raw", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-system",
|
|
Prompt: "Help me write tests.",
|
|
Raw: true,
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
}
|