Mirror of https://github.com/ollama/ollama.git
fix incorrect chat truncation
The dynamically calculated `NumCtx` value wasn't making it all the way to the chat handler.

This fix also made us notice that the minimum setting of `NumCtx` to 4 inside `server/sched.go` was accidentally removed in #10364. That minimum doesn't make it back out to the client code, which matters for embeddings, as demonstrated in `TestAllMiniLMEmbedTruncate`. This should be cleaned up further; the remaining issue is probably caused by the start and end tokens in the embedding, so small context sizes still need some work there.

See the comment in `server/routes.go` for more information on the temporary hack that's been added to propagate the dynamically calculated `NumCtx` (the `-1` guard there is to keep embeddings working if you set `NumCtx` to some small value like `1`).

Fixes: #10441
Commit d20cd8df80 (parent 5cfc1c39f3)
5 changed files with 52 additions and 18 deletions
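To make the two `NumCtx` rules described above concrete, here is a minimal, self-contained sketch. The names (`clampNumCtx`, `resolveNumCtx`, `runnerInfo`) are illustrative only and are not part of ollama's code; they just mirror the scheduler-side minimum and the handler-side resolution of the `-1` sentinel.

```go
// Illustrative sketch only: these helpers mirror the behavior in the diff
// below but are not ollama APIs.
package main

import "fmt"

type runnerInfo struct {
	numCtx      int // total context the runner was loaded with
	numParallel int // parallel slots sharing that context
}

// clampNumCtx mirrors the scheduler-side rule: explicit values are clamped to
// a minimum of 4, while -1 ("calculate dynamically") is left alone.
func clampNumCtx(numCtx int) int {
	if numCtx != -1 && numCtx < 4 {
		return 4
	}
	return numCtx
}

// resolveNumCtx mirrors the handler-side rule: -1 becomes the runner's total
// context divided across its parallel slots.
func resolveNumCtx(numCtx int, r runnerInfo) int {
	if numCtx == -1 {
		return r.numCtx / r.numParallel
	}
	return numCtx
}

func main() {
	r := runnerInfo{numCtx: 8192, numParallel: 2}
	for _, requested := range []int{-1, 1, 2048} {
		effective := resolveNumCtx(clampNumCtx(requested), r)
		fmt.Printf("requested %5d -> effective %d\n", requested, effective)
	}
	// requested    -1 -> effective 4096
	// requested     1 -> effective 4
	// requested  2048 -> effective 2048
}
```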
@@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
 func TestAllMiniLMEmbeddings(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
     defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()

     req := api.EmbeddingRequest{
         Model:  "all-minilm",
         Prompt: "why is the sky blue?",
     }

-    res, err := embeddingTestHelper(ctx, t, req)
+    res, err := embeddingTestHelper(ctx, client, t, req)

     if err != nil {
         t.Fatalf("error: %v", err)
@@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
 func TestAllMiniLMEmbed(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
     defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()

     req := api.EmbedRequest{
         Model: "all-minilm",
         Input: "why is the sky blue?",
     }

-    res, err := embedTestHelper(ctx, t, req)
+    res, err := embedTestHelper(ctx, client, t, req)

     if err != nil {
         t.Fatalf("error: %v", err)
@@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) {
 func TestAllMiniLMBatchEmbed(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
     defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()

     req := api.EmbedRequest{
         Model: "all-minilm",
         Input: []string{"why is the sky blue?", "why is the grass green?"},
     }

-    res, err := embedTestHelper(ctx, t, req)
+    res, err := embedTestHelper(ctx, client, t, req)

     if err != nil {
         t.Fatalf("error: %v", err)
@@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 func TestAllMiniLMEmbedTruncate(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
     defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()

     truncTrue, truncFalse := true, false

@@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
     res := make(map[string]*api.EmbedResponse)

     for _, req := range reqs {
-        response, err := embedTestHelper(ctx, t, req.Request)
+        response, err := embedTestHelper(ctx, client, t, req.Request)
         if err != nil {
             t.Fatalf("error: %v", err)
         }
@@ -190,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
     }

     if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
-        t.Fatal("expected default request to truncate correctly")
+        t.Fatal("expected default request to truncate correctly. Wanted: ", res["Target Truncation"].Embeddings[0][0], "Got: ", res["Default Truncate"].Embeddings[0][0])
     }

     if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
@@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
     }

     // check that truncate set to false returns an error if context length is exceeded
-    _, err := embedTestHelper(ctx, t, api.EmbedRequest{
+    _, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
         Model:    "all-minilm",
         Input:    "why is the sky blue?",
         Truncate: &truncFalse,
@@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
     }
 }

-func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
-    client, _, cleanup := InitServerConnection(ctx, t)
-    defer cleanup()
+func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
     if err := PullIfMissing(ctx, client, req.Model); err != nil {
         t.Fatalf("failed to pull model %s: %v", req.Model, err)
     }
@@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq
     return response, nil
 }

-func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
-    client, _, cleanup := InitServerConnection(ctx, t)
-    defer cleanup()
+func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
     if err := PullIfMissing(ctx, client, req.Model); err != nil {
         t.Fatalf("failed to pull model %s: %v", req.Model, err)
     }
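With the helper refactor above, each test owns the server connection and the helpers simply reuse the shared client. As a reading aid, a hypothetical extra test (not part of this commit, and assuming the integration package's existing imports and helpers) would follow the same shape:

```go
func TestAllMiniLMEmbedShortPrompt(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	// one connection per test; the helper no longer creates (and tears down) its own
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	res, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
		Model: "all-minilm",
		Input: "hello",
	})
	if err != nil {
		t.Fatalf("error: %v", err)
	}
	if len(res.Embeddings) != 1 {
		t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
	}
}
```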
@@ -114,6 +114,16 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
         return nil, nil, nil, err
     }

+    // TODO(drifkin): `GetRunner` above changes opts, but we currently pass it by
+    // value. The following line is a hack to fix this for the now dynamically
+    // calculated NumCtx, but we should fix this properly (which could have other
+    // side effects, since perhaps we were relying on the values not being stomped
+    // on, particularly when NumCtx sometimes represents a numParallel-adjusted
+    // number and sometimes not)
+    if opts.NumCtx == -1 {
+        opts.NumCtx = runner.Options.NumCtx / runner.numParallel
+    }
+
     return runner.llama, model, &opts, nil
 }

@@ -80,8 +80,12 @@ func TestGenerateChat(t *testing.T) {
         loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
             // add small delay to simulate loading
             time.Sleep(time.Millisecond)
+            opts := api.DefaultOptions()
+            opts.NumCtx = 4096
             req.successCh <- &runnerRef{
-                llama: &mock,
+                llama:       &mock,
+                numParallel: 1,
+                Options:     &opts,
             }
         },
     },
@@ -184,7 +188,8 @@ func TestGenerateChat(t *testing.T) {

     t.Run("load model", func(t *testing.T) {
         w := createRequest(t, s.ChatHandler, api.ChatRequest{
-            Model: "test",
+            Model:   "test",
+            Options: map[string]any{"num_ctx": 2048},
         })

         if w.Code != http.StatusOK {
@@ -634,7 +639,9 @@ func TestGenerate(t *testing.T) {
             // add small delay to simulate loading
             time.Sleep(time.Millisecond)
             req.successCh <- &runnerRef{
-                llama: &mock,
+                llama:       &mock,
+                Options:     &api.Options{},
+                numParallel: 1,
             }
         },
     },
@@ -750,7 +757,8 @@ func TestGenerate(t *testing.T) {

     t.Run("load model", func(t *testing.T) {
         w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
-            Model: "test",
+            Model:   "test",
+            Options: map[string]any{"num_ctx": 2048},
         })

         if w.Code != http.StatusOK {
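From a client's perspective, what these handler tests pin down is that an explicit `num_ctx` option now reaches the chat and generate handlers, while omitting it lets the server fall back to the dynamically calculated context. A rough sketch using ollama's Go API client (the model name and option values here are arbitrary examples):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model:    "llama3.2",
		Messages: []api.Message{{Role: "user", Content: "why is the sky blue?"}},
		// Set num_ctx explicitly; omit Options entirely to use the server's
		// dynamically calculated context size.
		Options: map[string]any{"num_ctx": 2048},
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```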
@@ -81,6 +81,9 @@ func InitScheduler(ctx context.Context) *Scheduler {

 // context must be canceled to decrement ref count and release the runner
 func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
+    if opts.NumCtx != -1 && opts.NumCtx < 4 {
+        opts.NumCtx = 4
+    }
     req := &LlmRequest{
         ctx:   c,
         model: model,
@@ -585,6 +588,16 @@ func (runner *runnerRef) unload() {
     runner.gpus = nil
 }

+func runnerOptionsEqual(a, b api.Runner) bool {
+    // if one of the options is -1, then it means it needs to be dynamically calculated
+    if a.NumCtx == -1 {
+        a.NumCtx = b.NumCtx
+    } else if b.NumCtx == -1 {
+        b.NumCtx = a.NumCtx
+    }
+    return reflect.DeepEqual(a, b)
+}
+
 func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
     slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
     runner.refMu.Lock()
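The intent of `runnerOptionsEqual` is that a requested `NumCtx` of `-1` ("calculate it dynamically") should not by itself force a model reload when everything else matches. A minimal standalone sketch of that comparison rule, using a simplified `runnerOpts` struct as a stand-in for `api.Runner`:

```go
// Standalone illustration only: runnerOpts is a simplified stand-in for
// api.Runner, and runnerOptsEqual mirrors the logic in the diff above.
package main

import (
	"fmt"
	"reflect"
)

type runnerOpts struct {
	NumCtx    int
	NumGPU    int
	NumThread int
}

func runnerOptsEqual(a, b runnerOpts) bool {
	// -1 means "needs to be dynamically calculated", so adopt the other side's value
	if a.NumCtx == -1 {
		a.NumCtx = b.NumCtx
	} else if b.NumCtx == -1 {
		b.NumCtx = a.NumCtx
	}
	return reflect.DeepEqual(a, b)
}

func main() {
	loaded := runnerOpts{NumCtx: 4096, NumGPU: -1, NumThread: 0}
	requested := runnerOpts{NumCtx: -1, NumGPU: -1, NumThread: 0}
	fmt.Println(runnerOptsEqual(loaded, requested)) // true: no reload needed

	requested.NumGPU = 1
	fmt.Println(runnerOptsEqual(loaded, requested)) // false: a real option changed
}
```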
@@ -614,7 +627,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
     defer cancel()
     if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
         !reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
-        !reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
+        !runnerOptionsEqual(optsExisting, optsNew) ||
         runner.llama.Ping(ctx) != nil {
         return true
     }
@@ -148,7 +148,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
         successCh: make(chan *runnerRef, 1),
         errCh:     make(chan error, 1),
     }
-    b.req.opts.NumCtx = 4096
     b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
     return b
 }