Note: the request Options maps in the file below are typed map[string]any. Both interface{} and any (an alias for interface{} introduced in Go 1.18) name the empty interface that every type satisfies, so the two spellings are interchangeable.
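A minimal standalone sketch of that interchangeability (the identifiers here are invented for illustration and are not part of the benchmark package):

package main

import "fmt"

// describe takes any value; declaring the parameter as interface{} would compile identically.
func describe(v any) string {
	return fmt.Sprintf("%T: %v", v, v)
}

func main() {
	var legacy interface{} = 42
	var modern any = legacy // no conversion needed: any and interface{} are the same type
	fmt.Println(describe(modern))
	fmt.Println(describe(map[string]any{"num_predict": 100, "temperature": 0.1}))
}

The benchmark code below uses the any spelling throughout its Options maps.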
package benchmark

import (
	"context"
	"flag"
	"fmt"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// Command-line flags
var modelFlag string

func init() {
	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
	// DefValue only changes the default shown in the usage message; the flag
	// itself still defaults to "", which lets modelName detect a missing -m.
	flag.Lookup("m").DefValue = "model"
}

// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
	if modelFlag == "" {
		b.Fatal("Error: -m flag is required for benchmark tests")
	}
	return modelFlag
}

type TestCase struct {
	name      string
	prompt    string
	maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
	start := time.Now()
	var ttft time.Duration
	var metrics api.Metrics

	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		if ttft == 0 && resp.Response != "" {
			ttft = time.Since(start)
		}
		if resp.Done {
			metrics = resp.Metrics
		}
		return nil
	})
	// Fail before reporting: a failed request would otherwise report zeroed
	// metrics (and NaN throughput from zero-duration divisions).
	if err != nil {
		b.Fatal(err)
	}

	// Report custom metrics as part of the benchmark results
	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

	// Token throughput metrics
	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
	b.ReportMetric(promptThroughput, "prompt_tok/s")
	b.ReportMetric(genThroughput, "gen_tok/s")

	// Token counts
	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				b.StopTimer()
				// Ensure model is unloaded before each iteration
				unload(client, m, b)
				b.StartTimer()

				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Pre-warm the model
			warmup(client, m, tt.prompt, b)

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		b.Fatal(err)
	}
	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
		b.Fatalf("Model unavailable: %v", err)
	}

	return client
}

// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
	for range 3 {
		err := client.Generate(
			context.Background(),
			&api.GenerateRequest{
				Model:   model,
				Prompt:  prompt,
				Options: map[string]any{"num_predict": 50, "temperature": 0.1},
			},
			func(api.GenerateResponse) error { return nil },
		)
		if err != nil {
			b.Logf("Error during model warm-up: %v", err)
		}
	}
}

// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	time.Sleep(1 * time.Second)
}
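
For reference, because the model is selected through the custom -m flag registered in init, a typical invocation would look something like the line below. This is a sketch only: it assumes go test forwards the unrecognized -m flag to the test binary (placing it after the package path), the package path is inferred from the package name, and the model name is a placeholder.

go test ./benchmark -bench=. -m llama3.2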