mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
embed text document in modelfile
This commit is contained in:
commit
7a5f3616fd
10 changed files with 371 additions and 52 deletions
|
@ -85,6 +85,7 @@ llama_token llama_sample(
|
|||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
|
@ -93,6 +94,7 @@ import (
|
|||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
@ -408,3 +410,38 @@ func (llm *LLM) next() (C.llama_token, error) {
|
|||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (llm *LLM) Embedding(input string) ([]float64, error) {
|
||||
if !llm.EmbeddingOnly {
|
||||
return nil, errors.New("llama: embedding not enabled")
|
||||
}
|
||||
|
||||
tokens := llm.tokenize(input)
|
||||
if tokens == nil {
|
||||
return nil, errors.New("llama: tokenize embedding")
|
||||
}
|
||||
|
||||
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
|
||||
if retval != 0 {
|
||||
return nil, errors.New("llama: eval")
|
||||
}
|
||||
|
||||
n := int(C.llama_n_embd(llm.ctx))
|
||||
if n <= 0 {
|
||||
return nil, errors.New("llama: no embeddings generated")
|
||||
}
|
||||
|
||||
embedPtr := C.llama_get_embeddings(llm.ctx)
|
||||
if embedPtr == nil {
|
||||
return nil, errors.New("llama: embedding retrieval failed")
|
||||
}
|
||||
|
||||
header := reflect.SliceHeader{
|
||||
Data: uintptr(unsafe.Pointer(embedPtr)),
|
||||
Len: n,
|
||||
Cap: n,
|
||||
}
|
||||
embedSlice := *(*[]float64)(unsafe.Pointer(&header))
|
||||
|
||||
return embedSlice, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue