Mirror of https://github.com/ollama/ollama.git (synced 2025-05-10 18:06:33 +02:00)
api: return model capabilities from the show endpoint (#10066)
With support for multimodal models becoming more varied and common, it is important for clients to be able to easily see what capabilities a model has. Returning these from the show endpoint allows clients to easily determine what a model can do.
Commit e172f095ba (parent c001b98087)
9 changed files with 521 additions and 69 deletions
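As context for the changes below, here is a minimal client-side sketch of how the new field might be consumed. It assumes the Go api client from this repository, a locally running server that already includes this change (older servers simply omit the field), and "llava" purely as an example model name:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"slices"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Ask the show endpoint for model details, including the new capabilities list.
	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: "llava"})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("capabilities:", resp.Capabilities)

	// Decide client-side behaviour based on what the model supports.
	if slices.Contains(resp.Capabilities, model.CapabilityVision) {
		fmt.Println("this model accepts images")
	}
}
```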
api/types.go (24 lines changed)

@@ -12,6 +12,7 @@ import (
 	"time"

 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/types/model"
 )

 // StatusError is an error with an HTTP status code and message.

@@ -340,17 +341,18 @@ type ShowRequest struct {

 // ShowResponse is the response returned from [Client.Show].
 type ShowResponse struct {
 	License       string         `json:"license,omitempty"`
 	Modelfile     string         `json:"modelfile,omitempty"`
 	Parameters    string         `json:"parameters,omitempty"`
 	Template      string         `json:"template,omitempty"`
 	System        string         `json:"system,omitempty"`
 	Details       ModelDetails   `json:"details,omitempty"`
 	Messages      []Message      `json:"messages,omitempty"`
 	ModelInfo     map[string]any `json:"model_info,omitempty"`
 	ProjectorInfo map[string]any `json:"projector_info,omitempty"`
 	Tensors       []Tensor       `json:"tensors,omitempty"`
+	Capabilities  []model.Capability `json:"capabilities,omitempty"`
 	ModifiedAt    time.Time      `json:"modified_at,omitempty"`
 }

 // CopyRequest is the request passed to [Client.Copy].
cmd/cmd.go (15 lines changed)

@@ -18,6 +18,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"runtime"
+	"slices"
 	"sort"
 	"strconv"
 	"strings"

@@ -339,6 +340,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}

+	opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
+
+	// TODO: remove the projector info and vision info checks below,
+	// these are left in for backwards compatibility with older servers
+	// that don't have the capabilities field in the model info
 	if len(info.ProjectorInfo) != 0 {
 		opts.MultiModal = true
 	}

@@ -669,6 +675,15 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 		return
 	})

+	if len(resp.Capabilities) > 0 {
+		tableRender("Capabilities", func() (rows [][]string) {
+			for _, capability := range resp.Capabilities {
+				rows = append(rows, []string{"", capability.String()})
+			}
+			return
+		})
+	}
+
 	if resp.ProjectorInfo != nil {
 		tableRender("Projector", func() (rows [][]string) {
 			arch := resp.ProjectorInfo["general.architecture"].(string)
cmd/cmd_test.go

@@ -16,6 +16,7 @@ import (
 	"github.com/spf13/cobra"

 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )

 func TestShowInfo(t *testing.T) {

@@ -260,6 +261,34 @@ Weigh anchor!
 			t.Errorf("unexpected output (-want +got):\n%s", diff)
 		}
 	})

+	t.Run("capabilities", func(t *testing.T) {
+		var b bytes.Buffer
+		if err := showInfo(&api.ShowResponse{
+			Details: api.ModelDetails{
+				Family:            "test",
+				ParameterSize:     "7B",
+				QuantizationLevel: "FP16",
+			},
+			Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
+		}, false, &b); err != nil {
+			t.Fatal(err)
+		}
+
+		expect := " Model\n" +
+			" architecture test \n" +
+			" parameters 7B \n" +
+			" quantization FP16 \n" +
+			"\n" +
+			" Capabilities\n" +
+			" vision \n" +
+			" tools \n" +
+			"\n"
+
+		if diff := cmp.Diff(expect, b.String()); diff != "" {
+			t.Errorf("unexpected output (-want +got):\n%s", diff)
+		}
+	})
 }

 func TestDeleteHandler(t *testing.T) {
docs/api.md

@@ -1217,7 +1217,7 @@ Show information about a model including details, modelfile, template, parameters

 ```shell
 curl http://localhost:11434/api/show -d '{
-  "model": "llama3.2"
+  "model": "llava"
 }'
 ```

@@ -1260,7 +1260,11 @@ curl http://localhost:11434/api/show -d '{
     "tokenizer.ggml.pre": "llama-bpe",
     "tokenizer.ggml.token_type": [],  // populates if `verbose=true`
     "tokenizer.ggml.tokens": []       // populates if `verbose=true`
-  }
+  },
+  "capabilities": [
+    "completion",
+    "vision"
+  ],
 }
 ```
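For clients that call the REST endpoint directly instead of using the Go client, a minimal sketch of reading the new array; the showCapabilities struct below is illustrative only (it models just the one field this example needs), and the default localhost address is assumed:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"slices"
)

// showCapabilities models just the part of the /api/show response this example cares about.
type showCapabilities struct {
	Capabilities []string `json:"capabilities"`
}

func main() {
	body := bytes.NewBufferString(`{"model": "llava"}`)
	resp, err := http.Post("http://localhost:11434/api/show", "application/json", body)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var show showCapabilities
	if err := json.NewDecoder(resp.Body).Decode(&show); err != nil {
		log.Fatal(err)
	}

	// An older server simply omits the field, so this degrades to an empty list.
	fmt.Println("supports vision:", slices.Contains(show.Capabilities, "vision"))
}
```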
server/images.go (101 lines changed)

@@ -35,17 +35,11 @@ var (
 	errCapabilityCompletion = errors.New("completion")
 	errCapabilityTools      = errors.New("tools")
 	errCapabilityInsert     = errors.New("insert")
+	errCapabilityVision     = errors.New("vision")
+	errCapabilityEmbedding  = errors.New("embedding")
 	errInsecureProtocol     = errors.New("insecure protocol http")
 )

-type Capability string
-
-const (
-	CapabilityCompletion = Capability("completion")
-	CapabilityTools      = Capability("tools")
-	CapabilityInsert     = Capability("insert")
-)
-
 type registryOptions struct {
 	Insecure bool
 	Username string

@@ -72,46 +66,77 @@ type Model struct {
 	Template *template.Template
 }

+// Capabilities returns the capabilities that the model supports
+func (m *Model) Capabilities() []model.Capability {
+	capabilities := []model.Capability{}
+
+	// Check for completion capability
+	r, err := os.Open(m.ModelPath)
+	if err == nil {
+		defer r.Close()
+
+		f, _, err := ggml.Decode(r, 0)
+		if err == nil {
+			if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
+				capabilities = append(capabilities, model.CapabilityEmbedding)
+			} else {
+				capabilities = append(capabilities, model.CapabilityCompletion)
+			}
+			if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
+				capabilities = append(capabilities, model.CapabilityVision)
+			}
+		} else {
+			slog.Error("couldn't decode ggml", "error", err)
+		}
+	} else {
+		slog.Error("couldn't open model file", "error", err)
+	}
+
+	if m.Template == nil {
+		return capabilities
+	}
+
+	// Check for tools capability
+	if slices.Contains(m.Template.Vars(), "tools") {
+		capabilities = append(capabilities, model.CapabilityTools)
+	}
+
+	// Check for insert capability
+	if slices.Contains(m.Template.Vars(), "suffix") {
+		capabilities = append(capabilities, model.CapabilityInsert)
+	}
+
+	return capabilities
+}
+
 // CheckCapabilities checks if the model has the specified capabilities returning an error describing
 // any missing or unknown capabilities
-func (m *Model) CheckCapabilities(caps ...Capability) error {
+func (m *Model) CheckCapabilities(want ...model.Capability) error {
+	available := m.Capabilities()
 	var errs []error
-	for _, cap := range caps {
-		switch cap {
-		case CapabilityCompletion:
-			r, err := os.Open(m.ModelPath)
-			if err != nil {
-				slog.Error("couldn't open model file", "error", err)
-				continue
-			}
-			defer r.Close()
-
-			// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
-			f, _, err := ggml.Decode(r, 0)
-			if err != nil {
-				slog.Error("couldn't decode ggml", "error", err)
-				continue
-			}
-
-			if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
-				errs = append(errs, errCapabilityCompletion)
-			}
-		case CapabilityTools:
-			if !slices.Contains(m.Template.Vars(), "tools") {
-				errs = append(errs, errCapabilityTools)
-			}
-		case CapabilityInsert:
-			vars := m.Template.Vars()
-			if !slices.Contains(vars, "suffix") {
-				errs = append(errs, errCapabilityInsert)
-			}
-		default:
-			slog.Error("unknown capability", "capability", cap)
-			return fmt.Errorf("unknown capability: %s", cap)
-		}
-	}
+
+	// Map capabilities to their corresponding error
+	capToErr := map[model.Capability]error{
+		model.CapabilityCompletion: errCapabilityCompletion,
+		model.CapabilityTools:      errCapabilityTools,
+		model.CapabilityInsert:     errCapabilityInsert,
+		model.CapabilityVision:     errCapabilityVision,
+		model.CapabilityEmbedding:  errCapabilityEmbedding,
+	}
+
+	for _, cap := range want {
+		err, ok := capToErr[cap]
+		if !ok {
+			slog.Error("unknown capability", "capability", cap)
+			return fmt.Errorf("unknown capability: %s", cap)
+		}
+
+		if !slices.Contains(available, cap) {
+			errs = append(errs, err)
+		}
+	}

-	if err := errors.Join(errs...); err != nil {
+	if len(errs) > 0 {
 		return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
 	}
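One design note on the rewritten CheckCapabilities: the returned error wraps errCapabilities first and then the joined per-capability errors via fmt.Errorf("%w %w", ...), so callers such as handleScheduleError can branch with errors.Is while the message still reads naturally. A runnable sketch of that wrapping pattern, using local stand-in sentinel errors rather than the server package's unexported ones:

```go
package main

import (
	"errors"
	"fmt"
)

// Stand-ins for the server package's sentinel errors; the names and messages
// here are illustrative only, not the repository's definitions.
var (
	errCapabilities     = errors.New("does not support")
	errCapabilityVision = errors.New("vision")
)

func main() {
	// Mirror the shape of CheckCapabilities' return value:
	// the umbrella error first, then the joined per-capability errors.
	err := fmt.Errorf("%w %w", errCapabilities, errors.Join(errCapabilityVision))

	fmt.Println(err)                                 // "does not support vision"
	fmt.Println(errors.Is(err, errCapabilities))     // true: callers can branch on the umbrella error
	fmt.Println(errors.Is(err, errCapabilityVision)) // true: the specific capability is still matchable
}
```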
server/images_test.go (new file, 360 lines)

@@ -0,0 +1,360 @@

package server

import (
	"bytes"
	"encoding/binary"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/model"
)

// Constants for GGUF magic bytes and version
var (
	ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
	ggufVer   = uint32(3)                      // Version 3
)

// Helper function to create mock GGUF data
func createMockGGUFData(architecture string, vision bool) []byte {
	var buf bytes.Buffer

	// Write GGUF header
	buf.Write(ggufMagic)
	binary.Write(&buf, binary.LittleEndian, ggufVer)

	// Write tensor count (0 for our test)
	var numTensors uint64 = 0
	binary.Write(&buf, binary.LittleEndian, numTensors)

	// Calculate number of metadata entries
	numMetaEntries := uint64(1) // architecture entry
	if vision {
		numMetaEntries++
	}
	// Add embedding entry if architecture is "bert"
	if architecture == "bert" {
		numMetaEntries++
	}
	binary.Write(&buf, binary.LittleEndian, numMetaEntries)

	// Write architecture metadata
	archKey := "general.architecture"
	keyLen := uint64(len(archKey))
	binary.Write(&buf, binary.LittleEndian, keyLen)
	buf.WriteString(archKey)

	// String type (8)
	var strType uint32 = 8
	binary.Write(&buf, binary.LittleEndian, strType)

	// String length
	strLen := uint64(len(architecture))
	binary.Write(&buf, binary.LittleEndian, strLen)
	buf.WriteString(architecture)

	if vision {
		visionKey := architecture + ".vision.block_count"
		keyLen = uint64(len(visionKey))
		binary.Write(&buf, binary.LittleEndian, keyLen)
		buf.WriteString(visionKey)

		// uint32 type (4)
		var uint32Type uint32 = 4
		binary.Write(&buf, binary.LittleEndian, uint32Type)

		// uint32 value (1)
		var countVal uint32 = 1
		binary.Write(&buf, binary.LittleEndian, countVal)
	}
	// Write embedding metadata if architecture is "bert"
	if architecture == "bert" {
		poolKey := architecture + ".pooling_type"
		keyLen = uint64(len(poolKey))
		binary.Write(&buf, binary.LittleEndian, keyLen)
		buf.WriteString(poolKey)

		// uint32 type (4)
		var uint32Type uint32 = 4
		binary.Write(&buf, binary.LittleEndian, uint32Type)

		// uint32 value (1)
		var poolingVal uint32 = 1
		binary.Write(&buf, binary.LittleEndian, poolingVal)
	}

	return buf.Bytes()
}

func TestModelCapabilities(t *testing.T) {
	// Create a temporary directory for test files
	tempDir, err := os.MkdirTemp("", "model_capabilities_test")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	defer os.RemoveAll(tempDir)

	// Create different types of mock model files
	completionModelPath := filepath.Join(tempDir, "model.bin")
	visionModelPath := filepath.Join(tempDir, "vision_model.bin")
	embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
	// Create a simple model file for tests that don't depend on GGUF content
	simpleModelPath := filepath.Join(tempDir, "simple_model.bin")

	err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create completion model file: %v", err)
	}
	err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
	if err != nil {
		t.Fatalf("Failed to create completion model file: %v", err)
	}
	err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create embedding model file: %v", err)
	}
	err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
	if err != nil {
		t.Fatalf("Failed to create simple model file: %v", err)
	}

	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}
	chatTemplate, err := template.Parse("{{ .prompt }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}
	toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}

	testModels := []struct {
		name         string
		model        Model
		expectedCaps []model.Capability
	}{
		{
			name: "model with completion capability",
			model: Model{
				ModelPath: completionModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion},
		},
		{
			name: "model with completion, tools, and insert capability",
			model: Model{
				ModelPath: completionModelPath,
				Template:  toolsInsertTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model with tools and insert capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsInsertTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model with tools capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityTools},
		},
		{
			name: "model with vision capability",
			model: Model{
				ModelPath: visionModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
		},
		{
			name: "model with vision, tools, and insert capability",
			model: Model{
				ModelPath: visionModelPath,
				Template:  toolsInsertTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model with embedding capability",
			model: Model{
				ModelPath: embeddingModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityEmbedding},
		},
	}

	// compare two slices of model.Capability regardless of order
	compareCapabilities := func(a, b []model.Capability) bool {
		if len(a) != len(b) {
			return false
		}

		aCount := make(map[model.Capability]int)
		for _, cap := range a {
			aCount[cap]++
		}

		bCount := make(map[model.Capability]int)
		for _, cap := range b {
			bCount[cap]++
		}

		for cap, count := range aCount {
			if bCount[cap] != count {
				return false
			}
		}

		return true
	}

	for _, tt := range testModels {
		t.Run(tt.name, func(t *testing.T) {
			// Test Capabilities method
			caps := tt.model.Capabilities()
			if !compareCapabilities(caps, tt.expectedCaps) {
				t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
			}
		})
	}
}

func TestModelCheckCapabilities(t *testing.T) {
	// Create a temporary directory for test files
	tempDir, err := os.MkdirTemp("", "model_check_capabilities_test")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	defer os.RemoveAll(tempDir)

	visionModelPath := filepath.Join(tempDir, "vision_model.bin")
	simpleModelPath := filepath.Join(tempDir, "model.bin")
	embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")

	err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
	if err != nil {
		t.Fatalf("Failed to create simple model file: %v", err)
	}
	err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
	if err != nil {
		t.Fatalf("Failed to create vision model file: %v", err)
	}
	err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create embedding model file: %v", err)
	}

	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}
	chatTemplate, err := template.Parse("{{ .prompt }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}
	toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}

	tests := []struct {
		name           string
		model          Model
		checkCaps      []model.Capability
		expectedErrMsg string
	}{
		{
			name: "completion model without tools capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  chatTemplate,
			},
			checkCaps:      []model.Capability{model.CapabilityTools},
			expectedErrMsg: "does not support tools",
		},
		{
			name: "model with all needed capabilities",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsInsertTemplate,
			},
			checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model missing insert capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsTemplate,
			},
			checkCaps:      []model.Capability{model.CapabilityInsert},
			expectedErrMsg: "does not support insert",
		},
		{
			name: "model missing vision capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsTemplate,
			},
			checkCaps:      []model.Capability{model.CapabilityVision},
			expectedErrMsg: "does not support vision",
		},
		{
			name: "model with vision capability",
			model: Model{
				ModelPath: visionModelPath,
				Template:  chatTemplate,
			},
			checkCaps: []model.Capability{model.CapabilityVision},
		},
		{
			name: "model with embedding capability",
			model: Model{
				ModelPath: embeddingModelPath,
				Template:  chatTemplate,
			},
			checkCaps: []model.Capability{model.CapabilityEmbedding},
		},
		{
			name: "unknown capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  chatTemplate,
			},
			checkCaps:      []model.Capability{"unknown"},
			expectedErrMsg: "unknown capability",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test CheckCapabilities method
			err := tt.model.CheckCapabilities(tt.checkCaps...)
			if tt.expectedErrMsg == "" {
				if err != nil {
					t.Errorf("Expected no error, got: %v", err)
				}
			} else {
				if err == nil {
					t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg)
				} else if !strings.Contains(err.Error(), tt.expectedErrMsg) {
					t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err)
				}
			}
		})
	}
}
server/routes.go

@@ -87,7 +87,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options

 // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
 // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
 	if name == "" {
 		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}

@@ -144,7 +144,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}

-	model, err := GetModel(name.String())
+	m, err := GetModel(name.String())
 	if err != nil {
 		switch {
 		case errors.Is(err, fs.ErrNotExist):

@@ -159,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {

 	// expire the runner
 	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
-		s.sched.expireRunner(model)
+		s.sched.expireRunner(m)

 		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model: req.Model,

@@ -176,9 +176,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}

-	caps := []Capability{CapabilityCompletion}
+	caps := []model.Capability{model.CapabilityCompletion}
 	if req.Suffix != "" {
-		caps = append(caps, CapabilityInsert)
+		caps = append(caps, model.CapabilityInsert)
 	}

 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)

@@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}

-	isMllama := checkMllamaModelFamily(model)
+	isMllama := checkMllamaModelFamily(m)
 	if isMllama && len(req.Images) > 1 {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
 		return

@@ -211,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {

 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
-		if isMllama && len(model.ProjectorPaths) > 0 {
+		if isMllama && len(m.ProjectorPaths) > 0 {
 			data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
 			if err != nil {
 				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})

@@ -422,7 +422,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}

-	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return

@@ -530,7 +530,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}

-	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return

@@ -813,12 +813,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}

 	resp := &api.ShowResponse{
 		License:      strings.Join(m.License, "\n"),
 		System:       m.System,
 		Template:     m.Template.String(),
 		Details:      modelDetails,
 		Messages:     msgs,
+		Capabilities: m.Capabilities(),
 		ModifiedAt:   manifest.fi.ModTime(),
 	}

 	var params []string

@@ -1468,9 +1469,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}

-	caps := []Capability{CapabilityCompletion}
+	caps := []model.Capability{model.CapabilityCompletion}
 	if len(req.Tools) > 0 {
-		caps = append(caps, CapabilityTools)
+		caps = append(caps, model.CapabilityTools)
 	}

 	name := model.ParseName(req.Model)
server/sched.go

@@ -20,6 +20,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/types/model"
 )

 type LlmRequest struct {

@@ -195,7 +196,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				}

 				// Embedding models should always be loaded with parallel=1
-				if pending.model.CheckCapabilities(CapabilityCompletion) != nil {
+				if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil {
 					numParallel = 1
 				}
types/model/capability.go (new file, 15 lines)

@@ -0,0 +1,15 @@

package model

type Capability string

const (
	CapabilityCompletion = Capability("completion")
	CapabilityTools      = Capability("tools")
	CapabilityInsert     = Capability("insert")
	CapabilityVision     = Capability("vision")
	CapabilityEmbedding  = Capability("embedding")
)

func (c Capability) String() string {
	return string(c)
}
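Since Capability is a defined string type, a []model.Capability marshals straight to a JSON array of strings, which is exactly the "capabilities" array shown in the show response above. A small illustrative sketch, not part of the commit:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/types/model"
)

func main() {
	caps := []model.Capability{model.CapabilityCompletion, model.CapabilityVision}

	// A defined string type marshals as an ordinary JSON string,
	// so no custom MarshalJSON is needed for the show response.
	out, err := json.Marshal(caps)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(out)) // ["completion","vision"]
}
```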