From a7835c671615d71280ca7dba7264bd05a4f90915 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 30 Apr 2025 17:59:31 -0700
Subject: [PATCH] fix: write gguf padding (#10510)

* add gguf_test

* fix padding

padding was being added to offset but not to the running count
---
 fs/ggml/gguf.go      |  3 ++-
 fs/ggml/gguf_test.go | 63 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 1 deletion(-)
 create mode 100644 fs/ggml/gguf_test.go

diff --git a/fs/ggml/gguf.go b/fs/ggml/gguf.go
index fb3421576..b7029bc38 100644
--- a/fs/ggml/gguf.go
+++ b/fs/ggml/gguf.go
@@ -531,11 +531,12 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
 
 	var s uint64
 	for _, t := range ts {
-		t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
+		t.Offset = s
 		if err := ggufWriteTensorInfo(ws, t); err != nil {
 			return err
 		}
 		s += t.Size()
+		s += uint64(ggufPadding(int64(s), int64(alignment)))
 	}
 
 	for _, t := range ts {
diff --git a/fs/ggml/gguf_test.go b/fs/ggml/gguf_test.go
new file mode 100644
index 000000000..22e7a5514
--- /dev/null
+++ b/fs/ggml/gguf_test.go
@@ -0,0 +1,63 @@
+package ggml
+
+import (
+	"bytes"
+	"os"
+	"slices"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestWriteGGUF(t *testing.T) {
+	w, err := os.CreateTemp(t.TempDir(), "*.bin")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer w.Close()
+
+	if err := WriteGGUF(w, KV{
+		"general.alignment": uint32(16),
+	}, []Tensor{
+		{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
+		{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
+		{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
+		{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
+		{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
+		{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	r, err := os.Open(w.Name())
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer r.Close()
+
+	ff, _, err := Decode(r, 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if diff := cmp.Diff(ff.KV(), KV{
+		"general.alignment":       uint32(16),
+		"general.parameter_count": uint64(36),
+	}); diff != "" {
+		t.Errorf("Mismatch (-want +got):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(ff.Tensors(), Tensors{
+		Offset: 336,
+		items: []*Tensor{
+			{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
+			{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
+			{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
+			{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
+			{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
+			{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
+		},
+	}, cmp.AllowUnexported(Tensors{})); diff != "" {
+		t.Errorf("Mismatch (-want +got):\n%s", diff)
+	}
+}