// ollama/server/routes_tokenize_test.go

package server

import (
	"encoding/json"
	"fmt"
	"net/http"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/llm"
)

// mockModelLoader stubs model loading so handler tests never touch disk.
type mockModelLoader struct {
	LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
}

func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
	if ml.LoadModelFn != nil {
		return ml.LoadModelFn(name, params)
	}
	return nil, nil
}

// mockModel stubs the tokenizer entry points on llama.Model. The Fn fields
// let individual tests override behavior; the defaults return fixed values.
type mockModel struct {
	llama.Model
	TokenizeFn     func(text string, addBos bool, addEos bool) ([]int, error)
	TokenToPieceFn func(token int) string
}

func (m mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
	if m.TokenizeFn != nil {
		return m.TokenizeFn(text, addBos, addEos)
	}
	return []int{1, 2, 3}, nil
}

func (m mockModel) TokenToPiece(token int) string {
	if m.TokenToPieceFn != nil {
		return m.TokenToPieceFn(token)
	}
	return fmt.Sprint(token)
}

func TestTokenizeHandler(t *testing.T) {
	gin.SetMode(gin.TestMode)

	model := mockModel{}
	// Hedged wiring: assumes loadedModel can carry the mock via a `model`
	// field; adjust the field name (and its type) to match the real struct.
	mockLoader := &mockModelLoader{
		LoadModelFn: func(name string, params llama.ModelParams) (*loadedModel, error) {
			if name == "nonexistent" {
				return nil, fmt.Errorf("model %q not found", name)
			}
			return &loadedModel{model: model}, nil
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mockRunner{}),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mockRunner{},
				}
			},
		},
		ml: mockLoader,
	}

t.Run("method not allowed", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
}
})
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("missing text", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
Model: "test",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
Text: "test text",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("model not found", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
Model: "nonexistent",
Text: "test text",
})
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
})
t.Run("successful tokenization", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
Model: "test",
Text: "test text",
})
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp api.TokenizeResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
expectedTokens := []int{0, 1}
if len(resp.Tokens) != len(expectedTokens) {
t.Errorf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
}
for i, token := range resp.Tokens {
if token != expectedTokens[i] {
t.Errorf("expected token %d at position %d, got %d", expectedTokens[i], i, token)
}
}
})
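
	// Hedged sketch: exercises the TokenizeFn override hook that mockModel
	// declares. It assumes the handler tokenizes with the model returned by
	// the loader wired into s.ml above; adjust if the handler resolves
	// models some other way.
	t.Run("tokenizer override", func(t *testing.T) {
		model.TokenizeFn = func(text string, addBos bool, addEos bool) ([]int, error) {
			return []int{42}, nil
		}
		defer func() { model.TokenizeFn = nil }()

		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
			Model: "test",
			Text:  "test text",
		})
		if w.Code != http.StatusOK {
			t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
		}

		var resp api.TokenizeResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Fatal(err)
		}
		if len(resp.Tokens) != 1 || resp.Tokens[0] != 42 {
			t.Errorf("expected tokens [42], got %v", resp.Tokens)
		}
	})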
}