ml: update Context.Forward interface

update Context.Forward to accept multiple tensors to match
Context.Compute signature

update Context.Forward to return Context such that it can be chained
with Context.Compute
This commit is contained in:
Michael Yang 2025-02-21 11:57:08 -08:00
parent 41dc280491
commit 3e8b8a1933
6 changed files with 18 additions and 14 deletions

View file

@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
out, _, mask := cache.Get(context)
context.Forward(out)
context.Forward(mask)
context.Compute(out, mask)
context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)