OpenAI: v1/completions compatibility (#5209)

* OpenAI v1 models

* Refactor Writers

* Add Test

Co-Authored-By: Attila Kerekes

* Credit Co-Author

Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>

* Empty List Testing

* Use Namespace for Ownedby

* Update Test

* Add back envconfig

* v1/models docs

* Use ModelName Parser

* Test Names

* Remove Docs

* Clean Up

* Test name

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Add Middleware for Chat and List

* Completions Endpoint

* Testing Cleanup

* Test with Fatal

* Add functionality to chat test

* Rename function

* float types

* type cleanup

* cleaning

* more cleaning

* Extra test cases

* merge conflicts

* merge conflicts

* merge conflicts

* merge conflicts

* cleaning

* cleaning

---------

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
This commit is contained in:
royjhan 2024-07-02 16:01:45 -07:00 committed by GitHub
parent dddb58a38b
commit d626b99b54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 353 additions and 3 deletions

View file

@ -3,9 +3,11 @@ package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@ -69,6 +71,8 @@ func TestMiddleware(t *testing.T) {
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
@ -83,6 +87,130 @@ func TestMiddleware(t *testing.T) {
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string
switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "list handler",
Method: http.MethodGet,
@ -99,6 +227,8 @@ func TestMiddleware(t *testing.T) {
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
@ -162,8 +292,6 @@ func TestMiddleware(t *testing.T) {
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
tc.Expected(t, resp)
})
}