diff --git a/docs/openai.md b/docs/openai.md
index c30276a73..b0f9b353c 100644
--- a/docs/openai.md
+++ b/docs/openai.md
@@ -233,6 +233,8 @@ curl http://localhost:11434/v1/embeddings \
 - [x] `seed`
 - [x] `stop`
 - [x] `stream`
+- [x] `stream_options`
+  - [x] `include_usage`
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
@@ -261,6 +263,8 @@ curl http://localhost:11434/v1/embeddings \
 - [x] `seed`
 - [x] `stop`
 - [x] `stream`
+- [x] `stream_options`
+  - [x] `include_usage`
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
diff --git a/openai/openai.go b/openai/openai.go
index 6b28eee42..bc7bb1aff 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -75,10 +75,15 @@ type EmbedRequest struct {
 	Model string `json:"model"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage"`
+}
+
 type ChatCompletionRequest struct {
 	Model            string         `json:"model"`
 	Messages         []Message      `json:"messages"`
 	Stream           bool           `json:"stream"`
+	StreamOptions    *StreamOptions `json:"stream_options"`
 	MaxTokens        *int           `json:"max_tokens"`
 	Seed             *int           `json:"seed"`
 	Stop             any            `json:"stop"`
@@ -107,21 +112,23 @@ type ChatCompletionChunk struct {
 	Model             string        `json:"model"`
 	SystemFingerprint string        `json:"system_fingerprint"`
 	Choices           []ChunkChoice `json:"choices"`
+	Usage             *Usage        `json:"usage,omitempty"`
 }
 
 // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
 type CompletionRequest struct {
-	Model            string   `json:"model"`
-	Prompt           string   `json:"prompt"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	MaxTokens        *int     `json:"max_tokens"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	Seed             *int     `json:"seed"`
-	Stop             any      `json:"stop"`
-	Stream           bool     `json:"stream"`
-	Temperature      *float32 `json:"temperature"`
-	TopP             float32  `json:"top_p"`
-	Suffix           string   `json:"suffix"`
+	Model            string         `json:"model"`
+	Prompt           string         `json:"prompt"`
+	FrequencyPenalty float32        `json:"frequency_penalty"`
+	MaxTokens        *int           `json:"max_tokens"`
+	PresencePenalty  float32        `json:"presence_penalty"`
+	Seed             *int           `json:"seed"`
+	Stop             any            `json:"stop"`
+	Stream           bool           `json:"stream"`
+	StreamOptions    *StreamOptions `json:"stream_options"`
+	Temperature      *float32       `json:"temperature"`
+	TopP             float32        `json:"top_p"`
+	Suffix           string         `json:"suffix"`
 }
 
 type Completion struct {
@@ -141,6 +148,7 @@ type CompletionChunk struct {
 	Choices           []CompleteChunkChoice `json:"choices"`
 	Model             string                `json:"model"`
 	SystemFingerprint string                `json:"system_fingerprint"`
+	Usage             *Usage                `json:"usage,omitempty"`
 }
 
 type ToolCall struct {
@@ -197,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
 	return ErrorResponse{Error{Type: etype, Message: message}}
 }
 
+func toUsage(r api.ChatResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toolCallId() string {
 	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
 	b := make([]byte, 8)
@@ -246,11 +262,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 				return nil
 			}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsage(r),
 	}
 }
 
@@ -275,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
+func toUsageGenerate(r api.GenerateResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toCompletion(id string, r api.GenerateResponse) Completion {
 	return Completion{
 		Id: id,
@@ -292,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 				return nil
 			}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsageGenerate(r),
 	}
 }
 
@@ -566,14 +582,16 @@ type BaseWriter struct {
 }
 
 type ChatWriter struct {
-	stream bool
-	id     string
+	stream        bool
+	streamOptions *StreamOptions
+	id            string
 	BaseWriter
 }
 
 type CompleteWriter struct {
-	stream bool
-	id     string
+	stream        bool
+	streamOptions *StreamOptions
+	id            string
 	BaseWriter
 }
 
@@ -616,7 +634,8 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 
 	// chat chunk
 	if w.stream {
-		d, err := json.Marshal(toChunk(w.id, chatResponse))
+		c := toChunk(w.id, chatResponse)
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
@@ -628,6 +647,19 @@
 
 		if chatResponse.Done {
+			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+				u := toUsage(chatResponse)
+				c.Usage = &u
+				c.Choices = []ChunkChoice{}
+				d, err := json.Marshal(c)
+				if err != nil {
+					return 0, err
+				}
+				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+				if err != nil {
+					return 0, err
+				}
+			}
 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 			if err != nil {
 				return 0, err
 			}
@@ -665,7 +697,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 
 	// completion chunk
 	if w.stream {
-		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		c := toCompleteChunk(w.id, generateResponse)
+		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+			c.Usage = &Usage{}
+		}
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
@@ -677,6 +713,19 @@
 
 		if generateResponse.Done {
+			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+				u := toUsageGenerate(generateResponse)
+				c.Usage = &u
+				c.Choices = []CompleteChunkChoice{}
+				d, err := json.Marshal(c)
+				if err != nil {
+					return 0, err
+				}
+				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+				if err != nil {
+					return 0, err
+				}
+			}
 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 			if err != nil {
 				return 0, err
 			}
@@ -839,9 +888,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &CompleteWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			stream:     req.Stream,
-			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
+			stream:        req.Stream,
+			id:            fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			streamOptions: req.StreamOptions,
 		}
 
 		c.Writer = w
@@ -921,9 +971,10 @@ func ChatMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &ChatWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			stream:     req.Stream,
-			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
+			stream:        req.Stream,
+			id:            fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			streamOptions: req.StreamOptions,
 		}
 
 		c.Writer = w
