diff --git a/llm/server.go b/llm/server.go
index c6f117125..adc11aaea 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
 	}
 
-	slog.Info("starting llama server", "cmd", s.cmd.String())
+	slog.Info("starting llama server", "cmd", s.cmd)
 	if envconfig.Debug() {
 		filteredEnv := []string{}
 		for _, ev := range s.cmd.Env {
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
 	ServerStatusError
 )
 
-func (s ServerStatus) ToString() string {
+func (s ServerStatus) String() string {
 	switch s {
 	case ServerStatusReady:
 		return "llm server ready"
@@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string {
 	}
 }
 
-type ServerStatusResp struct {
-	Status          string  `json:"status"`
-	SlotsIdle       int     `json:"slots_idle"`
-	SlotsProcessing int     `json:"slots_processing"`
-	Error           string  `json:"error"`
-	Progress        float32 `json:"progress"`
+type ServerStatusResponse struct {
+	Status   ServerStatus `json:"status"`
+	Progress float32      `json:"progress"`
 }
 
 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		}
 		if s.cmd.ProcessState.ExitCode() == -1 {
 			// Most likely a signal killed it, log some more details to try to help troubleshoot
-			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
+			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
 		}
 		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 	}
@@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		return ServerStatusError, fmt.Errorf("read health request: %w", err)
 	}
 
-	var status ServerStatusResp
-	if err := json.Unmarshal(body, &status); err != nil {
+	var ssr ServerStatusResponse
+	if err := json.Unmarshal(body, &ssr); err != nil {
 		return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
 	}
 
-	switch status.Status {
-	case "ok":
-		return ServerStatusReady, nil
-	case "no slot available":
-		return ServerStatusNoSlotsAvailable, nil
-	case "loading model":
-		s.loadProgress = status.Progress
-		return ServerStatusLoadingModel, nil
+	switch ssr.Status {
+	case ServerStatusLoadingModel:
+		s.loadProgress = ssr.Progress
+		return ssr.Status, nil
+	case ServerStatusReady, ServerStatusNoSlotsAvailable:
+		return ssr.Status, nil
 	default:
-		return ServerStatusError, fmt.Errorf("server error: %+v", status)
+		return ssr.Status, fmt.Errorf("server error: %+v", ssr)
 	}
 }
 
@@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			status, _ := s.getServerStatus(ctx)
 			if lastStatus != status && status != ServerStatusReady {
 				// Only log on status changes
-				slog.Info("waiting for server to become available", "status", status.ToString())
+				slog.Info("waiting for server to become available", "status", status)
 			}
 			switch status {
 			case ServerStatusReady:
@@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 					slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
 					stallTimer = time.Now().Add(stallDuration)
 				} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
-					slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
+					slog.Debug("model load completed, waiting for server to become available", "status", status)
 					stallTimer = time.Now().Add(stallDuration)
 					fullyLoaded = true
 				}
@@ -671,63 +666,26 @@ type ImageData struct {
 	AspectRatioID int    `json:"aspect_ratio_id"`
 }
 
-type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
-
-	Timings struct {
-		PredictedN  int     `json:"predicted_n"`
-		PredictedMS float64 `json:"predicted_ms"`
-		PromptN     int     `json:"prompt_n"`
-		PromptMS    float64 `json:"prompt_ms"`
-	}
-}
-
 type CompletionRequest struct {
 	Prompt  string
 	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
+
+	Grammar string // set before sending the request to the subprocess
 }
 
 type CompletionResponse struct {
-	Content            string
-	DoneReason         string
-	Done               bool
-	PromptEvalCount    int
-	PromptEvalDuration time.Duration
-	EvalCount          int
-	EvalDuration       time.Duration
+	Content            string        `json:"content"`
+	DoneReason         string        `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
-	request := map[string]any{
-		"prompt":            req.Prompt,
-		"stream":            true,
-		"n_predict":         req.Options.NumPredict,
-		"n_keep":            req.Options.NumKeep,
-		"main_gpu":          req.Options.MainGPU,
-		"temperature":       req.Options.Temperature,
-		"top_k":             req.Options.TopK,
-		"top_p":             req.Options.TopP,
-		"min_p":             req.Options.MinP,
-		"typical_p":         req.Options.TypicalP,
-		"repeat_last_n":     req.Options.RepeatLastN,
-		"repeat_penalty":    req.Options.RepeatPenalty,
-		"presence_penalty":  req.Options.PresencePenalty,
-		"frequency_penalty": req.Options.FrequencyPenalty,
-		"mirostat":          req.Options.Mirostat,
-		"mirostat_tau":      req.Options.MirostatTau,
-		"mirostat_eta":      req.Options.MirostatEta,
-		"seed":              req.Options.Seed,
-		"stop":              req.Options.Stop,
-		"image_data":        req.Images,
-		"cache_prompt":      true,
-	}
-
 	if len(req.Format) > 0 {
 		switch string(req.Format) {
 		case `null`, `""`:
@@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			// these as "not set".
 			break
 		case `"json"`:
-			request["grammar"] = grammarJSON
+			req.Grammar = grammarJSON
 		default:
 			if req.Format[0] != '{' {
 				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if g == nil {
 				return fmt.Errorf("invalid JSON schema in format")
 			}
-			request["grammar"] = string(g)
+			req.Grammar = string(g)
 		}
 	}
 
+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting completion request due to client closing the connection")
@@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %s", status.ToString())
+		return fmt.Errorf("unexpected server status: %s", status)
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
@@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)
 
-	if err := enc.Encode(request); err != nil {
+	if err := enc.Encode(req); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
 
@@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			evt = line
 		}
 
-		var c completion
+		var c CompletionResponse
 		if err := json.Unmarshal(evt, &c); err != nil {
 			return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 		}
@@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			})
 		}
 
-		if c.Stop {
-			doneReason := "stop"
-			if c.StoppedLimit {
-				doneReason = "length"
-			}
-
-			fn(CompletionResponse{
-				Done:               true,
-				DoneReason:         doneReason,
-				PromptEvalCount:    c.Timings.PromptN,
-				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-				EvalCount:          c.Timings.PredictedN,
-				EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-			})
+		if c.Done {
+			fn(c)
 			return nil
 		}
 	}
@@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+		return nil, fmt.Errorf("unexpected server status: %s", status)
 	}
 
 	data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
 	}
 	return 0
 }
-
-func parseDurationMs(ms float64) time.Duration {
-	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-	if err != nil {
-		panic(err)
-	}
-
-	return dur
-}
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index 8662afc1e..83802d604 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -24,6 +24,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/runner/common"
 )
 
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
 	embedding  bool
 }
 
-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()
 
 	startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-] tags, tokenizing text and
 // generating image embeddings for each image
-func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
 	var inputs []input
 	var parts []string
 	var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
 	image *ImageContext
 
 	// status for external health reporting - loading, ready to serve, etc.
-	status ServerStatus
+	status llm.ServerStatus
 
 	// current progress on loading the model
 	progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	return nil
 }
 
-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-	api.Runner
-
-	NumKeep          int      `json:"n_keep"`
-	Seed             int      `json:"seed"`
-	NumPredict       int      `json:"n_predict"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	MinP             float32  `json:"min_p"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	Temperature      float32  `json:"temperature"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	Stop             []string `json:"stop"`
-}
-
-type ImageData struct {
-	Data          []byte `json:"data"`
-	ID            int    `json:"id"`
-	AspectRatioID int    `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
-
-	Options
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-	var req CompletionRequest
-	req.Options = Options(api.DefaultOptions())
+	var req llm.CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
 	}
 
+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	var samplingParams llama.SamplingParams
-	samplingParams.TopK = req.TopK
-	samplingParams.TopP = req.TopP
-	samplingParams.MinP = req.MinP
-	samplingParams.TypicalP = req.TypicalP
-	samplingParams.Temp = req.Temperature
-	samplingParams.RepeatLastN = req.RepeatLastN
-	samplingParams.PenaltyRepeat = req.RepeatPenalty
-	samplingParams.PenaltyFreq = req.FrequencyPenalty
-	samplingParams.PenaltyPresent = req.PresencePenalty
-	samplingParams.Mirostat = req.Mirostat
-	samplingParams.MirostatTau = req.MirostatTau
-	samplingParams.MirostatEta = req.MirostatEta
-	samplingParams.Seed = uint32(req.Seed)
-	samplingParams.Grammar = req.Grammar
+	// Extract options from the CompletionRequest
+	samplingParams := llama.SamplingParams{
+		TopK:           req.Options.TopK,
+		TopP:           req.Options.TopP,
+		MinP:           req.Options.MinP,
+		TypicalP:       req.Options.TypicalP,
+		Temp:           req.Options.Temperature,
+		RepeatLastN:    req.Options.RepeatLastN,
+		PenaltyRepeat:  req.Options.RepeatPenalty,
+		PenaltyFreq:    req.Options.FrequencyPenalty,
+		PenaltyPresent: req.Options.PresencePenalty,
+		Mirostat:       req.Options.Mirostat,
+		MirostatTau:    req.Options.MirostatTau,
+		MirostatEta:    req.Options.MirostatEta,
+		Seed:           uint32(req.Options.Seed),
+		Grammar:        req.Grammar,
+	}
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict:     req.NumPredict,
-		stop:           req.Stop,
-		numKeep:        req.NumKeep,
+		numPredict:     req.Options.NumPredict,
+		stop:           req.Options.Stop,
+		numKeep:        req.Options.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
 	})
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
 				// Send the final response
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Stop:         true,
-					StoppedLimit: seq.doneReason == "limit",
-					Timings: Timings{
-						PromptN:     seq.numPromptInputs,
-						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-						PredictedN:  seq.numDecoded,
-						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-					},
+				doneReason := "stop"
+				if seq.doneReason == "limit" {
+					doneReason = "length"
+				}
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+					Done:               true,
+					DoneReason:         doneReason,
+					PromptEvalCount:    seq.numPromptInputs,
+					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+					EvalCount:          seq.numDecoded,
+					EvalDuration:       time.Since(seq.startGenerationTime),
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 				}
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-type EmbeddingRequest struct {
-	Content     string `json:"content"`
-	CachePrompt bool   `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
-}
-
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-	var req EmbeddingRequest
+	var req llm.EmbeddingRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
 		return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 	embedding := <-seq.embedding
 
-	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
+	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
 		Embedding: embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
 }
 
-type HealthResponse struct {
-	Status   string  `json:"status"`
-	Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-	ServerStatusReady ServerStatus = iota
-	ServerStatusLoadingModel
-	ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-	switch s {
-	case ServerStatusReady:
-		return "ok"
-	case ServerStatusLoadingModel:
-		return "loading model"
-	default:
-		return "server error"
-	}
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(&HealthResponse{
-		Status:   s.status.ToString(),
+	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+		Status:   s.status,
 		Progress: s.progress,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
-	s.status = ServerStatusReady
+	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }
 
@@ -937,7 +852,7 @@ func Execute(args []string) error {
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
 		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
-		status:    ServerStatusLoadingModel,
+		status:    llm.ServerStatusLoadingModel,
 	}
 
 	var tensorSplitFloats []float32
diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index a411fddb1..adcb3f738 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}
 
+	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
 	if !cachePrompt {
 		numPast = 0
 	}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index c380ef221..d6339a615 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -24,6 +24,7 @@ import (
 	"golang.org/x/sync/semaphore"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -94,7 +95,7 @@ type NewSequenceParams struct {
 	embedding  bool
 }
 
-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()
 
 	startTime := time.Now()
@@ -145,7 +146,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
 	var inputs []input.Input
 	var parts []string
 	var matches [][]string
@@ -222,7 +223,7 @@ type Server struct {
 	model model.Model
 
 	// status for external health reporting - loading, ready to serve, etc.
-	status ServerStatus
+	status llm.ServerStatus
 
 	// current progress on loading the model
 	progress float32
@@ -501,75 +502,18 @@ func (s *Server) processBatch() error {
 	return nil
 }
 
-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-	api.Runner
-
-	NumKeep          int      `json:"n_keep"`
-	Seed             int      `json:"seed"`
-	NumPredict       int      `json:"n_predict"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	MinP             float32  `json:"min_p"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	Temperature      float32  `json:"temperature"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	Stop             []string `json:"stop"`
-}
-
-type ImageData struct {
-	Data          []byte `json:"data"`
-	ID            int    `json:"id"`
-	AspectRatioID int    `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
-
-	Options
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-	var req CompletionRequest
-	req.Options = Options(api.DefaultOptions())
+	var req llm.CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
 	}
 
+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -591,18 +535,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 
 	sampler := sample.NewSampler(
-		req.Temperature,
-		req.TopK,
-		req.TopP,
-		req.MinP,
-		req.Seed,
+		req.Options.Temperature,
+		req.Options.TopK,
+		req.Options.TopP,
+		req.Options.MinP,
+		req.Options.Seed,
 		grammar,
 	)
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict: req.NumPredict,
-		stop:       req.Stop,
-		numKeep:    int32(req.NumKeep),
+		numPredict: req.Options.NumPredict,
+		stop:       req.Options.Stop,
+		numKeep:    int32(req.Options.NumKeep),
 		sampler:    sampler,
 		embedding:  false,
 	})
@@ -625,7 +569,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -652,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -663,15 +607,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
 				// Send the final response
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Stop:         true,
-					StoppedLimit: seq.doneReason == "limit",
-					Timings: Timings{
-						PromptN:     seq.numPromptInputs,
-						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-						PredictedN:  seq.numPredicted,
-						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-					},
+				doneReason := "stop"
+				if seq.doneReason == "limit" {
+					doneReason = "length"
+				}
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+					Done:               true,
+					DoneReason:         doneReason,
+					PromptEvalCount:    seq.numPromptInputs,
+					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+					EvalCount:          seq.numPredicted,
+					EvalDuration:       time.Since(seq.startGenerationTime),
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 				}
@@ -682,43 +628,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-type EmbeddingRequest struct {
-	Content     string `json:"content"`
-	CachePrompt bool   `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
-}
-
-type HealthResponse struct {
-	Status   string  `json:"status"`
-	Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-	ServerStatusReady ServerStatus = iota
-	ServerStatusLoadingModel
-	ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-	switch s {
-	case ServerStatusReady:
-		return "ok"
-	case ServerStatusLoadingModel:
-		return "loading model"
-	default:
-		return "server error"
-	}
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(&HealthResponse{
-		Status:   s.status.ToString(),
+	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+		Status:   s.status,
 		Progress: s.progress,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -772,7 +685,7 @@ func (s *Server) loadModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
 
-	s.status = ServerStatusReady
+	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }
 
@@ -824,7 +737,7 @@ func Execute(args []string) error {
 
 	server := &Server{
 		batchSize: *batchSize,
-		status:    ServerStatusLoadingModel,
+		status:    llm.ServerStatusLoadingModel,
 	}
 
 	// TODO(jessegross): Parameters that need to be implemented:
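
For reference, a minimal standalone Go sketch (not part of the patch) of how the shared ServerStatusResponse now travels between the runner's /health handler and getServerStatus. The two types are re-declared locally, mirroring what the diff shows in llm/server.go; the exact constant order is an assumption. Because ServerStatus has only a String() method and no custom MarshalJSON, encoding/json serializes it as its integer value, and both processes decode it back into the same Go constants.

// Illustrative sketch only; types mirror the llm package as shown in the diff,
// constant order is assumed.
package main

import (
	"encoding/json"
	"fmt"
)

type ServerStatus int

const ( // assumed declaration order
	ServerStatusReady ServerStatus = iota
	ServerStatusNoSlotsAvailable
	ServerStatusLoadingModel
	ServerStatusNotResponding
	ServerStatusError
)

type ServerStatusResponse struct {
	Status   ServerStatus `json:"status"`
	Progress float32      `json:"progress"`
}

func main() {
	// Runner side: the /health handler encodes its current status and load progress.
	body, err := json.Marshal(ServerStatusResponse{
		Status:   ServerStatusLoadingModel,
		Progress: 0.42,
	})
	if err != nil {
		panic(err)
	}
	// The status is serialized as a plain integer, e.g. {"status":2,"progress":0.42}
	// with the constant order assumed above.
	fmt.Println(string(body))

	// Server side: getServerStatus decodes the same struct and switches on Status.
	var ssr ServerStatusResponse
	if err := json.Unmarshal(body, &ssr); err != nil {
		panic(err)
	}
	fmt.Println(ssr.Status == ServerStatusLoadingModel, ssr.Progress) // true 0.42
}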