integration: fix embedding tests error handling (#10478)

The cleanup routine from InitServerconnection should run in the defer of the test case to properly detect failures and report the server logs
This commit is contained in:
Daniel Hiltgen 2025-04-29 11:57:54 -07:00 committed by GitHub
parent a27462b708
commit 7bec2724a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
func TestAllMiniLMEmbeddings(t *testing.T) { func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{ req := api.EmbeddingRequest{
Model: "all-minilm", Model: "all-minilm",
Prompt: "why is the sky blue?", Prompt: "why is the sky blue?",
} }
res, err := embeddingTestHelper(ctx, t, req) res, err := embeddingTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
func TestAllMiniLMEmbed(t *testing.T) { func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{ req := api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
} }
res, err := embedTestHelper(ctx, t, req) res, err := embedTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) {
func TestAllMiniLMBatchEmbed(t *testing.T) { func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{ req := api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"}, Input: []string{"why is the sky blue?", "why is the grass green?"},
} }
res, err := embedTestHelper(ctx, t, req) res, err := embedTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
func TestAllMiniLMEmbedTruncate(t *testing.T) { func TestAllMiniLMEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
truncTrue, truncFalse := true, false truncTrue, truncFalse := true, false
@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
res := make(map[string]*api.EmbedResponse) res := make(map[string]*api.EmbedResponse)
for _, req := range reqs { for _, req := range reqs {
response, err := embedTestHelper(ctx, t, req.Request) response, err := embedTestHelper(ctx, client, t, req.Request)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
} }
// check that truncate set to false returns an error if context length is exceeded // check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, t, api.EmbedRequest{ _, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
Truncate: &truncFalse, Truncate: &truncFalse,
@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
} }
} }
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err) t.Fatalf("failed to pull model %s: %v", req.Model, err)
} }
@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq
return response, nil return response, nil
} }
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err) t.Fatalf("failed to pull model %s: %v", req.Model, err)
} }