mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
Merge pull request #9742 from ollama/mxyng/engine-error-embeddings
fix: error on models that don't support embeddings
This commit is contained in:
commit
ccfd41c4f0
2 changed files with 9 additions and 66 deletions
|
@ -691,65 +691,6 @@ type EmbeddingResponse struct {
|
||||||
Embedding []float32 `json:"embedding"`
|
Embedding []float32 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var req EmbeddingRequest
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
slog.Debug("embedding request", "content", req.Content)
|
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
|
||||||
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
|
||||||
if errors.Is(err, context.Canceled) {
|
|
||||||
slog.Info("aborting embeddings request due to client closing the connection")
|
|
||||||
} else {
|
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
found := false
|
|
||||||
for i, sq := range s.seqs {
|
|
||||||
if sq == nil {
|
|
||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
|
||||||
if err != nil {
|
|
||||||
s.mu.Unlock()
|
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.seqs[i] = seq
|
|
||||||
s.cond.Signal()
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
embedding := <-seq.embedding
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
|
||||||
Embedding: embedding,
|
|
||||||
}); err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type HealthResponse struct {
|
type HealthResponse struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Progress float32 `json:"progress"`
|
Progress float32 `json:"progress"`
|
||||||
|
@ -927,9 +868,13 @@ func Execute(args []string) error {
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/embedding", server.embeddings)
|
// TODO: support embeddings
|
||||||
mux.HandleFunc("/completion", server.completion)
|
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
|
||||||
mux.HandleFunc("/health", server.health)
|
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("POST /completion", server.completion)
|
||||||
|
mux.HandleFunc("GET /health", server.health)
|
||||||
|
|
||||||
httpServer := http.Server{
|
httpServer := http.Server{
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
|
|
|
@ -483,8 +483,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.Wait(); err != nil {
|
if err := g.Wait(); err != nil {
|
||||||
slog.Error("embedding generation failed", "error", err)
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -545,8 +544,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
|
|
||||||
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embedding: %v", err)})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue