mirror of https://github.com/ollama/ollama.git
fix incorrect chat truncation
The dynamically calculated `NumCtx` value wasn't making it all the way to the chat handler. Working on this fix also made us notice that the clamping of `NumCtx` to a minimum of 4 inside `server/sched.go` was accidentally removed in #10364. That minimum never makes it back out to the client code, which matters for embeddings, as demonstrated by `TestAllMiniLMEmbedTruncate`. This should be cleaned up further; the failure is probably caused by start and end tokens in the embedding, so very small context sizes still need some work there.

See the comment in `server/routes.go` for more information on the temporary hack added to propagate the dynamically calculated `NumCtx` (the -1 guard there keeps embeddings working if you set `NumCtx` to a small value such as 1).

Fixes: #10441
parent 5cfc1c39f3
commit d20cd8df80
5 changed files with 52 additions and 18 deletions
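For orientation, here is a minimal, self-contained Go sketch of the NumCtx flow this change relies on. It is not the ollama code: the types and the helper names clampNumCtx and resolveNumCtx are invented for illustration, but their bodies mirror the guard added to GetRunner in server/sched.go and the propagation hack added to scheduleRunner in server/routes.go, where -1 means "calculate the context window dynamically from the loaded runner".

package main

import "fmt"

type options struct {
	NumCtx int // -1 is the "dynamically calculated" sentinel
}

type loadedRunner struct {
	numCtx      int // context size the runner was actually loaded with
	numParallel int
}

// clampNumCtx mirrors the GetRunner guard: leave the -1 sentinel alone so it
// can be resolved later, otherwise enforce the minimum of 4.
func clampNumCtx(opts options) options {
	if opts.NumCtx != -1 && opts.NumCtx < 4 {
		opts.NumCtx = 4
	}
	return opts
}

// resolveNumCtx mirrors the scheduleRunner hack: replace the sentinel with the
// per-request share of the runner's context window so the chat handler can
// truncate against the real limit.
func resolveNumCtx(opts options, r loadedRunner) options {
	if opts.NumCtx == -1 {
		opts.NumCtx = r.numCtx / r.numParallel
	}
	return opts
}

func main() {
	r := loadedRunner{numCtx: 8192, numParallel: 2}
	opts := clampNumCtx(options{NumCtx: -1}) // sentinel passes through untouched
	opts = resolveNumCtx(opts, r)
	fmt.Println(opts.NumCtx) // 4096: the value the chat handler now sees for truncation
}

Keeping -1 as an in-band sentinel avoids changing the shape of api.Options, at the cost of the special-casing visible throughout the diff below.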
@@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
 func TestAllMiniLMEmbeddings(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
 
 	req := api.EmbeddingRequest{
 		Model:  "all-minilm",
 		Prompt: "why is the sky blue?",
 	}
 
-	res, err := embeddingTestHelper(ctx, t, req)
+	res, err := embeddingTestHelper(ctx, client, t, req)
 
 	if err != nil {
 		t.Fatalf("error: %v", err)

@@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
 func TestAllMiniLMEmbed(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
 
 	req := api.EmbedRequest{
 		Model: "all-minilm",
 		Input: "why is the sky blue?",
 	}
 
-	res, err := embedTestHelper(ctx, t, req)
+	res, err := embedTestHelper(ctx, client, t, req)
 
 	if err != nil {
 		t.Fatalf("error: %v", err)

@@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) {
 func TestAllMiniLMBatchEmbed(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
 
 	req := api.EmbedRequest{
 		Model: "all-minilm",
 		Input: []string{"why is the sky blue?", "why is the grass green?"},
 	}
 
-	res, err := embedTestHelper(ctx, t, req)
+	res, err := embedTestHelper(ctx, client, t, req)
 
 	if err != nil {
 		t.Fatalf("error: %v", err)

@@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
 
 	truncTrue, truncFalse := true, false
 

@@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	res := make(map[string]*api.EmbedResponse)
 
 	for _, req := range reqs {
-		response, err := embedTestHelper(ctx, t, req.Request)
+		response, err := embedTestHelper(ctx, client, t, req.Request)
 		if err != nil {
 			t.Fatalf("error: %v", err)
 		}

@@ -190,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	}
 
 	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
-		t.Fatal("expected default request to truncate correctly")
+		t.Fatal("expected default request to truncate correctly. Wanted: ", res["Target Truncation"].Embeddings[0][0], "Got: ", res["Default Truncate"].Embeddings[0][0])
 	}
 
 	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {

@@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	}
 
 	// check that truncate set to false returns an error if context length is exceeded
-	_, err := embedTestHelper(ctx, t, api.EmbedRequest{
+	_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
 		Model:    "all-minilm",
 		Input:    "why is the sky blue?",
 		Truncate: &truncFalse,

@@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	}
 }
 
-func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
+func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
 	if err := PullIfMissing(ctx, client, req.Model); err != nil {
 		t.Fatalf("failed to pull model %s: %v", req.Model, err)
 	}

@@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq
 	return response, nil
 }
 
-func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
+func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
 	if err := PullIfMissing(ctx, client, req.Model); err != nil {
 		t.Fatalf("failed to pull model %s: %v", req.Model, err)
 	}
@@ -114,6 +114,16 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
 		return nil, nil, nil, err
 	}
 
+	// TODO(drifkin): `GetRunner` above changes opts, but we currently pass it by
+	// value. The following line is a hack to fix this for the now dynamically
+	// calculated NumCtx, but we should fix this properly (which could have other
+	// side effects, since perhaps we were relying on the values not being stomped
+	// on, particularly when NumCtx sometimes represents a numParallel-adjusted
+	// number and sometimes not)
+	if opts.NumCtx == -1 {
+		opts.NumCtx = runner.Options.NumCtx / runner.numParallel
+	}
+
 	return runner.llama, model, &opts, nil
 }
 
@@ -80,8 +80,12 @@ func TestGenerateChat(t *testing.T) {
 			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
 				// add small delay to simulate loading
 				time.Sleep(time.Millisecond)
+				opts := api.DefaultOptions()
+				opts.NumCtx = 4096
 				req.successCh <- &runnerRef{
 					llama: &mock,
+					numParallel: 1,
+					Options: &opts,
 				}
 			},
 		},

@@ -184,7 +188,8 @@ func TestGenerateChat(t *testing.T) {
 
 	t.Run("load model", func(t *testing.T) {
 		w := createRequest(t, s.ChatHandler, api.ChatRequest{
 			Model: "test",
+			Options: map[string]any{"num_ctx": 2048},
 		})
 
 		if w.Code != http.StatusOK {

@@ -634,7 +639,9 @@ func TestGenerate(t *testing.T) {
 				// add small delay to simulate loading
 				time.Sleep(time.Millisecond)
 				req.successCh <- &runnerRef{
 					llama: &mock,
+					Options: &api.Options{},
+					numParallel: 1,
 				}
 			},
 		},

@@ -750,7 +757,8 @@ func TestGenerate(t *testing.T) {
 
 	t.Run("load model", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 			Model: "test",
+			Options: map[string]any{"num_ctx": 2048},
 		})
 
 		if w.Code != http.StatusOK {
@@ -81,6 +81,9 @@ func InitScheduler(ctx context.Context) *Scheduler {
 
 // context must be canceled to decrement ref count and release the runner
 func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
+	if opts.NumCtx != -1 && opts.NumCtx < 4 {
+		opts.NumCtx = 4
+	}
 	req := &LlmRequest{
 		ctx:   c,
 		model: model,

@@ -585,6 +588,16 @@ func (runner *runnerRef) unload() {
 	runner.gpus = nil
 }
 
+func runnerOptionsEqual(a, b api.Runner) bool {
+	// if one of the options is -1, then it means it needs to be dynamically calculated
+	if a.NumCtx == -1 {
+		a.NumCtx = b.NumCtx
+	} else if b.NumCtx == -1 {
+		b.NumCtx = a.NumCtx
+	}
+	return reflect.DeepEqual(a, b)
+}
+
 func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
 	slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
 	runner.refMu.Lock()

@@ -614,7 +627,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 	defer cancel()
 	if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
 		!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
-		!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
+		!runnerOptionsEqual(optsExisting, optsNew) ||
 		runner.llama.Ping(ctx) != nil {
 		return true
 	}
@@ -148,7 +148,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
 		successCh: make(chan *runnerRef, 1),
 		errCh:     make(chan error, 1),
 	}
-	b.req.opts.NumCtx = 4096
 	b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
 	return b
 }
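One note on the needsReload change: before this commit, optsExisting and optsNew were compared with reflect.DeepEqual, so a request still carrying the -1 sentinel would presumably never match a runner whose NumCtx had already been resolved, forcing a reload. The sketch below (an invented runnerOpts type, not the ollama code) shows the comparison semantics runnerOptionsEqual introduces, assuming the options differ only in NumCtx.

package main

import (
	"fmt"
	"reflect"
)

type runnerOpts struct {
	NumCtx   int
	NumBatch int
}

// runnerOptsEqual copies the resolved NumCtx over the -1 sentinel before the
// deep comparison, mirroring runnerOptionsEqual in server/sched.go.
func runnerOptsEqual(a, b runnerOpts) bool {
	if a.NumCtx == -1 {
		a.NumCtx = b.NumCtx
	} else if b.NumCtx == -1 {
		b.NumCtx = a.NumCtx
	}
	return reflect.DeepEqual(a, b)
}

func main() {
	loaded := runnerOpts{NumCtx: 4096, NumBatch: 512} // options on the already-loaded runner
	pending := runnerOpts{NumCtx: -1, NumBatch: 512}  // new request, context not yet resolved
	fmt.Println(runnerOptsEqual(loaded, pending))     // true: no spurious reload
	fmt.Println(runnerOptsEqual(loaded, runnerOpts{NumCtx: 2048, NumBatch: 512})) // false: a real change
}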