From d20cd8df8012a4600c3caf3e64b3d17c40f58903 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Mon, 28 Apr 2025 16:11:36 -0700 Subject: [PATCH] fix incorrect chat truncation The dynamically calculated `NumCtx` value wasn't making it all the way to the chat handler This fix made us notice that the minimum setting of `NumCtx` to 4 inside of `server/sched.go` was accidentally removed in #10364. The minimum doesn't make it out to the client code, which is important for embeddings, as demonstrated in `TestAllMiniLMEmbedTruncate`. This should be cleaned up more, but probably is caused by start and end tokens in the embedding, so small context sizes need some work there. See the comment in `server/routes.go` for more information on the temporary hack that's been added to propagate the dynamically calculated `NumCtx` (the -1 guard there is to keep embeddings working if you set `NumCtx` to some small value like `1`). Fixes: #10441 --- integration/embed_test.go | 28 ++++++++++++++++------------ server/routes.go | 10 ++++++++++ server/routes_generate_test.go | 16 ++++++++++++---- server/sched.go | 15 ++++++++++++++- server/sched_test.go | 1 - 5 files changed, 52 insertions(+), 18 deletions(-) diff --git a/integration/embed_test.go b/integration/embed_test.go index 8a95816a5..26e793205 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V { func TestAllMiniLMEmbeddings(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() req := api.EmbeddingRequest{ Model: "all-minilm", Prompt: "why is the sky blue?", } - res, err := embeddingTestHelper(ctx, t, req) + res, err := embeddingTestHelper(ctx, client, t, req) if err != nil { t.Fatalf("error: %v", err) @@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) { func TestAllMiniLMEmbed(t *testing.T) { ctx, cancel := 
context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() req := api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", } - res, err := embedTestHelper(ctx, t, req) + res, err := embedTestHelper(ctx, client, t, req) if err != nil { t.Fatalf("error: %v", err) @@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) { func TestAllMiniLMBatchEmbed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() req := api.EmbedRequest{ Model: "all-minilm", Input: []string{"why is the sky blue?", "why is the grass green?"}, } - res, err := embedTestHelper(ctx, t, req) + res, err := embedTestHelper(ctx, client, t, req) if err != nil { t.Fatalf("error: %v", err) @@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { func TestAllMiniLMEmbedTruncate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() truncTrue, truncFalse := true, false @@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { res := make(map[string]*api.EmbedResponse) for _, req := range reqs { - response, err := embedTestHelper(ctx, t, req.Request) + response, err := embedTestHelper(ctx, client, t, req.Request) if err != nil { t.Fatalf("error: %v", err) } @@ -190,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { } if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { - t.Fatal("expected default request to truncate correctly") + t.Fatal("expected default request to truncate correctly. 
Wanted: ", res["Target Truncation"].Embeddings[0][0], "Got: ", res["Default Truncate"].Embeddings[0][0]) } if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { @@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { } // check that truncate set to false returns an error if context length is exceeded - _, err := embedTestHelper(ctx, t, api.EmbedRequest{ + _, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", Truncate: &truncFalse, @@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { } } -func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { - client, _, cleanup := InitServerConnection(ctx, t) - defer cleanup() +func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("failed to pull model %s: %v", req.Model, err) } @@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq return response, nil } -func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { - client, _, cleanup := InitServerConnection(ctx, t) - defer cleanup() +func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("failed to pull model %s: %v", req.Model, err) } diff --git a/server/routes.go b/server/routes.go index 31acd0d1a..3d1dfcd68 100644 --- a/server/routes.go +++ b/server/routes.go @@ -114,6 +114,16 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return nil, nil, nil, err } + // TODO(drifkin): `GetRunner` above changes opts, but we currently pass it by + // value. 
The following line is a hack to fix this for the now dynamically + calculated NumCtx, but we should fix this properly (which could have other + side effects, since perhaps we were relying on the values not being stomped + on, particularly when NumCtx sometimes represents a numParallel-adjusted + number and sometimes not) + if opts.NumCtx == -1 { + opts.NumCtx = runner.Options.NumCtx / runner.numParallel + } + return runner.llama, model, &opts, nil } diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index dd77b574a..af75797a4 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -80,8 +80,12 @@ func TestGenerateChat(t *testing.T) { loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { // add small delay to simulate loading time.Sleep(time.Millisecond) + opts := api.DefaultOptions() + opts.NumCtx = 4096 req.successCh <- &runnerRef{ - llama: &mock, + llama: &mock, + numParallel: 1, + Options: &opts, } }, }, @@ -184,7 +188,8 @@ func TestGenerateChat(t *testing.T) { t.Run("load model", func(t *testing.T) { w := createRequest(t, s.ChatHandler, api.ChatRequest{ - Model: "test", + Model: "test", + Options: map[string]any{"num_ctx": 2048}, }) if w.Code != http.StatusOK { @@ -634,7 +639,9 @@ func TestGenerate(t *testing.T) { // add small delay to simulate loading time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ - llama: &mock, + llama: &mock, + Options: &api.Options{}, + numParallel: 1, } }, }, @@ -750,7 +757,8 @@ func TestGenerate(t *testing.T) { t.Run("load model", func(t *testing.T) { w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ - Model: "test", + Model: "test", + Options: map[string]any{"num_ctx": 2048}, }) if w.Code != http.StatusOK { diff --git a/server/sched.go b/server/sched.go index d5b19fbfd..28ebf309d 100644 --- a/server/sched.go +++ b/server/sched.go @@ -81,6 +81,9 @@ func InitScheduler(ctx context.Context) *Scheduler { // context must be canceled
to decrement ref count and release the runner func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) { + if opts.NumCtx != -1 && opts.NumCtx < 4 { + opts.NumCtx = 4 + } req := &LlmRequest{ ctx: c, model: model, @@ -585,6 +588,16 @@ func (runner *runnerRef) unload() { runner.gpus = nil } +func runnerOptionsEqual(a, b api.Runner) bool { + // if one of the options is -1, then it means it needs to be dynamically calculated + if a.NumCtx == -1 { + a.NumCtx = b.NumCtx + } else if b.NumCtx == -1 { + b.NumCtx = a.NumCtx + } + return reflect.DeepEqual(a, b) +} + func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool { slog.Debug("evaluating already loaded", "model", req.model.ModelPath) runner.refMu.Lock() @@ -614,7 +627,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool defer cancel() if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed? !reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed? - !reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed? + !runnerOptionsEqual(optsExisting, optsNew) || runner.llama.Ping(ctx) != nil { return true } diff --git a/server/sched_test.go b/server/sched_test.go index 1b620329c..274e18cec 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -148,7 +148,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), } - b.req.opts.NumCtx = 4096 b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}} return b }