openai: return usage as final chunk for streams (#6784)

* openai: return usage as final chunk for streams

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
This commit is contained in:
Anuraag (Rag) Agrawal 2024-12-13 10:09:30 +09:00 committed by GitHub
parent c216850523
commit e28f2d4900
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 176 additions and 33 deletions

View file

@ -75,10 +75,15 @@ type EmbedRequest struct {
Model string `json:"model"`
}
type StreamOptions struct {
IncludeUsage bool `json:"include_usage"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
MaxTokens *int `json:"max_tokens"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
@ -107,21 +112,23 @@ type ChatCompletionChunk struct {
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []ChunkChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
}
type Completion struct {
@ -141,6 +148,7 @@ type CompletionChunk struct {
Choices []CompleteChunkChoice `json:"choices"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Usage *Usage `json:"usage,omitempty"`
}
type ToolCall struct {
@ -197,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}}
}
func toUsage(r api.ChatResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
}
}
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
@ -246,11 +262,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
return nil
}(r.DoneReason),
}},
Usage: Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
Usage: toUsage(r),
}
}
@ -275,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
}
}
func toUsageGenerate(r api.GenerateResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
}
}
func toCompletion(id string, r api.GenerateResponse) Completion {
return Completion{
Id: id,
@ -292,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
return nil
}(r.DoneReason),
}},
Usage: Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
Usage: toUsageGenerate(r),
}
}
@ -566,14 +582,16 @@ type BaseWriter struct {
}
type ChatWriter struct {
stream bool
id string
stream bool
streamOptions *StreamOptions
id string
BaseWriter
}
type CompleteWriter struct {
stream bool
id string
stream bool
streamOptions *StreamOptions
id string
BaseWriter
}
@ -616,7 +634,8 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
// chat chunk
if w.stream {
d, err := json.Marshal(toChunk(w.id, chatResponse))
c := toChunk(w.id, chatResponse)
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
@ -628,6 +647,19 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsage(chatResponse)
c.Usage = &u
c.Choices = []ChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
@ -665,7 +697,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
// completion chunk
if w.stream {
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
c := toCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
@ -677,6 +713,19 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsageGenerate(generateResponse)
c.Usage = &u
c.Choices = []CompleteChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
@ -839,9 +888,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
@ -921,9 +971,10 @@ func ChatMiddleware() gin.HandlerFunc {
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w