embed text document in modelfile

This commit is contained in:
Bruce MacDonald 2023-08-09 10:26:19 -04:00 committed by GitHub
commit 7a5f3616fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 371 additions and 52 deletions

View file

@ -85,6 +85,7 @@ llama_token llama_sample(
}
*/
import "C"
import (
"bytes"
"embed"
@ -93,6 +94,7 @@ import (
"io"
"log"
"os"
"reflect"
"strings"
"sync"
"unicode/utf8"
@ -408,3 +410,38 @@ func (llm *LLM) next() (C.llama_token, error) {
return token, nil
}
// Embedding evaluates input with the model and returns the resulting
// embedding vector.
//
// It returns an error when the model was not loaded with EmbeddingOnly set,
// when tokenization yields nothing, when llama_eval reports a non-zero
// status, or when the context exposes no embedding data.
//
// The returned slice is a fresh Go allocation: the values are copied out of
// the buffer owned by the llama context, so the caller never holds a pointer
// into C memory.
func (llm *LLM) Embedding(input string) ([]float64, error) {
	if !llm.EmbeddingOnly {
		return nil, errors.New("llama: embedding not enabled")
	}

	tokens := llm.tokenize(input)
	if tokens == nil {
		return nil, errors.New("llama: tokenize embedding")
	}

	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
	if retval != 0 {
		return nil, errors.New("llama: eval")
	}

	n := int(C.llama_n_embd(llm.ctx))
	if n <= 0 {
		return nil, errors.New("llama: no embeddings generated")
	}

	embedPtr := C.llama_get_embeddings(llm.ctx)
	if embedPtr == nil {
		return nil, errors.New("llama: embedding retrieval failed")
	}

	// llama.cpp hands the embeddings back as a C float array (32-bit), so the
	// memory must be viewed as []C.float. Viewing it as []float64 would
	// misread every value and run past the end of the buffer.
	//
	// The header is filled in through a pointer to a real slice variable,
	// which is the only reflect.SliceHeader usage the unsafe.Pointer rules
	// permit (a composite-literal header is flagged by go vet).
	var raw []C.float
	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&raw))
	hdr.Data = uintptr(unsafe.Pointer(embedPtr))
	hdr.Len = n
	hdr.Cap = n

	// Widen each element into Go-owned memory so the result stays valid
	// independently of the C-side buffer.
	embedding := make([]float64, n)
	for i, v := range raw {
		embedding[i] = float64(v)
	}

	return embedding, nil
}