mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
server: validate local path on safetensor create (#9379)
More validation during the safetensor creation process. Properly handle relative paths (like ./model.safetensors) while rejecting absolute paths Add comprehensive test coverage for various paths No functionality changes for valid inputs - existing workflows remain unaffected Leverages Go 1.24's new os.Root functionality for secure containment
This commit is contained in:
parent
31e472baa4
commit
bebb6823c0
2 changed files with 131 additions and 1 deletions
|
@ -8,6 +8,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -34,6 +35,7 @@ var (
|
||||||
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
|
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
|
||||||
errUnknownType = errors.New("unknown type")
|
errUnknownType = errors.New("unknown type")
|
||||||
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
|
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
|
||||||
|
errFilePath = errors.New("file path must be relative")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) CreateHandler(c *gin.Context) {
|
func (s *Server) CreateHandler(c *gin.Context) {
|
||||||
|
@ -46,6 +48,13 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for v := range r.Files {
|
||||||
|
if !fs.ValidPath(v) {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||||
if !name.IsValid() {
|
if !name.IsValid() {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||||
|
@ -104,7 +113,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||||
if r.Adapters != nil {
|
if r.Adapters != nil {
|
||||||
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
|
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType} {
|
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
|
||||||
if errors.Is(err, badReq) {
|
if errors.Is(err, badReq) {
|
||||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||||
return
|
return
|
||||||
|
@ -221,8 +230,22 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tmpDir)
|
defer os.RemoveAll(tmpDir)
|
||||||
|
// Set up a root to validate paths
|
||||||
|
root, err := os.OpenRoot(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer root.Close()
|
||||||
|
|
||||||
for fp, digest := range files {
|
for fp, digest := range files {
|
||||||
|
if !fs.ValidPath(fp) {
|
||||||
|
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
|
||||||
|
}
|
||||||
|
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
// Path is likely outside the root
|
||||||
|
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
||||||
|
}
|
||||||
|
|
||||||
blobPath, err := GetBlobsPath(digest)
|
blobPath, err := GetBlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -270,6 +293,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer bin.Close()
|
||||||
|
|
||||||
f, _, err := ggml.Decode(bin, 0)
|
f, _, err := ggml.Decode(bin, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
106
server/create_test.go
Normal file
106
server/create_test.go
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertFromSafetensors(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
// Helper function to create a new layer and return its digest
|
||||||
|
makeTemp := func(content string) string {
|
||||||
|
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create layer: %v", err)
|
||||||
|
}
|
||||||
|
return l.Digest
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a safetensors compatible file with empty JSON content
|
||||||
|
var buf bytes.Buffer
|
||||||
|
headerSize := int64(len("{}"))
|
||||||
|
binary.Write(&buf, binary.LittleEndian, headerSize)
|
||||||
|
buf.WriteString("{}")
|
||||||
|
|
||||||
|
model := makeTemp(buf.String())
|
||||||
|
config := makeTemp(`{
|
||||||
|
"architectures": ["LlamaForCausalLM"],
|
||||||
|
"vocab_size": 32000
|
||||||
|
}`)
|
||||||
|
tokenizer := makeTemp(`{
|
||||||
|
"version": "1.0",
|
||||||
|
"truncation": null,
|
||||||
|
"padding": null,
|
||||||
|
"added_tokens": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"single_word": false,
|
||||||
|
"lstrip": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filePath string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
// Invalid
|
||||||
|
{
|
||||||
|
name: "InvalidRelativePathShallow",
|
||||||
|
filePath: filepath.Join("..", "file.safetensors"),
|
||||||
|
wantErr: errFilePath,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "InvalidRelativePathDeep",
|
||||||
|
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
|
||||||
|
wantErr: errFilePath,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "InvalidNestedPath",
|
||||||
|
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
|
||||||
|
wantErr: errFilePath,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AbsolutePathOutsideRoot",
|
||||||
|
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
|
||||||
|
wantErr: errFilePath, // Should fail since it's outside tmpDir
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ValidRelativePath",
|
||||||
|
filePath: "model.safetensors",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Create the minimum required file map for convertFromSafetensors
|
||||||
|
files := map[string]string{
|
||||||
|
tt.filePath: model,
|
||||||
|
"config.json": config,
|
||||||
|
"tokenizer.json": tokenizer,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
|
||||||
|
|
||||||
|
if (tt.wantErr == nil && err != nil) ||
|
||||||
|
(tt.wantErr != nil && err == nil) ||
|
||||||
|
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
|
||||||
|
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue