ollama/openai/openai.go
Blake Mizerany 9039c821a2
llama: preserve field order in user-defined JSON schemas (#8002)
Previously we decoded and re-encoded JSON schemas during validation,
which served no purpose since json.RawMessage already validates JSON
syntax. Worse, the re-encoding lost field ordering from the original
schema, which affects inference quality during step-by-step reasoning.

While fixing this ordering issue by using json.RawMessage directly,
testing revealed that schema_to_grammar (from llama.cpp) also fails to
preserve field order during grammar generation. This appears to be the
root cause of inference degradation.

This change prevents us from mangling the user's original schema order,
but we still need to address the ordering issue in schema_to_grammar.
That will be a separate change.

Updates #7978
2024-12-11 14:07:30 -08:00

933 lines
22 KiB
Go

// Package openai provides middleware for partial compatibility with the OpenAI REST API.
package openai
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Error is the error object carried inside an ErrorResponse.
// Param and Code have no omitempty tag, so they are always serialized
// (as null when left unset).
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
Param interface{} `json:"param"`
Code *string `json:"code"`
}
// ErrorResponse is the JSON envelope for errors: the Error object is nested
// under the "error" key.
type ErrorResponse struct {
Error Error `json:"error"`
}
// Message is a single chat message, used both in requests and in responses.
type Message struct {
Role string `json:"role"`
// Content is untyped because requests may send different JSON shapes;
// fromChatRequest accepts a string or a list of content parts.
Content any `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// Choice is one entry of a non-streaming chat completion's "choices" list.
type Choice struct {
Index int `json:"index"`
Message Message `json:"message"`
// FinishReason is a pointer so that it serializes as null when unset.
FinishReason *string `json:"finish_reason"`
}
// ChunkChoice is one entry of a streamed chat completion chunk's "choices"
// list; the message is carried under the "delta" key.
type ChunkChoice struct {
Index int `json:"index"`
Delta Message `json:"delta"`
// FinishReason is a pointer so that it serializes as null when unset.
FinishReason *string `json:"finish_reason"`
}
// CompleteChunkChoice is one entry of a text completion's "choices" list.
// It is used by both Completion and CompletionChunk.
type CompleteChunkChoice struct {
Text string `json:"text"`
Index int `json:"index"`
// FinishReason is a pointer so that it serializes as null when unset.
FinishReason *string `json:"finish_reason"`
}
// Usage reports token counts for a completion. The converters in this file
// set TotalTokens to the sum of the prompt and completion counts.
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// ResponseFormat is the "response_format" field of a chat completion request.
type ResponseFormat struct {
// Type selects the output format; fromChatRequest recognizes
// "json_object" and "json_schema" (case-insensitive, whitespace-trimmed).
Type string `json:"type"`
// JsonSchema is consulted only when Type is "json_schema".
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}
// JsonSchema wraps a user-supplied JSON schema.
type JsonSchema struct {
// Schema is kept as raw JSON so it is forwarded exactly as received,
// without being decoded and re-encoded.
Schema json.RawMessage `json:"schema"`
}
// EmbedRequest is the request body accepted by EmbeddingsMiddleware.
type EmbedRequest struct {
// Input is untyped; EmbeddingsMiddleware rejects a missing value and an
// empty list, and turns an empty string into a one-element list.
Input any `json:"input"`
Model string `json:"model"`
}
// ChatCompletionRequest is the request body accepted by ChatMiddleware.
// Optional numeric fields are pointers so that fromChatRequest can tell an
// omitted value from an explicit zero.
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
MaxTokens *int `json:"max_tokens"`
Seed *int `json:"seed"`
// Stop may be a single string or a list; fromChatRequest handles both.
Stop any `json:"stop"`
Temperature *float64 `json:"temperature"`
FrequencyPenalty *float64 `json:"frequency_penalty"`
PresencePenalty *float64 `json:"presence_penalty"`
TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
}
// ChatCompletion is the non-streaming chat completion response body
// produced by toChatCompletion.
type ChatCompletion struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
}
// ChatCompletionChunk is a single streamed chat completion event produced
// by toChunk. Unlike ChatCompletion it carries no usage information.
type ChatCompletionChunk struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []ChunkChoice `json:"choices"`
}
// CompletionRequest is the request body accepted by CompletionsMiddleware.
// Prompt currently only accepts a single string.
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
// FrequencyPenalty and PresencePenalty are plain values, so
// fromCompleteRequest always forwards them (zero when omitted).
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
Seed *int `json:"seed"`
// Stop may be a single string or a list of strings.
Stop any `json:"stop"`
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature"`
// TopP of 0 is treated as unset by fromCompleteRequest.
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
}
// Completion is the non-streaming text completion response body produced
// by toCompletion.
type Completion struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []CompleteChunkChoice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
}
// CompletionChunk is a single streamed text completion event produced by
// toCompleteChunk. Unlike Completion it carries no usage information.
type CompletionChunk struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []CompleteChunkChoice `json:"choices"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
}
// ToolCall describes one function call attached to a Message.
type ToolCall struct {
ID string `json:"id"`
Index int `json:"index"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
// Arguments holds the call arguments as a JSON-encoded string
// (see toToolCalls and fromChatRequest).
Arguments string `json:"arguments"`
} `json:"function"`
}
// Model is a single model entry, returned on its own by the retrieve
// endpoint and as an element of ListCompletion.
type Model struct {
Id string `json:"id"`
Object string `json:"object"`
// Created is a Unix timestamp in seconds.
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
// Embedding is one embedding vector within an EmbeddingList.
type Embedding struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
// Index is the position of this vector in the response's embedding list.
Index int `json:"index"`
}
// ListCompletion is the response body for the model listing endpoint.
type ListCompletion struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
// EmbeddingList is the response body for the embeddings endpoint.
type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
Usage EmbeddingUsage `json:"usage,omitempty"`
}
// EmbeddingUsage reports token counts for an embeddings request.
// toEmbeddingList sets both fields to the prompt evaluation count.
type EmbeddingUsage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
}
// NewError builds an ErrorResponse carrying message. The error type is
// derived from the HTTP status code: 400 maps to "invalid_request_error",
// 404 to "not_found_error", and every other code to "api_error".
func NewError(code int, message string) ErrorResponse {
	kind := "api_error"
	if code == http.StatusBadRequest {
		kind = "invalid_request_error"
	} else if code == http.StatusNotFound {
		kind = "not_found_error"
	}

	return ErrorResponse{Error: Error{Message: message, Type: kind}}
}
// toolCallId returns a random tool call identifier: the prefix "call_"
// followed by eight characters drawn from lowercase letters and digits.
func toolCallId() string {
	const (
		prefix    = "call_"
		alphabet  = "abcdefghijklmnopqrstuvwxyz0123456789"
		suffixLen = 8
	)

	var sb strings.Builder
	sb.Grow(len(prefix) + suffixLen)
	sb.WriteString(prefix)
	for i := 0; i < suffixLen; i++ {
		sb.WriteByte(alphabet[rand.Intn(len(alphabet))])
	}
	return sb.String()
}
// toToolCalls converts api tool calls into their OpenAI-style form. Every
// call gets a freshly generated ID and the type "function"; the arguments
// are serialized to a JSON string. If serialization fails the error is
// logged and that call is returned with empty arguments.
func toToolCalls(tc []api.ToolCall) []ToolCall {
	out := make([]ToolCall, len(tc))
	for i := range tc {
		call := &out[i]
		call.ID = toolCallId()
		call.Type = "function"
		call.Function.Name = tc[i].Function.Name
		call.Index = tc[i].Function.Index

		encoded, err := json.Marshal(tc[i].Function.Arguments)
		if err != nil {
			slog.Error("could not marshall function arguments to json", "error", err)
			continue
		}
		call.Function.Arguments = string(encoded)
	}
	return out
}
// toChatCompletion converts a complete api.ChatResponse into a
// non-streaming ChatCompletion with a single choice. The finish reason is
// "tool_calls" whenever the message contains tool calls; otherwise it is the
// response's done reason, or null if that is empty.
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := toToolCalls(r.Message.ToolCalls)

	reason := r.DoneReason
	if len(toolCalls) > 0 {
		reason = "tool_calls"
	}
	var finishReason *string
	if reason != "" {
		finishReason = &reason
	}

	choice := Choice{
		Index: 0,
		Message: Message{
			Role:      r.Message.Role,
			Content:   r.Message.Content,
			ToolCalls: toolCalls,
		},
		FinishReason: finishReason,
	}

	usage := Usage{
		PromptTokens:     r.PromptEvalCount,
		CompletionTokens: r.EvalCount,
		TotalTokens:      r.PromptEvalCount + r.EvalCount,
	}

	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices:           []Choice{choice},
		Usage:             usage,
	}
}
// toChunk converts one streamed api.ChatResponse into a ChatCompletionChunk
// with a single choice. The delta role is always "assistant", the creation
// time is the current time, and the finish reason is the response's done
// reason, or null if that is empty.
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
	delta := Message{
		Role:      "assistant",
		Content:   r.Message.Content,
		ToolCalls: toToolCalls(r.Message.ToolCalls),
	}

	var finishReason *string
	if reason := r.DoneReason; reason != "" {
		finishReason = &reason
	}

	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []ChunkChoice{
			{Index: 0, Delta: delta, FinishReason: finishReason},
		},
	}
}
// toCompletion converts a complete api.GenerateResponse into a
// non-streaming Completion with a single choice and token usage. The finish
// reason is the response's done reason, or null if that is empty.
func toCompletion(id string, r api.GenerateResponse) Completion {
	var finishReason *string
	if reason := r.DoneReason; reason != "" {
		finishReason = &reason
	}

	choice := CompleteChunkChoice{
		Text:         r.Response,
		Index:        0,
		FinishReason: finishReason,
	}

	usage := Usage{
		PromptTokens:     r.PromptEvalCount,
		CompletionTokens: r.EvalCount,
		TotalTokens:      r.PromptEvalCount + r.EvalCount,
	}

	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices:           []CompleteChunkChoice{choice},
		Usage:             usage,
	}
}
// toCompleteChunk converts one streamed api.GenerateResponse into a
// CompletionChunk with a single choice. The creation time is the current
// time, and the finish reason is the response's done reason, or null if
// that is empty.
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	var finishReason *string
	if reason := r.DoneReason; reason != "" {
		finishReason = &reason
	}

	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{
			{Text: r.Response, Index: 0, FinishReason: finishReason},
		},
	}
}
// toListCompletion converts an api.ListResponse into a ListCompletion.
// Each model's ID is its name, its creation time is its modification time,
// and its owner is the namespace parsed from the name.
//
// The data slice is always non-nil so that an empty model list is encoded
// as "data": [] rather than "data": null, which list-consuming clients
// cannot iterate.
func toListCompletion(r api.ListResponse) ListCompletion {
	data := make([]Model, 0, len(r.Models))
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}
// toEmbeddingList converts an api.EmbedResponse into an EmbeddingList for
// the given model name. Each vector keeps its position as its index, and
// both usage counters are set to the prompt evaluation count. When the
// response has no embeddings at all, a zero-valued EmbeddingList is
// returned.
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
	if r.Embeddings == nil {
		return EmbeddingList{}
	}

	var items []Embedding
	for idx := range r.Embeddings {
		items = append(items, Embedding{
			Object:    "embedding",
			Embedding: r.Embeddings[idx],
			Index:     idx,
		})
	}

	usage := EmbeddingUsage{
		PromptTokens: r.PromptEvalCount,
		TotalTokens:  r.PromptEvalCount,
	}

	return EmbeddingList{
		Object: "list",
		Data:   items,
		Model:  model,
		Usage:  usage,
	}
}
// toModel builds the Model entry for the model named m from an
// api.ShowResponse: the creation time is the response's modification time
// and the owner is the namespace parsed from m.
func toModel(r api.ShowResponse, m string) Model {
	var out Model
	out.Id = m
	out.Object = "model"
	out.Created = r.ModifiedAt.Unix()
	out.OwnedBy = model.ParseName(m).Namespace
	return out
}
// fromChatRequest converts an OpenAI-style chat completion request into an
// api.ChatRequest.
//
// Message content may be a plain string or a list of typed parts ("text" or
// "image_url"); every part of a list becomes its own api.Message. Content
// may be omitted only when the message carries tool calls. Tool calls on a
// message that also has content are forwarded with it instead of being
// silently dropped: they are attached to the string message, or to the last
// message produced from a parts list.
//
// Temperature and top_p default to 1.0 when unset. Non-string entries in a
// "stop" list are skipped. An error is returned for malformed content,
// unsupported image data, or tool call arguments that are not valid JSON.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			m := api.Message{Role: msg.Role, Content: content}
			if len(msg.ToolCalls) > 0 {
				toolCalls, err := toAPIToolCalls(msg.ToolCalls)
				if err != nil {
					return nil, err
				}
				m.ToolCalls = toolCalls
			}
			messages = append(messages, m)
		case []any:
			// Remember where this message's parts start so tool calls are
			// never attached to a message belonging to an earlier entry.
			first := len(messages)
			for _, c := range content {
				part, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch part["type"] {
				case "text":
					text, ok := part["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					img, err := decodeImagePart(part)
					if err != nil {
						return nil, err
					}
					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
			if len(msg.ToolCalls) > 0 {
				toolCalls, err := toAPIToolCalls(msg.ToolCalls)
				if err != nil {
					return nil, err
				}
				if len(messages) > first {
					messages[len(messages)-1].ToolCalls = toolCalls
				} else {
					messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
				}
			}
		default:
			// Content is optional only when tool calls are present.
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}
			toolCalls, err := toAPIToolCalls(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}

	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			// Pass the raw schema through untouched to keep the user's
			// field order.
			if r.ResponseFormat.JsonSchema != nil {
				format = r.ResponseFormat.JsonSchema.Schema
			}
		}
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
	}, nil
}

// toAPIToolCalls converts OpenAI-style tool calls into api tool calls,
// decoding each JSON-encoded argument string.
func toAPIToolCalls(tcs []ToolCall) ([]api.ToolCall, error) {
	out := make([]api.ToolCall, len(tcs))
	for i, tc := range tcs {
		out[i].Function.Name = tc.Function.Name
		if err := json.Unmarshal([]byte(tc.Function.Arguments), &out[i].Function.Arguments); err != nil {
			return nil, errors.New("invalid tool call arguments")
		}
	}
	return out, nil
}

// decodeImagePart extracts and base64-decodes the image from an "image_url"
// content part, whose value is either a string or an object with a "url"
// field. Only jpeg, jpg and png base64 data URLs are accepted.
func decodeImagePart(part map[string]any) ([]byte, error) {
	var url string
	switch v := part["image_url"].(type) {
	case map[string]any:
		s, ok := v["url"].(string)
		if !ok {
			return nil, errors.New("invalid message format")
		}
		url = s
	case string:
		url = v
	default:
		return nil, errors.New("invalid message format")
	}

	for _, t := range []string{"jpeg", "jpg", "png"} {
		encoded, ok := strings.CutPrefix(url, "data:image/"+t+";base64,")
		if !ok {
			continue
		}
		img, err := base64.StdEncoding.DecodeString(encoded)
		if err != nil {
			return nil, errors.New("invalid message format")
		}
		return img, nil
	}
	return nil, errors.New("invalid image input")
}
// fromCompleteRequest converts an OpenAI-style text completion request into
// an api.GenerateRequest. "stop" may be a string or a list of strings; any
// non-string list entry is an error. Temperature defaults to 1.0 when unset
// and top_p defaults to 1.0 when zero; the frequency and presence penalties
// are always forwarded.
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := map[string]any{}

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			str, ok := s.(string)
			if !ok {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
			stops = append(stops, str)
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// Start from the default and override it when the caller supplied one.
	options["temperature"] = 1.0
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	options["frequency_penalty"] = r.FrequencyPenalty
	options["presence_penalty"] = r.PresencePenalty

	options["top_p"] = 1.0
	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	}

	req := api.GenerateRequest{
		Model:   r.Model,
		Prompt:  r.Prompt,
		Options: options,
		Stream:  &r.Stream,
		Suffix:  r.Suffix,
	}
	return req, nil
}
// BaseWriter wraps a gin.ResponseWriter and provides the error translation
// (writeError) shared by the endpoint-specific writers below.
type BaseWriter struct {
gin.ResponseWriter
}
// ChatWriter rewrites chat responses into chat completion responses.
type ChatWriter struct {
// stream selects event-stream chunks instead of a single JSON body.
stream bool
// id is the completion ID placed in every response written.
id string
BaseWriter
}
// CompleteWriter rewrites generate responses into text completion responses.
type CompleteWriter struct {
// stream selects event-stream chunks instead of a single JSON body.
stream bool
// id is the completion ID placed in every response written.
id string
BaseWriter
}
// ListWriter rewrites model list responses into a ListCompletion.
type ListWriter struct {
BaseWriter
}
// RetrieveWriter rewrites a show response into a single Model.
type RetrieveWriter struct {
BaseWriter
// model is the model name taken from the request path.
model string
}
// EmbedWriter rewrites embed responses into an EmbeddingList.
type EmbedWriter struct {
BaseWriter
// model is the model name taken from the request body.
model string
}
// writeError decodes data as an api.StatusError and re-emits its message as
// a JSON ErrorResponse of type "api_error". On success it reports
// len(data) bytes as written.
func (w *BaseWriter) writeError(data []byte) (int, error) {
	var statusErr api.StatusError
	if err := json.Unmarshal(data, &statusErr); err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")

	body := NewError(http.StatusInternalServerError, statusErr.Error())
	if err := json.NewEncoder(w.ResponseWriter).Encode(body); err != nil {
		return 0, err
	}
	return len(data), nil
}
// writeResponse decodes data as an api.ChatResponse and writes it out in
// OpenAI form. When streaming, it emits one "data: ..." event per call and a
// final "data: [DONE]" event once the response is marked done; otherwise it
// writes a single JSON chat completion. On success it reports len(data)
// bytes as written.
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var resp api.ChatResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return 0, err
	}

	// chat completion
	if !w.stream {
		w.ResponseWriter.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, resp)); err != nil {
			return 0, err
		}
		return len(data), nil
	}

	// chat chunk
	payload, err := json.Marshal(toChunk(w.id, resp))
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
	if _, err := fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", payload); err != nil {
		return 0, err
	}

	if resp.Done {
		if _, err := w.ResponseWriter.Write([]byte("data: [DONE]\n\n")); err != nil {
			return 0, err
		}
	}
	return len(data), nil
}
// Write translates the wrapped handler's output: a payload written under any
// status other than 200 is treated as an error body, everything else as a
// chat response.
func (w *ChatWriter) Write(data []byte) (int, error) {
	if w.ResponseWriter.Status() == http.StatusOK {
		return w.writeResponse(data)
	}
	return w.writeError(data)
}
// writeResponse decodes data as an api.GenerateResponse and writes it out in
// OpenAI form. When streaming, it emits one "data: ..." event per call and a
// final "data: [DONE]" event once the response is marked done; otherwise it
// writes a single JSON completion. On success it reports len(data) bytes as
// written.
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var resp api.GenerateResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return 0, err
	}

	// completion
	if !w.stream {
		w.ResponseWriter.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, resp)); err != nil {
			return 0, err
		}
		return len(data), nil
	}

	// completion chunk
	payload, err := json.Marshal(toCompleteChunk(w.id, resp))
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
	if _, err := fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", payload); err != nil {
		return 0, err
	}

	if resp.Done {
		if _, err := w.ResponseWriter.Write([]byte("data: [DONE]\n\n")); err != nil {
			return 0, err
		}
	}
	return len(data), nil
}
// Write translates the wrapped handler's output: a payload written under any
// status other than 200 is treated as an error body, everything else as a
// generate response.
func (w *CompleteWriter) Write(data []byte) (int, error) {
	if w.ResponseWriter.Status() == http.StatusOK {
		return w.writeResponse(data)
	}
	return w.writeError(data)
}
// writeResponse decodes data as an api.ListResponse and writes the
// equivalent ListCompletion as JSON. On success it reports len(data) bytes
// as written.
func (w *ListWriter) writeResponse(data []byte) (int, error) {
	var resp api.ListResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")

	enc := json.NewEncoder(w.ResponseWriter)
	if err := enc.Encode(toListCompletion(resp)); err != nil {
		return 0, err
	}
	return len(data), nil
}
// Write translates the wrapped handler's output: a payload written under any
// status other than 200 is treated as an error body, everything else as a
// list response.
func (w *ListWriter) Write(data []byte) (int, error) {
	if w.ResponseWriter.Status() == http.StatusOK {
		return w.writeResponse(data)
	}
	return w.writeError(data)
}
// writeResponse decodes data as an api.ShowResponse and writes the
// corresponding Model entry for w.model as JSON. On success it reports
// len(data) bytes as written.
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
	var resp api.ShowResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return 0, err
	}

	// retrieve completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")

	enc := json.NewEncoder(w.ResponseWriter)
	if err := enc.Encode(toModel(resp, w.model)); err != nil {
		return 0, err
	}
	return len(data), nil
}
// Write translates the wrapped handler's output: a payload written under any
// status other than 200 is treated as an error body, everything else as a
// show response.
func (w *RetrieveWriter) Write(data []byte) (int, error) {
	if w.ResponseWriter.Status() == http.StatusOK {
		return w.writeResponse(data)
	}
	return w.writeError(data)
}
// writeResponse decodes data as an api.EmbedResponse and writes the
// equivalent EmbeddingList for w.model as JSON. On success it reports
// len(data) bytes as written.
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
	var resp api.EmbedResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")

	enc := json.NewEncoder(w.ResponseWriter)
	if err := enc.Encode(toEmbeddingList(w.model, resp)); err != nil {
		return 0, err
	}
	return len(data), nil
}
// Write translates the wrapped handler's output: a payload written under any
// status other than 200 is treated as an error body, everything else as an
// embed response.
func (w *EmbedWriter) Write(data []byte) (int, error) {
	if w.ResponseWriter.Status() == http.StatusOK {
		return w.writeResponse(data)
	}
	return w.writeError(data)
}
// ListMiddleware returns a gin handler that installs a ListWriter as the
// response writer before running the rest of the handler chain, so the
// downstream list response is rewritten into a ListCompletion.
func ListMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Writer = &ListWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}
		c.Next()
	}
}
// RetrieveMiddleware returns a gin handler that replaces the request body
// with an api.ShowRequest for the "model" path parameter and installs a
// RetrieveWriter, so the downstream show response is rewritten into a Model.
// If the request body cannot be encoded it aborts with a 500 error.
func RetrieveMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		name := c.Param("model")

		body := new(bytes.Buffer)
		if err := json.NewEncoder(body).Encode(api.ShowRequest{Name: name}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}
		c.Request.Body = io.NopCloser(body)

		// response writer
		c.Writer = &RetrieveWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      name,
		}
		c.Next()
	}
}
// CompletionsMiddleware returns a gin handler that parses an OpenAI-style
// completion request, replaces the request body with the equivalent
// api.GenerateRequest, and installs a CompleteWriter with a generated
// "cmpl-N" ID. Binding or conversion failures abort with a 400 error;
// an encoding failure aborts with a 500 error.
func CompletionsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req CompletionRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		genReq, err := fromCompleteRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		body := new(bytes.Buffer)
		if err := json.NewEncoder(body).Encode(genReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}
		c.Request.Body = io.NopCloser(body)

		c.Writer = &CompleteWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
		}
		c.Next()
	}
}
// EmbeddingsMiddleware returns a gin handler that parses an OpenAI-style
// embeddings request, replaces the request body with the equivalent
// api.EmbedRequest, and installs an EmbedWriter. An empty string input is
// forwarded as a one-element list; a missing input or an empty list aborts
// with a 400 "invalid input" error.
func EmbeddingsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req EmbedRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		switch input := req.Input.(type) {
		case nil:
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		case string:
			if input == "" {
				req.Input = []string{""}
			}
		case []any:
			if len(input) == 0 {
				c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
				return
			}
		}

		body := new(bytes.Buffer)
		embedReq := api.EmbedRequest{Model: req.Model, Input: req.Input}
		if err := json.NewEncoder(body).Encode(embedReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}
		c.Request.Body = io.NopCloser(body)

		c.Writer = &EmbedWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      req.Model,
		}
		c.Next()
	}
}
// ChatMiddleware returns a gin handler that parses an OpenAI-style chat
// completion request, replaces the request body with the equivalent
// api.ChatRequest, and installs a ChatWriter with a generated "chatcmpl-N"
// ID. Binding failures, an empty message list, and conversion failures
// abort with a 400 error; an encoding failure aborts with a 500 error.
func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatCompletionRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
			return
		}

		chatReq, err := fromChatRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		body := new(bytes.Buffer)
		if err := json.NewEncoder(body).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}
		c.Request.Body = io.NopCloser(body)

		c.Writer = &ChatWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
		}
		c.Next()
	}
}