llm: remove internal subprocess req and resp types (#9324)

This commit refactors the LLM subsystem by removing the internal subprocess
request and response types. Duplicate type definitions (CompletionRequest,
CompletionResponse, EmbeddingRequest, EmbeddingResponse, ImageData,
ServerStatus) are consolidated into the shared llm package, and the interface
between the llm package and the runners is standardized on those shared types.
The change also simplifies the server status response struct (now
ServerStatusResponse, carrying only a typed Status and a load Progress) and
moves the ParseDurationMs function to a common package. This cleanup reduces
code duplication between the two runner implementations (llamarunner and
ollamarunner).
Bruce MacDonald 2025-03-14 15:21:53 -07:00 committed by GitHub
parent 4e320b8b90
commit 3892c3a703
4 changed files with 125 additions and 354 deletions
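
For context, the most visible behavioral change is the health-check wire format: the runner's /health handler now encodes the shared llm.ServerStatusResponse directly, and the parent process switches on the typed llm.ServerStatus constants instead of matching strings like "ok" or "loading model". The following is a minimal, self-contained sketch of that round trip. It is illustrative only: the local type declarations mirror the consolidated llm package types shown in the diff rather than importing them, the constant order is illustrative, and the integer JSON encoding of Status assumes ServerStatus keeps Go's default marshaling (this diff does not add a custom MarshalJSON).

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the consolidated llm.ServerStatus; the real constants live in the llm package.
type ServerStatus int

const (
	ServerStatusReady ServerStatus = iota
	ServerStatusNoSlotsAvailable
	ServerStatusLoadingModel
	ServerStatusError
)

// Mirrors llm.ServerStatusResponse, now shared by the runner's /health handler
// and the parent's getServerStatus instead of two separate structs.
type ServerStatusResponse struct {
	Status   ServerStatus `json:"status"`
	Progress float32      `json:"progress"`
}

func main() {
	// Runner side: encode the current status and load progress directly.
	body, _ := json.Marshal(ServerStatusResponse{Status: ServerStatusLoadingModel, Progress: 0.42})

	// Parent side: decode and switch on the typed constant instead of
	// comparing status strings.
	var ssr ServerStatusResponse
	if err := json.Unmarshal(body, &ssr); err != nil {
		panic(err)
	}
	switch ssr.Status {
	case ServerStatusLoadingModel:
		fmt.Printf("loading, progress %.2f\n", ssr.Progress)
	case ServerStatusReady, ServerStatusNoSlotsAvailable:
		fmt.Println("server up")
	default:
		fmt.Println("server error")
	}
}

With both sides sharing one struct, the string matching in getServerStatus goes away, which is what the corresponding hunk in the first file below replaces.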

View file

@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
         s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
     }
-    slog.Info("starting llama server", "cmd", s.cmd.String())
+    slog.Info("starting llama server", "cmd", s.cmd)
     if envconfig.Debug() {
         filteredEnv := []string{}
         for _, ev := range s.cmd.Env {
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
     ServerStatusError
 )

-func (s ServerStatus) ToString() string {
+func (s ServerStatus) String() string {
     switch s {
     case ServerStatusReady:
         return "llm server ready"
@@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string {
     }
 }

-type ServerStatusResp struct {
-    Status          string  `json:"status"`
-    SlotsIdle       int     `json:"slots_idle"`
-    SlotsProcessing int     `json:"slots_processing"`
-    Error           string  `json:"error"`
-    Progress        float32 `json:"progress"`
+type ServerStatusResponse struct {
+    Status   ServerStatus `json:"status"`
+    Progress float32      `json:"progress"`
 }

 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
     }
     if s.cmd.ProcessState.ExitCode() == -1 {
         // Most likely a signal killed it, log some more details to try to help troubleshoot
-        slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
+        slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
     }
     return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 }
@@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
         return ServerStatusError, fmt.Errorf("read health request: %w", err)
     }

-    var status ServerStatusResp
-    if err := json.Unmarshal(body, &status); err != nil {
+    var ssr ServerStatusResponse
+    if err := json.Unmarshal(body, &ssr); err != nil {
         return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
     }

-    switch status.Status {
-    case "ok":
-        return ServerStatusReady, nil
-    case "no slot available":
-        return ServerStatusNoSlotsAvailable, nil
-    case "loading model":
-        s.loadProgress = status.Progress
-        return ServerStatusLoadingModel, nil
+    switch ssr.Status {
+    case ServerStatusLoadingModel:
+        s.loadProgress = ssr.Progress
+        return ssr.Status, nil
+    case ServerStatusReady, ServerStatusNoSlotsAvailable:
+        return ssr.Status, nil
     default:
-        return ServerStatusError, fmt.Errorf("server error: %+v", status)
+        return ssr.Status, fmt.Errorf("server error: %+v", ssr)
     }
 }
@@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
     status, _ := s.getServerStatus(ctx)
     if lastStatus != status && status != ServerStatusReady {
         // Only log on status changes
-        slog.Info("waiting for server to become available", "status", status.ToString())
+        slog.Info("waiting for server to become available", "status", status)
     }
     switch status {
     case ServerStatusReady:
@@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
         slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
         stallTimer = time.Now().Add(stallDuration)
     } else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
-        slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
+        slog.Debug("model load completed, waiting for server to become available", "status", status)
         stallTimer = time.Now().Add(stallDuration)
         fullyLoaded = true
     }
@@ -671,63 +666,26 @@ type ImageData struct {
     AspectRatioID int `json:"aspect_ratio_id"`
 }

-type completion struct {
-    Content      string `json:"content"`
-    Model        string `json:"model"`
-    Prompt       string `json:"prompt"`
-    Stop         bool   `json:"stop"`
-    StoppedLimit bool   `json:"stopped_limit"`
-
-    Timings struct {
-        PredictedN  int     `json:"predicted_n"`
-        PredictedMS float64 `json:"predicted_ms"`
-        PromptN     int     `json:"prompt_n"`
-        PromptMS    float64 `json:"prompt_ms"`
-    }
-}
-
 type CompletionRequest struct {
     Prompt  string
     Format  json.RawMessage
     Images  []ImageData
     Options *api.Options
+
+    Grammar string // set before sending the request to the subprocess
 }

 type CompletionResponse struct {
-    Content            string
-    DoneReason         string
-    Done               bool
-    PromptEvalCount    int
-    PromptEvalDuration time.Duration
-    EvalCount          int
-    EvalDuration       time.Duration
+    Content            string        `json:"content"`
+    DoneReason         string        `json:"done_reason"`
+    Done               bool          `json:"done"`
+    PromptEvalCount    int           `json:"prompt_eval_count"`
+    PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+    EvalCount          int           `json:"eval_count"`
+    EvalDuration       time.Duration `json:"eval_duration"`
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
-    request := map[string]any{
-        "prompt": req.Prompt,
-        "stream": true,
-        "n_predict": req.Options.NumPredict,
-        "n_keep": req.Options.NumKeep,
-        "main_gpu": req.Options.MainGPU,
-        "temperature": req.Options.Temperature,
-        "top_k": req.Options.TopK,
-        "top_p": req.Options.TopP,
-        "min_p": req.Options.MinP,
-        "typical_p": req.Options.TypicalP,
-        "repeat_last_n": req.Options.RepeatLastN,
-        "repeat_penalty": req.Options.RepeatPenalty,
-        "presence_penalty": req.Options.PresencePenalty,
-        "frequency_penalty": req.Options.FrequencyPenalty,
-        "mirostat": req.Options.Mirostat,
-        "mirostat_tau": req.Options.MirostatTau,
-        "mirostat_eta": req.Options.MirostatEta,
-        "seed": req.Options.Seed,
-        "stop": req.Options.Stop,
-        "image_data": req.Images,
-        "cache_prompt": true,
-    }
-
     if len(req.Format) > 0 {
         switch string(req.Format) {
         case `null`, `""`:
@@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
             // these as "not set".
             break
         case `"json"`:
-            request["grammar"] = grammarJSON
+            req.Grammar = grammarJSON
         default:
             if req.Format[0] != '{' {
                 return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
             if g == nil {
                 return fmt.Errorf("invalid JSON schema in format")
             }
-            request["grammar"] = string(g)
+            req.Grammar = string(g)
         }
     }
+
+    if req.Options == nil {
+        opts := api.DefaultOptions()
+        req.Options = &opts
+    }

     if err := s.sem.Acquire(ctx, 1); err != nil {
         if errors.Is(err, context.Canceled) {
             slog.Info("aborting completion request due to client closing the connection")
@@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
     if err != nil {
         return err
     } else if status != ServerStatusReady {
-        return fmt.Errorf("unexpected server status: %s", status.ToString())
+        return fmt.Errorf("unexpected server status: %s", status)
     }

     // Handling JSON marshaling with special characters unescaped.
@@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
     enc := json.NewEncoder(buffer)
     enc.SetEscapeHTML(false)
-    if err := enc.Encode(request); err != nil {
+    if err := enc.Encode(req); err != nil {
         return fmt.Errorf("failed to marshal data: %v", err)
     }
@@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
         evt = line
     }

-    var c completion
+    var c CompletionResponse
     if err := json.Unmarshal(evt, &c); err != nil {
         return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
     }
@@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
         })
     }

-    if c.Stop {
-        doneReason := "stop"
-        if c.StoppedLimit {
-            doneReason = "length"
-        }
-
-        fn(CompletionResponse{
-            Done: true,
-            DoneReason: doneReason,
-            PromptEvalCount: c.Timings.PromptN,
-            PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-            EvalCount: c.Timings.PredictedN,
-            EvalDuration: parseDurationMs(c.Timings.PredictedMS),
-        })
+    if c.Done {
+        fn(c)
         return nil
     }
 }
@@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
     if err != nil {
         return nil, err
     } else if status != ServerStatusReady {
-        return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+        return nil, fmt.Errorf("unexpected server status: %s", status)
     }

     data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
     }
     return 0
 }
-
-func parseDurationMs(ms float64) time.Duration {
-    dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-    if err != nil {
-        panic(err)
-    }
-
-    return dur
-}

View file

@@ -24,6 +24,7 @@ import (
     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/llama"
+    "github.com/ollama/ollama/llm"
     "github.com/ollama/ollama/runner/common"
 )
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
     embedding bool
 }

-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
     s.ready.Wait()

     startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // generating image embeddings for each image
-func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
     var inputs []input
     var parts []string
     var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
     image *ImageContext

     // status for external health reporting - loading, ready to serve, etc.
-    status ServerStatus
+    status llm.ServerStatus

     // current progress on loading the model
     progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
     return nil
 }

-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-    api.Runner
-
-    NumKeep int `json:"n_keep"`
-    Seed int `json:"seed"`
-    NumPredict int `json:"n_predict"`
-    TopK int `json:"top_k"`
-    TopP float32 `json:"top_p"`
-    MinP float32 `json:"min_p"`
-    TypicalP float32 `json:"typical_p"`
-    RepeatLastN int `json:"repeat_last_n"`
-    Temperature float32 `json:"temperature"`
-    RepeatPenalty float32 `json:"repeat_penalty"`
-    PresencePenalty float32 `json:"presence_penalty"`
-    FrequencyPenalty float32 `json:"frequency_penalty"`
-    Mirostat int `json:"mirostat"`
-    MirostatTau float32 `json:"mirostat_tau"`
-    MirostatEta float32 `json:"mirostat_eta"`
-    Stop []string `json:"stop"`
-}
-
-type ImageData struct {
-    Data []byte `json:"data"`
-    ID int `json:"id"`
-    AspectRatioID int `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-    Prompt string `json:"prompt"`
-    Images []ImageData `json:"image_data"`
-    Grammar string `json:"grammar"`
-    CachePrompt bool `json:"cache_prompt"`
-
-    Options
-}
-
-type Timings struct {
-    PredictedN int `json:"predicted_n"`
-    PredictedMS float64 `json:"predicted_ms"`
-    PromptN int `json:"prompt_n"`
-    PromptMS float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-    Content string `json:"content"`
-    Stop bool `json:"stop"`
-
-    Model string `json:"model,omitempty"`
-    Prompt string `json:"prompt,omitempty"`
-    StoppedLimit bool `json:"stopped_limit,omitempty"`
-    PredictedN int `json:"predicted_n,omitempty"`
-    PredictedMS float64 `json:"predicted_ms,omitempty"`
-    PromptN int `json:"prompt_n,omitempty"`
-    PromptMS float64 `json:"prompt_ms,omitempty"`
-
-    Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-    var req CompletionRequest
-    req.Options = Options(api.DefaultOptions())
+    var req llm.CompletionRequest
     if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
         http.Error(w, "Bad request", http.StatusBadRequest)
         return
     }

+    if req.Options == nil {
+        opts := api.DefaultOptions()
+        req.Options = &opts
+    }
+
     // Set the headers to indicate streaming
     w.Header().Set("Content-Type", "application/json")
     w.Header().Set("Transfer-Encoding", "chunked")
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
         return
     }

-    var samplingParams llama.SamplingParams
-    samplingParams.TopK = req.TopK
-    samplingParams.TopP = req.TopP
-    samplingParams.MinP = req.MinP
-    samplingParams.TypicalP = req.TypicalP
-    samplingParams.Temp = req.Temperature
-    samplingParams.RepeatLastN = req.RepeatLastN
-    samplingParams.PenaltyRepeat = req.RepeatPenalty
-    samplingParams.PenaltyFreq = req.FrequencyPenalty
-    samplingParams.PenaltyPresent = req.PresencePenalty
-    samplingParams.Mirostat = req.Mirostat
-    samplingParams.MirostatTau = req.MirostatTau
-    samplingParams.MirostatEta = req.MirostatEta
-    samplingParams.Seed = uint32(req.Seed)
-    samplingParams.Grammar = req.Grammar
+    // Extract options from the CompletionRequest
+    samplingParams := llama.SamplingParams{
+        TopK: req.Options.TopK,
+        TopP: req.Options.TopP,
+        MinP: req.Options.MinP,
+        TypicalP: req.Options.TypicalP,
+        Temp: req.Options.Temperature,
+        RepeatLastN: req.Options.RepeatLastN,
+        PenaltyRepeat: req.Options.RepeatPenalty,
+        PenaltyFreq: req.Options.FrequencyPenalty,
+        PenaltyPresent: req.Options.PresencePenalty,
+        Mirostat: req.Options.Mirostat,
+        MirostatTau: req.Options.MirostatTau,
+        MirostatEta: req.Options.MirostatEta,
+        Seed: uint32(req.Options.Seed),
+        Grammar: req.Grammar,
+    }

     seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-        numPredict: req.NumPredict,
-        stop: req.Stop,
-        numKeep: req.NumKeep,
+        numPredict: req.Options.NumPredict,
+        stop: req.Options.Stop,
+        numKeep: req.Options.NumKeep,
         samplingParams: &samplingParams,
         embedding: false,
     })
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
     found := false
     for i, sq := range s.seqs {
         if sq == nil {
-            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
             if err != nil {
                 s.mu.Unlock()
                 http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
             return
         case content, ok := <-seq.responses:
             if ok {
-                if err := json.NewEncoder(w).Encode(&CompletionResponse{
+                if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
                     Content: content,
                 }); err != nil {
                     http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
                 flusher.Flush()
             } else {
                 // Send the final response
-                if err := json.NewEncoder(w).Encode(&CompletionResponse{
-                    Stop: true,
-                    StoppedLimit: seq.doneReason == "limit",
-                    Timings: Timings{
-                        PromptN: seq.numPromptInputs,
-                        PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-                        PredictedN: seq.numDecoded,
-                        PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-                    },
+                doneReason := "stop"
+                if seq.doneReason == "limit" {
+                    doneReason = "length"
+                }
+                if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+                    Done: true,
+                    DoneReason: doneReason,
+                    PromptEvalCount: seq.numPromptInputs,
+                    PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+                    EvalCount: seq.numDecoded,
+                    EvalDuration: time.Since(seq.startGenerationTime),
                 }); err != nil {
                     http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
                 }
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
     }
 }

-type EmbeddingRequest struct {
-    Content string `json:"content"`
-    CachePrompt bool `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-    Embedding []float32 `json:"embedding"`
-}
-
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-    var req EmbeddingRequest
+    var req llm.EmbeddingRequest
     if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
         http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
         return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     found := false
     for i, sq := range s.seqs {
         if sq == nil {
-            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
             if err != nil {
                 s.mu.Unlock()
                 http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     embedding := <-seq.embedding

-    if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
+    if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
         Embedding: embedding,
     }); err != nil {
         http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
     }
 }

-type HealthResponse struct {
-    Status string `json:"status"`
-    Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-    ServerStatusReady ServerStatus = iota
-    ServerStatusLoadingModel
-    ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-    switch s {
-    case ServerStatusReady:
-        return "ok"
-    case ServerStatusLoadingModel:
-        return "loading model"
-    default:
-        return "server error"
-    }
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
     w.Header().Set("Content-Type", "application/json")
-    if err := json.NewEncoder(w).Encode(&HealthResponse{
-        Status: s.status.ToString(),
+    if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+        Status: s.status,
         Progress: s.progress,
     }); err != nil {
         http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
         panic(err)
     }

-    s.status = ServerStatusReady
+    s.status = llm.ServerStatusReady
     s.ready.Done()
 }
@@ -937,7 +852,7 @@ func Execute(args []string) error {
         parallel: *parallel,
         seqs: make([]*Sequence, *parallel),
         seqsSem: semaphore.NewWeighted(int64(*parallel)),
-        status: ServerStatusLoadingModel,
+        status: llm.ServerStatusLoadingModel,
     }

     var tensorSplitFloats []float32

View file

@@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
         return nil, nil, err
     }

+    // TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
     if !cachePrompt {
         numPast = 0
     }

View file

@@ -24,6 +24,7 @@ import (
     "golang.org/x/sync/semaphore"

     "github.com/ollama/ollama/api"
+    "github.com/ollama/ollama/llm"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/model"
     "github.com/ollama/ollama/model/input"
@@ -94,7 +95,7 @@ type NewSequenceParams struct {
     embedding bool
 }

-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
     s.ready.Wait()

     startTime := time.Now()
@@ -145,7 +146,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
     var inputs []input.Input
     var parts []string
     var matches [][]string
@@ -222,7 +223,7 @@ type Server struct {
     model model.Model

     // status for external health reporting - loading, ready to serve, etc.
-    status ServerStatus
+    status llm.ServerStatus

     // current progress on loading the model
     progress float32
@@ -501,75 +502,18 @@ func (s *Server) processBatch() error {
     return nil
 }

-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-    api.Runner
-
-    NumKeep int `json:"n_keep"`
-    Seed int `json:"seed"`
-    NumPredict int `json:"n_predict"`
-    TopK int `json:"top_k"`
-    TopP float32 `json:"top_p"`
-    MinP float32 `json:"min_p"`
-    TypicalP float32 `json:"typical_p"`
-    RepeatLastN int `json:"repeat_last_n"`
-    Temperature float32 `json:"temperature"`
-    RepeatPenalty float32 `json:"repeat_penalty"`
-    PresencePenalty float32 `json:"presence_penalty"`
-    FrequencyPenalty float32 `json:"frequency_penalty"`
-    Mirostat int `json:"mirostat"`
-    MirostatTau float32 `json:"mirostat_tau"`
-    MirostatEta float32 `json:"mirostat_eta"`
-    Stop []string `json:"stop"`
-}
-
-type ImageData struct {
-    Data []byte `json:"data"`
-    ID int `json:"id"`
-    AspectRatioID int `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-    Prompt string `json:"prompt"`
-    Images []ImageData `json:"image_data"`
-    Grammar string `json:"grammar"`
-    CachePrompt bool `json:"cache_prompt"`
-
-    Options
-}
-
-type Timings struct {
-    PredictedN int `json:"predicted_n"`
-    PredictedMS float64 `json:"predicted_ms"`
-    PromptN int `json:"prompt_n"`
-    PromptMS float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-    Content string `json:"content"`
-    Stop bool `json:"stop"`
-
-    Model string `json:"model,omitempty"`
-    Prompt string `json:"prompt,omitempty"`
-    StoppedLimit bool `json:"stopped_limit,omitempty"`
-    PredictedN int `json:"predicted_n,omitempty"`
-    PredictedMS float64 `json:"predicted_ms,omitempty"`
-    PromptN int `json:"prompt_n,omitempty"`
-    PromptMS float64 `json:"prompt_ms,omitempty"`
-
-    Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-    var req CompletionRequest
-    req.Options = Options(api.DefaultOptions())
+    var req llm.CompletionRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
         http.Error(w, "Bad request", http.StatusBadRequest)
         return
     }

+    if req.Options == nil {
+        opts := api.DefaultOptions()
+        req.Options = &opts
+    }
+
     // Set the headers to indicate streaming
     w.Header().Set("Content-Type", "application/json")
     w.Header().Set("Transfer-Encoding", "chunked")
@@ -591,18 +535,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
     }

     sampler := sample.NewSampler(
-        req.Temperature,
-        req.TopK,
-        req.TopP,
-        req.MinP,
-        req.Seed,
+        req.Options.Temperature,
+        req.Options.TopK,
+        req.Options.TopP,
+        req.Options.MinP,
+        req.Options.Seed,
         grammar,
     )

     seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-        numPredict: req.NumPredict,
-        stop: req.Stop,
-        numKeep: int32(req.NumKeep),
+        numPredict: req.Options.NumPredict,
+        stop: req.Options.Stop,
+        numKeep: int32(req.Options.NumKeep),
         sampler: sampler,
         embedding: false,
     })
@@ -625,7 +569,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
     found := false
     for i, sq := range s.seqs {
         if sq == nil {
-            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
             if err != nil {
                 s.mu.Unlock()
                 http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -652,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
             return
         case content, ok := <-seq.responses:
             if ok {
-                if err := json.NewEncoder(w).Encode(&CompletionResponse{
+                if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
                     Content: content,
                 }); err != nil {
                     http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -663,15 +607,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
                 flusher.Flush()
             } else {
                 // Send the final response
-                if err := json.NewEncoder(w).Encode(&CompletionResponse{
-                    Stop: true,
-                    StoppedLimit: seq.doneReason == "limit",
-                    Timings: Timings{
-                        PromptN: seq.numPromptInputs,
-                        PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-                        PredictedN: seq.numPredicted,
-                        PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-                    },
+                doneReason := "stop"
+                if seq.doneReason == "limit" {
+                    doneReason = "length"
+                }
+                if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+                    Done: true,
+                    DoneReason: doneReason,
+                    PromptEvalCount: seq.numPromptInputs,
+                    PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+                    EvalCount: seq.numPredicted,
+                    EvalDuration: time.Since(seq.startGenerationTime),
                 }); err != nil {
                     http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
                 }
@@ -682,43 +628,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
     }
 }

-type EmbeddingRequest struct {
-    Content string `json:"content"`
-    CachePrompt bool `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-    Embedding []float32 `json:"embedding"`
-}
-
-type HealthResponse struct {
-    Status string `json:"status"`
-    Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-    ServerStatusReady ServerStatus = iota
-    ServerStatusLoadingModel
-    ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-    switch s {
-    case ServerStatusReady:
-        return "ok"
-    case ServerStatusLoadingModel:
-        return "loading model"
-    default:
-        return "server error"
-    }
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
     w.Header().Set("Content-Type", "application/json")
-    if err := json.NewEncoder(w).Encode(&HealthResponse{
-        Status: s.status.ToString(),
+    if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+        Status: s.status,
         Progress: s.progress,
     }); err != nil {
         http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -772,7 +685,7 @@ func (s *Server) loadModel(
     s.seqs = make([]*Sequence, s.parallel)
     s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

-    s.status = ServerStatusReady
+    s.status = llm.ServerStatusReady
     s.ready.Done()
 }
@@ -824,7 +737,7 @@ func Execute(args []string) error {
     server := &Server{
         batchSize: *batchSize,
-        status: ServerStatusLoadingModel,
+        status: llm.ServerStatusLoadingModel,
     }

     // TODO(jessegross): Parameters that need to be implemented: