diff --git a/openai/openai.go b/openai/openai.go index bc7bb1aff..214801fa6 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -20,6 +20,8 @@ import ( "github.com/ollama/ollama/types/model" ) +var finishReasonToolCalls = "tool_calls" + type Error struct { Message string `json:"message"` Type string `json:"type"` @@ -266,7 +268,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } } -func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { +func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk { toolCalls := toToolCalls(r.Message.ToolCalls) return ChatCompletionChunk{ Id: id, @@ -279,6 +281,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls}, FinishReason: func(reason string) *string { if len(reason) > 0 { + if toolCallSent { + return &finishReasonToolCalls + } return &reason } return nil @@ -585,6 +590,7 @@ type ChatWriter struct { stream bool streamOptions *StreamOptions id string + toolCallSent bool BaseWriter } @@ -634,11 +640,14 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { // chat chunk if w.stream { - c := toChunk(w.id, chatResponse) + c := toChunk(w.id, chatResponse, w.toolCallSent) d, err := json.Marshal(c) if err != nil { return 0, err } + if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 { + w.toolCallSent = true + } w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))