Mirror of https://github.com/ollama/ollama.git
openai: return usage as final chunk for streams (#6784)
Co-authored-by: ParthSareen <parth.sareen@ollama.com>

Commit e28f2d4900 (parent c216850523)
3 changed files with 176 additions and 33 deletions
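With this change, a streaming client that sets `"stream_options": {"include_usage": true}` receives one extra chunk at the end of the stream, carrying an empty `choices` array and the token counts in `usage`, immediately before the `data: [DONE]` sentinel. Below is a minimal Go client sketch of that flow; it assumes a local Ollama server on the default port 11434, and `"test-model"` is only a placeholder model name.

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Placeholder model name; substitute one that is actually available locally.
	body := strings.NewReader(`{
		"model": "test-model",
		"messages": [{"role": "user", "content": "Hello"}],
		"stream": true,
		"stream_options": {"include_usage": true}
	}`)

	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Each SSE event arrives as a "data: ..." line. The usage-only chunk is the
	// last data event before the "[DONE]" sentinel.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := strings.TrimPrefix(scanner.Text(), "data: ")
		if line == "" || line == "[DONE]" {
			continue
		}
		var chunk struct {
			Choices []json.RawMessage `json:"choices"`
			Usage   *struct {
				PromptTokens     int `json:"prompt_tokens"`
				CompletionTokens int `json:"completion_tokens"`
				TotalTokens      int `json:"total_tokens"`
			} `json:"usage"`
		}
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			continue
		}
		if chunk.Usage != nil && len(chunk.Choices) == 0 {
			fmt.Printf("prompt=%d completion=%d total=%d\n",
				chunk.Usage.PromptTokens, chunk.Usage.CompletionTokens, chunk.Usage.TotalTokens)
		}
	}
}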
docs/openai.md:

@@ -233,6 +233,8 @@ curl http://localhost:11434/v1/embeddings \
 - [x] `seed`
 - [x] `stop`
 - [x] `stream`
+- [x] `stream_options`
+  - [x] `include_usage`
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`

@@ -261,6 +263,8 @@ curl http://localhost:11434/v1/embeddings \
 - [x] `seed`
 - [x] `stop`
 - [x] `stream`
+- [x] `stream_options`
+  - [x] `include_usage`
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
openai/openai.go (117 lines changed):
@@ -75,10 +75,15 @@ type EmbedRequest struct {
 	Model string `json:"model"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage"`
+}
+
 type ChatCompletionRequest struct {
 	Model         string         `json:"model"`
 	Messages      []Message      `json:"messages"`
 	Stream        bool           `json:"stream"`
+	StreamOptions *StreamOptions `json:"stream_options"`
 	MaxTokens     *int           `json:"max_tokens"`
 	Seed          *int           `json:"seed"`
 	Stop          any            `json:"stop"`
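As a quick illustration of how the new field behaves during request decoding, here is a self-contained sketch using trimmed-down copies of the structs above (not the package's full definitions). `StreamOptions` is a pointer so the handler can tell an omitted `stream_options` object apart from one that explicitly disables usage.

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the structs introduced above, for illustration only.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

type ChatCompletionRequest struct {
	Model         string         `json:"model"`
	Stream        bool           `json:"stream"`
	StreamOptions *StreamOptions `json:"stream_options"`
}

func main() {
	raw := `{"model": "test-model", "stream": true, "stream_options": {"include_usage": true}}`

	var req ChatCompletionRequest
	if err := json.Unmarshal([]byte(raw), &req); err != nil {
		panic(err)
	}

	// nil means stream_options was omitted entirely; a non-nil value carries
	// the client's explicit include_usage choice.
	fmt.Println(req.StreamOptions != nil && req.StreamOptions.IncludeUsage) // true
}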
@@ -107,21 +112,23 @@ type ChatCompletionChunk struct {
 	Model             string        `json:"model"`
 	SystemFingerprint string        `json:"system_fingerprint"`
 	Choices           []ChunkChoice `json:"choices"`
+	Usage             *Usage        `json:"usage,omitempty"`
 }
 
 // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
 type CompletionRequest struct {
 	Model            string         `json:"model"`
 	Prompt           string         `json:"prompt"`
 	FrequencyPenalty float32        `json:"frequency_penalty"`
 	MaxTokens        *int           `json:"max_tokens"`
 	PresencePenalty  float32        `json:"presence_penalty"`
 	Seed             *int           `json:"seed"`
 	Stop             any            `json:"stop"`
 	Stream           bool           `json:"stream"`
+	StreamOptions    *StreamOptions `json:"stream_options"`
 	Temperature      *float32       `json:"temperature"`
 	TopP             float32        `json:"top_p"`
 	Suffix           string         `json:"suffix"`
 }
 
 type Completion struct {

@@ -141,6 +148,7 @@ type CompletionChunk struct {
 	Choices           []CompleteChunkChoice `json:"choices"`
 	Model             string                `json:"model"`
 	SystemFingerprint string                `json:"system_fingerprint"`
+	Usage             *Usage                `json:"usage,omitempty"`
 }
 
 type ToolCall struct {

@@ -197,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
 	return ErrorResponse{Error{Type: etype, Message: message}}
 }
 
+func toUsage(r api.ChatResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toolCallId() string {
 	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
 	b := make([]byte, 8)
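The helper simply copies Ollama's native evaluation counters into OpenAI-style usage fields, with `total_tokens` as their sum; `toUsageGenerate` further down does the same for `api.GenerateResponse`. A standalone sketch of that mapping, using a local stand-in for the response type (the real code operates on `api.ChatResponse`):

package main

import "fmt"

// Stand-ins for the relevant fields of api.ChatResponse and the Usage struct.
type chatResponse struct {
	PromptEvalCount int
	EvalCount       int
}

type usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

func toUsage(r chatResponse) usage {
	return usage{
		PromptTokens:     r.PromptEvalCount,
		CompletionTokens: r.EvalCount,
		TotalTokens:      r.PromptEvalCount + r.EvalCount,
	}
}

func main() {
	// e.g. 11 prompt tokens and 25 generated tokens -> 36 total.
	fmt.Printf("%+v\n", toUsage(chatResponse{PromptEvalCount: 11, EvalCount: 25}))
}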
@@ -246,11 +262,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 			return nil
 		}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsage(r),
 	}
 }
 

@@ -275,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
+func toUsageGenerate(r api.GenerateResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toCompletion(id string, r api.GenerateResponse) Completion {
 	return Completion{
 		Id: id,

@@ -292,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 			return nil
 		}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsageGenerate(r),
 	}
 }
 
@@ -566,14 +582,16 @@ type BaseWriter struct {
 }
 
 type ChatWriter struct {
 	stream        bool
+	streamOptions *StreamOptions
 	id            string
 	BaseWriter
 }
 
 type CompleteWriter struct {
 	stream        bool
+	streamOptions *StreamOptions
 	id            string
 	BaseWriter
 }
 

@@ -616,7 +634,8 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 
 	// chat chunk
 	if w.stream {
-		d, err := json.Marshal(toChunk(w.id, chatResponse))
+		c := toChunk(w.id, chatResponse)
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}

@@ -628,6 +647,19 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 		}
 
 		if chatResponse.Done {
+			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+				u := toUsage(chatResponse)
+				c.Usage = &u
+				c.Choices = []ChunkChoice{}
+				d, err := json.Marshal(c)
+				if err != nil {
+					return 0, err
+				}
+				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+				if err != nil {
+					return 0, err
+				}
+			}
 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 			if err != nil {
 				return 0, err
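When `include_usage` is set, the tail of a chat stream therefore consists of two final SSE events: a usage-only chunk whose `choices` array is empty, followed by the `[DONE]` terminator. The sketch below reproduces just that framing into a buffer using the same `data: ...` layout; the chunk literal is a simplified stand-in, not the handler's real `ChatCompletionChunk` value.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

func main() {
	// Simplified stand-in for the final chunk: choices emptied, usage attached.
	finalChunk := map[string]any{
		"object":  "chat.completion.chunk",
		"choices": []any{},
		"usage": map[string]int{
			"prompt_tokens":     11,
			"completion_tokens": 25,
			"total_tokens":      36,
		},
	}

	var buf bytes.Buffer
	d, _ := json.Marshal(finalChunk)
	fmt.Fprintf(&buf, "data: %s\n\n", d) // usage-only chunk
	buf.WriteString("data: [DONE]\n\n")  // stream terminator

	fmt.Print(buf.String())
}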
@@ -665,7 +697,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 
 	// completion chunk
 	if w.stream {
-		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		c := toCompleteChunk(w.id, generateResponse)
+		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+			c.Usage = &Usage{}
+		}
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}

@@ -677,6 +713,19 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 		}
 
 		if generateResponse.Done {
+			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+				u := toUsageGenerate(generateResponse)
+				c.Usage = &u
+				c.Choices = []CompleteChunkChoice{}
+				d, err := json.Marshal(c)
+				if err != nil {
+					return 0, err
+				}
+				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+				if err != nil {
+					return 0, err
+				}
+			}
 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 			if err != nil {
 				return 0, err

@@ -839,9 +888,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &CompleteWriter{
 			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
 			stream:     req.Stream,
 			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			streamOptions: req.StreamOptions,
 		}
 
 		c.Writer = w

@@ -921,9 +971,10 @@ func ChatMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &ChatWriter{
 			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
 			stream:     req.Stream,
 			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			streamOptions: req.StreamOptions,
 		}
 
 		c.Writer = w
openai/openai_test.go:

@@ -112,6 +112,45 @@ func TestChatMiddleware(t *testing.T) {
 			Stream: &True,
 		},
 	},
+	{
+		name: "chat handler with streaming usage",
+		body: `{
+			"model": "test-model",
+			"messages": [
+				{"role": "user", "content": "Hello"}
+			],
+			"stream": true,
+			"stream_options": {"include_usage": true},
+			"max_tokens": 999,
+			"seed": 123,
+			"stop": ["\n", "stop"],
+			"temperature": 3.0,
+			"frequency_penalty": 4.0,
+			"presence_penalty": 5.0,
+			"top_p": 6.0,
+			"response_format": {"type": "json_object"}
+		}`,
+		req: api.ChatRequest{
+			Model: "test-model",
+			Messages: []api.Message{
+				{
+					Role:    "user",
+					Content: "Hello",
+				},
+			},
+			Options: map[string]any{
+				"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
+				"seed":              123.0,
+				"stop":              []any{"\n", "stop"},
+				"temperature":       3.0,
+				"frequency_penalty": 4.0,
+				"presence_penalty":  5.0,
+				"top_p":             6.0,
+			},
+			Format: json.RawMessage(`"json"`),
+			Stream: &True,
+		},
+	},
 	{
 		name: "chat handler with image content",
 		body: `{

@@ -363,6 +402,55 @@ func TestCompletionsMiddleware(t *testing.T) {
 			Stream: &False,
 		},
 	},
+	{
+		name: "completions handler stream",
+		body: `{
+			"model": "test-model",
+			"prompt": "Hello",
+			"stream": true,
+			"temperature": 0.8,
+			"stop": ["\n", "stop"],
+			"suffix": "suffix"
+		}`,
+		req: api.GenerateRequest{
+			Model:  "test-model",
+			Prompt: "Hello",
+			Options: map[string]any{
+				"frequency_penalty": 0.0,
+				"presence_penalty":  0.0,
+				"temperature":       0.8,
+				"top_p":             1.0,
+				"stop":              []any{"\n", "stop"},
+			},
+			Suffix: "suffix",
+			Stream: &True,
+		},
+	},
+	{
+		name: "completions handler stream with usage",
+		body: `{
+			"model": "test-model",
+			"prompt": "Hello",
+			"stream": true,
+			"stream_options": {"include_usage": true},
+			"temperature": 0.8,
+			"stop": ["\n", "stop"],
+			"suffix": "suffix"
+		}`,
+		req: api.GenerateRequest{
+			Model:  "test-model",
+			Prompt: "Hello",
+			Options: map[string]any{
+				"frequency_penalty": 0.0,
+				"presence_penalty":  0.0,
+				"temperature":       0.8,
+				"top_p":             1.0,
+				"stop":              []any{"\n", "stop"},
+			},
+			Suffix: "suffix",
+			Stream: &True,
+		},
+	},
 	{
 		name: "completions handler error forwarding",
 		body: `{