fix incorrect chat truncation

The dynamically calculated `NumCtx` value wasn't making it all the way
to the chat handler.

This fix also made us notice that the clamping of `NumCtx` to a
minimum of 4 inside of `server/sched.go` had been accidentally removed
in #10364. The minimum doesn't make it out to the client code, which is
important for embeddings, as demonstrated in
`TestAllMiniLMEmbedTruncate`. This should be cleaned up further; the
need for a minimum is probably caused by start and end tokens in the
embedding, so very small context sizes need some work there. See the
comment in `server/routes.go` for more information on the temporary
hack that's been added to propagate the dynamically calculated `NumCtx`
(the -1 guard there keeps embeddings working if you set `NumCtx` to
some small value like `1`).
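
To illustrate (an editorial sketch, not part of this commit): with the
clamp restored, an embedding request that sets a tiny `num_ctx` should
still succeed, since the scheduler raises it to 4. Assuming a running
local server and the `all-minilm` model, using the Go client:

    package main

    import (
        "context"
        "fmt"
        "log"

        "github.com/ollama/ollama/api"
    )

    func main() {
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }
        resp, err := client.Embed(context.Background(), &api.EmbedRequest{
            Model:   "all-minilm",
            Input:   "why is the sky blue?",
            Options: map[string]any{"num_ctx": 1}, // clamped to 4 by the scheduler
        })
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println(len(resp.Embeddings[0])) // an embedding is still produced
    }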

Fixes: #10441
Author: Devon Rifkin
Date:   2025-04-28 16:11:36 -07:00
Parent: 5cfc1c39f3
Commit: d20cd8df80

5 changed files with 52 additions and 18 deletions

integration/embed_test.go

@@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
 func TestAllMiniLMEmbeddings(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
 
 	req := api.EmbeddingRequest{
 		Model:  "all-minilm",
 		Prompt: "why is the sky blue?",
 	}
 
-	res, err := embeddingTestHelper(ctx, t, req)
+	res, err := embeddingTestHelper(ctx, client, t, req)
 
 	if err != nil {
 		t.Fatalf("error: %v", err)
@@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
 func TestAllMiniLMEmbed(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
 
 	req := api.EmbedRequest{
 		Model: "all-minilm",
 		Input: "why is the sky blue?",
 	}
 
-	res, err := embedTestHelper(ctx, t, req)
+	res, err := embedTestHelper(ctx, client, t, req)
 
 	if err != nil {
 		t.Fatalf("error: %v", err)
@@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) {
 func TestAllMiniLMBatchEmbed(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
 
 	req := api.EmbedRequest{
 		Model: "all-minilm",
 		Input: []string{"why is the sky blue?", "why is the grass green?"},
 	}
 
-	res, err := embedTestHelper(ctx, t, req)
+	res, err := embedTestHelper(ctx, client, t, req)
 
 	if err != nil {
 		t.Fatalf("error: %v", err)
@@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
 
 	truncTrue, truncFalse := true, false
@@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	res := make(map[string]*api.EmbedResponse)
 	for _, req := range reqs {
-		response, err := embedTestHelper(ctx, t, req.Request)
+		response, err := embedTestHelper(ctx, client, t, req.Request)
 		if err != nil {
 			t.Fatalf("error: %v", err)
 		}
@@ -190,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	}
 
 	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
-		t.Fatal("expected default request to truncate correctly")
+		t.Fatal("expected default request to truncate correctly. Wanted: ", res["Target Truncation"].Embeddings[0][0], "Got: ", res["Default Truncate"].Embeddings[0][0])
 	}
 
 	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
@@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	}
 
 	// check that truncate set to false returns an error if context length is exceeded
-	_, err := embedTestHelper(ctx, t, api.EmbedRequest{
+	_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
 		Model:    "all-minilm",
 		Input:    "why is the sky blue?",
 		Truncate: &truncFalse,
@@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	}
 }
 
-func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
+func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
 	if err := PullIfMissing(ctx, client, req.Model); err != nil {
 		t.Fatalf("failed to pull model %s: %v", req.Model, err)
 	}
@@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq
 	return response, nil
 }
 
-func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
+func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
 	if err := PullIfMissing(ctx, client, req.Model); err != nil {
 		t.Fatalf("failed to pull model %s: %v", req.Model, err)
 	}

server/routes.go

@@ -114,6 +114,16 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
 		return nil, nil, nil, err
 	}
 
+	// TODO(drifkin): `GetRunner` above changes opts, but we currently pass it by
+	// value. The following line is a hack to fix this for the now dynamically
+	// calculated NumCtx, but we should fix this properly (which could have other
+	// side effects, since perhaps we were relying on the values not being stomped
+	// on, particularly when NumCtx sometimes represents a numParallel-adjusted
+	// number and sometimes not)
+	if opts.NumCtx == -1 {
+		opts.NumCtx = runner.Options.NumCtx / runner.numParallel
+	}
+
 	return runner.llama, model, &opts, nil
 }
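
Aside (an editorial sketch, not part of the diff; simplified types):
the pass-by-value pitfall the TODO describes boils down to this:

    package main

    import "fmt"

    // options stands in for api.Options; only NumCtx matters here.
    type options struct{ NumCtx int }

    // getRunner mirrors the real GetRunner receiving opts by value:
    // it mutates its own copy, which the caller never sees.
    func getRunner(o options) { o.NumCtx = 4096 }

    func main() {
        o := options{NumCtx: -1}
        getRunner(o)
        fmt.Println(o.NumCtx) // still -1, hence the explicit fix-up above
    }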

server/routes_generate_test.go

@@ -80,8 +80,12 @@ func TestGenerateChat(t *testing.T) {
 		loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
 			// add small delay to simulate loading
 			time.Sleep(time.Millisecond)
+			opts := api.DefaultOptions()
+			opts.NumCtx = 4096
 			req.successCh <- &runnerRef{
-				llama: &mock,
+				llama:       &mock,
+				numParallel: 1,
+				Options:     &opts,
 			}
 		},
 	},
@@ -184,7 +188,8 @@ func TestGenerateChat(t *testing.T) {
 
 	t.Run("load model", func(t *testing.T) {
 		w := createRequest(t, s.ChatHandler, api.ChatRequest{
-			Model: "test",
+			Model:   "test",
+			Options: map[string]any{"num_ctx": 2048},
 		})
 
 		if w.Code != http.StatusOK {
@@ -634,7 +639,9 @@ func TestGenerate(t *testing.T) {
 			// add small delay to simulate loading
 			time.Sleep(time.Millisecond)
 			req.successCh <- &runnerRef{
-				llama: &mock,
+				llama:       &mock,
+				Options:     &api.Options{},
+				numParallel: 1,
 			}
 		},
 	},
@@ -750,7 +757,8 @@ func TestGenerate(t *testing.T) {
 
 	t.Run("load model", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
-			Model: "test",
+			Model:   "test",
+			Options: map[string]any{"num_ctx": 2048},
 		})
 
 		if w.Code != http.StatusOK {

server/sched.go

@@ -81,6 +81,9 @@ func InitScheduler(ctx context.Context) *Scheduler {
 // context must be canceled to decrement ref count and release the runner
 func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
+	if opts.NumCtx != -1 && opts.NumCtx < 4 {
+		opts.NumCtx = 4
+	}
 	req := &LlmRequest{
 		ctx:             c,
 		model:           model,
@@ -585,6 +588,16 @@ func (runner *runnerRef) unload() {
 	runner.gpus = nil
 }
 
+func runnerOptionsEqual(a, b api.Runner) bool {
+	// if one of the options is -1, then it means it needs to be dynamically calculated
+	if a.NumCtx == -1 {
+		a.NumCtx = b.NumCtx
+	} else if b.NumCtx == -1 {
+		b.NumCtx = a.NumCtx
+	}
+	return reflect.DeepEqual(a, b)
+}
+
 func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
 	slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
 	runner.refMu.Lock()
@@ -614,7 +627,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 	defer cancel()
 	if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
 		!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
-		!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
+		!runnerOptionsEqual(optsExisting, optsNew) ||
 		runner.llama.Ping(ctx) != nil {
 		return true
 	}
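
Aside (an editorial illustration with made-up values; this fragment
would have to live in the `server` package, since `runnerOptionsEqual`
is unexported): -1 acts as "match whatever the loaded runner has", so a
request that leaves NumCtx dynamic doesn't force a reload. Both
arguments are passed by value, so the normalization never mutates the
caller's options.

    loaded := api.Runner{NumCtx: 8192} // options of the already-loaded runner
    runnerOptionsEqual(api.Runner{NumCtx: -1}, loaded)   // true: -1 defers to the loaded value
    runnerOptionsEqual(api.Runner{NumCtx: 2048}, loaded) // false: a real mismatch forces a reload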

server/sched_test.go

@@ -148,7 +148,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
 		successCh: make(chan *runnerRef, 1),
 		errCh:     make(chan error, 1),
 	}
-	b.req.opts.NumCtx = 4096
 	b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
 	return b
 }