server: parallelize embeddings in API web handler instead of in subprocess runner (#6220)

For simplicity, perform parallelization of embedding requests in the API handler instead of offloading this to the subprocess runner. This keeps the scheduling story simpler as it builds on existing parallel requests, similar to existing text completion functionality.
This commit is contained in:
Jeffrey Morgan 2024-08-11 11:57:10 -07:00 committed by GitHub
parent 25906d72d1
commit 15c2d8fe14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 53 additions and 71 deletions

View file

@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string) (*EmbedResponse, error)
Embedding(ctx context.Context, input string) ([]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@ -883,24 +883,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil
}
type EmbedRequest struct {
Content []string `json:"content"`
type EmbeddingRequest struct {
Content string `json:"content"`
}
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_n"`
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
// each input will use a slot, so we need to acquire the semaphore for
// the number of inputs up to numParallel
slots := int64(min(len(input), s.numParallel))
if err := s.sem.Acquire(ctx, slots); err != nil {
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
}
defer s.sem.Release(slots)
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
@ -910,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbedRequest{Content: input})
data, err := json.Marshal(EmbeddingRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("do embedding request: %w", err)
}
@ -937,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
return nil, fmt.Errorf("%s", body)
}
var e EmbedResponse
var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return &e, nil
return e.Embedding, nil
}
type TokenizeRequest struct {