mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
llm: remove internal subprocess req and resp types (#9324)
This commit refactors the LLM subsystem by removing internal subprocess request and response types. It consolidates duplicate type definitions across the codebase, moving them to centralized locations. The change also standardizes interfaces between components, simplifies the ServerStatusResp struct, and moves the ParseDurationMs function to a common package. This cleanup reduces code duplication between different runner implementations (llamarunner and ollamarunner).
This commit is contained in:
parent
4e320b8b90
commit
3892c3a703
4 changed files with 125 additions and 354 deletions
136
llm/server.go
136
llm/server.go
|
@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||||
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
|
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("starting llama server", "cmd", s.cmd.String())
|
slog.Info("starting llama server", "cmd", s.cmd)
|
||||||
if envconfig.Debug() {
|
if envconfig.Debug() {
|
||||||
filteredEnv := []string{}
|
filteredEnv := []string{}
|
||||||
for _, ev := range s.cmd.Env {
|
for _, ev := range s.cmd.Env {
|
||||||
|
@ -470,7 +470,7 @@ const ( // iota is reset to 0
|
||||||
ServerStatusError
|
ServerStatusError
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s ServerStatus) ToString() string {
|
func (s ServerStatus) String() string {
|
||||||
switch s {
|
switch s {
|
||||||
case ServerStatusReady:
|
case ServerStatusReady:
|
||||||
return "llm server ready"
|
return "llm server ready"
|
||||||
|
@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerStatusResp struct {
|
type ServerStatusResponse struct {
|
||||||
Status string `json:"status"`
|
Status ServerStatus `json:"status"`
|
||||||
SlotsIdle int `json:"slots_idle"`
|
Progress float32 `json:"progress"`
|
||||||
SlotsProcessing int `json:"slots_processing"`
|
|
||||||
Error string `json:"error"`
|
|
||||||
Progress float32 `json:"progress"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||||
|
@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||||
}
|
}
|
||||||
if s.cmd.ProcessState.ExitCode() == -1 {
|
if s.cmd.ProcessState.ExitCode() == -1 {
|
||||||
// Most likely a signal killed it, log some more details to try to help troubleshoot
|
// Most likely a signal killed it, log some more details to try to help troubleshoot
|
||||||
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
|
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
|
||||||
}
|
}
|
||||||
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||||
}
|
}
|
||||||
|
@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||||
return ServerStatusError, fmt.Errorf("read health request: %w", err)
|
return ServerStatusError, fmt.Errorf("read health request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var status ServerStatusResp
|
var ssr ServerStatusResponse
|
||||||
if err := json.Unmarshal(body, &status); err != nil {
|
if err := json.Unmarshal(body, &ssr); err != nil {
|
||||||
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
|
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch status.Status {
|
switch ssr.Status {
|
||||||
case "ok":
|
case ServerStatusLoadingModel:
|
||||||
return ServerStatusReady, nil
|
s.loadProgress = ssr.Progress
|
||||||
case "no slot available":
|
return ssr.Status, nil
|
||||||
return ServerStatusNoSlotsAvailable, nil
|
case ServerStatusReady, ServerStatusNoSlotsAvailable:
|
||||||
case "loading model":
|
return ssr.Status, nil
|
||||||
s.loadProgress = status.Progress
|
|
||||||
return ServerStatusLoadingModel, nil
|
|
||||||
default:
|
default:
|
||||||
return ServerStatusError, fmt.Errorf("server error: %+v", status)
|
return ssr.Status, fmt.Errorf("server error: %+v", ssr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||||
status, _ := s.getServerStatus(ctx)
|
status, _ := s.getServerStatus(ctx)
|
||||||
if lastStatus != status && status != ServerStatusReady {
|
if lastStatus != status && status != ServerStatusReady {
|
||||||
// Only log on status changes
|
// Only log on status changes
|
||||||
slog.Info("waiting for server to become available", "status", status.ToString())
|
slog.Info("waiting for server to become available", "status", status)
|
||||||
}
|
}
|
||||||
switch status {
|
switch status {
|
||||||
case ServerStatusReady:
|
case ServerStatusReady:
|
||||||
|
@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||||
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
|
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
|
||||||
stallTimer = time.Now().Add(stallDuration)
|
stallTimer = time.Now().Add(stallDuration)
|
||||||
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
|
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
|
||||||
slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
|
slog.Debug("model load completed, waiting for server to become available", "status", status)
|
||||||
stallTimer = time.Now().Add(stallDuration)
|
stallTimer = time.Now().Add(stallDuration)
|
||||||
fullyLoaded = true
|
fullyLoaded = true
|
||||||
}
|
}
|
||||||
|
@ -671,63 +666,26 @@ type ImageData struct {
|
||||||
AspectRatioID int `json:"aspect_ratio_id"`
|
AspectRatioID int `json:"aspect_ratio_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type completion struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Stop bool `json:"stop"`
|
|
||||||
StoppedLimit bool `json:"stopped_limit"`
|
|
||||||
|
|
||||||
Timings struct {
|
|
||||||
PredictedN int `json:"predicted_n"`
|
|
||||||
PredictedMS float64 `json:"predicted_ms"`
|
|
||||||
PromptN int `json:"prompt_n"`
|
|
||||||
PromptMS float64 `json:"prompt_ms"`
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Format json.RawMessage
|
Format json.RawMessage
|
||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
Options *api.Options
|
||||||
|
|
||||||
|
Grammar string // set before sending the request to the subprocess
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
Content string
|
Content string `json:"content"`
|
||||||
DoneReason string
|
DoneReason string `json:"done_reason"`
|
||||||
Done bool
|
Done bool `json:"done"`
|
||||||
PromptEvalCount int
|
PromptEvalCount int `json:"prompt_eval_count"`
|
||||||
PromptEvalDuration time.Duration
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||||
EvalCount int
|
EvalCount int `json:"eval_count"`
|
||||||
EvalDuration time.Duration
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||||
request := map[string]any{
|
|
||||||
"prompt": req.Prompt,
|
|
||||||
"stream": true,
|
|
||||||
"n_predict": req.Options.NumPredict,
|
|
||||||
"n_keep": req.Options.NumKeep,
|
|
||||||
"main_gpu": req.Options.MainGPU,
|
|
||||||
"temperature": req.Options.Temperature,
|
|
||||||
"top_k": req.Options.TopK,
|
|
||||||
"top_p": req.Options.TopP,
|
|
||||||
"min_p": req.Options.MinP,
|
|
||||||
"typical_p": req.Options.TypicalP,
|
|
||||||
"repeat_last_n": req.Options.RepeatLastN,
|
|
||||||
"repeat_penalty": req.Options.RepeatPenalty,
|
|
||||||
"presence_penalty": req.Options.PresencePenalty,
|
|
||||||
"frequency_penalty": req.Options.FrequencyPenalty,
|
|
||||||
"mirostat": req.Options.Mirostat,
|
|
||||||
"mirostat_tau": req.Options.MirostatTau,
|
|
||||||
"mirostat_eta": req.Options.MirostatEta,
|
|
||||||
"seed": req.Options.Seed,
|
|
||||||
"stop": req.Options.Stop,
|
|
||||||
"image_data": req.Images,
|
|
||||||
"cache_prompt": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req.Format) > 0 {
|
if len(req.Format) > 0 {
|
||||||
switch string(req.Format) {
|
switch string(req.Format) {
|
||||||
case `null`, `""`:
|
case `null`, `""`:
|
||||||
|
@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
// these as "not set".
|
// these as "not set".
|
||||||
break
|
break
|
||||||
case `"json"`:
|
case `"json"`:
|
||||||
request["grammar"] = grammarJSON
|
req.Grammar = grammarJSON
|
||||||
default:
|
default:
|
||||||
if req.Format[0] != '{' {
|
if req.Format[0] != '{' {
|
||||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||||
|
@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
if g == nil {
|
if g == nil {
|
||||||
return fmt.Errorf("invalid JSON schema in format")
|
return fmt.Errorf("invalid JSON schema in format")
|
||||||
}
|
}
|
||||||
request["grammar"] = string(g)
|
req.Grammar = string(g)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Options == nil {
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
req.Options = &opts
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
slog.Info("aborting completion request due to client closing the connection")
|
slog.Info("aborting completion request due to client closing the connection")
|
||||||
|
@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady {
|
||||||
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
return fmt.Errorf("unexpected server status: %s", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handling JSON marshaling with special characters unescaped.
|
// Handling JSON marshaling with special characters unescaped.
|
||||||
|
@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
enc := json.NewEncoder(buffer)
|
enc := json.NewEncoder(buffer)
|
||||||
enc.SetEscapeHTML(false)
|
enc.SetEscapeHTML(false)
|
||||||
|
|
||||||
if err := enc.Encode(request); err != nil {
|
if err := enc.Encode(req); err != nil {
|
||||||
return fmt.Errorf("failed to marshal data: %v", err)
|
return fmt.Errorf("failed to marshal data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
evt = line
|
evt = line
|
||||||
}
|
}
|
||||||
|
|
||||||
var c completion
|
var c CompletionResponse
|
||||||
if err := json.Unmarshal(evt, &c); err != nil {
|
if err := json.Unmarshal(evt, &c); err != nil {
|
||||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Stop {
|
if c.Done {
|
||||||
doneReason := "stop"
|
fn(c)
|
||||||
if c.StoppedLimit {
|
|
||||||
doneReason = "length"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn(CompletionResponse{
|
|
||||||
Done: true,
|
|
||||||
DoneReason: doneReason,
|
|
||||||
PromptEvalCount: c.Timings.PromptN,
|
|
||||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
|
||||||
EvalCount: c.Timings.PredictedN,
|
|
||||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
|
||||||
})
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady {
|
||||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
return nil, fmt.Errorf("unexpected server status: %s", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
||||||
|
@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseDurationMs(ms float64) time.Duration {
|
|
||||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return dur
|
|
||||||
}
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/runner/common"
|
"github.com/ollama/ollama/runner/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -99,7 +100,7 @@ type NewSequenceParams struct {
|
||||||
embedding bool
|
embedding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
s.ready.Wait()
|
s.ready.Wait()
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||||
// inputs processes the prompt and images into a list of inputs
|
// inputs processes the prompt and images into a list of inputs
|
||||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||||
// generating image embeddings for each image
|
// generating image embeddings for each image
|
||||||
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
|
||||||
var inputs []input
|
var inputs []input
|
||||||
var parts []string
|
var parts []string
|
||||||
var matches [][]string
|
var matches [][]string
|
||||||
|
@ -229,7 +230,7 @@ type Server struct {
|
||||||
image *ImageContext
|
image *ImageContext
|
||||||
|
|
||||||
// status for external health reporting - loading, ready to serve, etc.
|
// status for external health reporting - loading, ready to serve, etc.
|
||||||
status ServerStatus
|
status llm.ServerStatus
|
||||||
|
|
||||||
// current progress on loading the model
|
// current progress on loading the model
|
||||||
progress float32
|
progress float32
|
||||||
|
@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
|
||||||
// this way the api acts as a proxy instead of using a different api for the
|
|
||||||
// runner
|
|
||||||
type Options struct {
|
|
||||||
api.Runner
|
|
||||||
|
|
||||||
NumKeep int `json:"n_keep"`
|
|
||||||
Seed int `json:"seed"`
|
|
||||||
NumPredict int `json:"n_predict"`
|
|
||||||
TopK int `json:"top_k"`
|
|
||||||
TopP float32 `json:"top_p"`
|
|
||||||
MinP float32 `json:"min_p"`
|
|
||||||
TypicalP float32 `json:"typical_p"`
|
|
||||||
RepeatLastN int `json:"repeat_last_n"`
|
|
||||||
Temperature float32 `json:"temperature"`
|
|
||||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
|
||||||
PresencePenalty float32 `json:"presence_penalty"`
|
|
||||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
|
||||||
Mirostat int `json:"mirostat"`
|
|
||||||
MirostatTau float32 `json:"mirostat_tau"`
|
|
||||||
MirostatEta float32 `json:"mirostat_eta"`
|
|
||||||
Stop []string `json:"stop"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageData struct {
|
|
||||||
Data []byte `json:"data"`
|
|
||||||
ID int `json:"id"`
|
|
||||||
AspectRatioID int `json:"aspect_ratio_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CompletionRequest struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Images []ImageData `json:"image_data"`
|
|
||||||
Grammar string `json:"grammar"`
|
|
||||||
CachePrompt bool `json:"cache_prompt"`
|
|
||||||
|
|
||||||
Options
|
|
||||||
}
|
|
||||||
|
|
||||||
type Timings struct {
|
|
||||||
PredictedN int `json:"predicted_n"`
|
|
||||||
PredictedMS float64 `json:"predicted_ms"`
|
|
||||||
PromptN int `json:"prompt_n"`
|
|
||||||
PromptMS float64 `json:"prompt_ms"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CompletionResponse struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
Stop bool `json:"stop"`
|
|
||||||
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
|
||||||
PredictedN int `json:"predicted_n,omitempty"`
|
|
||||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
|
||||||
PromptN int `json:"prompt_n,omitempty"`
|
|
||||||
PromptMS float64 `json:"prompt_ms,omitempty"`
|
|
||||||
|
|
||||||
Timings Timings `json:"timings"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
var req CompletionRequest
|
var req llm.CompletionRequest
|
||||||
req.Options = Options(api.DefaultOptions())
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Options == nil {
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
req.Options = &opts
|
||||||
|
}
|
||||||
|
|
||||||
// Set the headers to indicate streaming
|
// Set the headers to indicate streaming
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("Transfer-Encoding", "chunked")
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var samplingParams llama.SamplingParams
|
// Extract options from the CompletionRequest
|
||||||
samplingParams.TopK = req.TopK
|
samplingParams := llama.SamplingParams{
|
||||||
samplingParams.TopP = req.TopP
|
TopK: req.Options.TopK,
|
||||||
samplingParams.MinP = req.MinP
|
TopP: req.Options.TopP,
|
||||||
samplingParams.TypicalP = req.TypicalP
|
MinP: req.Options.MinP,
|
||||||
samplingParams.Temp = req.Temperature
|
TypicalP: req.Options.TypicalP,
|
||||||
samplingParams.RepeatLastN = req.RepeatLastN
|
Temp: req.Options.Temperature,
|
||||||
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
RepeatLastN: req.Options.RepeatLastN,
|
||||||
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
PenaltyRepeat: req.Options.RepeatPenalty,
|
||||||
samplingParams.PenaltyPresent = req.PresencePenalty
|
PenaltyFreq: req.Options.FrequencyPenalty,
|
||||||
samplingParams.Mirostat = req.Mirostat
|
PenaltyPresent: req.Options.PresencePenalty,
|
||||||
samplingParams.MirostatTau = req.MirostatTau
|
Mirostat: req.Options.Mirostat,
|
||||||
samplingParams.MirostatEta = req.MirostatEta
|
MirostatTau: req.Options.MirostatTau,
|
||||||
samplingParams.Seed = uint32(req.Seed)
|
MirostatEta: req.Options.MirostatEta,
|
||||||
samplingParams.Grammar = req.Grammar
|
Seed: uint32(req.Options.Seed),
|
||||||
|
Grammar: req.Grammar,
|
||||||
|
}
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.NumPredict,
|
numPredict: req.Options.NumPredict,
|
||||||
stop: req.Stop,
|
stop: req.Options.Stop,
|
||||||
numKeep: req.NumKeep,
|
numKeep: req.Options.NumKeep,
|
||||||
samplingParams: &samplingParams,
|
samplingParams: &samplingParams,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
})
|
})
|
||||||
|
@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
found := false
|
found := false
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
|
@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
Content: content,
|
Content: content,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
} else {
|
} else {
|
||||||
// Send the final response
|
// Send the final response
|
||||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
doneReason := "stop"
|
||||||
Stop: true,
|
if seq.doneReason == "limit" {
|
||||||
StoppedLimit: seq.doneReason == "limit",
|
doneReason = "length"
|
||||||
Timings: Timings{
|
}
|
||||||
PromptN: seq.numPromptInputs,
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
Done: true,
|
||||||
PredictedN: seq.numDecoded,
|
DoneReason: doneReason,
|
||||||
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
PromptEvalCount: seq.numPromptInputs,
|
||||||
},
|
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||||
|
EvalCount: seq.numDecoded,
|
||||||
|
EvalDuration: time.Since(seq.startGenerationTime),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
CachePrompt bool `json:"cache_prompt"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
|
||||||
Embedding []float32 `json:"embedding"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
var req EmbeddingRequest
|
var req llm.EmbeddingRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
found := false
|
found := false
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
|
@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
embedding := <-seq.embedding
|
embedding := <-seq.embedding
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
||||||
Embedding: embedding,
|
Embedding: embedding,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type HealthResponse struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
Progress float32 `json:"progress"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServerStatus int
|
|
||||||
|
|
||||||
const (
|
|
||||||
ServerStatusReady ServerStatus = iota
|
|
||||||
ServerStatusLoadingModel
|
|
||||||
ServerStatusError
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s ServerStatus) ToString() string {
|
|
||||||
switch s {
|
|
||||||
case ServerStatusReady:
|
|
||||||
return "ok"
|
|
||||||
case ServerStatusLoadingModel:
|
|
||||||
return "loading model"
|
|
||||||
default:
|
|
||||||
return "server error"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||||
Status: s.status.ToString(),
|
Status: s.status,
|
||||||
Progress: s.progress,
|
Progress: s.progress,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
@ -879,7 +794,7 @@ func (s *Server) loadModel(
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.status = ServerStatusReady
|
s.status = llm.ServerStatusReady
|
||||||
s.ready.Done()
|
s.ready.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -937,7 +852,7 @@ func Execute(args []string) error {
|
||||||
parallel: *parallel,
|
parallel: *parallel,
|
||||||
seqs: make([]*Sequence, *parallel),
|
seqs: make([]*Sequence, *parallel),
|
||||||
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
||||||
status: ServerStatusLoadingModel,
|
status: llm.ServerStatusLoadingModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
var tensorSplitFloats []float32
|
var tensorSplitFloats []float32
|
||||||
|
|
|
@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
|
||||||
if !cachePrompt {
|
if !cachePrompt {
|
||||||
numPast = 0
|
numPast = 0
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"golang.org/x/sync/semaphore"
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
|
@ -94,7 +95,7 @@ type NewSequenceParams struct {
|
||||||
embedding bool
|
embedding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
s.ready.Wait()
|
s.ready.Wait()
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
@ -145,7 +146,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||||
// inputs processes the prompt and images into a list of inputs
|
// inputs processes the prompt and images into a list of inputs
|
||||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||||
// decoding images
|
// decoding images
|
||||||
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
|
func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
|
||||||
var inputs []input.Input
|
var inputs []input.Input
|
||||||
var parts []string
|
var parts []string
|
||||||
var matches [][]string
|
var matches [][]string
|
||||||
|
@ -222,7 +223,7 @@ type Server struct {
|
||||||
model model.Model
|
model model.Model
|
||||||
|
|
||||||
// status for external health reporting - loading, ready to serve, etc.
|
// status for external health reporting - loading, ready to serve, etc.
|
||||||
status ServerStatus
|
status llm.ServerStatus
|
||||||
|
|
||||||
// current progress on loading the model
|
// current progress on loading the model
|
||||||
progress float32
|
progress float32
|
||||||
|
@ -501,75 +502,18 @@ func (s *Server) processBatch() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
|
||||||
// this way the api acts as a proxy instead of using a different api for the
|
|
||||||
// runner
|
|
||||||
type Options struct {
|
|
||||||
api.Runner
|
|
||||||
|
|
||||||
NumKeep int `json:"n_keep"`
|
|
||||||
Seed int `json:"seed"`
|
|
||||||
NumPredict int `json:"n_predict"`
|
|
||||||
TopK int `json:"top_k"`
|
|
||||||
TopP float32 `json:"top_p"`
|
|
||||||
MinP float32 `json:"min_p"`
|
|
||||||
TypicalP float32 `json:"typical_p"`
|
|
||||||
RepeatLastN int `json:"repeat_last_n"`
|
|
||||||
Temperature float32 `json:"temperature"`
|
|
||||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
|
||||||
PresencePenalty float32 `json:"presence_penalty"`
|
|
||||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
|
||||||
Mirostat int `json:"mirostat"`
|
|
||||||
MirostatTau float32 `json:"mirostat_tau"`
|
|
||||||
MirostatEta float32 `json:"mirostat_eta"`
|
|
||||||
Stop []string `json:"stop"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageData struct {
|
|
||||||
Data []byte `json:"data"`
|
|
||||||
ID int `json:"id"`
|
|
||||||
AspectRatioID int `json:"aspect_ratio_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CompletionRequest struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Images []ImageData `json:"image_data"`
|
|
||||||
Grammar string `json:"grammar"`
|
|
||||||
CachePrompt bool `json:"cache_prompt"`
|
|
||||||
|
|
||||||
Options
|
|
||||||
}
|
|
||||||
|
|
||||||
type Timings struct {
|
|
||||||
PredictedN int `json:"predicted_n"`
|
|
||||||
PredictedMS float64 `json:"predicted_ms"`
|
|
||||||
PromptN int `json:"prompt_n"`
|
|
||||||
PromptMS float64 `json:"prompt_ms"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CompletionResponse struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
Stop bool `json:"stop"`
|
|
||||||
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
|
||||||
PredictedN int `json:"predicted_n,omitempty"`
|
|
||||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
|
||||||
PromptN int `json:"prompt_n,omitempty"`
|
|
||||||
PromptMS float64 `json:"prompt_ms,omitempty"`
|
|
||||||
|
|
||||||
Timings Timings `json:"timings"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
var req CompletionRequest
|
var req llm.CompletionRequest
|
||||||
req.Options = Options(api.DefaultOptions())
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Options == nil {
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
req.Options = &opts
|
||||||
|
}
|
||||||
|
|
||||||
// Set the headers to indicate streaming
|
// Set the headers to indicate streaming
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("Transfer-Encoding", "chunked")
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
@ -591,18 +535,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := sample.NewSampler(
|
sampler := sample.NewSampler(
|
||||||
req.Temperature,
|
req.Options.Temperature,
|
||||||
req.TopK,
|
req.Options.TopK,
|
||||||
req.TopP,
|
req.Options.TopP,
|
||||||
req.MinP,
|
req.Options.MinP,
|
||||||
req.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
grammar,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.NumPredict,
|
numPredict: req.Options.NumPredict,
|
||||||
stop: req.Stop,
|
stop: req.Options.Stop,
|
||||||
numKeep: int32(req.NumKeep),
|
numKeep: int32(req.Options.NumKeep),
|
||||||
sampler: sampler,
|
sampler: sampler,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
})
|
})
|
||||||
|
@ -625,7 +569,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
found := false
|
found := false
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
|
@ -652,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
Content: content,
|
Content: content,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
@ -663,15 +607,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
} else {
|
} else {
|
||||||
// Send the final response
|
// Send the final response
|
||||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
doneReason := "stop"
|
||||||
Stop: true,
|
if seq.doneReason == "limit" {
|
||||||
StoppedLimit: seq.doneReason == "limit",
|
doneReason = "length"
|
||||||
Timings: Timings{
|
}
|
||||||
PromptN: seq.numPromptInputs,
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
Done: true,
|
||||||
PredictedN: seq.numPredicted,
|
DoneReason: doneReason,
|
||||||
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
PromptEvalCount: seq.numPromptInputs,
|
||||||
},
|
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||||
|
EvalCount: seq.numPredicted,
|
||||||
|
EvalDuration: time.Since(seq.startGenerationTime),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
@ -682,43 +628,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
CachePrompt bool `json:"cache_prompt"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
|
||||||
Embedding []float32 `json:"embedding"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type HealthResponse struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
Progress float32 `json:"progress"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServerStatus int
|
|
||||||
|
|
||||||
const (
|
|
||||||
ServerStatusReady ServerStatus = iota
|
|
||||||
ServerStatusLoadingModel
|
|
||||||
ServerStatusError
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s ServerStatus) ToString() string {
|
|
||||||
switch s {
|
|
||||||
case ServerStatusReady:
|
|
||||||
return "ok"
|
|
||||||
case ServerStatusLoadingModel:
|
|
||||||
return "loading model"
|
|
||||||
default:
|
|
||||||
return "server error"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||||
Status: s.status.ToString(),
|
Status: s.status,
|
||||||
Progress: s.progress,
|
Progress: s.progress,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
@ -772,7 +685,7 @@ func (s *Server) loadModel(
|
||||||
s.seqs = make([]*Sequence, s.parallel)
|
s.seqs = make([]*Sequence, s.parallel)
|
||||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||||
|
|
||||||
s.status = ServerStatusReady
|
s.status = llm.ServerStatusReady
|
||||||
s.ready.Done()
|
s.ready.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -824,7 +737,7 @@ func Execute(args []string) error {
|
||||||
|
|
||||||
server := &Server{
|
server := &Server{
|
||||||
batchSize: *batchSize,
|
batchSize: *batchSize,
|
||||||
status: ServerStatusLoadingModel,
|
status: llm.ServerStatusLoadingModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jessegross): Parameters that need to be implemented:
|
// TODO(jessegross): Parameters that need to be implemented:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue