// Package benchmark provides tools for performance testing of the Ollama inference server and supported models.
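//
// The benchmarks assume an Ollama server is already running and reachable at
// serverURL; they do not start one. A typical invocation (the package path may
// differ depending on where this file lives in the repository) is:
//
//	go test -bench=BenchmarkServerInference ./benchmark/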
package benchmark

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"testing"
	"text/tabwriter"
	"time"

	"github.com/ollama/ollama/api"
)

// serverURL is the default Ollama server URL for benchmarking
const serverURL = "http://127.0.0.1:11434"

// metrics collects all benchmark results for final reporting
var metrics []BenchmarkMetrics

// models contains the list of model names to benchmark
var models = []string{
	"llama3.2:1b",
	// "qwen2.5:7b",
	// "llama3.3:70b",
}

// TestCase defines a benchmark test scenario with prompt characteristics
type TestCase struct {
	name      string // Human-readable test name
	prompt    string // Input prompt text
	maxTokens int    // Maximum tokens to generate
}

// BenchmarkMetrics contains performance measurements for a single test run
type BenchmarkMetrics struct {
	model           string        // Model being tested
	scenario        string        // cold_start or warm_start
	testName        string        // Name of the test case
	ttft            time.Duration // Time To First Token (TTFT)
	totalTime       time.Duration // Total time for complete response
	totalTokens     int           // Total generated tokens
	tokensPerSecond float64       // Calculated throughput
}

// ScenarioType defines the initialization state for benchmarking
type ScenarioType int

const (
	ColdStart ScenarioType = iota // Model is loaded from cold state
	WarmStart                     // Model is already loaded in memory
)

// String implements fmt.Stringer for ScenarioType
func (s ScenarioType) String() string {
	return [...]string{"cold_start", "warm_start"}[s]
}

// BenchmarkServerInference is the main entry point for benchmarking Ollama inference performance.
// It tests all configured models with different prompt lengths and start scenarios.
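//
// Each sub-benchmark is named model/scenario/test case, for example
// "llama3.2:1b/cold_start/short_prompt", so a -bench regular expression can be
// used to select a single combination.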
func BenchmarkServerInference(b *testing.B) {
	b.Logf("Starting benchmark suite with %d models", len(models))

	// Verify server availability
	resp, err := http.Get(serverURL + "/api/version")
	if err != nil {
		b.Fatalf("Server unavailable: %v", err)
	}
	resp.Body.Close()
	b.Log("Server available")

	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}

	// Register cleanup handler for results reporting
	b.Cleanup(func() { reportMetrics(metrics) })

	// Main benchmark loop
	for _, model := range models {
		client := api.NewClient(mustParse(serverURL), http.DefaultClient)
		// Verify model availability
		if _, err := client.Show(context.Background(), &api.ShowRequest{Model: model}); err != nil {
			b.Fatalf("Model unavailable: %v", err)
		}

		for _, tt := range tests {
			testName := fmt.Sprintf("%s/%s/%s", model, ColdStart, tt.name)
			b.Run(testName, func(b *testing.B) {
				m := runBenchmark(b, tt, model, ColdStart, client)
				metrics = append(metrics, m...)
			})
		}

		for _, tt := range tests {
			testName := fmt.Sprintf("%s/%s/%s", model, WarmStart, tt.name)
			b.Run(testName, func(b *testing.B) {
				m := runBenchmark(b, tt, model, WarmStart, client)
				metrics = append(metrics, m...)
			})
		}
	}
}

// runBenchmark executes multiple iterations of a specific test case and scenario.
// Returns collected metrics for all iterations.
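//
// The iteration count b.N is chosen by the testing framework; it can be pinned
// with go test's -benchtime flag (for example, -benchtime=10x asks for ten
// iterations per sub-benchmark).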
func runBenchmark(b *testing.B, tt TestCase, model string, scenario ScenarioType, client *api.Client) []BenchmarkMetrics {
	results := make([]BenchmarkMetrics, b.N)

	// Run benchmark iterations
	for i := 0; i < b.N; i++ {
		switch scenario {
		case WarmStart:
			// Pre-warm the model by generating some tokens
			for j := 0; j < 2; j++ {
				client.Generate(
					context.Background(),
					&api.GenerateRequest{
						Model:   model,
						Prompt:  tt.prompt,
						Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
					},
					func(api.GenerateResponse) error { return nil },
				)
			}
		case ColdStart:
			unloadModel(client, model, b)
		}
		b.ResetTimer()

		results[i] = runSingleIteration(context.Background(), client, tt, model, b)
		results[i].scenario = scenario.String()
	}
	return results
}

// unloadModel forces the model to unload by sending a generate request with
// KeepAlive set to 0, which tells the server to free the model immediately.
// Includes a short delay to ensure unloading completes before the next test.
func unloadModel(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	time.Sleep(100 * time.Millisecond)
}

// runSingleIteration measures performance metrics for a single inference request.
// Captures TTFT, total generation time, and calculates tokens/second.
func runSingleIteration(ctx context.Context, client *api.Client, tt TestCase, model string, b *testing.B) BenchmarkMetrics {
	start := time.Now()
	var ttft time.Duration
	var tokens int
	lastToken := start

	req := &api.GenerateRequest{
		Model:   model,
		Prompt:  tt.prompt,
		Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
	}

	if b != nil {
		b.Logf("Prompt length: %d chars", len(tt.prompt))
	}

	// Execute generation request with metrics collection
	if err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		if ttft == 0 {
			ttft = time.Since(start)
		}
		if resp.Response != "" {
			tokens++
			lastToken = time.Now()
		}
		return nil
	}); err != nil && b != nil {
		b.Logf("Generation error: %v", err)
	}

	totalTime := lastToken.Sub(start)
	return BenchmarkMetrics{
		model:           model,
		testName:        tt.name,
		ttft:            ttft,
		totalTime:       totalTime,
		totalTokens:     tokens,
		tokensPerSecond: float64(tokens) / totalTime.Seconds(),
	}
}

// reportMetrics processes collected metrics and prints formatted results.
// Generates both human-readable tables and CSV output with averaged statistics.
func reportMetrics(results []BenchmarkMetrics) {
	if len(results) == 0 {
		return
	}

	// Aggregate results by test case
	type statsKey struct {
		model    string
		scenario string
		testName string
	}
	stats := make(map[statsKey]*struct {
		ttftSum      time.Duration
		totalTimeSum time.Duration
		tokensSum    int
		iterations   int
	})

	for _, m := range results {
		key := statsKey{m.model, m.scenario, m.testName}
		if _, exists := stats[key]; !exists {
			stats[key] = &struct {
				ttftSum      time.Duration
				totalTimeSum time.Duration
				tokensSum    int
				iterations   int
			}{}
		}

		stats[key].ttftSum += m.ttft
		stats[key].totalTimeSum += m.totalTime
		stats[key].tokensSum += m.totalTokens
		stats[key].iterations++
	}

	// Calculate averages
	var averaged []BenchmarkMetrics
	for key, data := range stats {
		count := data.iterations
		averaged = append(averaged, BenchmarkMetrics{
			model:           key.model,
			scenario:        key.scenario,
			testName:        key.testName,
			ttft:            data.ttftSum / time.Duration(count),
			totalTime:       data.totalTimeSum / time.Duration(count),
			totalTokens:     data.tokensSum / count,
			tokensPerSecond: float64(data.tokensSum) / data.totalTimeSum.Seconds(),
		})
	}

	// Print formatted results
	printTableResults(averaged)
	printCSVResults(averaged)
}

// printTableResults displays averaged metrics in a formatted table
func printTableResults(averaged []BenchmarkMetrics) {
	w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
	fmt.Fprintln(w, "\nAVERAGED BENCHMARK RESULTS")
	fmt.Fprintln(w, "Model\tScenario\tTest Name\tTTFT (ms)\tTotal Time (ms)\tTokens\tTokens/sec")
	for _, m := range averaged {
		fmt.Fprintf(w, "%s\t%s\t%s\t%.2f\t%.2f\t%d\t%.2f\n",
			m.model,
			m.scenario,
			m.testName,
			float64(m.ttft.Milliseconds()),
			float64(m.totalTime.Milliseconds()),
			m.totalTokens,
			m.tokensPerSecond,
		)
	}
	w.Flush()
}

// printCSVResults outputs averaged metrics in CSV format
func printCSVResults(averaged []BenchmarkMetrics) {
	fmt.Println("\nCSV OUTPUT")
	fmt.Println("model,scenario,test_name,ttft_ms,total_ms,tokens,tokens_per_sec")
	for _, m := range averaged {
		fmt.Printf("%s,%s,%s,%.2f,%.2f,%d,%.2f\n",
			m.model,
			m.scenario,
			m.testName,
			float64(m.ttft.Milliseconds()),
			float64(m.totalTime.Milliseconds()),
			m.totalTokens,
			m.tokensPerSecond,
		)
	}
}

// mustParse parses rawURL and panics if it is not a valid URL
func mustParse(rawURL string) *url.URL {
	u, err := url.Parse(rawURL)
	if err != nil {
		panic(err)
	}
	return u
}