mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 02:16:36 +02:00
api: enable tool streaming (#7836)
This commit is contained in:
parent
e3936d4fb3
commit
ce7455a8e1
4 changed files with 289 additions and 13 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -25,10 +26,14 @@ type mockRunner struct {
|
|||
// CompletionRequest is only valid until the next call to Completion
|
||||
llm.CompletionRequest
|
||||
llm.CompletionResponse
|
||||
CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
|
||||
}
|
||||
|
||||
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
m.CompletionRequest = r
|
||||
if m.CompletionFn != nil {
|
||||
return m.CompletionFn(ctx, r, fn)
|
||||
}
|
||||
fn(m.CompletionResponse)
|
||||
return nil
|
||||
}
|
||||
|
@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
|
|||
Model: "test",
|
||||
Modelfile: fmt.Sprintf(`FROM %s
|
||||
TEMPLATE """
|
||||
{{- if .System }}System: {{ .System }} {{ end }}
|
||||
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||
{{- if .Tools }}
|
||||
{{ .Tools }}
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- .Role }}: {{ .Content }}
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}
|
||||
{{ end }}"""
|
||||
`, createBinFile(t, llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
|
@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
|
@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
|
@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
|
@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||
})
|
||||
|
||||
t.Run("messages with tools (non-streaming)", func(t *testing.T) {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("failed to create test-system model: %d", w.Code)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mock.CompletionResponse = llm.CompletionResponse{
|
||||
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
|
||||
Done: true,
|
||||
DoneReason: "done",
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
}
|
||||
|
||||
streamRequest := true
|
||||
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Seattle?"},
|
||||
},
|
||||
Tools: tools,
|
||||
Stream: &streamRequest,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
var errResp struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
|
||||
t.Logf("Failed to decode error response: %v", err)
|
||||
} else {
|
||||
t.Logf("Error response: %s", errResp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Message.ToolCalls == nil {
|
||||
t.Error("expected tool calls, got nil")
|
||||
}
|
||||
|
||||
expectedToolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Seattle, WA",
|
||||
"unit": "celsius",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
|
||||
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("messages with tools (streaming)", func(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate streaming response with multiple chunks
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
// Send chunks with small delays to simulate streaming
|
||||
responses := []llm.CompletionResponse{
|
||||
{
|
||||
Content: `{"name":"get_`,
|
||||
Done: false,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
},
|
||||
{
|
||||
Content: `weather","arguments":{"location":"Seattle`,
|
||||
Done: false,
|
||||
PromptEvalCount: 2,
|
||||
PromptEvalDuration: 1,
|
||||
},
|
||||
{
|
||||
Content: `, WA","unit":"celsius"}}`,
|
||||
Done: true,
|
||||
DoneReason: "tool_call",
|
||||
PromptEvalCount: 3,
|
||||
PromptEvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, resp := range responses {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
fn(resp)
|
||||
time.Sleep(10 * time.Millisecond) // Small delay between chunks
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Seattle?"},
|
||||
},
|
||||
Tools: tools,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Read and validate the streamed responses
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
var finalToolCall api.ToolCall
|
||||
|
||||
for {
|
||||
var resp api.ChatResponse
|
||||
if err := decoder.Decode(&resp); err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Done {
|
||||
if len(resp.Message.ToolCalls) != 1 {
|
||||
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
|
||||
}
|
||||
finalToolCall = resp.Message.ToolCalls[0]
|
||||
}
|
||||
}
|
||||
|
||||
expectedToolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Seattle, WA",
|
||||
"unit": "celsius",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
||||
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue