mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
Integration test improvements (#9654)
Add some new test coverage for various model architectures, and switch from orca-mini to the small llama model.
This commit is contained in:
parent
56dc316a57
commit
ed4e139314
9 changed files with 709 additions and 67 deletions
412
integration/api_test.go
Normal file
412
integration/api_test.go
Normal file
|
@ -0,0 +1,412 @@
|
|||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestAPIGenerate(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue? be brief",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering"}
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
stream bool
|
||||
}{
|
||||
{
|
||||
name: "stream",
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
name: "no_stream",
|
||||
stream: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
// Fields that must always be present
|
||||
if response.Model == "" {
|
||||
t.Errorf("response missing model: %#v", response)
|
||||
}
|
||||
if response.Done {
|
||||
// Required fields for final updates:
|
||||
if response.DoneReason == "" && *req.Stream {
|
||||
// TODO - is the lack of done reason on non-stream a bug?
|
||||
t.Errorf("final response missing done_reason: %#v", response)
|
||||
}
|
||||
if response.Metrics.TotalDuration == 0 {
|
||||
t.Errorf("final response missing total_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.LoadDuration == 0 {
|
||||
t.Errorf("final response missing load_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.PromptEvalDuration == 0 {
|
||||
t.Errorf("final response missing prompt_eval_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.EvalCount == 0 {
|
||||
t.Errorf("final response missing eval_count: %#v", response)
|
||||
}
|
||||
if response.Metrics.EvalDuration == 0 {
|
||||
t.Errorf("final response missing eval_duration: %#v", response)
|
||||
}
|
||||
if len(response.Context) == 0 {
|
||||
t.Errorf("final response missing context: %#v", response)
|
||||
}
|
||||
|
||||
// Note: caching can result in no prompt eval count, so this can't be verified reliably
|
||||
// if response.Metrics.PromptEvalCount == 0 {
|
||||
// t.Errorf("final response missing prompt_eval_count: %#v", response)
|
||||
// }
|
||||
|
||||
} // else incremental response, nothing to check right now...
|
||||
buf.Write([]byte(response.Response))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
req.Stream = &test.stream
|
||||
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
|
||||
genErr = client.Generate(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("failed with %s request prompt %s ", req.Model, req.Prompt)
|
||||
}
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Errorf("none of %v found in %s", anyResp, response)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Validate PS while we're at it...
|
||||
resp, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list models API error: %s", err)
|
||||
}
|
||||
if resp == nil || len(resp.Models) == 0 {
|
||||
t.Fatalf("list models API returned empty list while model should still be loaded")
|
||||
}
|
||||
// Find the model we just loaded and verify some attributes
|
||||
found := false
|
||||
for _, model := range resp.Models {
|
||||
if strings.Contains(model.Name, req.Model) {
|
||||
found = true
|
||||
if model.Model == "" {
|
||||
t.Errorf("model field omitted: %#v", model)
|
||||
}
|
||||
if model.Size == 0 {
|
||||
t.Errorf("size omitted: %#v", model)
|
||||
}
|
||||
if model.Digest == "" {
|
||||
t.Errorf("digest omitted: %#v", model)
|
||||
}
|
||||
verifyModelDetails(t, model.Details)
|
||||
var nilTime time.Time
|
||||
if model.ExpiresAt == nilTime {
|
||||
t.Errorf("expires_at omitted: %#v", model)
|
||||
}
|
||||
// SizeVRAM could be zero.
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("unable to locate running model: %#v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIChat(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "why is the sky blue? be brief",
|
||||
},
|
||||
},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering"}
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
stream bool
|
||||
}{
|
||||
{
|
||||
name: "stream",
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
name: "no_stream",
|
||||
stream: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
fn := func(response api.ChatResponse) error {
|
||||
// Fields that must always be present
|
||||
if response.Model == "" {
|
||||
t.Errorf("response missing model: %#v", response)
|
||||
}
|
||||
if response.Done {
|
||||
// Required fields for final updates:
|
||||
var nilTime time.Time
|
||||
if response.CreatedAt == nilTime {
|
||||
t.Errorf("final response missing total_duration: %#v", response)
|
||||
}
|
||||
if response.DoneReason == "" {
|
||||
t.Errorf("final response missing done_reason: %#v", response)
|
||||
}
|
||||
if response.Metrics.TotalDuration == 0 {
|
||||
t.Errorf("final response missing total_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.LoadDuration == 0 {
|
||||
t.Errorf("final response missing load_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.PromptEvalDuration == 0 {
|
||||
t.Errorf("final response missing prompt_eval_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.EvalCount == 0 {
|
||||
t.Errorf("final response missing eval_count: %#v", response)
|
||||
}
|
||||
if response.Metrics.EvalDuration == 0 {
|
||||
t.Errorf("final response missing eval_duration: %#v", response)
|
||||
}
|
||||
|
||||
if response.Metrics.PromptEvalCount == 0 {
|
||||
t.Errorf("final response missing prompt_eval_count: %#v", response)
|
||||
}
|
||||
} // else incremental response, nothing to check right now...
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
req.Stream = &test.stream
|
||||
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("chat never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("chat stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("failed with %s request prompt %v", req.Model, req.Messages)
|
||||
}
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Errorf("none of %v found in %s", anyResp, response)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for chat")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIListModels(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Make sure we have at least one model so an empty list can be considered a failure
|
||||
if err := PullIfMissing(ctx, client, smol); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to list models: %s", err)
|
||||
}
|
||||
if len(resp.Models) == 0 {
|
||||
t.Fatalf("list should not be empty")
|
||||
}
|
||||
model := resp.Models[0]
|
||||
if model.Name == "" {
|
||||
t.Errorf("first model name empty: %#v", model)
|
||||
}
|
||||
var nilTime time.Time
|
||||
if model.ModifiedAt == nilTime {
|
||||
t.Errorf("first model modified_at empty: %#v", model)
|
||||
}
|
||||
if model.Size == 0 {
|
||||
t.Errorf("first model size empty: %#v", model)
|
||||
}
|
||||
if model.Digest == "" {
|
||||
t.Errorf("first model digest empty: %#v", model)
|
||||
}
|
||||
verifyModelDetails(t, model.Details)
|
||||
}
|
||||
|
||||
func verifyModelDetails(t *testing.T, details api.ModelDetails) {
|
||||
if details.Format == "" {
|
||||
t.Errorf("first model details.format empty: %#v", details)
|
||||
}
|
||||
if details.Family == "" {
|
||||
t.Errorf("first model details.family empty: %#v", details)
|
||||
}
|
||||
if details.ParameterSize == "" {
|
||||
t.Errorf("first model details.parameter_size empty: %#v", details)
|
||||
}
|
||||
if details.QuantizationLevel == "" {
|
||||
t.Errorf("first model details.quantization_level empty: %#v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIShowModel(t *testing.T) {
|
||||
modelName := "llama3.2"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if err := PullIfMissing(ctx, client, modelName); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to show model: %s", err)
|
||||
}
|
||||
if resp.License == "" {
|
||||
t.Errorf("%s missing license: %#v", modelName, resp)
|
||||
}
|
||||
if resp.Modelfile == "" {
|
||||
t.Errorf("%s missing modelfile: %#v", modelName, resp)
|
||||
}
|
||||
if resp.Parameters == "" {
|
||||
t.Errorf("%s missing parameters: %#v", modelName, resp)
|
||||
}
|
||||
if resp.Template == "" {
|
||||
t.Errorf("%s missing template: %#v", modelName, resp)
|
||||
}
|
||||
// llama3 omits system
|
||||
verifyModelDetails(t, resp.Details)
|
||||
// llama3 ommits messages
|
||||
if len(resp.ModelInfo) == 0 {
|
||||
t.Errorf("%s missing model_info: %#v", modelName, resp)
|
||||
}
|
||||
// llama3 omits projectors
|
||||
var nilTime time.Time
|
||||
if resp.ModifiedAt == nilTime {
|
||||
t.Errorf("%s missing modified_at: %#v", modelName, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIEmbeddings(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
resp, err := client.Embeddings(ctx, &req)
|
||||
if err != nil {
|
||||
t.Fatalf("embeddings call failed %s", err)
|
||||
}
|
||||
if len(resp.Embedding) == 0 {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
}
|
|
@ -14,12 +14,12 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
func TestBlueSky(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
|
@ -31,6 +31,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUnicode(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
|
@ -93,7 +94,7 @@ func TestUnicodeModelDir(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
|
|
|
@ -21,7 +21,7 @@ func TestMultiModelConcurrency(t *testing.T) {
|
|||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Model: "llama3.2:1b",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
|
@ -67,7 +67,7 @@ func TestMultiModelConcurrency(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
func TestIntegrationConcurrentPredict(t *testing.T) {
|
||||
req, resp := GenerateRequests()
|
||||
reqLimit := len(req)
|
||||
iterLimit := 5
|
||||
|
@ -117,6 +117,9 @@ func TestMultiModelStress(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if maxVram < 2*format.GibiByte {
|
||||
t.Skip("VRAM less than 2G, skipping model stress tests")
|
||||
}
|
||||
|
||||
type model struct {
|
||||
name string
|
||||
|
@ -125,8 +128,8 @@ func TestMultiModelStress(t *testing.T) {
|
|||
|
||||
smallModels := []model{
|
||||
{
|
||||
name: "orca-mini",
|
||||
size: 2992 * format.MebiByte,
|
||||
name: "llama3.2:1b",
|
||||
size: 2876 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "phi",
|
||||
|
|
|
@ -12,58 +12,51 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntegrationLlava(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: "llava:7b",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
func TestVisionModels(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
type testCase struct {
|
||||
model string
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
model: "llava:7b",
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
{
|
||||
model: "llama3.2-vision",
|
||||
},
|
||||
{
|
||||
model: "gemma3",
|
||||
},
|
||||
}
|
||||
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
for _, v := range testCases {
|
||||
t.Run(v.model, func(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
|
||||
func TestIntegrationMllama(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
// TODO fix up once we publish the final image
|
||||
Model: "x/llama3.2-vision",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
resp := "the ollamas"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// mllama models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
|
|
|
@ -17,7 +17,7 @@ var (
|
|||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
|
@ -25,7 +25,7 @@ var (
|
|||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
|
@ -35,12 +35,12 @@ var (
|
|||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight"},
|
||||
{"sunlight", "scattering", "interact"},
|
||||
{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
||||
func TestIntegrationSimple(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
|
|
|
@ -30,7 +30,7 @@ func TestMaxQueue(t *testing.T) {
|
|||
t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
|
@ -61,7 +61,7 @@ func TestMaxQueue(t *testing.T) {
|
|||
}()
|
||||
|
||||
// Give the generate a chance to get started before we start hammering on embed requests
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
threadCount += 10 // Add a few extra to ensure we push the queue past its limit
|
||||
busyCount := 0
|
||||
|
|
195
integration/model_arch_test.go
Normal file
195
integration/model_arch_test.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
//go:build integration && models
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var (
|
||||
started = time.Now()
|
||||
chatModels = []string{
|
||||
"granite3-moe:latest",
|
||||
"granite-code:latest",
|
||||
"nemotron-mini:latest",
|
||||
"command-r:latest",
|
||||
"gemma2:latest",
|
||||
"gemma:latest",
|
||||
"internlm2:latest",
|
||||
"phi3.5:latest",
|
||||
"phi3:latest",
|
||||
// "phi:latest", // flaky, sometimes generates no response on first query
|
||||
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
|
||||
"falcon:latest",
|
||||
"falcon2:latest",
|
||||
"minicpm-v:latest",
|
||||
"mistral:latest",
|
||||
"orca-mini:latest",
|
||||
"llama2:latest",
|
||||
"llama3.1:latest",
|
||||
"llama3.2:latest",
|
||||
"llama3.2-vision:latest",
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
}
|
||||
)
|
||||
|
||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||
deadline, hasDeadline := t.Deadline()
|
||||
if !hasDeadline {
|
||||
return 8 * time.Minute, 10 * time.Minute
|
||||
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
|
||||
t.Skip("too little time")
|
||||
return time.Duration(0), time.Duration(0)
|
||||
}
|
||||
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
|
||||
}
|
||||
|
||||
func TestModelsGenerate(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// TODO use info API eventually
|
||||
var maxVram uint64
|
||||
var err error
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err = strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
if maxVram > 0 {
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list models failed %v", err)
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
|
||||
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO - fiddle with context size
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
|
||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelsEmbed(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// TODO use info API eventually
|
||||
var maxVram uint64
|
||||
var err error
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err = strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(filepath.Join("testdata", "embed.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open test data file: %s", err)
|
||||
}
|
||||
testCase := map[string][]float64{}
|
||||
err = json.Unmarshal(data, &testCase)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load test data: %s", err)
|
||||
}
|
||||
for model, expected := range testCase {
|
||||
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
if maxVram > 0 {
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list models failed %v", err)
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
|
||||
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
|
||||
}
|
||||
}
|
||||
}
|
||||
req := api.EmbeddingRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
resp, err := client.Embeddings(ctx, &req)
|
||||
if err != nil {
|
||||
t.Fatalf("embeddings call failed %s", err)
|
||||
}
|
||||
if len(resp.Embedding) == 0 {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
if len(expected) != len(resp.Embedding) {
|
||||
expStr := make([]string, len(resp.Embedding))
|
||||
for i, v := range resp.Embedding {
|
||||
expStr[i] = fmt.Sprintf("%0.6f", v)
|
||||
}
|
||||
// When adding new models, use this output to populate the testdata/embed.json
|
||||
fmt.Printf("expected\n%s\n", strings.Join(expStr, ", "))
|
||||
t.Fatalf("expected %d, got %d", len(expected), len(resp.Embedding))
|
||||
}
|
||||
sim := cosineSimilarity(resp.Embedding, expected)
|
||||
if sim < 0.99 {
|
||||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], resp.Embedding[0:5], sim)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
21
integration/testdata/embed.json
vendored
Normal file
21
integration/testdata/embed.json
vendored
Normal file
File diff suppressed because one or more lines are too long
|
@ -24,9 +24,14 @@ import (
|
|||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
smol = "llama3.2:1b"
|
||||
)
|
||||
|
||||
func Init() {
|
||||
lifecycle.InitLogging()
|
||||
}
|
||||
|
@ -140,7 +145,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
|
|||
|
||||
showCtx, cancel := context.WithDeadlineCause(
|
||||
ctx,
|
||||
time.Now().Add(10*time.Second),
|
||||
time.Now().Add(20*time.Second),
|
||||
fmt.Errorf("show for existing model %s took too long", modelName),
|
||||
)
|
||||
defer cancel()
|
||||
|
@ -157,7 +162,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
|
|||
}
|
||||
slog.Info("model missing", "model", modelName)
|
||||
|
||||
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
|
||||
stallDuration := 60 * time.Second // This includes checksum verification, which can take a while on larger models, and slower systems
|
||||
stallTimer := time.NewTimer(stallDuration)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
// fmt.Print(".")
|
||||
|
@ -283,11 +288,11 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|||
}
|
||||
|
||||
// Generate a set of requests
|
||||
// By default each request uses orca-mini as the model
|
||||
// By default each request uses llama3.2 as the model
|
||||
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
return []api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
|
@ -296,7 +301,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
|
@ -305,7 +310,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
|
@ -314,7 +319,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
|
@ -323,7 +328,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "what is the composition of air?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
|
@ -341,3 +346,15 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
// TODO use info API in the future
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
require.NoError(t, err)
|
||||
// Don't hammer on small VRAM cards...
|
||||
if maxVram < gb*format.GibiByte {
|
||||
t.Skip("skipping with small VRAM to avoid timeouts")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue