Update the /api/create endpoint to use JSON (#7935)

Replaces `POST /api/create` to use JSON instead of a Modelfile.

This is a breaking change.
This commit is contained in:
Patrick Devine 2024-12-31 18:02:30 -08:00 committed by GitHub
parent 459d822b51
commit 86a622cbdc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1523 additions and 1094 deletions

View file

@ -2,18 +2,24 @@ package parser
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"strings"
"testing"
"unicode/utf16"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/encoding"
"golang.org/x/text/encoding/unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
func TestParseFileFile(t *testing.T) {
@ -673,3 +679,150 @@ func TestParseMultiByte(t *testing.T) {
})
}
}
func TestCreateRequest(t *testing.T) {
cases := []struct {
input string
expected *api.CreateRequest
}{
{
`FROM test`,
&api.CreateRequest{From: "test"},
},
{
`FROM test
TEMPLATE some template
`,
&api.CreateRequest{
From: "test",
Template: "some template",
},
},
{
`FROM test
LICENSE single license
PARAMETER temperature 0.5
MESSAGE user Hello
`,
&api.CreateRequest{
From: "test",
License: []string{"single license"},
Parameters: map[string]any{"temperature": float32(0.5)},
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
},
},
{
`FROM test
PARAMETER temperature 0.5
PARAMETER top_k 1
SYSTEM You are a bot.
LICENSE license1
LICENSE license2
MESSAGE user Hello there!
MESSAGE assistant Hi! How are you?
`,
&api.CreateRequest{
From: "test",
License: []string{"license1", "license2"},
System: "You are a bot.",
Parameters: map[string]any{"temperature": float32(0.5), "top_k": int64(1)},
Messages: []api.Message{
{Role: "user", Content: "Hello there!"},
{Role: "assistant", Content: "Hi! How are you?"},
},
},
},
}
for _, c := range cases {
s, err := unicode.UTF8.NewEncoder().String(c.input)
if err != nil {
t.Fatal(err)
}
p, err := ParseFile(strings.NewReader(s))
if err != nil {
t.Error(err)
}
actual, err := p.CreateRequest()
if err != nil {
t.Error(err)
}
if diff := cmp.Diff(actual, c.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
t.Helper()
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
t.Fatal(err)
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) (string, string) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := llm.WriteGGUF(f, kv, ti); err != nil {
t.Fatal(err)
}
// Calculate sha256 of file
if _, err := f.Seek(0, 0); err != nil {
t.Fatal(err)
}
digest, _ := getSHA256Digest(t, f)
return f.Name(), digest
}
func TestCreateRequestFiles(t *testing.T) {
name, digest := createBinFile(t, nil, nil)
cases := []struct {
input string
expected *api.CreateRequest
}{
{
fmt.Sprintf("FROM %s", name),
&api.CreateRequest{Files: map[string]string{name: digest}},
},
}
for _, c := range cases {
s, err := unicode.UTF8.NewEncoder().String(c.input)
if err != nil {
t.Fatal(err)
}
p, err := ParseFile(strings.NewReader(s))
if err != nil {
t.Error(err)
}
actual, err := p.CreateRequest()
if err != nil {
t.Error(err)
}
if diff := cmp.Diff(actual, c.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}