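// Package benchmark measures end-to-end inference performance of a local
// model runner (such as the ollama runner) over its HTTP completion API.
//
// The runner is assumed to already be serving on localhost:8080 (see
// runnerURL below); it is not started by the benchmark. A typical
// invocation, pinning the iteration count so each case runs a fixed number
// of times:
//
//	go test -bench=BenchmarkRunnerInference -benchtime=3x .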
package benchmark

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"testing"
	"time"
)

const (
	runnerURL     = "http://localhost:8080"
	warmupPrompts = 2  // Number of warm-up requests per test case
	warmupTokens  = 50 // Smaller token count for warm-up requests
)

var runnerMetrics []BenchmarkMetrics
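
// TestCase, BenchmarkMetrics, and the reportMetrics helper are defined
// elsewhere in this package. A minimal sketch of their shapes, inferred from
// how this file uses them (not the canonical definitions):
//
//	type TestCase struct {
//		name      string
//		prompt    string
//		maxTokens int
//	}
//
//	type BenchmarkMetrics struct {
//		testName        string
//		ttft            time.Duration
//		totalTime       time.Duration
//		totalTokens     int
//		tokensPerSecond float64
//	}
//
//	func reportMetrics(metrics []BenchmarkMetrics) { /* summarize results */ }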

// CompletionRequest represents the request body for the completion endpoint
type CompletionRequest struct {
	Prompt      string  `json:"prompt"`
	NumPredict  int     `json:"n_predict"`
	Temperature float32 `json:"temperature"`
}

// CompletionResponse represents a single response chunk from the streaming API
type CompletionResponse struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`
	Timings struct {
		PredictedN  int `json:"predicted_n"`
		PredictedMs int `json:"predicted_ms"`
		PromptN     int `json:"prompt_n"`
		PromptMs    int `json:"prompt_ms"`
	} `json:"timings"`
}
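
// The decoder in runCompletion expects the response body to be a stream of
// concatenated JSON objects of this shape; the values below are illustrative,
// not captured output:
//
//	{"content":"Once","stop":false}
//	{"content":" upon","stop":false}
//	{"content":"","stop":true,"timings":{"predicted_n":100,"predicted_ms":2450,"prompt_n":12,"prompt_ms":85}}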

// warmUp performs warm-up requests before the actual benchmark
func warmUp(b *testing.B, tt TestCase) {
	b.Logf("Warming up for test case %s", tt.name)
	warmupTest := TestCase{
		name:      tt.name + "_warmup",
		prompt:    tt.prompt,
		maxTokens: warmupTokens,
	}

	for i := 0; i < warmupPrompts; i++ {
		runCompletion(context.Background(), warmupTest, b)
		time.Sleep(100 * time.Millisecond) // Brief pause between warm-up requests
	}
	b.Log("Warm-up complete")
}

func BenchmarkRunnerInference(b *testing.B) {
	b.Logf("Starting benchmark suite")

	// Verify server availability
	resp, err := http.Get(runnerURL + "/health")
	if err != nil {
		b.Fatalf("Runner unavailable: %v", err)
	}
	resp.Body.Close()
	b.Log("Runner available")

	tests := []TestCase{
		{
			name:      "short_prompt",
			prompt:    formatPrompt("Write a long story"),
			maxTokens: 100,
		},
		{
			name:      "medium_prompt",
			prompt:    formatPrompt("Write a detailed economic analysis"),
			maxTokens: 500,
		},
		{
			name:      "long_prompt",
			prompt:    formatPrompt("Write a comprehensive AI research paper"),
			maxTokens: 1000,
		},
	}

	// Register cleanup handler for results reporting
	b.Cleanup(func() { reportMetrics(runnerMetrics) })

	// Main benchmark loop
	for _, tt := range tests {
		b.Run(tt.name, func(b *testing.B) {
			// Perform warm-up requests
			warmUp(b, tt)

			// Wait a bit after warm-up before starting the actual benchmark
			time.Sleep(500 * time.Millisecond)

			m := make([]BenchmarkMetrics, b.N)

			// Reset once so warm-up time is excluded, then time all b.N
			// iterations together
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				m[i] = runCompletion(context.Background(), tt, b)
			}
			runnerMetrics = append(runnerMetrics, m...)
		})
	}
}
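
// Individual cases can be selected with the usual sub-benchmark filter, e.g.
//
//	go test -bench=BenchmarkRunnerInference/short_prompt .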

// formatPrompt wraps raw text in a chat template. The special tokens match
// the Llama 3 instruct format, so results are only meaningful for models
// trained on that template.
func formatPrompt(text string) string {
	return fmt.Sprintf("<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", text)
}

// runCompletion sends a single completion request and consumes the streamed
// response, recording time to first token and overall throughput
func runCompletion(ctx context.Context, tt TestCase, b *testing.B) BenchmarkMetrics {
	start := time.Now()
	var ttft time.Duration
	var tokens int
	lastToken := start

	// Create request body
	reqBody := CompletionRequest{
		Prompt:      tt.prompt,
		NumPredict:  tt.maxTokens,
		Temperature: 0.1,
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		b.Fatalf("Failed to marshal request: %v", err)
	}

	// Create HTTP request
	req, err := http.NewRequestWithContext(ctx, "POST", runnerURL+"/completion", bytes.NewBuffer(jsonData))
	if err != nil {
		b.Fatalf("Failed to create request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	// Execute request
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		b.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		b.Fatalf("Unexpected status: %s", resp.Status)
	}

	// Process streaming response: one JSON object per chunk
	decoder := json.NewDecoder(resp.Body)
	for {
		var chunk CompletionResponse
		if err := decoder.Decode(&chunk); err != nil {
			if err == io.EOF {
				break
			}
			b.Fatalf("Failed to decode response: %v", err)
		}

		// Time to first token: elapsed time at the first non-empty chunk
		if ttft == 0 && chunk.Content != "" {
			ttft = time.Since(start)
		}

		if chunk.Content != "" {
			tokens++
			lastToken = time.Now()
		}

		if chunk.Stop {
			break
		}
	}

	// Throughput is measured up to the last received token; each streamed
	// chunk is counted as one token
	totalTime := lastToken.Sub(start)
	return BenchmarkMetrics{
		testName:        tt.name,
		ttft:            ttft,
		totalTime:       totalTime,
		totalTokens:     tokens,
		tokensPerSecond: float64(tokens) / totalTime.Seconds(),
	}
}