convert: only extract large files

This commit is contained in:
Michael Yang 2024-06-29 16:53:59 -07:00
parent 781fc2d576
commit eafc607abb
10 changed files with 120 additions and 200 deletions

View file

@ -7,9 +7,9 @@ import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"log/slog"
"os"
"path/filepath"
"slices"
)
@ -32,8 +32,8 @@ type Tokenizer struct {
Template string
}
func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
v, err := parseVocabulary(d)
func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
v, err := parseVocabulary(fsys)
if err != nil {
return nil, err
}
@ -44,7 +44,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
}
addedTokens := make(map[string]token)
if f, err := os.Open(filepath.Join(d, "tokenizer.json")); errors.Is(err, os.ErrNotExist) {
if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
} else if err != nil {
return nil, err
} else {
@ -87,7 +87,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
}
}
if f, err := os.Open(filepath.Join(d, "tokenizer_config.json")); errors.Is(err, os.ErrNotExist) {
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
} else if err != nil {
return nil, err
} else {
@ -172,8 +172,8 @@ type Vocabulary struct {
Types []int32
}
func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
f, err := os.Open(filepath.Join(p, "tokenizer.json"))
func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
f, err := fsys.Open("tokenizer.json")
if err != nil {
return nil, err
}
@ -219,20 +219,20 @@ func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
return &v, nil
}
func parseVocabulary(d string) (*Vocabulary, error) {
patterns := map[string]func(string) (*Vocabulary, error){
func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
patterns := map[string]func(fs.FS) (*Vocabulary, error){
"tokenizer.model": parseSentencePiece,
"tokenizer.json": parseVocabularyFromTokenizer,
}
for pattern, parseFn := range patterns {
if _, err := os.Stat(filepath.Join(d, pattern)); errors.Is(err, os.ErrNotExist) {
if _, err := fs.Stat(fsys, pattern); errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
return nil, err
}
return parseFn(d)
return parseFn(fsys)
}
return nil, errors.New("unknown tensor format")