fix: stream accumulator exits early (#10593)

the stream accumulator exits as soon as it sees `api.ProgressResponse(status="success")` which isn't strictly correctly
since some requests may have multiple successes, e.g. `/api/create` when the source model needs to be pulled.
This commit is contained in:
Michael Yang 2025-05-08 13:17:30 -07:00 committed by GitHub
parent 6e9a7a2568
commit 0d6e35d3c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 96 additions and 12 deletions

View file

@ -1341,31 +1341,29 @@ func Serve(ln net.Listener) error {
func waitForStream(c *gin.Context, ch chan any) { func waitForStream(c *gin.Context, ch chan any) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
var latest api.ProgressResponse
for resp := range ch { for resp := range ch {
switch r := resp.(type) { switch r := resp.(type) {
case api.ProgressResponse: case api.ProgressResponse:
if r.Status == "success" { latest = r
c.JSON(http.StatusOK, r)
return
}
case gin.H: case gin.H:
status, ok := r["status"].(int) status, ok := r["status"].(int)
if !ok { if !ok {
status = http.StatusInternalServerError status = http.StatusInternalServerError
} }
if errorMsg, ok := r["error"].(string); ok { errorMsg, ok := r["error"].(string)
c.JSON(status, gin.H{"error": errorMsg}) if !ok {
return errorMsg = "unknown error"
} else {
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
return
} }
c.JSON(status, gin.H{"error": errorMsg})
return
default: default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "unknown message type"})
return return
} }
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
c.JSON(http.StatusOK, latest)
} }
func streamResponse(c *gin.Context, ch chan any) { func streamResponse(c *gin.Context, ch chan any) {

View file

@ -21,6 +21,8 @@ import (
"testing" "testing"
"unicode" "unicode"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
@ -882,3 +884,87 @@ func TestFilterThinkTags(t *testing.T) {
} }
} }
} }
func TestWaitForStream(t *testing.T) {
gin.SetMode(gin.TestMode)
cases := []struct {
name string
messages []any
expectCode int
expectBody string
}{
{
name: "error",
messages: []any{
gin.H{"error": "internal server error"},
},
expectCode: http.StatusInternalServerError,
expectBody: `{"error":"internal server error"}`,
},
{
name: "error status",
messages: []any{
gin.H{"status": http.StatusNotFound, "error": "not found"},
},
expectCode: http.StatusNotFound,
expectBody: `{"error":"not found"}`,
},
{
name: "unknown error",
messages: []any{
gin.H{"msg": "something else"},
},
expectCode: http.StatusInternalServerError,
expectBody: `{"error":"unknown error"}`,
},
{
name: "unknown type",
messages: []any{
struct{}{},
},
expectCode: http.StatusInternalServerError,
expectBody: `{"error":"unknown message type"}`,
},
{
name: "progress success",
messages: []any{
api.ProgressResponse{Status: "success"},
},
expectCode: http.StatusOK,
expectBody: `{"status":"success"}`,
},
{
name: "progress more than success",
messages: []any{
api.ProgressResponse{Status: "success"},
api.ProgressResponse{Status: "one more thing"},
},
expectCode: http.StatusOK,
expectBody: `{"status":"one more thing"}`,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ch := make(chan any, len(tt.messages))
for _, msg := range tt.messages {
ch <- msg
}
close(ch)
waitForStream(c, ch)
if w.Code != tt.expectCode {
t.Errorf("expected status %d, got %d", tt.expectCode, w.Code)
}
if diff := cmp.Diff(w.Body.String(), tt.expectBody); diff != "" {
t.Errorf("body mismatch (-want +got):\n%s", diff)
}
})
}
}