add suffix support to generate endpoint

this change is triggered by the presence of "suffix", particularly
useful for code completion tasks
This commit is contained in:
Michael Yang 2024-06-20 19:13:36 -07:00
parent 987dbab0b0
commit d290e87513
6 changed files with 155 additions and 27 deletions

View file

@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) {
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add 10ms delay to simulate loading
time.Sleep(10 * time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) {
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) {
}
})
t.Run("missing capabilities", func(t *testing.T) {
t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "bert",
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) {
}
if actual.TotalDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
t.Errorf("expected total duration > 0, got 0")
}
}
@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) {
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) {
}
})
t.Run("missing capabilities", func(t *testing.T) {
t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "bert",
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) {
}
})
t.Run("missing capabilities suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) {
}
if actual.TotalDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
t.Errorf("expected total duration > 0, got 0")
}
}
@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) {
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-suffix",
Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("prompt without suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("raw", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",