diff --git a/openai/openai_test.go b/openai/openai_test.go
index 0c2a7d806..d8c821d39 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -112,6 +112,45 @@ func TestChatMiddleware(t *testing.T) {
 			Stream: &True,
 		},
 	},
+	{
+		name: "chat handler with streaming usage",
+		body: `{
+			"model": "test-model",
+			"messages": [
+				{"role": "user", "content": "Hello"}
+			],
+			"stream": true,
+			"stream_options": {"include_usage": true},
+			"max_tokens": 999,
+			"seed": 123,
+			"stop": ["\n", "stop"],
+			"temperature": 3.0,
+			"frequency_penalty": 4.0,
+			"presence_penalty": 5.0,
+			"top_p": 6.0,
+			"response_format": {"type": "json_object"}
+		}`,
+		req: api.ChatRequest{
+			Model: "test-model",
+			Messages: []api.Message{
+				{
+					Role:    "user",
+					Content: "Hello",
+				},
+			},
+			Options: map[string]any{
+				"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
+				"seed":              123.0,
+				"stop":              []any{"\n", "stop"},
+				"temperature":       3.0,
+				"frequency_penalty": 4.0,
+				"presence_penalty":  5.0,
+				"top_p":             6.0,
+			},
+			Format: json.RawMessage(`"json"`),
+			Stream: &True,
+		},
+	},
 	{
 		name: "chat handler with image content",
 		body: `{
@@ -363,6 +402,55 @@ func TestCompletionsMiddleware(t *testing.T) {
 			Stream: &False,
 		},
 	},
+	{
+		name: "completions handler stream",
+		body: `{
+			"model": "test-model",
+			"prompt": "Hello",
+			"stream": true,
+			"temperature": 0.8,
+			"stop": ["\n", "stop"],
+			"suffix": "suffix"
+		}`,
+		req: api.GenerateRequest{
+			Model:  "test-model",
+			Prompt: "Hello",
+			Options: map[string]any{
+				"frequency_penalty": 0.0,
+				"presence_penalty":  0.0,
+				"temperature":       0.8,
+				"top_p":             1.0,
+				"stop":              []any{"\n", "stop"},
+			},
+			Suffix: "suffix",
+			Stream: &True,
+		},
+	},
+	{
+		name: "completions handler stream with usage",
+		body: `{
+			"model": "test-model",
+			"prompt": "Hello",
+			"stream": true,
+			"stream_options": {"include_usage": true},
+			"temperature": 0.8,
+			"stop": ["\n", "stop"],
+			"suffix": "suffix"
+		}`,
+		req: api.GenerateRequest{
+			Model:  "test-model",
+			Prompt: "Hello",
+			Options: map[string]any{
+				"frequency_penalty": 0.0,
+				"presence_penalty":  0.0,
+				"temperature":       0.8,
+				"top_p":             1.0,
+				"stop":              []any{"\n", "stop"},
+			},
+			Suffix: "suffix",
+			Stream: &True,
+		},
+	},
 	{
 		name: "completions handler error forwarding",
 		body: `{