From 0d6e35d3c67cf37de1c425d178c71d7351083013 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 8 May 2025 13:17:30 -0700 Subject: [PATCH] fix: stream accumulator exits early (#10593) the stream accumulator exits as soon as it sees `api.ProgressResponse(status="success")` which isn't strictly correctly since some requests may have multiple successes, e.g. `/api/create` when the source model needs to be pulled. --- server/routes.go | 22 +++++------ server/routes_test.go | 86 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/server/routes.go b/server/routes.go index 8886073cf..8b0c7aca5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1341,31 +1341,29 @@ func Serve(ln net.Listener) error { func waitForStream(c *gin.Context, ch chan any) { c.Header("Content-Type", "application/json") + var latest api.ProgressResponse for resp := range ch { switch r := resp.(type) { case api.ProgressResponse: - if r.Status == "success" { - c.JSON(http.StatusOK, r) - return - } + latest = r case gin.H: status, ok := r["status"].(int) if !ok { status = http.StatusInternalServerError } - if errorMsg, ok := r["error"].(string); ok { - c.JSON(status, gin.H{"error": errorMsg}) - return - } else { - c.JSON(status, gin.H{"error": "unexpected error format in progress response"}) - return + errorMsg, ok := r["error"].(string) + if !ok { + errorMsg = "unknown error" } + c.JSON(status, gin.H{"error": errorMsg}) + return default: - c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": "unknown message type"}) return } } - c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"}) + + c.JSON(http.StatusOK, latest) } func streamResponse(c *gin.Context, ch chan any) { diff --git a/server/routes_test.go b/server/routes_test.go index fd63b78be..7c44bc957 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -21,6 +21,8 @@ import ( "testing" "unicode" + "github.com/gin-gonic/gin" + "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/openai" @@ -882,3 +884,87 @@ func TestFilterThinkTags(t *testing.T) { } } } + +func TestWaitForStream(t *testing.T) { + gin.SetMode(gin.TestMode) + + cases := []struct { + name string + messages []any + expectCode int + expectBody string + }{ + { + name: "error", + messages: []any{ + gin.H{"error": "internal server error"}, + }, + expectCode: http.StatusInternalServerError, + expectBody: `{"error":"internal server error"}`, + }, + { + name: "error status", + messages: []any{ + gin.H{"status": http.StatusNotFound, "error": "not found"}, + }, + expectCode: http.StatusNotFound, + expectBody: `{"error":"not found"}`, + }, + { + name: "unknown error", + messages: []any{ + gin.H{"msg": "something else"}, + }, + expectCode: http.StatusInternalServerError, + expectBody: `{"error":"unknown error"}`, + }, + { + name: "unknown type", + messages: []any{ + struct{}{}, + }, + expectCode: http.StatusInternalServerError, + expectBody: `{"error":"unknown message type"}`, + }, + { + name: "progress success", + messages: []any{ + api.ProgressResponse{Status: "success"}, + }, + expectCode: http.StatusOK, + expectBody: `{"status":"success"}`, + }, + { + name: "progress more than success", + messages: []any{ + api.ProgressResponse{Status: "success"}, + api.ProgressResponse{Status: "one more thing"}, + }, + expectCode: http.StatusOK, + expectBody: `{"status":"one more thing"}`, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + ch := make(chan any, len(tt.messages)) + for _, msg := range tt.messages { + ch <- msg + } + close(ch) + + waitForStream(c, ch) + + if w.Code != tt.expectCode { + t.Errorf("expected status %d, got %d", tt.expectCode, w.Code) + } + + if diff := cmp.Diff(w.Body.String(), tt.expectBody); diff != "" { + t.Errorf("body mismatch (-want +got):\n%s", diff) + } + }) + } +}