mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
OpenAI: /v1/models and /v1/models/{model} compatibility (#5007)
* OpenAI v1 models * Refactor Writers * Add Test Co-Authored-By: Attila Kerekes * Credit Co-Author Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com> * Empty List Testing * Use Namespace for Ownedby * Update Test * Add back envconfig * v1/models docs * Use ModelName Parser * Test Names * Remove Docs * Clean Up * Test name Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com> * Add Middleware for Chat and List * Testing Cleanup * Test with Fatal * Add functionality to chat test * OpenAI: /v1/models/{model} compatibility (#5028) * Retrieve Model * OpenAI Delete Model * Retrieve Middleware * Remove Delete from Branch * Update Test * Middleware Test File * Function name * Cleanup * Test Update * Test Update --------- Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com> Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
This commit is contained in:
parent
422dcc3856
commit
996bb1b85e
6 changed files with 387 additions and 14 deletions
163
openai/openai.go
163
openai/openai.go
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
|
@ -85,6 +86,18 @@ type ChatCompletionChunk struct {
|
|||
Choices []ChunkChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
type ListCompletion struct {
|
||||
Object string `json:"object"`
|
||||
Data []Model `json:"data"`
|
||||
}
|
||||
|
||||
func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
|
@ -145,7 +158,33 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|||
}
|
||||
}
|
||||
|
||||
func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
||||
func toListCompletion(r api.ListResponse) ListCompletion {
|
||||
var data []Model
|
||||
for _, m := range r.Models {
|
||||
data = append(data, Model{
|
||||
Id: m.Name,
|
||||
Object: "model",
|
||||
Created: m.ModifiedAt.Unix(),
|
||||
OwnedBy: model.ParseName(m.Name).Namespace,
|
||||
})
|
||||
}
|
||||
|
||||
return ListCompletion{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func toModel(r api.ShowResponse, m string) Model {
|
||||
return Model{
|
||||
Id: m,
|
||||
Object: "model",
|
||||
Created: r.ModifiedAt.Unix(),
|
||||
OwnedBy: model.ParseName(m).Namespace,
|
||||
}
|
||||
}
|
||||
|
||||
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
||||
var messages []api.Message
|
||||
for _, msg := range r.Messages {
|
||||
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
||||
|
@ -208,13 +247,26 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
|||
}
|
||||
}
|
||||
|
||||
type writer struct {
|
||||
stream bool
|
||||
id string
|
||||
type BaseWriter struct {
|
||||
gin.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *writer) writeError(code int, data []byte) (int, error) {
|
||||
type ChatWriter struct {
|
||||
stream bool
|
||||
id string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type ListWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type RetrieveWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
if err != nil {
|
||||
|
@ -230,7 +282,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
|
|||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *writer) writeResponse(data []byte) (int, error) {
|
||||
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
|
@ -270,7 +322,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
|
|||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *writer) Write(data []byte) (int, error) {
|
||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(code, data)
|
||||
|
@ -279,7 +331,92 @@ func (w *writer) Write(data []byte) (int, error) {
|
|||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func Middleware() gin.HandlerFunc {
|
||||
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||
var listResponse api.ListResponse
|
||||
err := json.Unmarshal(data, &listResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(code, data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||
var showResponse api.ShowResponse
|
||||
err := json.Unmarshal(data, &showResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// retrieve completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(code, data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RetrieveMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
// response writer
|
||||
w := &RetrieveWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: c.Param("model"),
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req ChatCompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
|
@ -294,17 +431,17 @@ func Middleware() gin.HandlerFunc {
|
|||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
|
||||
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &writer{
|
||||
ResponseWriter: c.Writer,
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue