api: return model capabilities from the show endpoint (#10066)

With support for multimodal models becoming more varied and common, it is important for clients to be able to easily see what capabilities a model has. Returning these from the show endpoint will allow clients to easily see what a model can do.
This commit is contained in:
Bruce MacDonald 2025-04-01 15:21:46 -07:00 committed by GitHub
parent c001b98087
commit e172f095ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 521 additions and 69 deletions

View file

@ -87,7 +87,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
@ -144,7 +144,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
model, err := GetModel(name.String())
m, err := GetModel(name.String())
if err != nil {
switch {
case errors.Is(err, fs.ErrNotExist):
@ -159,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
s.sched.expireRunner(model)
s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
@ -176,9 +176,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
caps := []Capability{CapabilityCompletion}
caps := []model.Capability{model.CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, CapabilityInsert)
caps = append(caps, model.CapabilityInsert)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
isMllama := checkMllamaModelFamily(model)
isMllama := checkMllamaModelFamily(m)
if isMllama && len(req.Images) > 1 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
return
@ -211,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
if isMllama && len(model.ProjectorPaths) > 0 {
if isMllama && len(m.ProjectorPaths) > 0 {
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
@ -422,7 +422,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@ -530,7 +530,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@ -813,12 +813,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
ModifiedAt: manifest.fi.ModTime(),
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
}
var params []string
@ -1468,9 +1469,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
caps := []Capability{CapabilityCompletion}
caps := []model.Capability{model.CapabilityCompletion}
if len(req.Tools) > 0 {
caps = append(caps, CapabilityTools)
caps = append(caps, model.CapabilityTools)
}
name := model.ParseName(req.Model)