//go:build integration package integration import ( "bytes" "context" "fmt" "math/rand" "strings" "testing" "time" "github.com/ollama/ollama/api" ) func TestAPIGenerate(t *testing.T) { initialTimeout := 60 * time.Second streamTimeout := 30 * time.Second ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() // Set up the test data req := api.GenerateRequest{ Model: smol, Prompt: "why is the sky blue? be brief", Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } anyResp := []string{"rayleigh", "scattering"} client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("pull failed %s", err) } tests := []struct { name string stream bool }{ { name: "stream", stream: true, }, { name: "no_stream", stream: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { stallTimer := time.NewTimer(initialTimeout) var buf bytes.Buffer fn := func(response api.GenerateResponse) error { // Fields that must always be present if response.Model == "" { t.Errorf("response missing model: %#v", response) } if response.Done { // Required fields for final updates: if response.DoneReason == "" && *req.Stream { // TODO - is the lack of done reason on non-stream a bug? t.Errorf("final response missing done_reason: %#v", response) } if response.Metrics.TotalDuration == 0 { t.Errorf("final response missing total_duration: %#v", response) } if response.Metrics.LoadDuration == 0 { t.Errorf("final response missing load_duration: %#v", response) } if response.Metrics.PromptEvalDuration == 0 { t.Errorf("final response missing prompt_eval_duration: %#v", response) } if response.Metrics.EvalCount == 0 { t.Errorf("final response missing eval_count: %#v", response) } if response.Metrics.EvalDuration == 0 { t.Errorf("final response missing eval_duration: %#v", response) } if len(response.Context) == 0 { t.Errorf("final response missing context: %#v", response) } // Note: caching can result in no prompt eval count, so this can't be verified reliably // if response.Metrics.PromptEvalCount == 0 { // t.Errorf("final response missing prompt_eval_count: %#v", response) // } } // else incremental response, nothing to check right now... buf.Write([]byte(response.Response)) if !stallTimer.Reset(streamTimeout) { return fmt.Errorf("stall was detected while streaming response, aborting") } return nil } done := make(chan int) var genErr error go func() { req.Stream = &test.stream req.Options["seed"] = rand.Int() // bust cache for prompt eval results genErr = client.Generate(ctx, &req, fn) done <- 0 }() select { case <-stallTimer.C: if buf.Len() == 0 { t.Errorf("generate never started. Timed out after :%s", initialTimeout.String()) } else { t.Errorf("generate stalled. Response so far:%s", buf.String()) } case <-done: if genErr != nil { t.Fatalf("failed with %s request prompt %s ", req.Model, req.Prompt) } // Verify the response contains the expected data response := buf.String() atLeastOne := false for _, resp := range anyResp { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } if !atLeastOne { t.Errorf("none of %v found in %s", anyResp, response) } case <-ctx.Done(): t.Error("outer test context done while waiting for generate") } }) } // Validate PS while we're at it... resp, err := client.ListRunning(ctx) if err != nil { t.Fatalf("list models API error: %s", err) } if resp == nil || len(resp.Models) == 0 { t.Fatalf("list models API returned empty list while model should still be loaded") } // Find the model we just loaded and verify some attributes found := false for _, model := range resp.Models { if strings.Contains(model.Name, req.Model) { found = true if model.Model == "" { t.Errorf("model field omitted: %#v", model) } if model.Size == 0 { t.Errorf("size omitted: %#v", model) } if model.Digest == "" { t.Errorf("digest omitted: %#v", model) } verifyModelDetails(t, model.Details) var nilTime time.Time if model.ExpiresAt == nilTime { t.Errorf("expires_at omitted: %#v", model) } // SizeVRAM could be zero. } } if !found { t.Errorf("unable to locate running model: %#v", resp) } } func TestAPIChat(t *testing.T) { initialTimeout := 60 * time.Second streamTimeout := 30 * time.Second ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() // Set up the test data req := api.ChatRequest{ Model: smol, Messages: []api.Message{ { Role: "user", Content: "why is the sky blue? be brief", }, }, Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } anyResp := []string{"rayleigh", "scattering"} client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("pull failed %s", err) } tests := []struct { name string stream bool }{ { name: "stream", stream: true, }, { name: "no_stream", stream: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { stallTimer := time.NewTimer(initialTimeout) var buf bytes.Buffer fn := func(response api.ChatResponse) error { // Fields that must always be present if response.Model == "" { t.Errorf("response missing model: %#v", response) } if response.Done { // Required fields for final updates: var nilTime time.Time if response.CreatedAt == nilTime { t.Errorf("final response missing total_duration: %#v", response) } if response.DoneReason == "" { t.Errorf("final response missing done_reason: %#v", response) } if response.Metrics.TotalDuration == 0 { t.Errorf("final response missing total_duration: %#v", response) } if response.Metrics.LoadDuration == 0 { t.Errorf("final response missing load_duration: %#v", response) } if response.Metrics.PromptEvalDuration == 0 { t.Errorf("final response missing prompt_eval_duration: %#v", response) } if response.Metrics.EvalCount == 0 { t.Errorf("final response missing eval_count: %#v", response) } if response.Metrics.EvalDuration == 0 { t.Errorf("final response missing eval_duration: %#v", response) } if response.Metrics.PromptEvalCount == 0 { t.Errorf("final response missing prompt_eval_count: %#v", response) } } // else incremental response, nothing to check right now... buf.Write([]byte(response.Message.Content)) if !stallTimer.Reset(streamTimeout) { return fmt.Errorf("stall was detected while streaming response, aborting") } return nil } done := make(chan int) var genErr error go func() { req.Stream = &test.stream req.Options["seed"] = rand.Int() // bust cache for prompt eval results genErr = client.Chat(ctx, &req, fn) done <- 0 }() select { case <-stallTimer.C: if buf.Len() == 0 { t.Errorf("chat never started. Timed out after :%s", initialTimeout.String()) } else { t.Errorf("chat stalled. Response so far:%s", buf.String()) } case <-done: if genErr != nil { t.Fatalf("failed with %s request prompt %v", req.Model, req.Messages) } // Verify the response contains the expected data response := buf.String() atLeastOne := false for _, resp := range anyResp { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } if !atLeastOne { t.Errorf("none of %v found in %s", anyResp, response) } case <-ctx.Done(): t.Error("outer test context done while waiting for chat") } }) } } func TestAPIListModels(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() // Make sure we have at least one model so an empty list can be considered a failure if err := PullIfMissing(ctx, client, smol); err != nil { t.Fatalf("pull failed %s", err) } resp, err := client.List(ctx) if err != nil { t.Fatalf("unable to list models: %s", err) } if len(resp.Models) == 0 { t.Fatalf("list should not be empty") } model := resp.Models[0] if model.Name == "" { t.Errorf("first model name empty: %#v", model) } var nilTime time.Time if model.ModifiedAt == nilTime { t.Errorf("first model modified_at empty: %#v", model) } if model.Size == 0 { t.Errorf("first model size empty: %#v", model) } if model.Digest == "" { t.Errorf("first model digest empty: %#v", model) } verifyModelDetails(t, model.Details) } func verifyModelDetails(t *testing.T, details api.ModelDetails) { if details.Format == "" { t.Errorf("first model details.format empty: %#v", details) } if details.Family == "" { t.Errorf("first model details.family empty: %#v", details) } if details.ParameterSize == "" { t.Errorf("first model details.parameter_size empty: %#v", details) } if details.QuantizationLevel == "" { t.Errorf("first model details.quantization_level empty: %#v", details) } } func TestAPIShowModel(t *testing.T) { modelName := "llama3.2" ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() if err := PullIfMissing(ctx, client, modelName); err != nil { t.Fatalf("pull failed %s", err) } resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName}) if err != nil { t.Fatalf("unable to show model: %s", err) } if resp.License == "" { t.Errorf("%s missing license: %#v", modelName, resp) } if resp.Modelfile == "" { t.Errorf("%s missing modelfile: %#v", modelName, resp) } if resp.Parameters == "" { t.Errorf("%s missing parameters: %#v", modelName, resp) } if resp.Template == "" { t.Errorf("%s missing template: %#v", modelName, resp) } // llama3 omits system verifyModelDetails(t, resp.Details) // llama3 ommits messages if len(resp.ModelInfo) == 0 { t.Errorf("%s missing model_info: %#v", modelName, resp) } // llama3 omits projectors var nilTime time.Time if resp.ModifiedAt == nilTime { t.Errorf("%s missing modified_at: %#v", modelName, resp) } } func TestAPIEmbeddings(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() req := api.EmbeddingRequest{ Model: "orca-mini", Prompt: "why is the sky blue?", Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("pull failed %s", err) } resp, err := client.Embeddings(ctx, &req) if err != nil { t.Fatalf("embeddings call failed %s", err) } if len(resp.Embedding) == 0 { t.Errorf("zero length embedding response") } }