From 3892c3a7032c99db250c3266276c4525d243950a Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 14 Mar 2025 15:21:53 -0700 Subject: [PATCH] llm: remove internal subprocess req and resp types (#9324) This commit refactors the LLM subsystem by removing internal subprocess request and response types. It consolidates duplicate type definitions across the codebase, moving them to centralized locations. The change also standardizes interfaces between components, simplifies the ServerStatusResp struct, and moves the ParseDurationMs function to a common package. This cleanup reduces code duplication between different runner implementations (llamarunner and ollamarunner). --- llm/server.go | 136 +++++++------------------ runner/llamarunner/runner.go | 185 +++++++++------------------------- runner/ollamarunner/cache.go | 1 + runner/ollamarunner/runner.go | 157 +++++++---------------------- 4 files changed, 125 insertions(+), 354 deletions(-) diff --git a/llm/server.go b/llm/server.go index c6f117125..adc11aaea 100644 --- a/llm/server.go +++ b/llm/server.go @@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal) } - slog.Info("starting llama server", "cmd", s.cmd.String()) + slog.Info("starting llama server", "cmd", s.cmd) if envconfig.Debug() { filteredEnv := []string{} for _, ev := range s.cmd.Env { @@ -470,7 +470,7 @@ const ( // iota is reset to 0 ServerStatusError ) -func (s ServerStatus) ToString() string { +func (s ServerStatus) String() string { switch s { case ServerStatusReady: return "llm server ready" @@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string { } } -type ServerStatusResp struct { - Status string `json:"status"` - SlotsIdle int `json:"slots_idle"` - SlotsProcessing int `json:"slots_processing"` - Error string `json:"error"` - Progress float32 `json:"progress"` +type ServerStatusResponse struct { + Status ServerStatus `json:"status"` + Progress float32 `json:"progress"` } func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { @@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { } if s.cmd.ProcessState.ExitCode() == -1 { // Most likely a signal killed it, log some more details to try to help troubleshoot - slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String()) + slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState) } return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) } @@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { return ServerStatusError, fmt.Errorf("read health request: %w", err) } - var status ServerStatusResp - if err := json.Unmarshal(body, &status); err != nil { + var ssr ServerStatusResponse + if err := json.Unmarshal(body, &ssr); err != nil { return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err) } - switch status.Status { - case "ok": - return ServerStatusReady, nil - case "no slot available": - return ServerStatusNoSlotsAvailable, nil - case "loading model": - s.loadProgress = status.Progress - return ServerStatusLoadingModel, nil + switch ssr.Status { + case ServerStatusLoadingModel: + s.loadProgress = ssr.Progress + return ssr.Status, nil + case ServerStatusReady, 
ServerStatusNoSlotsAvailable: + return ssr.Status, nil default: - return ServerStatusError, fmt.Errorf("server error: %+v", status) + return ssr.Status, fmt.Errorf("server error: %+v", ssr) } } @@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { status, _ := s.getServerStatus(ctx) if lastStatus != status && status != ServerStatusReady { // Only log on status changes - slog.Info("waiting for server to become available", "status", status.ToString()) + slog.Info("waiting for server to become available", "status", status) } switch status { case ServerStatusReady: @@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress)) stallTimer = time.Now().Add(stallDuration) } else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 { - slog.Debug("model load completed, waiting for server to become available", "status", status.ToString()) + slog.Debug("model load completed, waiting for server to become available", "status", status) stallTimer = time.Now().Add(stallDuration) fullyLoaded = true } @@ -671,63 +666,26 @@ type ImageData struct { AspectRatioID int `json:"aspect_ratio_id"` } -type completion struct { - Content string `json:"content"` - Model string `json:"model"` - Prompt string `json:"prompt"` - Stop bool `json:"stop"` - StoppedLimit bool `json:"stopped_limit"` - - Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` - } -} - type CompletionRequest struct { Prompt string Format json.RawMessage Images []ImageData Options *api.Options + + Grammar string // set before sending the request to the subprocess } type CompletionResponse struct { - Content string - DoneReason string - Done bool - PromptEvalCount int - PromptEvalDuration time.Duration - EvalCount int - EvalDuration time.Duration + Content string `json:"content"` + DoneReason string `json:"done_reason"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration time.Duration `json:"eval_duration"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { - request := map[string]any{ - "prompt": req.Prompt, - "stream": true, - "n_predict": req.Options.NumPredict, - "n_keep": req.Options.NumKeep, - "main_gpu": req.Options.MainGPU, - "temperature": req.Options.Temperature, - "top_k": req.Options.TopK, - "top_p": req.Options.TopP, - "min_p": req.Options.MinP, - "typical_p": req.Options.TypicalP, - "repeat_last_n": req.Options.RepeatLastN, - "repeat_penalty": req.Options.RepeatPenalty, - "presence_penalty": req.Options.PresencePenalty, - "frequency_penalty": req.Options.FrequencyPenalty, - "mirostat": req.Options.Mirostat, - "mirostat_tau": req.Options.MirostatTau, - "mirostat_eta": req.Options.MirostatEta, - "seed": req.Options.Seed, - "stop": req.Options.Stop, - "image_data": req.Images, - "cache_prompt": true, - } - if len(req.Format) > 0 { switch string(req.Format) { case `null`, `""`: @@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu // these as "not set". 
break case `"json"`: - request["grammar"] = grammarJSON + req.Grammar = grammarJSON default: if req.Format[0] != '{' { return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format) @@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if g == nil { return fmt.Errorf("invalid JSON schema in format") } - request["grammar"] = string(g) + req.Grammar = string(g) } } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + if err := s.sem.Acquire(ctx, 1); err != nil { if errors.Is(err, context.Canceled) { slog.Info("aborting completion request due to client closing the connection") @@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if err != nil { return err } else if status != ServerStatusReady { - return fmt.Errorf("unexpected server status: %s", status.ToString()) + return fmt.Errorf("unexpected server status: %s", status) } // Handling JSON marshaling with special characters unescaped. @@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu enc := json.NewEncoder(buffer) enc.SetEscapeHTML(false) - if err := enc.Encode(request); err != nil { + if err := enc.Encode(req); err != nil { return fmt.Errorf("failed to marshal data: %v", err) } @@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu evt = line } - var c completion + var c CompletionResponse if err := json.Unmarshal(evt, &c); err != nil { return fmt.Errorf("error unmarshalling llm prediction response: %v", err) } @@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu }) } - if c.Stop { - doneReason := "stop" - if c.StoppedLimit { - doneReason = "length" - } - - fn(CompletionResponse{ - Done: true, - DoneReason: doneReason, - PromptEvalCount: c.Timings.PromptN, - PromptEvalDuration: parseDurationMs(c.Timings.PromptMS), - EvalCount: c.Timings.PredictedN, - EvalDuration: parseDurationMs(c.Timings.PredictedMS), - }) + if c.Done { + fn(c) return nil } } @@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err if err != nil { return nil, err } else if status != ServerStatusReady { - return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) + return nil, fmt.Errorf("unexpected server status: %s", status) } data, err := json.Marshal(EmbeddingRequest{Content: input}) @@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 { } return 0 } - -func parseDurationMs(ms float64) time.Duration { - dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) - if err != nil { - panic(err) - } - - return dur -} diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 8662afc1e..83802d604 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/runner/common" ) @@ -99,7 +100,7 @@ type NewSequenceParams struct { embedding bool } -func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { +func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() startTime := time.Now() @@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params 
NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // generating image embeddings for each image -func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) { var inputs []input var parts []string var matches [][]string @@ -229,7 +230,7 @@ type Server struct { image *ImageContext // status for external health reporting - loading, ready to serve, etc. - status ServerStatus + status llm.ServerStatus // current progress on loading the model progress float32 @@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return nil } -// TODO (jmorganca): use structs from the api package to avoid duplication -// this way the api acts as a proxy instead of using a different api for the -// runner -type Options struct { - api.Runner - - NumKeep int `json:"n_keep"` - Seed int `json:"seed"` - NumPredict int `json:"n_predict"` - TopK int `json:"top_k"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TypicalP float32 `json:"typical_p"` - RepeatLastN int `json:"repeat_last_n"` - Temperature float32 `json:"temperature"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - Mirostat int `json:"mirostat"` - MirostatTau float32 `json:"mirostat_tau"` - MirostatEta float32 `json:"mirostat_eta"` - Stop []string `json:"stop"` -} - -type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` - AspectRatioID int `json:"aspect_ratio_id"` -} - -type CompletionRequest struct { - Prompt string `json:"prompt"` - Images []ImageData `json:"image_data"` - Grammar string `json:"grammar"` - CachePrompt bool `json:"cache_prompt"` - - Options -} - -type Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` -} - -type CompletionResponse struct { - Content string `json:"content"` - Stop bool `json:"stop"` - - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` - StoppedLimit bool `json:"stopped_limit,omitempty"` - PredictedN int `json:"predicted_n,omitempty"` - PredictedMS float64 `json:"predicted_ms,omitempty"` - PromptN int `json:"prompt_n,omitempty"` - PromptMS float64 `json:"prompt_ms,omitempty"` - - Timings Timings `json:"timings"` -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { - var req CompletionRequest - req.Options = Options(api.DefaultOptions()) + var req llm.CompletionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + // Set the headers to indicate streaming w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") @@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - var samplingParams llama.SamplingParams - samplingParams.TopK = req.TopK - samplingParams.TopP = req.TopP - samplingParams.MinP = req.MinP - samplingParams.TypicalP = req.TypicalP - samplingParams.Temp = req.Temperature - samplingParams.RepeatLastN = req.RepeatLastN - samplingParams.PenaltyRepeat = req.RepeatPenalty - samplingParams.PenaltyFreq = req.FrequencyPenalty - 
samplingParams.PenaltyPresent = req.PresencePenalty - samplingParams.Mirostat = req.Mirostat - samplingParams.MirostatTau = req.MirostatTau - samplingParams.MirostatEta = req.MirostatEta - samplingParams.Seed = uint32(req.Seed) - samplingParams.Grammar = req.Grammar + // Extract options from the CompletionRequest + samplingParams := llama.SamplingParams{ + TopK: req.Options.TopK, + TopP: req.Options.TopP, + MinP: req.Options.MinP, + TypicalP: req.Options.TypicalP, + Temp: req.Options.Temperature, + RepeatLastN: req.Options.RepeatLastN, + PenaltyRepeat: req.Options.RepeatPenalty, + PenaltyFreq: req.Options.FrequencyPenalty, + PenaltyPresent: req.Options.PresencePenalty, + Mirostat: req.Options.Mirostat, + MirostatTau: req.Options.MirostatTau, + MirostatEta: req.Options.MirostatEta, + Seed: uint32(req.Options.Seed), + Grammar: req.Grammar, + } seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.NumPredict, - stop: req.Stop, - numKeep: req.NumKeep, + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: req.Options.NumKeep, samplingParams: &samplingParams, embedding: false, }) @@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { // Send the final response - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Stop: true, - StoppedLimit: seq.doneReason == "limit", - Timings: Timings{ - PromptN: seq.numPromptInputs, - PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), - PredictedN: seq.numDecoded, - PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), - }, + doneReason := "stop" + if seq.doneReason == "limit" { + doneReason = "length" + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + Done: true, + DoneReason: doneReason, + PromptEvalCount: seq.numPromptInputs, + PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + EvalCount: seq.numDecoded, + EvalDuration: time.Since(seq.startGenerationTime), }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } -type EmbeddingRequest struct { - Content string `json:"content"` - CachePrompt bool `json:"cache_prompt"` -} - -type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` -} - func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { - var req EmbeddingRequest + var req llm.EmbeddingRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) return @@ -761,7 
+700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embedding := <-seq.embedding - if err := json.NewEncoder(w).Encode(&EmbeddingResponse{ + if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ Embedding: embedding, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } } -type HealthResponse struct { - Status string `json:"status"` - Progress float32 `json:"progress"` -} - -type ServerStatus int - -const ( - ServerStatusReady ServerStatus = iota - ServerStatusLoadingModel - ServerStatusError -) - -func (s ServerStatus) ToString() string { - switch s { - case ServerStatusReady: - return "ok" - case ServerStatusLoadingModel: - return "loading model" - default: - return "server error" - } -} - func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&HealthResponse{ - Status: s.status.ToString(), + if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ + Status: s.status, Progress: s.progress, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -879,7 +794,7 @@ func (s *Server) loadModel( panic(err) } - s.status = ServerStatusReady + s.status = llm.ServerStatusReady s.ready.Done() } @@ -937,7 +852,7 @@ func Execute(args []string) error { parallel: *parallel, seqs: make([]*Sequence, *parallel), seqsSem: semaphore.NewWeighted(int64(*parallel)), - status: ServerStatusLoadingModel, + status: llm.ServerStatusLoadingModel, } var tensorSplitFloats []float32 diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index a411fddb1..adcb3f738 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp return nil, nil, err } + // TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved? 
if !cachePrompt { numPast = 0 } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index c380ef221..d6339a615 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -24,6 +24,7 @@ import ( "golang.org/x/sync/semaphore" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -94,7 +95,7 @@ type NewSequenceParams struct { embedding bool } -func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { +func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() startTime := time.Now() @@ -145,7 +146,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) { +func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) { var inputs []input.Input var parts []string var matches [][]string @@ -222,7 +223,7 @@ type Server struct { model model.Model // status for external health reporting - loading, ready to serve, etc. - status ServerStatus + status llm.ServerStatus // current progress on loading the model progress float32 @@ -501,75 +502,18 @@ func (s *Server) processBatch() error { return nil } -// TODO (jmorganca): use structs from the api package to avoid duplication -// this way the api acts as a proxy instead of using a different api for the -// runner -type Options struct { - api.Runner - - NumKeep int `json:"n_keep"` - Seed int `json:"seed"` - NumPredict int `json:"n_predict"` - TopK int `json:"top_k"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TypicalP float32 `json:"typical_p"` - RepeatLastN int `json:"repeat_last_n"` - Temperature float32 `json:"temperature"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - Mirostat int `json:"mirostat"` - MirostatTau float32 `json:"mirostat_tau"` - MirostatEta float32 `json:"mirostat_eta"` - Stop []string `json:"stop"` -} - -type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` - AspectRatioID int `json:"aspect_ratio_id"` -} - -type CompletionRequest struct { - Prompt string `json:"prompt"` - Images []ImageData `json:"image_data"` - Grammar string `json:"grammar"` - CachePrompt bool `json:"cache_prompt"` - - Options -} - -type Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` -} - -type CompletionResponse struct { - Content string `json:"content"` - Stop bool `json:"stop"` - - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` - StoppedLimit bool `json:"stopped_limit,omitempty"` - PredictedN int `json:"predicted_n,omitempty"` - PredictedMS float64 `json:"predicted_ms,omitempty"` - PromptN int `json:"prompt_n,omitempty"` - PromptMS float64 `json:"prompt_ms,omitempty"` - - Timings Timings `json:"timings"` -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { - var req CompletionRequest - req.Options = Options(api.DefaultOptions()) + var req 
llm.CompletionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + // Set the headers to indicate streaming w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") @@ -591,18 +535,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } sampler := sample.NewSampler( - req.Temperature, - req.TopK, - req.TopP, - req.MinP, - req.Seed, + req.Options.Temperature, + req.Options.TopK, + req.Options.TopP, + req.Options.MinP, + req.Options.Seed, grammar, ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.NumPredict, - stop: req.Stop, - numKeep: int32(req.NumKeep), + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: int32(req.Options.NumKeep), sampler: sampler, embedding: false, }) @@ -625,7 +569,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -652,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -663,15 +607,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { // Send the final response - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Stop: true, - StoppedLimit: seq.doneReason == "limit", - Timings: Timings{ - PromptN: seq.numPromptInputs, - PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), - PredictedN: seq.numPredicted, - PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), - }, + doneReason := "stop" + if seq.doneReason == "limit" { + doneReason = "length" + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + Done: true, + DoneReason: doneReason, + PromptEvalCount: seq.numPromptInputs, + PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + EvalCount: seq.numPredicted, + EvalDuration: time.Since(seq.startGenerationTime), }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -682,43 +628,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } -type EmbeddingRequest struct { - Content string `json:"content"` - CachePrompt bool `json:"cache_prompt"` -} - -type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` -} - -type HealthResponse struct { - Status string `json:"status"` - Progress float32 `json:"progress"` -} - -type ServerStatus int - -const ( - ServerStatusReady ServerStatus = iota - ServerStatusLoadingModel - ServerStatusError -) - -func (s ServerStatus) ToString() string { - switch s { - case ServerStatusReady: - return "ok" - case ServerStatusLoadingModel: - return "loading model" - default: - return "server error" - } -} - func (s *Server) health(w 
http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&HealthResponse{ - Status: s.status.ToString(), + if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ + Status: s.status, Progress: s.progress, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -772,7 +685,7 @@ func (s *Server) loadModel( s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) - s.status = ServerStatusReady + s.status = llm.ServerStatusReady s.ready.Done() } @@ -824,7 +737,7 @@ func Execute(args []string) error { server := &Server{ batchSize: *batchSize, - status: ServerStatusLoadingModel, + status: llm.ServerStatusLoadingModel, } // TODO(jessegross): Parameters that need to be implemented:
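
With both runners decoding llm.CompletionRequest and streaming llm.CompletionResponse, the client no longer translates between a private map/struct pair and the runner's wire format; once a chunk arrives with Done set, it can be forwarded unchanged (the fn(c) call above). The following is a minimal, self-contained sketch of that shared-wire-type pattern. It uses trimmed local stand-ins for the two structs, not the actual ollama packages, and omits most fields.

    // sketch_shared_types.go
    //
    // Illustrative only: local stand-ins for llm.CompletionRequest and
    // llm.CompletionResponse showing the "one wire type on both sides"
    // pattern this patch adopts. Field sets are trimmed; the real structs
    // live in the ollama llm package.
    package main

    import (
    	"bufio"
    	"bytes"
    	"encoding/json"
    	"fmt"
    	"time"
    )

    type CompletionRequest struct {
    	Prompt  string `json:"prompt"`
    	Grammar string `json:"grammar,omitempty"`
    }

    type CompletionResponse struct {
    	Content      string        `json:"content"`
    	Done         bool          `json:"done"`
    	DoneReason   string        `json:"done_reason,omitempty"`
    	EvalCount    int           `json:"eval_count,omitempty"`
    	EvalDuration time.Duration `json:"eval_duration,omitempty"`
    }

    func main() {
    	// Client side: encode the request once; no map[string]any translation.
    	var wire bytes.Buffer
    	enc := json.NewEncoder(&wire)
    	enc.SetEscapeHTML(false)
    	if err := enc.Encode(CompletionRequest{Prompt: "hello"}); err != nil {
    		panic(err)
    	}

    	// Runner side: decode into the very same struct.
    	var req CompletionRequest
    	if err := json.NewDecoder(&wire).Decode(&req); err != nil {
    		panic(err)
    	}

    	// Runner streams newline-delimited responses; the last one sets Done,
    	// so the caller can forward it as-is.
    	var stream bytes.Buffer
    	out := json.NewEncoder(&stream)
    	out.Encode(CompletionResponse{Content: "hi"})
    	out.Encode(CompletionResponse{Done: true, DoneReason: "stop", EvalCount: 1, EvalDuration: 42 * time.Millisecond})

    	sc := bufio.NewScanner(&stream)
    	for sc.Scan() {
    		var c CompletionResponse
    		if err := json.Unmarshal(sc.Bytes(), &c); err != nil {
    			panic(err)
    		}
    		if c.Done {
    			fmt.Printf("done: reason=%s evals=%d in %v\n", c.DoneReason, c.EvalCount, c.EvalDuration)
    			return
    		}
    		fmt.Println("chunk:", c.Content)
    	}
    }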
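
The health endpoint now reports the ServerStatus value itself rather than the strings "ok", "loading model", and "no slot available", so getServerStatus switches on constants instead of matching strings. Because ServerStatus has no custom JSON marshaller in this patch, it crosses the subprocess boundary as its integer iota value; that is safe here since client and runner are compiled from the same module, but it does tie the wire format to the constant ordering. A small sketch follows, with an illustrative constant ordering and a local stand-in for llm.ServerStatusResponse.

    // sketch_status_enum.go
    //
    // Illustrative stand-ins for llm.ServerStatus and llm.ServerStatusResponse.
    // The constant ordering below is for illustration only; the point is that
    // a plain int enum is what crosses the subprocess boundary as JSON, so
    // both sides must be built against the same constants.
    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    type ServerStatus int

    const (
    	ServerStatusReady ServerStatus = iota
    	ServerStatusNoSlotsAvailable
    	ServerStatusLoadingModel
    	ServerStatusError
    )

    type ServerStatusResponse struct {
    	Status   ServerStatus `json:"status"`
    	Progress float32      `json:"progress"`
    }

    func main() {
    	// Runner side: the health handler reports the enum and load progress.
    	body, err := json.Marshal(ServerStatusResponse{
    		Status:   ServerStatusLoadingModel,
    		Progress: 0.42,
    	})
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(string(body)) // {"status":2,"progress":0.42}

    	// Client side: decode and branch on constants rather than on strings.
    	var ssr ServerStatusResponse
    	if err := json.Unmarshal(body, &ssr); err != nil {
    		panic(err)
    	}
    	switch ssr.Status {
    	case ServerStatusLoadingModel:
    		fmt.Printf("loading model: %.0f%%\n", ssr.Progress*100)
    	case ServerStatusReady, ServerStatusNoSlotsAvailable:
    		fmt.Println("server available or briefly busy")
    	default:
    		fmt.Printf("server error: %+v\n", ssr)
    	}
    }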
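
The timing fields change shape as well: the old runner reported float64 milliseconds (prompt_ms, predicted_ms) that the client rebuilt into durations with parseDurationMs, while the shared CompletionResponse carries time.Duration values, which encoding/json writes as integer nanoseconds. That lets the helper be deleted and stops the runner from truncating sub-millisecond precision via Milliseconds(). A sketch comparing the two paths; the helper below is reconstructed from the lines this patch removes.

    // sketch_duration_fields.go
    //
    // Illustrative only: why the parseDurationMs helper could be deleted.
    // The old wire format sent timings as float64 milliseconds and the
    // client rebuilt a time.Duration from them; the shared
    // CompletionResponse now carries time.Duration directly, which
    // encoding/json serializes as integer nanoseconds.
    package main

    import (
    	"encoding/json"
    	"fmt"
    	"time"
    )

    // oldParseDurationMs mirrors the removed llm helper (reconstructed from
    // the deleted lines in this patch).
    func oldParseDurationMs(ms float64) time.Duration {
    	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
    	if err != nil {
    		panic(err)
    	}
    	return dur
    }

    type timings struct {
    	EvalDuration time.Duration `json:"eval_duration"`
    }

    func main() {
    	elapsed := 1234567 * time.Microsecond // 1.234567s

    	// Old path: truncate to whole milliseconds, ship a float, reparse.
    	old := oldParseDurationMs(float64(elapsed.Milliseconds()))
    	fmt.Println("old round trip:", old) // 1.234s - sub-millisecond precision lost

    	// New path: the Duration itself is the JSON value (nanoseconds).
    	body, _ := json.Marshal(timings{EvalDuration: elapsed})
    	fmt.Println(string(body)) // {"eval_duration":1234567000}

    	var t timings
    	_ = json.Unmarshal(body, &t)
    	fmt.Println("new round trip:", t.EvalDuration) // 1.234567s
    }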