ggml: Use pointer receivers for Context

Context is currently mixed between pointer and value receivers. Change
this to be all pointer receivers so don't have to reason about whether
the things we are updating in the struct will be retained.
This commit is contained in:
Jesse Gross 2025-03-11 16:06:06 -07:00 committed by Jesse Gross
parent bc108b9ad6
commit f33ccd5d27

View file

@ -484,7 +484,7 @@ type Context struct {
maxGraphNodes int
}
func (c Context) Input() ml.Context {
func (c *Context) Input() ml.Context {
if c.b.input != nil {
return &Context{
b: c.b,
@ -494,10 +494,10 @@ func (c Context) Input() ml.Context {
}
}
return &c
return c
}
func (c Context) Layer(i int) ml.Context {
func (c *Context) Layer(i int) ml.Context {
if buft, ok := c.b.layers[i]; ok {
return &Context{
b: c.b,
@ -507,7 +507,7 @@ func (c Context) Layer(i int) ml.Context {
}
}
return &c
return c
}
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
@ -522,7 +522,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
return c
}
func (c Context) Compute(tensors ...ml.Tensor) {
func (c *Context) Compute(tensors ...ml.Tensor) {
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
C.ggml_backend_sched_reset(c.b.sched)
@ -541,7 +541,7 @@ func (c Context) Compute(tensors ...ml.Tensor) {
}
}
func (c Context) Reserve() error {
func (c *Context) Reserve() error {
if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
C.ggml_backend_sched_reset(c.b.sched)
return errors.New("failed to reserve graph")
@ -559,7 +559,7 @@ func (c Context) Reserve() error {
return nil
}
func (c Context) MaxGraphNodes() int {
func (c *Context) MaxGraphNodes() int {
return c.maxGraphNodes
}
@ -576,7 +576,7 @@ func pad(length, pad C.size_t) C.size_t {
return ((length + pad - 1) / pad) * pad
}
func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
if c.buft == nil {
panic("set Input or Layer before creating tensors")
}
@ -621,7 +621,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
return &Tensor{b: c.b, t: t}, nil
}
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
t, err := c.newTensor(dtype, shape)
if err != nil {
panic(err)
@ -630,7 +630,7 @@ func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
return t
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
t, err := c.newTensor(dtype, shape)
if err != nil {
panic(err)
@ -658,7 +658,7 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
return nil
}
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil {
return nil, err
}
@ -675,7 +675,7 @@ func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
return t, nil
}
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil {
return nil, err
}