fs: move ml.Config to fs package

This commit is contained in:
Michael Yang 2025-03-18 14:38:44 -07:00 committed by Michael Yang
parent e53b3cbd0c
commit 3b96a93672
16 changed files with 55 additions and 40 deletions

13
fs/config.go Normal file
View file

@ -0,0 +1,13 @@
package fs
type Config interface {
Architecture() string
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}

View file

@ -5,6 +5,7 @@ import (
"slices" "slices"
"testing" "testing"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
@ -373,7 +374,7 @@ func TestCanResume(t *testing.T) {
type testBackend struct{} type testBackend struct{}
func (b *testBackend) Config() ml.Config { func (b *testBackend) Config() fs.Config {
panic("not implemented") panic("not implemented")
} }

View file

@ -9,22 +9,12 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"github.com/ollama/ollama/fs"
) )
type Config interface {
Architecture() string
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}
type Backend interface { type Backend interface {
Config() Config Config() fs.Config
Get(name string) Tensor Get(name string) Tensor
NewContext() Context NewContext() Context
NewContextSize(size int) Context NewContextSize(size int) Context

View file

@ -24,7 +24,8 @@ import (
"unsafe" "unsafe"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -41,7 +42,7 @@ func devices() []*C.struct_ggml_backend_device {
} }
type Backend struct { type Backend struct {
meta *fs.GGML meta *fsggml.GGML
sched *C.struct_ggml_backend_sched sched *C.struct_ggml_backend_sched
tensors map[string]*C.struct_ggml_tensor tensors map[string]*C.struct_ggml_tensor
@ -58,7 +59,7 @@ type Backend struct {
} }
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) { func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1) meta, n, err := fsggml.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -182,7 +183,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
maxTensors += blocks * 2 maxTensors += blocks * 2
type tensor struct { type tensor struct {
source *fs.Tensor source *fsggml.Tensor
target string target string
} }
@ -413,7 +414,7 @@ func init() {
ml.RegisterBackend("ggml", New) ml.RegisterBackend("ggml", New)
} }
func (b *Backend) Config() ml.Config { func (b *Backend) Config() fs.Config {
return b.meta.KV() return b.meta.KV()
} }

View file

@ -16,7 +16,8 @@ import (
_ "golang.org/x/image/tiff" _ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp" _ "golang.org/x/image/webp"
fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend" _ "github.com/ollama/ollama/ml/backend"
@ -83,10 +84,10 @@ func (m *Base) Config() config {
return m.config return m.config
} }
var models = make(map[string]func(ml.Config) (Model, error)) var models = make(map[string]func(fs.Config) (Model, error))
// Register registers a model constructor for the given architecture // Register registers a model constructor for the given architecture
func Register(name string, f func(ml.Config) (Model, error)) { func Register(name string, f func(fs.Config) (Model, error)) {
if _, ok := models[name]; ok { if _, ok := models[name]; ok {
panic("model: model already registered") panic("model: model already registered")
} }
@ -131,14 +132,14 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
meta, _, err := fs.Decode(r, -1) meta, _, err := fsggml.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return getTextProcessor(meta.KV()) return getTextProcessor(meta.KV())
} }
func getTextProcessor(kv fs.KV) (TextProcessor, error) { func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
arch := kv.Architecture() arch := kv.Architecture()
f, ok := models[arch] f, ok := models[arch]
if !ok { if !ok {

View file

@ -7,7 +7,8 @@ import (
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
@ -139,7 +140,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
} }
func TestGetTextProcessor(t *testing.T) { func TestGetTextProcessor(t *testing.T) {
tp, err := getTextProcessor(fs.KV{}) tp, err := getTextProcessor(fsggml.KV{})
if err == nil { if err == nil {
t.Error("expected error") t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") { } else if !strings.Contains(err.Error(), "unsupported model architecture") {
@ -148,10 +149,10 @@ func TestGetTextProcessor(t *testing.T) {
t.Error("expected nil tp") t.Error("expected nil tp")
} }
models["dummy"] = func(ml.Config) (Model, error) { models["dummy"] = func(fs.Config) (Model, error) {
return notTextProcessorModel{}, nil return notTextProcessorModel{}, nil
} }
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"}) tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
if err == nil { if err == nil {
t.Error("expected error") t.Error("expected error")
} else if !strings.Contains(err.Error(), "not a TextProcessor") { } else if !strings.Contains(err.Error(), "not a TextProcessor") {

View file

@ -3,6 +3,7 @@ package gemma2
import ( import (
"math" "math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
@ -35,7 +36,7 @@ const (
gemma27BLayerCount = 46 gemma27BLayerCount = 46
) )
func New(c ml.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
SentencePieceModel: model.NewSentencePieceModel( SentencePieceModel: model.NewSentencePieceModel(
&model.Vocabulary{ &model.Vocabulary{

View file

@ -6,6 +6,7 @@ import (
"math" "math"
"slices" "slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
@ -52,7 +53,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
return visionOutputs return visionOutputs
} }
func New(c ml.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
SentencePieceModel: model.NewSentencePieceModel( SentencePieceModel: model.NewSentencePieceModel(
&model.Vocabulary{ &model.Vocabulary{

View file

@ -3,6 +3,7 @@ package gemma3
import ( import (
"math" "math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
@ -40,7 +41,7 @@ const (
cacheTypeCausal cacheTypeCausal
) )
func newTextModel(c ml.Config) *TextModel { func newTextModel(c fs.Config) *TextModel {
numBlocks := int(c.Uint("block_count")) numBlocks := int(c.Uint("block_count"))
m := TextModel{ m := TextModel{

View file

@ -3,6 +3,7 @@ package gemma3
import ( import (
"math" "math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
) )
@ -111,7 +112,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
return hiddenState return hiddenState
} }
func newVisionModel(c ml.Config) *VisionModel { func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{ return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")), Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{ VisionModelOptions: &VisionModelOptions{

View file

@ -3,7 +3,7 @@ package gemma3
import ( import (
"image" "image"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc" "github.com/ollama/ollama/model/imageproc"
) )
@ -11,7 +11,7 @@ type ImageProcessor struct {
imageSize, patchSize, numChannels int imageSize, patchSize, numChannels int
} }
func newImageProcessor(c ml.Config) ImageProcessor { func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{ return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")), imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")), patchSize: int(c.Uint("vision.patch_size")),

View file

@ -5,6 +5,7 @@ import (
"math" "math"
"strings" "strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
@ -30,7 +31,7 @@ type Model struct {
*Options *Options
} }
func New(c ml.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
} }

View file

@ -8,6 +8,7 @@ import (
"image" "image"
"slices" "slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
@ -32,7 +33,7 @@ const (
selfAttentionLayer selfAttentionLayer
) )
func New(c ml.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
// Verify unified config // Verify unified config
if c.Uint("vision.block_count") == 0 { if c.Uint("vision.block_count") == 0 {
return nil, fmt.Errorf("non-unified vision model not supported") return nil, fmt.Errorf("non-unified vision model not supported")

View file

@ -4,6 +4,7 @@ import (
"math" "math"
"slices" "slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
@ -220,7 +221,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask
return m.Output.Forward(ctx, hiddenState) return m.Output.Forward(ctx, hiddenState)
} }
func newTextModel(c ml.Config) *TextModel { func newTextModel(c fs.Config) *TextModel {
var decoderLayers []TextDecoderLayer var decoderLayers []TextDecoderLayer
for i := range c.Uint("block_count") { for i := range c.Uint("block_count") {
var textDecoderLayer TextDecoderLayer var textDecoderLayer TextDecoderLayer

View file

@ -4,6 +4,7 @@ import (
"math" "math"
"slices" "slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
) )
@ -213,7 +214,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
return hiddenState.Concat(ctx, hiddenStates, 0) return hiddenState.Concat(ctx, hiddenStates, 0)
} }
func newVisionModel(c ml.Config) *VisionModel { func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{ return &VisionModel{
Transformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))}, Transformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))}, GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},

View file

@ -8,14 +8,14 @@ import (
"golang.org/x/image/draw" "golang.org/x/image/draw"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/fs"
) )
type ImageProcessor struct { type ImageProcessor struct {
imageSize, numChannels, maxNumTiles int imageSize, numChannels, maxNumTiles int
} }
func newImageProcessor(c ml.Config) ImageProcessor { func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{ return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")), imageSize: int(c.Uint("vision.image_size")),
numChannels: int(c.Uint("vision.num_channels")), numChannels: int(c.Uint("vision.num_channels")),