// ollama/server/routes_tokenization_test.go

package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llm"
)
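
// TestTokenize exercises the /api/tokenize endpoint: it wires a Server to a
// mock scheduler and runner, then checks request validation and a successful
// tokenization round trip. For reference, the equivalent request outside the
// test harness would look roughly like this (assuming the default listen
// address):
//
//	curl http://localhost:11434/api/tokenize \
//	    -d '{"model":"test","text":"Hello world how are you"}'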
func TestTokenize(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add a small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}
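
	// Run the scheduler for the lifetime of the test; loadFn above hands the
	// mock runner back, so no real model is ever loaded.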
	go s.sched.Run(context.TODO())

	t.Run("missing body", func(t *testing.T) {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", nil)

		s.TokenizeHandler(w, r)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "missing request body\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
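
	// With an empty JSON body the handler reports the missing text before any
	// model lookup, so that is the error asserted here.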
	t.Run("missing model", func(t *testing.T) {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", strings.NewReader("{}"))

		s.TokenizeHandler(w, r)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("tokenize text", func(t *testing.T) {
		// First create the model
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "test",
			Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
				"general.architecture":          "llama",
				"llama.block_count":             uint32(1),
				"llama.context_length":          uint32(8192),
				"llama.embedding_length":        uint32(4096),
				"llama.attention.head_count":    uint32(32),
				"llama.attention.head_count_kv": uint32(8),
				"tokenizer.ggml.tokens":         []string{""},
				"tokenizer.ggml.scores":         []float32{0},
				"tokenizer.ggml.token_type":     []int32{0},
			}, []llm.Tensor{
				{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
				{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			})),
		})
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create model: %d", w.Code)
		}

		// Now test tokenization
		body, err := json.Marshal(api.TokenizeRequest{
			Model: "test",
			Text:  "Hello world how are you",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w = httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")

		s.TokenizeHandler(w, r)
		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
		}

		var resp api.TokenizeResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Errorf("failed to decode response: %v", err)
		}

		// The mock tokenizer emits one sequential token per word, so five
		// words map to tokens 0 through 4.
		expected := []int{0, 1, 2, 3, 4}
		if diff := cmp.Diff(resp.Tokens, expected); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("tokenize empty text", func(t *testing.T) {
		body, err := json.Marshal(api.TokenizeRequest{
			Model: "test",
			Text:  "",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")

		s.TokenizeHandler(w, r)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}
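
// TestDetokenize exercises the /api/detokenize endpoint: it reuses the same
// mock scheduler setup as TestTokenize, then covers a successful token round
// trip plus the empty-tokens and unknown-model error paths.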
func TestDetokenize(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add a small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	t.Run("detokenize tokens", func(t *testing.T) {
		// Create the model first
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "test",
			Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
				"general.architecture":          "llama",
				"llama.block_count":             uint32(1),
				"llama.context_length":          uint32(8192),
				"llama.embedding_length":        uint32(4096),
				"llama.attention.head_count":    uint32(32),
				"llama.attention.head_count_kv": uint32(8),
				"tokenizer.ggml.tokens":         []string{""},
				"tokenizer.ggml.scores":         []float32{0},
				"tokenizer.ggml.token_type":     []int32{0},
			}, []llm.Tensor{
				{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
				{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			})),
			Stream: &stream,
		})
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create model: %d - %s", w.Code, w.Body.String())
		}

		body, err := json.Marshal(api.DetokenizeRequest{
			Model:  "test",
			Tokens: []int{0, 1, 2, 3, 4},
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w = httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")

		s.DetokenizeHandler(w, r)
		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
		}
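
		// Only assert that the response decodes cleanly; the detokenized text
		// itself is not compared against an expected value here.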
		var resp api.DetokenizeResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Errorf("failed to decode response: %v", err)
		}
	})

	t.Run("detokenize empty tokens", func(t *testing.T) {
		body, err := json.Marshal(api.DetokenizeRequest{
			Model: "test",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")

		s.DetokenizeHandler(w, r)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "missing tokens for detokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("detokenize missing model", func(t *testing.T) {
		body, err := json.Marshal(api.DetokenizeRequest{
			Tokens: []int{0, 1, 2},
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")

		s.DetokenizeHandler(w, r)
		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "model '' not found\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}