From 87f0a49fe6b0db7de0d6fa76e5d2a27963c10ca7 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 16 Dec 2024 21:57:49 -0800 Subject: [PATCH] llm: do not silently fail for supplied, but invalid formats (#8130) Changes in #8002 introduced fixes for bugs with mangling JSON Schemas. It also fixed a bug where the server would silently fail when clients requested invalid formats. It also, unfortunately, introduced a bug where the server would reject requests with an empty format, which should be allowed. The change in #8127 updated the code to allow the empty format, but also reintroduced the regression where the server would silently fail when the format was set, but invalid. This commit fixes both regressions. The server does not reject the empty format, but it does reject invalid formats. It also adds tests to help us catch regressions in the future. Also, the updated code provides a more detailed error message when a client sends a non-empty, but invalid format, echoing the invalid format in the response. This commit also takes the opportunity to remove superfluous linter checks. 
--- .golangci.yaml | 4 --- llm/server.go | 58 ++++++++++++++++++++++++------------------ llm/server_test.go | 63 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 29 deletions(-) create mode 100644 llm/server_test.go diff --git a/.golangci.yaml b/.golangci.yaml index 2e0ed3c7b..9d59fd6c0 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -8,8 +8,6 @@ linters: - containedctx - contextcheck - errcheck - - exportloopref - - gci - gocheckcompilerdirectives - gofmt - gofumpt @@ -30,8 +28,6 @@ linters: - wastedassign - whitespace linters-settings: - gci: - sections: [standard, default, localmodule] staticcheck: checks: - all diff --git a/llm/server.go b/llm/server.go index 4cebd8a4c..832863720 100644 --- a/llm/server.go +++ b/llm/server.go @@ -674,21 +674,6 @@ type CompletionResponse struct { } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { - if err := s.sem.Acquire(ctx, 1); err != nil { - if errors.Is(err, context.Canceled) { - slog.Info("aborting completion request due to client closing the connection") - } else { - slog.Error("Failed to acquire semaphore", "error", err) - } - return err - } - defer s.sem.Release(1) - - // put an upper limit on num_predict to avoid the model running on forever - if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx { - req.Options.NumPredict = 10 * s.options.NumCtx - } - request := map[string]any{ "prompt": req.Prompt, "stream": true, @@ -714,6 +699,39 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu "cache_prompt": true, } + if len(req.Format) > 0 { + switch { + case bytes.Equal(req.Format, []byte(`""`)): + // fallthrough + case bytes.Equal(req.Format, []byte(`"json"`)): + request["grammar"] = grammarJSON + case bytes.HasPrefix(req.Format, []byte("{")): + // User provided a JSON schema + g := llama.SchemaToGrammar(req.Format) + if g == nil { + return fmt.Errorf("invalid JSON 
schema in format") + } + request["grammar"] = string(g) + default: + return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema", req.Format) + } + } + + if err := s.sem.Acquire(ctx, 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting completion request due to client closing the connection") + } else { + slog.Error("Failed to acquire semaphore", "error", err) + } + return err + } + defer s.sem.Release(1) + + // put an upper limit on num_predict to avoid the model running on forever + if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx { + req.Options.NumPredict = 10 * s.options.NumCtx + } + // Make sure the server is ready status, err := s.getServerStatusRetry(ctx) if err != nil { @@ -722,16 +740,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return fmt.Errorf("unexpected server status: %s", status.ToString()) } - if bytes.Equal(req.Format, []byte(`"json"`)) { - request["grammar"] = grammarJSON - } else if bytes.HasPrefix(req.Format, []byte("{")) { - g := llama.SchemaToGrammar(req.Format) - if g == nil { - return fmt.Errorf("invalid JSON schema in format") - } - request["grammar"] = string(g) - } - // Handling JSON marshaling with special characters unescaped. buffer := &bytes.Buffer{} enc := json.NewEncoder(buffer) diff --git a/llm/server_test.go b/llm/server_test.go new file mode 100644 index 000000000..e6f79a585 --- /dev/null +++ b/llm/server_test.go @@ -0,0 +1,63 @@ +package llm + +import ( + "context" + "errors" + "fmt" + "strings" + "testing" + + "github.com/ollama/ollama/api" + "golang.org/x/sync/semaphore" +) + +func TestLLMServerCompletionFormat(t *testing.T) { + // This test was written to fix an already deployed issue. It is a bit + // of a mess, and but it's good enough, until we can refactoring the + // Completion method to be more testable. 
+ + ctx, cancel := context.WithCancel(context.Background()) + s := &llmServer{ + sem: semaphore.NewWeighted(1), // required to prevent nil panic + } + + checkInvalid := func(format string) { + t.Helper() + err := s.Completion(ctx, CompletionRequest{ + Options: new(api.Options), + Format: []byte(format), + }, nil) + + want := fmt.Sprintf("invalid format: %q; expected \"json\" or a valid JSON Schema", format) + if err == nil || !strings.Contains(err.Error(), want) { + t.Fatalf("err = %v; want %q", err, want) + } + } + + checkInvalid("X") // invalid format + checkInvalid(`"X"`) // invalid JSON Schema + + cancel() // prevent further processing if request makes it past the format check + + checkCanceled := func(err error) { + t.Helper() + if !errors.Is(err, context.Canceled) { + t.Fatalf("Completion: err = %v; expected context.Canceled", err) + } + } + + valids := []string{`"json"`, `{"type":"object"}`, ``, `""`} + for _, valid := range valids { + err := s.Completion(ctx, CompletionRequest{ + Options: new(api.Options), + Format: []byte(valid), + }, nil) + checkCanceled(err) + } + + err := s.Completion(ctx, CompletionRequest{ + Options: new(api.Options), + Format: nil, // missing format + }, nil) + checkCanceled(err) +}