mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 09:56:28 +02:00
fix: stream accumulator exits early (#10593)
the stream accumulator exits as soon as it sees `api.ProgressResponse(status="success")` which isn't strictly correctly since some requests may have multiple successes, e.g. `/api/create` when the source model needs to be pulled.
This commit is contained in:
parent
6e9a7a2568
commit
0d6e35d3c6
2 changed files with 96 additions and 12 deletions
|
@ -1341,31 +1341,29 @@ func Serve(ln net.Listener) error {
|
|||
|
||||
func waitForStream(c *gin.Context, ch chan any) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
var latest api.ProgressResponse
|
||||
for resp := range ch {
|
||||
switch r := resp.(type) {
|
||||
case api.ProgressResponse:
|
||||
if r.Status == "success" {
|
||||
c.JSON(http.StatusOK, r)
|
||||
return
|
||||
}
|
||||
latest = r
|
||||
case gin.H:
|
||||
status, ok := r["status"].(int)
|
||||
if !ok {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
if errorMsg, ok := r["error"].(string); ok {
|
||||
c.JSON(status, gin.H{"error": errorMsg})
|
||||
return
|
||||
} else {
|
||||
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
|
||||
return
|
||||
errorMsg, ok := r["error"].(string)
|
||||
if !ok {
|
||||
errorMsg = "unknown error"
|
||||
}
|
||||
c.JSON(status, gin.H{"error": errorMsg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unknown message type"})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
|
||||
|
||||
c.JSON(http.StatusOK, latest)
|
||||
}
|
||||
|
||||
func streamResponse(c *gin.Context, ch chan any) {
|
||||
|
|
|
@ -21,6 +21,8 @@ import (
|
|||
"testing"
|
||||
"unicode"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/openai"
|
||||
|
@ -882,3 +884,87 @@ func TestFilterThinkTags(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitForStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
messages []any
|
||||
expectCode int
|
||||
expectBody string
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
messages: []any{
|
||||
gin.H{"error": "internal server error"},
|
||||
},
|
||||
expectCode: http.StatusInternalServerError,
|
||||
expectBody: `{"error":"internal server error"}`,
|
||||
},
|
||||
{
|
||||
name: "error status",
|
||||
messages: []any{
|
||||
gin.H{"status": http.StatusNotFound, "error": "not found"},
|
||||
},
|
||||
expectCode: http.StatusNotFound,
|
||||
expectBody: `{"error":"not found"}`,
|
||||
},
|
||||
{
|
||||
name: "unknown error",
|
||||
messages: []any{
|
||||
gin.H{"msg": "something else"},
|
||||
},
|
||||
expectCode: http.StatusInternalServerError,
|
||||
expectBody: `{"error":"unknown error"}`,
|
||||
},
|
||||
{
|
||||
name: "unknown type",
|
||||
messages: []any{
|
||||
struct{}{},
|
||||
},
|
||||
expectCode: http.StatusInternalServerError,
|
||||
expectBody: `{"error":"unknown message type"}`,
|
||||
},
|
||||
{
|
||||
name: "progress success",
|
||||
messages: []any{
|
||||
api.ProgressResponse{Status: "success"},
|
||||
},
|
||||
expectCode: http.StatusOK,
|
||||
expectBody: `{"status":"success"}`,
|
||||
},
|
||||
{
|
||||
name: "progress more than success",
|
||||
messages: []any{
|
||||
api.ProgressResponse{Status: "success"},
|
||||
api.ProgressResponse{Status: "one more thing"},
|
||||
},
|
||||
expectCode: http.StatusOK,
|
||||
expectBody: `{"status":"one more thing"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
ch := make(chan any, len(tt.messages))
|
||||
for _, msg := range tt.messages {
|
||||
ch <- msg
|
||||
}
|
||||
close(ch)
|
||||
|
||||
waitForStream(c, ch)
|
||||
|
||||
if w.Code != tt.expectCode {
|
||||
t.Errorf("expected status %d, got %d", tt.expectCode, w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), tt.expectBody); diff != "" {
|
||||
t.Errorf("body mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue