mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
fix: stream accumulator exits early (#10593)
the stream accumulator exits as soon as it sees `api.ProgressResponse(status="success")` which isn't strictly correctly since some requests may have multiple successes, e.g. `/api/create` when the source model needs to be pulled.
This commit is contained in:
parent
6e9a7a2568
commit
0d6e35d3c6
2 changed files with 96 additions and 12 deletions
|
@ -1341,31 +1341,29 @@ func Serve(ln net.Listener) error {
|
||||||
|
|
||||||
func waitForStream(c *gin.Context, ch chan any) {
|
func waitForStream(c *gin.Context, ch chan any) {
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
|
var latest api.ProgressResponse
|
||||||
for resp := range ch {
|
for resp := range ch {
|
||||||
switch r := resp.(type) {
|
switch r := resp.(type) {
|
||||||
case api.ProgressResponse:
|
case api.ProgressResponse:
|
||||||
if r.Status == "success" {
|
latest = r
|
||||||
c.JSON(http.StatusOK, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case gin.H:
|
case gin.H:
|
||||||
status, ok := r["status"].(int)
|
status, ok := r["status"].(int)
|
||||||
if !ok {
|
if !ok {
|
||||||
status = http.StatusInternalServerError
|
status = http.StatusInternalServerError
|
||||||
}
|
}
|
||||||
if errorMsg, ok := r["error"].(string); ok {
|
errorMsg, ok := r["error"].(string)
|
||||||
c.JSON(status, gin.H{"error": errorMsg})
|
if !ok {
|
||||||
return
|
errorMsg = "unknown error"
|
||||||
} else {
|
|
||||||
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
c.JSON(status, gin.H{"error": errorMsg})
|
||||||
|
return
|
||||||
default:
|
default:
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "unknown message type"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
|
|
||||||
|
c.JSON(http.StatusOK, latest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamResponse(c *gin.Context, ch chan any) {
|
func streamResponse(c *gin.Context, ch chan any) {
|
||||||
|
|
|
@ -21,6 +21,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
|
@ -882,3 +884,87 @@ func TestFilterThinkTags(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWaitForStream(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
messages []any
|
||||||
|
expectCode int
|
||||||
|
expectBody string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "error",
|
||||||
|
messages: []any{
|
||||||
|
gin.H{"error": "internal server error"},
|
||||||
|
},
|
||||||
|
expectCode: http.StatusInternalServerError,
|
||||||
|
expectBody: `{"error":"internal server error"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error status",
|
||||||
|
messages: []any{
|
||||||
|
gin.H{"status": http.StatusNotFound, "error": "not found"},
|
||||||
|
},
|
||||||
|
expectCode: http.StatusNotFound,
|
||||||
|
expectBody: `{"error":"not found"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown error",
|
||||||
|
messages: []any{
|
||||||
|
gin.H{"msg": "something else"},
|
||||||
|
},
|
||||||
|
expectCode: http.StatusInternalServerError,
|
||||||
|
expectBody: `{"error":"unknown error"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown type",
|
||||||
|
messages: []any{
|
||||||
|
struct{}{},
|
||||||
|
},
|
||||||
|
expectCode: http.StatusInternalServerError,
|
||||||
|
expectBody: `{"error":"unknown message type"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "progress success",
|
||||||
|
messages: []any{
|
||||||
|
api.ProgressResponse{Status: "success"},
|
||||||
|
},
|
||||||
|
expectCode: http.StatusOK,
|
||||||
|
expectBody: `{"status":"success"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "progress more than success",
|
||||||
|
messages: []any{
|
||||||
|
api.ProgressResponse{Status: "success"},
|
||||||
|
api.ProgressResponse{Status: "one more thing"},
|
||||||
|
},
|
||||||
|
expectCode: http.StatusOK,
|
||||||
|
expectBody: `{"status":"one more thing"}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
ch := make(chan any, len(tt.messages))
|
||||||
|
for _, msg := range tt.messages {
|
||||||
|
ch <- msg
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
|
||||||
|
waitForStream(c, ch)
|
||||||
|
|
||||||
|
if w.Code != tt.expectCode {
|
||||||
|
t.Errorf("expected status %d, got %d", tt.expectCode, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), tt.expectBody); diff != "" {
|
||||||
|
t.Errorf("body mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue