mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
Request and model concurrency
This change adds support for multiple concurrent requests, as well as loading multiple models by spawning multiple runners. The default settings are currently set at 1 concurrent request per model and only 1 loaded model at a time, but these can be adjusted by setting OLLAMA_NUM_PARALLEL and OLLAMA_MAX_LOADED_MODELS.
This commit is contained in:
parent
ee448deaba
commit
34b9db5afc
30 changed files with 2572 additions and 1387 deletions
|
@ -4,7 +4,6 @@ package integration
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -24,5 +23,5 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
|||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
|
||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||
}
|
||||
|
|
225
integration/concurrency_test.go
Normal file
225
integration/concurrency_test.go
Normal file
|
@ -0,0 +1,225 @@
|
|||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMultiModelConcurrency(t *testing.T) {
|
||||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "tinydolphin",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
[]string{"sunlight"},
|
||||
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
GenerateTestHelper(ctx, t, req[i], resp[i])
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req, resp := GenerateRequests()
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 5; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the 4 concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
||||
func TestMultiModelStress(t *testing.T) {
|
||||
vram := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if vram == "" {
|
||||
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
||||
}
|
||||
max, err := strconv.ParseUint(vram, 10, 64)
|
||||
require.NoError(t, err)
|
||||
const MB = uint64(1024 * 1024)
|
||||
type model struct {
|
||||
name string
|
||||
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
|
||||
}
|
||||
|
||||
smallModels := []model{
|
||||
{
|
||||
name: "orca-mini",
|
||||
size: 2992 * MB,
|
||||
},
|
||||
{
|
||||
name: "phi",
|
||||
size: 2616 * MB,
|
||||
},
|
||||
{
|
||||
name: "gemma:2b",
|
||||
size: 2364 * MB,
|
||||
},
|
||||
{
|
||||
name: "stable-code:3b",
|
||||
size: 2608 * MB,
|
||||
},
|
||||
{
|
||||
name: "starcoder2:3b",
|
||||
size: 2166 * MB,
|
||||
},
|
||||
}
|
||||
mediumModels := []model{
|
||||
{
|
||||
name: "llama2",
|
||||
size: 5118 * MB,
|
||||
},
|
||||
{
|
||||
name: "mistral",
|
||||
size: 4620 * MB,
|
||||
},
|
||||
{
|
||||
name: "orca-mini:7b",
|
||||
size: 5118 * MB,
|
||||
},
|
||||
{
|
||||
name: "dolphin-mistral",
|
||||
size: 4620 * MB,
|
||||
},
|
||||
{
|
||||
name: "gemma:7b",
|
||||
size: 5000 * MB,
|
||||
},
|
||||
// TODO - uncomment this once #3565 is merged and this is rebased on it
|
||||
// {
|
||||
// name: "codellama:7b",
|
||||
// size: 5118 * MB,
|
||||
// },
|
||||
}
|
||||
|
||||
// These seem to be too slow to be useful...
|
||||
// largeModels := []model{
|
||||
// {
|
||||
// name: "llama2:13b",
|
||||
// size: 7400 * MB,
|
||||
// },
|
||||
// {
|
||||
// name: "codellama:13b",
|
||||
// size: 7400 * MB,
|
||||
// },
|
||||
// {
|
||||
// name: "orca-mini:13b",
|
||||
// size: 7400 * MB,
|
||||
// },
|
||||
// {
|
||||
// name: "gemma:7b",
|
||||
// size: 5000 * MB,
|
||||
// },
|
||||
// {
|
||||
// name: "starcoder2:15b",
|
||||
// size: 9100 * MB,
|
||||
// },
|
||||
// }
|
||||
|
||||
var chosenModels []model
|
||||
switch {
|
||||
case max < 10000*MB:
|
||||
slog.Info("selecting small models")
|
||||
chosenModels = smallModels
|
||||
// case max < 30000*MB:
|
||||
default:
|
||||
slog.Info("selecting medium models")
|
||||
chosenModels = mediumModels
|
||||
// default:
|
||||
// slog.Info("selecting large models")
|
||||
// chosenModels = largModels
|
||||
}
|
||||
|
||||
req, resp := GenerateRequests()
|
||||
|
||||
for i := range req {
|
||||
if i > len(chosenModels) {
|
||||
break
|
||||
}
|
||||
req[i].Model = chosenModels[i].name
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Make sure all the models are pulled before we get started
|
||||
for _, r := range req {
|
||||
require.NoError(t, PullIfMissing(ctx, client, r.Model))
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
consumed := uint64(256 * MB) // Assume some baseline usage
|
||||
for i := 0; i < len(req); i++ {
|
||||
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
|
||||
if i > 1 && consumed > max {
|
||||
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
|
||||
break
|
||||
}
|
||||
consumed += chosenModels[i].size
|
||||
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
|
||||
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 3; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
|
@ -4,7 +4,6 @@ package integration
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
|
|||
"num_ctx": 128,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
|
||||
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ package integration
|
|||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
resp := "the ollamas"
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
|
||||
GenerateTestHelper(ctx, t, req, []string{resp})
|
||||
}
|
||||
|
||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
||||
|
|
|
@ -4,8 +4,6 @@ package integration
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -45,25 +43,5 @@ var (
|
|||
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
}
|
||||
|
||||
// TODO
|
||||
// The server always loads a new runner and closes the old one, which forces serial execution
|
||||
// At present this test case fails with concurrency problems. Eventually we should try to
|
||||
// get true concurrency working with n_parallel support in the backend
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TODO - create a parallel test with 2 different models once we support concurrency
|
||||
|
|
|
@ -5,13 +5,14 @@ package integration
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
@ -23,9 +24,13 @@ import (
|
|||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Init() {
|
||||
lifecycle.InitLogging()
|
||||
}
|
||||
|
||||
func FindPort() string {
|
||||
port := 0
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
|
@ -41,7 +46,7 @@ func FindPort() string {
|
|||
return strconv.Itoa(port)
|
||||
}
|
||||
|
||||
func GetTestEndpoint() (string, string) {
|
||||
func GetTestEndpoint() (*api.Client, string) {
|
||||
defaultPort := "11434"
|
||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
||||
|
||||
|
@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
|
|||
port = FindPort()
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s:%s", host, port)
|
||||
slog.Info("server connection", "url", url)
|
||||
return scheme, url
|
||||
slog.Info("server connection", "host", host, "port", port)
|
||||
|
||||
return api.NewClient(
|
||||
&url.URL{
|
||||
Scheme: scheme,
|
||||
Host: net.JoinHostPort(host, port),
|
||||
},
|
||||
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
|
||||
}
|
||||
|
||||
// TODO make fanicier, grab logs, etc.
|
||||
var serverMutex sync.Mutex
|
||||
var serverReady bool
|
||||
|
||||
func StartServer(ctx context.Context, ollamaHost string) error {
|
||||
func startServer(ctx context.Context, ollamaHost string) error {
|
||||
// Make sure the server has been built
|
||||
CLIName, err := filepath.Abs("../ollama")
|
||||
if err != nil {
|
||||
|
@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
|
||||
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
|
||||
slog.Info("checking status of model", "model", modelName)
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
requestJSON, err := json.Marshal(showReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
|
||||
if err != nil {
|
||||
showCtx, cancel := context.WithDeadlineCause(
|
||||
ctx,
|
||||
time.Now().Add(5*time.Second),
|
||||
fmt.Errorf("show for existing model %s took too long", modelName),
|
||||
)
|
||||
defer cancel()
|
||||
_, err := client.Show(showCtx, showReq)
|
||||
var statusError api.StatusError
|
||||
switch {
|
||||
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
||||
break
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
// Make the request with the HTTP client
|
||||
response, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode == 200 {
|
||||
default:
|
||||
slog.Info("model already present", "model", modelName)
|
||||
return nil
|
||||
}
|
||||
slog.Info("model missing", "status", response.StatusCode)
|
||||
slog.Info("model missing", "model", modelName)
|
||||
|
||||
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
|
||||
stallTimer := time.NewTimer(stallDuration)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
// fmt.Print(".")
|
||||
if !stallTimer.Reset(stallDuration) {
|
||||
return fmt.Errorf("stall was detected, aborting status reporting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
|
||||
requestJSON, err = json.Marshal(pullReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("pulling", "model", modelName)
|
||||
var pullError error
|
||||
|
||||
response, err = client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
done := make(chan int)
|
||||
go func() {
|
||||
pullError = client.Pull(ctx, pullReq, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
return fmt.Errorf("download stalled")
|
||||
case <-done:
|
||||
return pullError
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != 200 {
|
||||
return fmt.Errorf("failed to pull model") // TODO more details perhaps
|
||||
}
|
||||
slog.Info("model pulled", "model", modelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
var serverProcMutex sync.Mutex
|
||||
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
|
||||
|
||||
// TODO maybe stuff in an init routine?
|
||||
lifecycle.InitLogging()
|
||||
|
||||
requestJSON, err := json.Marshal(genReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Error serializing request: %v", err)
|
||||
// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
|
||||
// Starts the server if needed
|
||||
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
|
||||
client, testEndpoint := GetTestEndpoint()
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
serverProcMutex.Lock()
|
||||
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate log file: %s", err)
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
require.NoError(t, startServer(ctx, testEndpoint))
|
||||
}
|
||||
defer func() {
|
||||
|
||||
return client, testEndpoint, func() {
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
defer serverProcMutex.Unlock()
|
||||
if t.Failed() {
|
||||
|
@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
|
|||
os.Stderr.Write(data)
|
||||
slog.Warn("END OF SERVER")
|
||||
}
|
||||
err = os.Remove(lifecycle.ServerLogFile)
|
||||
err := os.Remove(lifecycle.ServerLogFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
scheme, testEndpoint := GetTestEndpoint()
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
serverProcMutex.Lock()
|
||||
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate log file: %s", err)
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
assert.NoError(t, StartServer(ctx, testEndpoint))
|
||||
}
|
||||
|
||||
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
|
||||
if err != nil {
|
||||
t.Fatalf("Error pulling model: %v", err)
|
||||
}
|
||||
|
||||
// Make the request and get the response
|
||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating request: %v", err)
|
||||
}
|
||||
|
||||
// Set the content type for the request
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Make the request with the HTTP client
|
||||
response, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
t.Fatalf("Error making request: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
body, err := io.ReadAll(response.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, response.StatusCode, 200, string(body))
|
||||
|
||||
// Verify the response is valid JSON
|
||||
var payload api.GenerateResponse
|
||||
err = json.Unmarshal(body, &payload)
|
||||
if err != nil {
|
||||
assert.NoError(t, err, body)
|
||||
}
|
||||
|
||||
// Verify the response contains the expected data
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(payload.Response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
|
||||
}
|
||||
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
// fmt.Print(".")
|
||||
buf.Write([]byte(response.Response))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
genReq.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Generate(ctx, &genReq, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
|
||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a set of requests
|
||||
// By default each request uses orca-mini as the model
|
||||
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
return []api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the composition of air?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
[][]string{
|
||||
[]string{"sunlight"},
|
||||
[]string{"soil", "organic", "earth", "black", "tan"},
|
||||
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||
[]string{"fourth", "july", "declaration", "independence"},
|
||||
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue