mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
api: structured outputs - chat endpoint (#7900)
Adds structured outputs to chat endpoint --------- Co-authored-by: Michael Yang <mxyng@pm.me> Co-authored-by: Hieu Nguyen <hieunguyen1053@outlook.com>
This commit is contained in:
parent
eb8366d658
commit
630e7dc6ff
10 changed files with 180 additions and 25 deletions
|
@ -94,7 +94,7 @@ type ChatRequest struct {
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
|
|
||||||
// Format is the format to return the response in (e.g. "json").
|
// Format is the format to return the response in (e.g. "json").
|
||||||
Format string `json:"format"`
|
Format json.RawMessage `json:"format,omitempty"`
|
||||||
|
|
||||||
// KeepAlive controls how long the model will stay loaded into memory
|
// KeepAlive controls how long the model will stay loaded into memory
|
||||||
// following the request.
|
// following the request.
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -1038,7 +1039,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||||
req := &api.ChatRequest{
|
req := &api.ChatRequest{
|
||||||
Model: opts.Model,
|
Model: opts.Model,
|
||||||
Messages: opts.Messages,
|
Messages: opts.Messages,
|
||||||
Format: opts.Format,
|
Format: json.RawMessage(opts.Format),
|
||||||
Options: opts.Options,
|
Options: opts.Options,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/cgo"
|
"runtime/cgo"
|
||||||
"slices"
|
"slices"
|
||||||
|
@ -699,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
|
||||||
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
|
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
|
||||||
C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
|
C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type JsonSchema struct {
|
||||||
|
Defs map[string]any `json:"$defs,omitempty"`
|
||||||
|
Properties map[string]any `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (js JsonSchema) AsGrammar() string {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(js); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
cStr := C.CString(b.String())
|
||||||
|
defer C.free(unsafe.Pointer(cStr))
|
||||||
|
|
||||||
|
// Allocate buffer for grammar output with reasonable size
|
||||||
|
const maxLen = 32768 // 32KB
|
||||||
|
buf := make([]byte, maxLen)
|
||||||
|
|
||||||
|
// Call C function to convert schema to grammar
|
||||||
|
length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
|
||||||
|
if length == 0 {
|
||||||
|
slog.Warn("unable to convert schema to grammar")
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(buf[:length])
|
||||||
|
}
|
||||||
|
|
|
@ -1 +1,70 @@
|
||||||
package llama
|
package llama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJsonSchema(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
schema JsonSchema
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty schema",
|
||||||
|
schema: JsonSchema{
|
||||||
|
Type: "object",
|
||||||
|
},
|
||||||
|
expected: `array ::= "[" space ( value ("," space value)* )? "]" space
|
||||||
|
boolean ::= ("true" | "false") space
|
||||||
|
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||||
|
decimal-part ::= [0-9]{1,16}
|
||||||
|
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||||
|
null ::= "null" space
|
||||||
|
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||||
|
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||||
|
root ::= object
|
||||||
|
space ::= | " " | "\n" [ \t]{0,20}
|
||||||
|
string ::= "\"" char* "\"" space
|
||||||
|
value ::= object | array | string | number | boolean | null`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid schema with circular reference",
|
||||||
|
schema: JsonSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]any{
|
||||||
|
"self": map[string]any{
|
||||||
|
"$ref": "#", // Self reference
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "", // Should return empty string for invalid schema
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "schema with invalid type",
|
||||||
|
schema: JsonSchema{
|
||||||
|
Type: "invalid_type", // Invalid type
|
||||||
|
Properties: map[string]any{
|
||||||
|
"foo": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "", // Should return empty string for invalid schema
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := tc.schema.AsGrammar()
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
|
||||||
|
if diff := cmp.Diff(tc.expected, result); diff != "" {
|
||||||
|
t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
29
llama/sampling_ext.cpp
vendored
29
llama/sampling_ext.cpp
vendored
|
@ -1,11 +1,13 @@
|
||||||
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include "sampling_ext.h"
|
#include "sampling_ext.h"
|
||||||
|
#include "json-schema-to-grammar.h"
|
||||||
|
|
||||||
struct gpt_sampler *gpt_sampler_cinit(
|
struct gpt_sampler *gpt_sampler_cinit(
|
||||||
const struct llama_model *model, struct gpt_sampler_cparams *params)
|
const struct llama_model *model, struct gpt_sampler_cparams *params)
|
||||||
{
|
{
|
||||||
try {
|
try
|
||||||
|
{
|
||||||
gpt_sampler_params sparams;
|
gpt_sampler_params sparams;
|
||||||
sparams.top_k = params->top_k;
|
sparams.top_k = params->top_k;
|
||||||
sparams.top_p = params->top_p;
|
sparams.top_p = params->top_p;
|
||||||
|
@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
|
||||||
sparams.seed = params->seed;
|
sparams.seed = params->seed;
|
||||||
sparams.grammar = params->grammar;
|
sparams.grammar = params->grammar;
|
||||||
return gpt_sampler_init(model, sparams);
|
return gpt_sampler_init(model, sparams);
|
||||||
} catch (const std::exception & err) {
|
}
|
||||||
|
catch (const std::exception &err)
|
||||||
|
{
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -54,3 +58,24 @@ void gpt_sampler_caccept(
|
||||||
{
|
{
|
||||||
gpt_sampler_accept(sampler, id, apply_grammar);
|
gpt_sampler_accept(sampler, id, apply_grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
nlohmann::json schema = nlohmann::json::parse(json_schema);
|
||||||
|
std::string grammar_str = json_schema_to_grammar(schema);
|
||||||
|
size_t len = grammar_str.length();
|
||||||
|
if (len >= max_len)
|
||||||
|
{
|
||||||
|
len = max_len - 1;
|
||||||
|
}
|
||||||
|
strncpy(grammar, grammar_str.c_str(), len);
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
catch (const std::exception &e)
|
||||||
|
{
|
||||||
|
strncpy(grammar, "", max_len - 1);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
2
llama/sampling_ext.h
vendored
2
llama/sampling_ext.h
vendored
|
@ -47,6 +47,8 @@ extern "C"
|
||||||
llama_token id,
|
llama_token id,
|
||||||
bool apply_grammar);
|
bool apply_grammar);
|
||||||
|
|
||||||
|
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -634,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||||
const jsonGrammar = `
|
const jsonGrammar = `
|
||||||
root ::= object
|
root ::= object
|
||||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
object ::=
|
object ::=
|
||||||
"{" ws (
|
"{" ws (
|
||||||
string ":" ws value
|
string ":" ws value
|
||||||
("," ws string ":" ws value)*
|
("," ws string ":" ws value)*
|
||||||
)? "}" ws
|
)? "}" ws
|
||||||
|
|
||||||
array ::=
|
array ::=
|
||||||
"[" ws (
|
"[" ws (
|
||||||
value
|
value
|
||||||
("," ws value)*
|
("," ws value)*
|
||||||
)? "]" ws
|
)? "]" ws
|
||||||
|
|
||||||
string ::=
|
string ::=
|
||||||
"\"" (
|
"\"" (
|
||||||
[^"\\\x7F\x00-\x1F] |
|
[^"\\\x7F\x00-\x1F] |
|
||||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
)* "\"" ws
|
)* "\"" ws
|
||||||
|
|
||||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
ws ::= ([ \t\n] ws)?
|
ws ::= ([ \t\n] ws)?
|
||||||
`
|
`
|
||||||
|
@ -684,7 +679,7 @@ type completion struct {
|
||||||
|
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Format string
|
Format json.RawMessage
|
||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
Options *api.Options
|
||||||
}
|
}
|
||||||
|
@ -749,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Format == "json" {
|
// TODO (parthsareen): Move conversion to grammar with sampling logic
|
||||||
request["grammar"] = jsonGrammar
|
// API should do error handling for invalid formats
|
||||||
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
if req.Format != nil {
|
||||||
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
|
||||||
|
request["grammar"] = jsonGrammar
|
||||||
|
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
||||||
|
slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||||
|
}
|
||||||
|
} else if schema, err := func() (llama.JsonSchema, error) {
|
||||||
|
var schema llama.JsonSchema
|
||||||
|
err := json.Unmarshal(req.Format, &schema)
|
||||||
|
return schema, err
|
||||||
|
}(); err == nil {
|
||||||
|
request["grammar"] = schema.AsGrammar()
|
||||||
|
} else {
|
||||||
|
slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,12 @@ type Usage struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type JsonSchema struct {
|
||||||
|
Schema map[string]any `json:"schema"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbedRequest struct {
|
type EmbedRequest struct {
|
||||||
|
@ -482,9 +487,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
options["top_p"] = 1.0
|
options["top_p"] = 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
var format string
|
var format json.RawMessage
|
||||||
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
|
if r.ResponseFormat != nil {
|
||||||
format = "json"
|
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
|
||||||
|
// Support the old "json_object" type for OpenAI compatibility
|
||||||
|
case "json_object":
|
||||||
|
format = json.RawMessage(`"json"`)
|
||||||
|
case "json_schema":
|
||||||
|
if r.ResponseFormat.JsonSchema != nil {
|
||||||
|
schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal json schema: %w", err)
|
||||||
|
}
|
||||||
|
format = schema
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.ChatRequest{
|
return &api.ChatRequest{
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
@ -107,7 +108,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
"presence_penalty": 5.0,
|
"presence_penalty": 5.0,
|
||||||
"top_p": 6.0,
|
"top_p": 6.0,
|
||||||
},
|
},
|
||||||
Format: "json",
|
Format: json.RawMessage(`"json"`),
|
||||||
Stream: &True,
|
Stream: &True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -316,13 +317,13 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||||
t.Fatal("requests did not match")
|
t.Fatalf("requests did not match: %+v", diff)
|
||||||
}
|
}
|
||||||
|
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||||
if !reflect.DeepEqual(tc.err, errResp) {
|
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
|
||||||
t.Fatal("errors did not match")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -278,7 +278,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: json.RawMessage(req.Format),
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}, func(cr llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
res := api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue