mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 02:16:36 +02:00
server: parallelize embeddings in API web handler instead of in subprocess runner (#6220)
For simplicity, perform parallelization of embedding requests in the API handler instead of offloading this to the subprocess runner. This keeps the scheduling story simpler as it builds on existing parallel requests, similar to existing text completion functionality.
This commit is contained in:
parent
25906d72d1
commit
15c2d8fe14
4 changed files with 53 additions and 71 deletions
|
@ -23,6 +23,7 @@ import (
|
|||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
|
@ -346,6 +347,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
var count int
|
||||
for i, s := range input {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
|
@ -368,25 +370,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
count += len(tokens)
|
||||
|
||||
input[i] = s
|
||||
}
|
||||
embeddings, err := r.Embed(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
slog.Error("embedding generation failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
|
||||
var g errgroup.Group
|
||||
embeddings := make([][]float32, len(input))
|
||||
for i, text := range input {
|
||||
g.Go(func() error {
|
||||
embedding, err := r.Embedding(c.Request.Context(), text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
embeddings[i] = normalize(embedding)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
for i, e := range embeddings.Embedding {
|
||||
embeddings.Embedding[i] = normalize(e)
|
||||
if err := g.Wait(); err != nil {
|
||||
slog.Error("embedding generation failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.EmbedResponse{
|
||||
Model: req.Model,
|
||||
Embeddings: embeddings.Embedding,
|
||||
Embeddings: embeddings,
|
||||
TotalDuration: time.Since(checkpointStart),
|
||||
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
||||
PromptEvalCount: embeddings.PromptEvalCount,
|
||||
PromptEvalCount: count,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
@ -430,21 +443,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
||||
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
|
||||
embedding := make([]float64, len(embeddings.Embedding[0]))
|
||||
|
||||
for i, v := range embeddings.Embedding[0] {
|
||||
embedding[i] = float64(v)
|
||||
var e []float64
|
||||
for _, v := range embedding {
|
||||
e = append(e, float64(v))
|
||||
}
|
||||
|
||||
resp := api.EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
Embedding: e,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue