From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Tue, 8 Apr 2025 15:28:34 -0700
Subject: [PATCH] embeddings

allow a loaded model in llama.cpp to be used for both embeddings and
causal attention text generation instead of forcing one or the other
---
 src/llama-context.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 5a2eef9b..9c1fe93f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1225,7 +1225,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_all = 0;
 
     // count outputs
-    if (batch.logits && !embd_pooled) {
+    if (batch.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs_all += batch.logits[i] != 0;
         }
@@ -1337,7 +1337,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}
 
-    auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
+    auto * t_logits = cparams.causal_attn ? res->get_logits() : nullptr;
     auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
     if (t_embd && res->get_embd_pooled()) {
@@ -1481,7 +1481,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_embd = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
+    bool has_logits =  cparams.causal_attn;
     bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     // TODO: hacky enc-dec support
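
With this patch, a context created with embeddings enabled on a causal-attention model keeps its logits path instead of having it forced to null. Below is a minimal usage sketch, not part of the patch itself: it assumes the llama.cpp C API names as of early 2025 (llama_model_load_from_file, llama_init_from_model, llama_batch_get_one, llama_get_logits_ith, llama_get_embeddings_ith), which have shifted across releases; the model path and token values are placeholders.

// usage sketch (not part of the patch); "model.gguf" and the token ids are placeholders
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    llama_backend_init();

    llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());
    if (!model) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // request embedding outputs
    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;  // per-token embeddings, matching the has_embd condition above

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        return 1;
    }

    // toy pre-tokenized batch; real code would call llama_tokenize on a prompt
    std::vector<llama_token> tokens = { 1, 2, 3 };
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    if (llama_decode(ctx, batch) == 0) {
        // after the patch, a causal-attention model with embeddings enabled
        // yields both outputs for the last token, rather than only embeddings
        const float * logits = llama_get_logits_ith(ctx, -1);
        const float * embd   = llama_get_embeddings_ith(ctx, -1);
        printf("logits: %p embeddings: %p\n", (const void *) logits, (const void *) embd);
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}

Note that the per-token embedding read via llama_get_embeddings_ith relies on pooling being disabled; with a pooled configuration the embeddings would instead be fetched per sequence, which is why the sketch sets LLAMA_POOLING_TYPE_NONE explicitly.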