From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Tue, 8 Apr 2025 15:28:34 -0700
Subject: [PATCH] embeddings

allow a loaded model in llama.cpp to be used for
both embeddings and causal attention text generation
instead of forcing one or the other
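
For illustration only, a minimal caller-side sketch (not part of the patch)
of what this enables: a single loaded model whose context returns both
logits for sampling and token embeddings from the same llama_decode call.
Function names follow a recent llama.cpp API revision and may differ in
other versions; the model path, prompt, and error handling are placeholders.

    #include "llama.h"

    #include <cstdio>
    #include <string>
    #include <vector>

    int main(int argc, char ** argv) {
        if (argc < 2) {
            fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
            return 1;
        }

        llama_backend_init();

        // load the model once
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file(argv[1], mparams);

        // enable embedding extraction in addition to logits; before this
        // patch, embeddings == true suppressed the logits buffer entirely
        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings = true;
        llama_context * ctx = llama_init_from_model(model, cparams);

        // tokenize a short prompt
        const llama_vocab * vocab = llama_model_get_vocab(model);
        const std::string prompt = "hello world";
        std::vector<llama_token> tokens(prompt.size() + 8);
        const int n_tokens = llama_tokenize(vocab, prompt.c_str(), (int) prompt.size(),
                                            tokens.data(), (int) tokens.size(),
                                            /*add_special=*/true, /*parse_special=*/false);
        tokens.resize(n_tokens);

        // a single decode; with this patch both outputs are populated
        llama_batch batch = llama_batch_get_one(tokens.data(), (int) tokens.size());
        llama_decode(ctx, batch);

        // logits for the last token (causal attention text generation path)
        const float * logits = llama_get_logits_ith(ctx, -1);

        // embedding for the last token (assumes pooling type NONE)
        const float * embd = llama_get_embeddings_ith(ctx, -1);

        printf("logits[0] = %f  embd[0] = %f\n", logits[0], embd[0]);

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }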
---
 src/llama-context.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 5a2eef9b..9c1fe93f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1225,7 +1225,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_all = 0;
 
     // count outputs
-    if (batch.logits && !embd_pooled) {
+    if (batch.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs_all += batch.logits[i] != 0;
         }
@@ -1337,7 +1337,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
+        auto * t_logits = cparams.causal_attn ? res->get_logits() : nullptr;
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1481,7 +1481,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
+    bool has_logits =  cparams.causal_attn;
     bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     // TODO: hacky enc-dec support