llm: set done reason at server level (#9830)

No functional change. Many different done reasons can be set at the runner
level, so rather than obscuring them we should return them to the server
process and let it choose what to do with the done reason. This separates
the API concerns from the runner.
Bruce MacDonald 2025-04-03 10:19:24 -07:00 committed by GitHub
parent b51e0f397c
commit e53b3cbd0c
5 changed files with 54 additions and 42 deletions
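
In short, the runners now report a typed llm.DoneReason and the server's HTTP handlers decide how to present it, converting to the API string only when a response is done. A minimal sketch of that flow, assuming the github.com/ollama/ollama module path; the helpers finishReason and toAPIDoneReason are hypothetical and only illustrate the split, they are not part of this diff:

package main

import (
	"fmt"

	"github.com/ollama/ollama/llm"
)

// finishReason is a hypothetical runner-side helper: the runner only decides
// *why* generation ended, using the typed constants added in this commit.
func finishReason(hitPredictLimit bool) llm.DoneReason {
	if hitPredictLimit {
		return llm.DoneReasonLength
	}
	return llm.DoneReasonStop
}

// toAPIDoneReason is a hypothetical server-side helper: the server process,
// not the runner, chooses the user-facing string for the API response.
func toAPIDoneReason(cr llm.CompletionResponse) string {
	if !cr.Done {
		return ""
	}
	return cr.DoneReason.String() // "stop", "length", or "" for a closed connection
}

func main() {
	cr := llm.CompletionResponse{Done: true, DoneReason: finishReason(true)}
	fmt.Println(toAPIDoneReason(cr)) // length
}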


@@ -675,9 +675,32 @@ type CompletionRequest struct {
 	Grammar string // set before sending the request to the subprocess
 }
 
+// DoneReason represents the reason why a completion response is done
+type DoneReason int
+
+const (
+	// DoneReasonStop indicates the completion stopped naturally
+	DoneReasonStop DoneReason = iota
+	// DoneReasonLength indicates the completion stopped due to length limits
+	DoneReasonLength
+	// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
+	DoneReasonConnectionClosed
+)
+
+func (d DoneReason) String() string {
+	switch d {
+	case DoneReasonLength:
+		return "length"
+	case DoneReasonStop:
+		return "stop"
+	default:
+		return "" // closed
+	}
+}
+
 type CompletionResponse struct {
 	Content            string        `json:"content"`
-	DoneReason         string        `json:"done_reason"`
+	DoneReason         DoneReason    `json:"done_reason"`
 	Done               bool          `json:"done"`
 	PromptEvalCount    int           `json:"prompt_eval_count"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
@@ -786,7 +809,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			continue
 		}
 
-		// slog.Debug("got line", "line", string(line))
 		evt, ok := bytes.CutPrefix(line, []byte("data: "))
 		if !ok {
 			evt = line
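
One usage note on the new type: DoneReason has no custom JSON marshalling in this diff, so it crosses the runner/server boundary as its integer value, and the user-facing string only appears when a handler calls String(). A small hedged example of the mapping:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/llm"
)

func main() {
	// String() gives the API-facing labels defined above.
	fmt.Println(llm.DoneReasonStop.String())             // stop
	fmt.Println(llm.DoneReasonLength.String())           // length
	fmt.Println(llm.DoneReasonConnectionClosed.String()) // (empty string)

	// With no MarshalJSON defined in this diff, the enum is serialized
	// between the runner subprocess and the server as a number.
	b, _ := json.Marshal(llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonLength})
	fmt.Println(string(b)) // ... "done_reason":1 ...
}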


@@ -83,7 +83,7 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
-	doneReason string
+	doneReason llm.DoneReason
 
 	// Metrics
 	startProcessingTime time.Time
@@ -301,7 +301,7 @@ func flushPending(seq *Sequence) bool {
 	}
 }
 
-func (s *Server) removeSequence(seqIndex int, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
 	flushPending(seq)
@@ -380,7 +380,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(seqIdx, llm.DoneReasonLength)
 			continue
 		}
@@ -482,7 +482,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			seq.embedding <- embed
-			s.removeSequence(i, "")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
@@ -499,7 +499,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
 
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
@@ -530,7 +530,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
@@ -543,7 +543,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		}
 
 		if !flushPending(seq) {
-			s.removeSequence(i, "connection")
+			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
@@ -657,14 +657,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			flusher.Flush()
 		} else {
-			// Send the final response
-			doneReason := "stop"
-			if seq.doneReason == "limit" {
-				doneReason = "length"
-			}
 			if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 				Done:               true,
-				DoneReason:         doneReason,
+				DoneReason:         seq.doneReason,
 				PromptEvalCount:    seq.numPromptInputs,
 				PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 				EvalCount:          seq.numDecoded,


@@ -82,7 +82,7 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
-	doneReason string
+	doneReason llm.DoneReason
 
 	// Metrics
 	startProcessingTime time.Time
@@ -341,7 +341,7 @@ func flushPending(seq *Sequence) bool {
 	}
 }
 
-func (s *Server) removeSequence(seqIndex int, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
 	flushPending(seq)
@@ -391,7 +391,7 @@ func (s *Server) processBatch() error {
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(seqIdx, llm.DoneReasonLength)
 			continue
 		}
@@ -510,7 +510,7 @@ func (s *Server) processBatch() error {
 		if seq.embeddingOnly {
 			// TODO(jessegross): Embedding support
 			slog.Warn("generation of embedding outputs not yet supported")
-			s.removeSequence(i, "")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
@@ -528,7 +528,7 @@ func (s *Server) processBatch() error {
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
 
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
@@ -564,7 +564,7 @@ func (s *Server) processBatch() error {
 			}
 
 			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
@@ -577,7 +577,7 @@ func (s *Server) processBatch() error {
 		}
 
 		if !flushPending(seq) {
-			s.removeSequence(i, "connection")
+			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
@@ -690,14 +690,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			flusher.Flush()
 		} else {
-			// Send the final response
-			doneReason := "stop"
-			if seq.doneReason == "limit" {
-				doneReason = "length"
-			}
 			if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 				Done:               true,
-				DoneReason:         doneReason,
+				DoneReason:         seq.doneReason,
 				PromptEvalCount:    seq.numPromptInputs,
 				PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 				EvalCount:          seq.numPredicted,


@@ -308,11 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		Options: opts,
 	}, func(cr llm.CompletionResponse) {
 		res := api.GenerateResponse{
 			Model:      req.Model,
 			CreatedAt:  time.Now().UTC(),
 			Response:   cr.Content,
 			Done:       cr.Done,
-			DoneReason: cr.DoneReason,
 			Metrics: api.Metrics{
 				PromptEvalCount:    cr.PromptEvalCount,
 				PromptEvalDuration: cr.PromptEvalDuration,
@@ -326,6 +325,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		if cr.Done {
+			res.DoneReason = cr.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -1533,11 +1533,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		Options: opts,
 	}, func(r llm.CompletionResponse) {
 		res := api.ChatResponse{
 			Model:      req.Model,
 			CreatedAt:  time.Now().UTC(),
 			Message:    api.Message{Role: "assistant", Content: r.Content},
 			Done:       r.Done,
-			DoneReason: r.DoneReason,
 			Metrics: api.Metrics{
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
@@ -1547,6 +1546,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 
 		if r.Done {
+			res.DoneReason = r.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 		}


@@ -58,7 +58,7 @@ func TestGenerateChat(t *testing.T) {
 	mock := mockRunner{
 		CompletionResponse: llm.CompletionResponse{
 			Done:               true,
-			DoneReason:         "stop",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    1,
 			PromptEvalDuration: 1,
 			EvalCount:          1,
@@ -401,7 +401,7 @@ func TestGenerateChat(t *testing.T) {
 	mock.CompletionResponse = llm.CompletionResponse{
 		Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
 		Done:               true,
-		DoneReason:         "done",
+		DoneReason:         llm.DoneReasonStop,
 		PromptEvalCount:    1,
 		PromptEvalDuration: 1,
 		EvalCount:          1,
@@ -519,7 +519,7 @@ func TestGenerateChat(t *testing.T) {
 		{
 			Content:            `, WA","unit":"celsius"}}`,
 			Done:               true,
-			DoneReason:         "tool_call",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    3,
 			PromptEvalDuration: 1,
 		},
@@ -594,7 +594,7 @@ func TestGenerate(t *testing.T) {
 	mock := mockRunner{
 		CompletionResponse: llm.CompletionResponse{
 			Done:               true,
-			DoneReason:         "stop",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    1,
 			PromptEvalDuration: 1,
 			EvalCount:          1,