diff --git a/.golangci.yaml b/.golangci.yaml index 2e0ed3c7b..9d59fd6c0 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -8,8 +8,6 @@ linters: - containedctx - contextcheck - errcheck - - exportloopref - - gci - gocheckcompilerdirectives - gofmt - gofumpt @@ -30,8 +28,6 @@ linters: - wastedassign - whitespace linters-settings: - gci: - sections: [standard, default, localmodule] staticcheck: checks: - all diff --git a/llm/server.go b/llm/server.go index 4cebd8a4c..832863720 100644 --- a/llm/server.go +++ b/llm/server.go @@ -674,21 +674,6 @@ type CompletionResponse struct { } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { - if err := s.sem.Acquire(ctx, 1); err != nil { - if errors.Is(err, context.Canceled) { - slog.Info("aborting completion request due to client closing the connection") - } else { - slog.Error("Failed to acquire semaphore", "error", err) - } - return err - } - defer s.sem.Release(1) - - // put an upper limit on num_predict to avoid the model running on forever - if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx { - req.Options.NumPredict = 10 * s.options.NumCtx - } - request := map[string]any{ "prompt": req.Prompt, "stream": true, @@ -714,6 +699,39 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu "cache_prompt": true, } + if len(req.Format) > 0 { + switch { + case bytes.Equal(req.Format, []byte(`""`)): + // fallthrough + case bytes.Equal(req.Format, []byte(`"json"`)): + request["grammar"] = grammarJSON + case bytes.HasPrefix(req.Format, []byte("{")): + // User provided a JSON schema + g := llama.SchemaToGrammar(req.Format) + if g == nil { + return fmt.Errorf("invalid JSON schema in format") + } + request["grammar"] = string(g) + default: + return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema", req.Format) + } + } + + if err := s.sem.Acquire(ctx, 1); err != nil { + if errors.Is(err, 
context.Canceled) { + slog.Info("aborting completion request due to client closing the connection") + } else { + slog.Error("Failed to acquire semaphore", "error", err) + } + return err + } + defer s.sem.Release(1) + + // put an upper limit on num_predict to avoid the model running on forever + if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx { + req.Options.NumPredict = 10 * s.options.NumCtx + } + // Make sure the server is ready status, err := s.getServerStatusRetry(ctx) if err != nil { @@ -722,16 +740,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return fmt.Errorf("unexpected server status: %s", status.ToString()) } - if bytes.Equal(req.Format, []byte(`"json"`)) { - request["grammar"] = grammarJSON - } else if bytes.HasPrefix(req.Format, []byte("{")) { - g := llama.SchemaToGrammar(req.Format) - if g == nil { - return fmt.Errorf("invalid JSON schema in format") - } - request["grammar"] = string(g) - } - // Handling JSON marshaling with special characters unescaped. buffer := &bytes.Buffer{} enc := json.NewEncoder(buffer) diff --git a/llm/server_test.go b/llm/server_test.go new file mode 100644 index 000000000..e6f79a585 --- /dev/null +++ b/llm/server_test.go @@ -0,0 +1,63 @@ +package llm + +import ( + "context" + "errors" + "fmt" + "strings" + "testing" + + "github.com/ollama/ollama/api" + "golang.org/x/sync/semaphore" +) + +func TestLLMServerCompletionFormat(t *testing.T) { + // This test was written to fix an already deployed issue. It is a bit + // of a mess, but it's good enough until we can refactor the + // Completion method to be more testable. 
+ + ctx, cancel := context.WithCancel(context.Background()) + s := &llmServer{ + sem: semaphore.NewWeighted(1), // required to prevent nil panic + } + + checkInvalid := func(format string) { + t.Helper() + err := s.Completion(ctx, CompletionRequest{ + Options: new(api.Options), + Format: []byte(format), + }, nil) + + want := fmt.Sprintf("invalid format: %q; expected \"json\" or a valid JSON Schema", format) + if err == nil || !strings.Contains(err.Error(), want) { + t.Fatalf("err = %v; want %q", err, want) + } + } + + checkInvalid("X") // invalid format + checkInvalid(`"X"`) // invalid JSON Schema + + cancel() // prevent further processing if request makes it past the format check + + checkCanceled := func(err error) { + t.Helper() + if !errors.Is(err, context.Canceled) { + t.Fatalf("Completion: err = %v; expected context.Canceled", err) + } + } + + valids := []string{`"json"`, `{"type":"object"}`, ``, `""`} + for _, valid := range valids { + err := s.Completion(ctx, CompletionRequest{ + Options: new(api.Options), + Format: []byte(valid), + }, nil) + checkCanceled(err) + } + + err := s.Completion(ctx, CompletionRequest{ + Options: new(api.Options), + Format: nil, // missing format + }, nil) + checkCanceled(err) +}