diff --git a/Makefile.sync b/Makefile.sync index 007282746..2bb8d4ab8 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggerganov/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=46e3556e01b824e52395fb050b29804b6cff2a7c +FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c .PHONY: help help: diff --git a/llama/build-info.cpp b/llama/build-info.cpp index e169b926b..58c1b0801 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "46e3556e01b824e52395fb050b29804b6cff2a7c"; +char const *LLAMA_COMMIT = "d7cfe1ffe0f435d0048a6058d529daf76e072d9c"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp index 4bb140ee2..d2b0d50e3 100644 --- a/llama/llama.cpp/common/common.cpp +++ b/llama/llama.cpp/common/common.cpp @@ -2,6 +2,9 @@ #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING #endif +#include "ggml.h" +#include "gguf.h" + #include "common.h" #include "log.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: @@ -70,6 +73,22 @@ #include #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + +// +// CURL utils +// + +using curl_ptr = std::unique_ptr; + +// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one +struct curl_slist_ptr { + struct curl_slist * ptr = nullptr; + ~curl_slist_ptr() { + if (ptr) { + curl_slist_free_all(ptr); + } + } +}; #endif // LLAMA_USE_CURL using json = nlohmann::ordered_json; @@ -464,6 +483,48 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string string_join(const std::vector & values, const std::string & separator) { + std::ostringstream result; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + result << separator; + } + result << values[i]; + } + return result.str(); +} + +std::vector string_split(const std::string & str, const std::string & delimiter) { + std::vector parts; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + parts.push_back(str.substr(start, end - start)); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + parts.push_back(str.substr(start)); + + return parts; +} + +std::string string_repeat(const std::string & str, size_t n) { + if (n == 0) { + return ""; + } + + std::string result; + result.reserve(str.length() * n); + + for (size_t i = 0; i < n; ++i) { + result += str; + } + + return result; +} + std::string string_from(bool value) { return value ? 
"true" : "false"; } @@ -846,7 +907,7 @@ struct common_init_result common_init_from_params(common_params & params) { } else if (!params.model_url.empty()) { model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); } else { - model = llama_load_model_from_file(params.model.c_str(), mparams); + model = llama_model_load_from_file(params.model.c_str(), mparams); } if (model == NULL) { @@ -854,26 +915,28 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + const llama_vocab * vocab = llama_model_get_vocab(model); + if (params.reranking) { bool ok = true; - if (llama_token_bos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__); + if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); ok = false; } - if (llama_token_eos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__); + if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__); ok = false; } - if (llama_token_sep(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__); + if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); ok = false; } if (!ok) { - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -881,10 +944,10 @@ struct common_init_result common_init_from_params(common_params & params) { auto cparams = common_context_params_to_llama(params); - llama_context * lctx = llama_new_context_with_model(model, cparams); + llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -895,25 +958,26 @@ struct common_init_result common_init_from_params(common_params & params) { if (!params.control_vectors.empty()) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; - if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model); const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } - int err = llama_control_vector_apply(lctx, - cvec.data.data(), - cvec.data.size(), - cvec.n_embd, - params.control_vector_layer_start, - params.control_vector_layer_end); + int err = llama_apply_adapter_cvec( + lctx, + cvec.data.data(), + cvec.data.size(), + cvec.n_embd, + params.control_vector_layer_start, + params.control_vector_layer_end); if (err) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -921,12 +985,12 @@ struct common_init_result common_init_from_params(common_params & params) { // load and optionally apply lora adapters for (auto & la : params.lora_adapters) { - llama_lora_adapter_ptr lora; - lora.reset(llama_lora_adapter_init(model, la.path.c_str())); + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model, la.path.c_str())); 
if (lora == nullptr) { LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -935,17 +999,17 @@ struct common_init_result common_init_from_params(common_params & params) { } if (!params.lora_init_without_apply) { - common_lora_adapters_apply(lctx, params.lora_adapters); + common_set_adapter_lora(lctx, params.lora_adapters); } - if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); params.sampling.ignore_eos = false; } if (params.sampling.ignore_eos) { - for (llama_token i = 0; i < llama_n_vocab(model); i++) { - if (llama_token_is_eog(model, i)) { + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); params.sampling.logit_bias.push_back({i, -INFINITY}); } @@ -966,8 +1030,9 @@ struct common_init_result common_init_from_params(common_params & params) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); std::vector tmp; - llama_token bos = llama_token_bos(model); - llama_token eos = llama_token_eos(model); + llama_token bos = llama_vocab_bos(vocab); + llama_token eos = llama_vocab_eos(vocab); + // some models (e.g. T5) don't have a BOS token if (bos != LLAMA_TOKEN_NULL) { tmp.push_back(bos); @@ -982,7 +1047,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_encoder(model)) { llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = bos; } tmp.clear(); @@ -1002,11 +1067,11 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } -void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora) { - llama_lora_adapter_clear(ctx); +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { + llama_clear_adapter_lora(ctx); for (auto & la : lora) { if (la.scale != 0.0f) { - llama_lora_adapter_set(ctx, la.ptr, la.scale); + llama_set_adapter_lora(ctx, la.ptr, la.scale); } } } @@ -1020,7 +1085,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } - mparams.rpc_servers = params.rpc_servers.c_str(); mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; @@ -1123,7 +1187,8 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl - std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; if (!curl) { LOG_ERR("%s: error initializing libcurl\n", __func__); return false; @@ -1137,11 +1202,9 @@ static bool common_download_file(const std::string & url, 
const std::string & pa // Check if hf-token or bearer-token was specified if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer "; - auth_header += hf_token.c_str(); - struct curl_slist *http_headers = NULL; - http_headers = curl_slist_append(http_headers, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers); + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); } #if defined(_WIN32) @@ -1411,7 +1474,7 @@ struct llama_model * common_load_model_from_url( } } - return llama_load_model_from_file(local_path.c_str(), params); + return llama_model_load_from_file(local_path.c_str(), params); } struct llama_model * common_load_model_from_hf( @@ -1437,6 +1500,80 @@ struct llama_model * common_load_model_from_hf( return common_load_model_from_url(model_url, local_path, hf_token, params); } +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) + * + * Return pair of (with "repo" already having tag removed) + * + * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + */ +std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { + auto parts = string_split(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? 
parts.back() : "latest"; + std::string hf_repo = parts[0]; + if (string_split(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); + } + + // fetch model info from Hugging Face Hub API + json model_info; + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + std::string res_str; + std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + static_cast(data)->append((char * ) ptr, size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (!hf_token.empty()) { + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + } + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + throw std::runtime_error("error: cannot make GET request to HF API"); + } + + long res_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + if (res_code == 200) { + model_info = json::parse(res_str); + } else if (res_code == 401) { + throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + } else { + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); + } + + // check response + if (!model_info.contains("ggufFile")) { + throw std::runtime_error("error: model does not have ggufFile"); + } + json & gguf_file = model_info.at("ggufFile"); + if (!gguf_file.contains("rfilename")) { + throw std::runtime_error("error: ggufFile does not have rfilename"); + } + + return std::make_pair(hf_repo, gguf_file.at("rfilename")); +} + #else struct llama_model * common_load_model_from_url( @@ -1458,6 +1595,11 @@ struct llama_model * common_load_model_from_hf( return nullptr; } +std::pair common_get_hf_file(const std::string &, const std::string &) { + LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); + return std::make_pair("", ""); +} + #endif // LLAMA_USE_CURL // @@ -1556,21 +1698,23 @@ std::vector common_tokenize( const std::string & text, bool add_special, bool parse_special) { - return common_tokenize(llama_get_model(ctx), text, add_special, parse_special); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_tokenize(vocab, text, add_special, parse_special); } std::vector common_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const std::string & text, bool add_special, bool parse_special) { // upper limit for the number 
of tokens int n_tokens = text.length() + 2 * add_special; std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -1579,12 +1723,18 @@ std::vector common_tokenize( } std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_token_to_piece(vocab, token, special); +} + +std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) { std::string piece; piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' - const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); if (n_chars < 0) { piece.resize(-n_chars); - int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); GGML_ASSERT(check == -n_chars); } else { @@ -1594,13 +1744,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token return piece; } -std::string common_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { +std::string common_detokenize(const struct llama_context * ctx, const std::vector & tokens, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_detokenize(vocab, tokens, special); +} + +std::string common_detokenize(const struct llama_vocab * vocab, const std::vector & tokens, bool special) { std::string text; text.resize(std::max(text.capacity(), tokens.size())); - int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); if (n_chars < 0) { text.resize(-n_chars); - n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization } @@ -1610,103 +1766,6 @@ std::string common_detokenize(llama_context * ctx, const std::vector 0) { - std::vector model_template(res + 1, 0); - llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size()); - return std::string(model_template.data(), model_template.size() - 1); - } - return ""; -} - -bool common_chat_verify_template(const std::string & tmpl) { - llama_chat_message chat[] = {{"user", "test"}}; - int res = 
llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); - return res >= 0; -} - -std::string common_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, - const std::vector & msgs, - bool add_ass) { - int alloc_size = 0; - bool fallback = false; // indicate if we must fallback to default chatml - std::vector chat; - for (auto & msg : msgs) { - chat.push_back({msg.role.c_str(), msg.content.c_str()}); - alloc_size += (msg.role.size() + msg.content.size()) * 1.25; - } - - const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); - std::vector buf(alloc_size); - - // run the first time to get the total output length - int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - - // error: chat template is not supported - if (res < 0) { - if (ptr_tmpl != nullptr) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } else { - // If the built-in template is not supported, we default to chatml - res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - fallback = true; - } - } - - // if it turns out that our buffer is too small, we resize it - if ((size_t) res > buf.size()) { - buf.resize(res); - res = llama_chat_apply_template( - fallback ? nullptr : model, - fallback ? "chatml" : ptr_tmpl, - chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - } - - std::string formatted_chat(buf.data(), res); - return formatted_chat; -} - -std::string common_chat_format_single(const struct llama_model * model, - const std::string & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass) { - std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? 
"" : common_chat_apply_template(model, tmpl, past_msg, false); - std::vector chat_new(past_msg); - // if the past_msg ends with a newline, we must preserve it in the formatted version - if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { - ss << "\n"; - }; - // format chat with new_msg - chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass); - // get the diff part - ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); - return ss.str(); -} - -std::string common_chat_format_example(const struct llama_model * model, - const std::string & tmpl) { - std::vector msgs = { - {"system", "You are a helpful assistant"}, - {"user", "Hello"}, - {"assistant", "Hi there"}, - {"user", "How are you?"}, - }; - return common_chat_apply_template(model, tmpl, msgs, true); -} - // // KV cache utils // diff --git a/llama/llama.cpp/common/common.h b/llama/llama.cpp/common/common.h index 0d452cf0f..efe8e7f79 100644 --- a/llama/llama.cpp/common/common.h +++ b/llama/llama.cpp/common/common.h @@ -4,6 +4,7 @@ #include "llama-cpp.h" +#include #include #include #include @@ -24,11 +25,11 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" -struct common_lora_adapter_info { +struct common_adapter_lora_info { std::string path; float scale; - struct llama_lora_adapter * ptr; + struct llama_adapter_lora * ptr; }; using llama_tokens = std::vector; @@ -103,6 +104,17 @@ enum dimre_method { DIMRE_METHOD_MEAN, }; +enum common_conversation_mode { + COMMON_CONVERSATION_MODE_DISABLED = 0, + COMMON_CONVERSATION_MODE_ENABLED = 1, + COMMON_CONVERSATION_MODE_AUTO = 2, +}; + +struct common_grammar_trigger { + std::string word; + bool at_start; +}; + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -128,6 +140,7 @@ struct common_params_sampling { int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float top_n_sigma = -1.00f;// -1.0 = disabled float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; @@ -148,7 +161,11 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_TEMPERATURE, }; - std::string grammar; // optional BNF-like grammar to constrain sampling + std::string grammar; // optional BNF-like grammar to constrain sampling + bool grammar_lazy = false; + std::vector grammar_trigger_words; // optional trigger words to trigger lazy grammar + std::vector grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens. 
+ std::set preserved_tokens; std::vector logit_bias; // logit biases to apply @@ -161,15 +178,19 @@ struct common_params_speculative { int32_t n_ctx = 0; // draft context size int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) float p_split = 0.1f; // speculative decoding split probability - float p_min = 0.9f; // minimum speculative decoding probability (greedy) + float p_min = 0.75f; // minimum speculative decoding probability (greedy) struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - std::string model = ""; // draft model for speculative decoding // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + + std::string model = ""; // draft model for speculative decoding // NOLINT + std::string model_url = ""; // model url to download // NOLINT }; struct common_params_vocoder { @@ -178,6 +199,13 @@ struct common_params_vocoder { std::string model = ""; // model path // NOLINT std::string model_url = ""; // model url to download // NOLINT + + bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT +}; + +enum common_reasoning_format { + COMMON_REASONING_FORMAT_NONE, + COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content` }; struct common_params { @@ -240,14 +268,13 @@ struct common_params { std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT - std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. 
reverse prompts) std::vector kv_overrides; - bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply) - std::vector lora_adapters; // lora adapter path with user defined scale + bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) + std::vector lora_adapters; // lora adapter path with user defined scale std::vector control_vectors; // control vector with user defined scale @@ -271,11 +298,11 @@ struct common_params { bool kl_divergence = false; // compute KL divergence bool usage = false; // print usage + bool completion = false; // print source-able completion script bool use_color = false; // use color to distinguish generations and inputs bool special = false; // enable special token output bool interactive = false; // interactive mode bool interactive_first = false; // wait for user input immediately - bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it @@ -301,6 +328,8 @@ struct common_params { ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V + common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; + // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector // NOLINT std::vector image; // path to image file(s) @@ -322,7 +351,9 @@ struct common_params { std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT std::string chat_template = ""; // NOLINT + bool use_jinja = false; // NOLINT bool enable_chat_template = true; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; std::vector api_keys; @@ -401,13 +432,13 @@ bool set_process_priority(enum ggml_sched_priority prio); // #ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif #else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) #endif LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) @@ -416,6 +447,10 @@ std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); +std::string string_join(const std::vector & values, const std::string & separator); +std::vector string_split(const std::string & str, const std::string & delimiter); +std::string string_repeat(const std::string & str, size_t n); + void string_replace_all(std::string & s, const std::string & search, const std::string & replace); template @@ -454,6 +489,11 @@ static bool string_starts_with(const std::string & str, return str.rfind(prefix, 0) == 0; } +static bool string_ends_with(const std::string & str, + const std::string & suffix) { // While we wait for C++20's std::string::ends_with... 
+ return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); @@ -481,7 +521,7 @@ struct common_init_result { llama_model_ptr model; llama_context_ptr context; - std::vector lora; + std::vector lora; }; struct common_init_result common_init_from_params(common_params & params); @@ -495,6 +535,7 @@ struct llama_model * common_load_model_from_url( const std::string & local_path, const std::string & hf_token, const struct llama_model_params & params); + struct llama_model * common_load_model_from_hf( const std::string & repo, const std::string & remote_path, @@ -502,8 +543,12 @@ struct llama_model * common_load_model_from_hf( const std::string & hf_token, const struct llama_model_params & params); +std::pair common_get_hf_file( + const std::string & hf_repo_with_tag, + const std::string & hf_token); + // clear LoRA adapters from context, then apply new list of adapters -void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora); +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); // // Batch utils @@ -541,7 +586,7 @@ std::vector common_tokenize( bool parse_special = false); std::vector common_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const std::string & text, bool add_special, bool parse_special = false); @@ -553,48 +598,23 @@ std::string common_token_to_piece( llama_token token, bool special = true); +std::string common_token_to_piece( + const struct llama_vocab * vocab, + llama_token token, + bool special = true); + // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` // optionally renders special/control tokens std::string common_detokenize( - llama_context * ctx, + const struct llama_context * ctx, const std::vector & tokens, bool special = true); -// -// Chat template utils -// - -// same with llama_chat_message, but uses std::string -struct common_chat_msg { - std::string role; - std::string content; -}; - -// Get the built-in chat template for the model. Return empty string if not present. -std::string common_get_builtin_chat_template(const struct llama_model * model); - -// Check if the template supplied via "--chat-template" is supported or not. 
Returns true if it's valid -bool common_chat_verify_template(const std::string & tmpl); - -// CPP wrapper for llama_chat_apply_template -// If the built-in template is not supported, we default to chatml -// If the custom "tmpl" is not supported, we throw an error -std::string common_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, - const std::vector & chat, - bool add_ass); - -// Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single(const struct llama_model * model, - const std::string & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass); - -// Returns an example of formatted chat -std::string common_chat_format_example(const struct llama_model * model, - const std::string & tmpl); +std::string common_detokenize( + const struct llama_vocab * vocab, + const std::vector & tokens, + bool special = true); // // KV cache utils diff --git a/llama/llama.cpp/common/json-schema-to-grammar.cpp b/llama/llama.cpp/common/json-schema-to-grammar.cpp index 2a8dbd22d..30c28808f 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.cpp +++ b/llama/llama.cpp/common/json-schema-to-grammar.cpp @@ -1,4 +1,6 @@ #include "json-schema-to-grammar.h" +#include "common.h" + #include #include #include @@ -11,11 +13,6 @@ using json = nlohmann::ordered_json; -template -static std::string join(Iterator begin, Iterator end, const std::string & separator); - -static std::string repeat(const std::string & str, size_t n); - static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); @@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & if (sub_len > 0) { auto from_sub = from.substr(i + 1); auto to_sub = to.substr(i + 1); - auto sub_zeros = repeat("0", sub_len); - auto sub_nines = repeat("9", sub_len); + auto sub_zeros = string_repeat("0", sub_len); + auto sub_nines = string_repeat("9", sub_len); auto to_reached = false; out << "("; @@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & auto max_digits = max_s.length(); for (auto digits = min_digits; digits < max_digits; digits++) { - uniform_range(min_s, repeat("9", digits)); - min_s = "1" + repeat("0", digits); + uniform_range(min_s, string_repeat("9", digits)); + min_s = "1" + string_repeat("0", digits); out << " | "; } uniform_range(min_s, max_s); @@ -318,49 +315,6 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; -template -std::string join(Iterator begin, Iterator end, const std::string & separator) { - std::ostringstream result; - if (begin != end) { - result << *begin; - for (Iterator it = begin + 1; it != end; ++it) { - result << separator << *it; - } - } - return result.str(); -} - -static std::vector split(const std::string & str, const std::string & delimiter) { - std::vector tokens; - size_t start = 0; - size_t end = str.find(delimiter); - - while (end != std::string::npos) { - tokens.push_back(str.substr(start, end - start)); - start = end + delimiter.length(); - end = str.find(delimiter, start); - } - - tokens.push_back(str.substr(start)); - - return tokens; -} - 
-static std::string repeat(const std::string & str, size_t n) { - if (n == 0) { - return ""; - } - - std::string result; - result.reserve(str.length() * n); - - for (size_t i = 0; i < n; ++i) { - result += str; - } - - return result; -} - static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; std::string result; @@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) { class SchemaConverter { private: + friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); std::function _fetch_json; bool _dotall; std::unordered_map _rules; @@ -418,7 +373,7 @@ private: for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } - return join(rules.begin(), rules.end(), " | "); + return string_join(rules, " | "); } std::string _visit_pattern(const std::string & pattern, const std::string & name) { @@ -481,7 +436,7 @@ private: for (const auto & item : ret) { results.push_back(to_rule(item)); } - return std::make_pair(join(results.begin(), results.end(), " "), false); + return std::make_pair(string_join(results, " "), false); }; while (i < length) { @@ -539,7 +494,7 @@ private: } curly_brackets += '}'; i++; - auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); int min_times = 0; int max_times = std::numeric_limits::max(); try { @@ -809,10 +764,11 @@ private: public: SchemaConverter( const std::function & fetch_json, - bool dotall) + bool dotall, + bool compact_spaces) : _fetch_json(fetch_json), _dotall(dotall) { - _rules["space"] = SPACE_RULE; + _rules["space"] = compact_spaces ? "\" \"?" 
: SPACE_RULE; } void resolve_refs(json & schema, const std::string & url) { @@ -854,7 +810,7 @@ public: return; } std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); + std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; if (target.is_null() || !target.contains(sel)) { @@ -905,7 +861,7 @@ public: for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -1019,10 +975,10 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { - fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str()); + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); } } @@ -1035,11 +991,35 @@ public: } }; -std::string json_schema_to_grammar(const json & schema) { - SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); - auto copy = schema; - converter.resolve_refs(copy, "input"); - converter.visit(copy, ""); +std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { +#ifdef LLAMA_USE_LLGUIDANCE + if (!force_gbnf) { + return "%llguidance {}\nstart: %json " + schema.dump(); + } +#else + (void)force_gbnf; +#endif // LLAMA_USE_LLGUIDANCE + return build_grammar([&](const common_grammar_builder & callbacks) { + auto copy = schema; + callbacks.resolve_refs(copy); + callbacks.add_schema("", copy); + }); +} + +std::string build_grammar(const std::function & cb, const common_grammar_options & options) { + SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces); + common_grammar_builder builder { + /* .add_rule = */ [&](const std::string & name, const std::string & rule) { + return converter._add_rule(name, rule); + }, + /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) { + return converter.visit(schema, name == "root" ? 
"" : name); + }, + /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) { + converter.resolve_refs(schema, ""); + } + }; + cb(builder); converter.check_errors(); return converter.format_grammar(); } diff --git a/llama/llama.cpp/common/json-schema-to-grammar.h b/llama/llama.cpp/common/json-schema-to-grammar.h index 41623b346..62a3b0a44 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.h +++ b/llama/llama.cpp/common/json-schema-to-grammar.h @@ -5,4 +5,18 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, + bool force_gbnf = false); + +struct common_grammar_builder { + std::function add_rule; + std::function add_schema; + std::function resolve_refs; +}; + +struct common_grammar_options { + bool dotall = false; + bool compact_spaces = false; +}; + +std::string build_grammar(const std::function & cb, const common_grammar_options & options = {}); diff --git a/llama/llama.cpp/common/log.cpp b/llama/llama.cpp/common/log.cpp index 04c7c0ed1..52b31470c 100644 --- a/llama/llama.cpp/common/log.cpp +++ b/llama/llama.cpp/common/log.cpp @@ -1,5 +1,6 @@ #include "log.h" +#include #include #include #include @@ -14,16 +15,6 @@ void common_log_set_verbosity_thold(int verbosity) { common_log_verbosity_thold = verbosity; } -#define LOG_COL_DEFAULT "\033[0m" -#define LOG_COL_BOLD "\033[1m" -#define LOG_COL_RED "\033[31m" -#define LOG_COL_GREEN "\033[32m" -#define LOG_COL_YELLOW "\033[33m" -#define LOG_COL_BLUE "\033[34m" -#define LOG_COL_MAGENTA "\033[35m" -#define LOG_COL_CYAN "\033[36m" -#define LOG_COL_WHITE "\033[37m" - static int64_t t_us() { return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); } @@ -206,6 +197,7 @@ public: vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy); } #endif + va_end(args_copy); } entry.level = level; diff --git a/llama/llama.cpp/common/log.h b/llama/llama.cpp/common/log.h index 66605cc69..c56bb50d9 100644 --- a/llama/llama.cpp/common/log.h +++ b/llama/llama.cpp/common/log.h @@ -2,9 +2,20 @@ #include "ggml.h" // for ggml_log_level +#define LOG_CLR_TO_EOL "\033[K\r" +#define LOG_COL_DEFAULT "\033[0m" +#define LOG_COL_BOLD "\033[1m" +#define LOG_COL_RED "\033[31m" +#define LOG_COL_GREEN "\033[32m" +#define LOG_COL_YELLOW "\033[33m" +#define LOG_COL_BLUE "\033[34m" +#define LOG_COL_MAGENTA "\033[35m" +#define LOG_COL_CYAN "\033[36m" +#define LOG_COL_WHITE "\033[37m" + #ifndef __GNUC__ # define LOG_ATTRIBUTE_FORMAT(...) -#elif defined(__MINGW32__) +#elif defined(__MINGW32__) && !defined(__clang__) # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) #else # define LOG_ATTRIBUTE_FORMAT(...) 
__attribute__((format(printf, __VA_ARGS__))) diff --git a/llama/llama.cpp/common/sampling.cpp b/llama/llama.cpp/common/sampling.cpp index e83a971c7..37a0d9c85 100644 --- a/llama/llama.cpp/common/sampling.cpp +++ b/llama/llama.cpp/common/sampling.cpp @@ -113,7 +113,10 @@ struct common_sampler { void set_logits(struct llama_context * ctx, int idx) { const auto * logits = llama_get_logits_ith(ctx, idx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); cur.resize(n_vocab); @@ -131,24 +134,47 @@ std::string common_params_sampling::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, - top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, + top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, mirostat, mirostat_eta, mirostat_tau); return std::string(result); } struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); lparams.no_perf = params.no_perf; + struct llama_sampler * grmr; + if (params.grammar.compare(0, 11, "%llguidance") == 0) { +#ifdef LLAMA_USE_LLGUIDANCE + grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); +#else + GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); +#endif // LLAMA_USE_LLGUIDANCE + } else { + std::vector trigger_words; + trigger_words.reserve(params.grammar_trigger_words.size()); + for (const auto & str : params.grammar_trigger_words) { + trigger_words.push_back(str.word.c_str()); + } + + grmr = params.grammar_lazy + ? 
llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", + trigger_words.data(), trigger_words.size(), + params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) + : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); + } + auto * result = new common_sampler { /* .params = */ params, - /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), + /* .grmr = */ grmr, /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, @@ -157,56 +183,62 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_logit_bias( - llama_n_vocab(model), + llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); if (params.mirostat == 0) { - for (const auto & cnstr : params.samplers) { - switch (cnstr) { - case COMMON_SAMPLER_TYPE_DRY: - { - std::vector c_breakers; - c_breakers.reserve(params.dry_sequence_breakers.size()); - for (const auto & str : params.dry_sequence_breakers) { - c_breakers.push_back(str.c_str()); - } + if (params.top_n_sigma >= 0) { + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); + } else { + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto & str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); + } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); - } - break; - case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); - break; - case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); - break; - case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); - break; - case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); - break; - default: - GGML_ASSERT(false && "unknown sampler type"); + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + } + break; + case 
COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); + } } } llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); diff --git a/llama/llama.cpp/common/sampling.h b/llama/llama.cpp/common/sampling.h index 348911b18..2064421db 100644 --- a/llama/llama.cpp/common/sampling.h +++ b/llama/llama.cpp/common/sampling.h @@ -102,3 +102,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr); std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names); std::vector common_sampler_types_from_chars(const std::string & chars); + +llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, + const char * grammar_kind, const char * grammar_data); diff --git a/llama/llama.cpp/examples/llava/clip.cpp b/llama/llama.cpp/examples/llava/clip.cpp index 86b91d5cb..54265bebb 100644 --- a/llama/llama.cpp/examples/llava/clip.cpp +++ b/llama/llama.cpp/examples/llava/clip.cpp @@ -7,6 +7,7 @@ #include "ggml-cpu.h" #include "ggml-alloc.h" #include "ggml-backend.h" +#include "gguf.h" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" @@ -39,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -114,6 +116,7 @@ static std::string format(const char * fmt, ...) 
{ #define KEY_HAS_VIS_ENC "clip.has_vision_encoder" #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" @@ -131,6 +134,7 @@ static std::string format(const char * fmt, ...) { #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -172,6 +176,15 @@ static std::string format(const char * fmt, ...) { #define TN_MINICPMV_ATTN "resampler.attn.%s.%s" #define TN_MINICPMV_LN "resampler.ln_%s.%s" +#define TN_GLM_ADAPER_CONV "adapter.conv.%s" +#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" +#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s" +#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s" +#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" +#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" +#define TN_GLM_BOI_W "adapter.boi" +#define TN_GLM_EOI_W "adapter.eoi" + enum projector_type { PROJECTOR_TYPE_MLP, @@ -179,6 +192,7 @@ enum projector_type { PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_GLM_EDGE, PROJECTOR_TYPE_MERGER, PROJECTOR_TYPE_UNKNOWN, }; @@ -188,6 +202,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, }; @@ -275,7 +290,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { { const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = gguf_get_arr_data(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? 
nullptr : gguf_get_arr_data(ctx_gguf, i); std::stringstream ss; ss << "["; for (int j = 0; j < arr_n; j++) { @@ -444,8 +459,9 @@ struct clip_hparams { char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) - int32_t image_grid_pinpoints[32]; + std::vector image_grid_pinpoints; int32_t image_crop_resolution; + std::unordered_set vision_feature_layer; }; struct clip_layer { @@ -512,6 +528,12 @@ struct clip_vision_model { struct ggml_tensor * mm_4_w = NULL; struct ggml_tensor * mm_4_b = NULL; + //GLMV-Edge projection + struct ggml_tensor * mm_model_adapter_conv_w; + struct ggml_tensor * mm_model_adapter_conv_b; + struct ggml_tensor * boi_w; + struct ggml_tensor * eoi_w; + // MobileVLM projection struct ggml_tensor * mm_model_mlp_1_w; struct ggml_tensor * mm_model_mlp_1_b; @@ -572,12 +594,14 @@ struct clip_ctx { bool has_vision_encoder = false; bool has_llava_projector = false; bool has_minicpmv_projector = false; + bool has_glm_projector = false; bool has_qwen2vl_merger = false; int minicpmv_version = 2; struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; + int32_t max_feature_layer; float image_mean[3]; float image_std[3]; bool use_gelu = false; @@ -644,13 +668,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; - int n_layer = hparams.n_layer; const float eps = hparams.eps; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; const int batch_size = imgs->size; - if (ctx->has_llava_projector || ctx->has_minicpmv_projector) { + if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { GGML_ASSERT(batch_size == 1); } @@ -730,6 +753,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 else if (ctx->minicpmv_version == 3) { pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); } + else if (ctx->minicpmv_version == 4) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -742,14 +768,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); } + std::vector embedding_stack; + const auto & vision_feature_layer = hparams.vision_feature_layer; + // loop over layers - if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) { - // TODO: figure out why we doing thing in this way ??? - n_layer += 1; - } - for (int il = 0; il < n_layer - 1; il++) { + for (int il = 0; il < ctx->max_feature_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + // If this is an embedding feature layer, save the output. + // NOTE: 0 index here refers to the input to the encoder. 
+ if (vision_feature_layer.find(il) != vision_feature_layer.end()) { + embedding_stack.push_back(embeddings); + } + //const size_t nb_q_w = model.layers[il].q_w->nb[0]; // layernorm1 @@ -837,7 +868,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; - } // post-layernorm @@ -848,6 +878,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } + // final layer is a vision feature layer + if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) { + embedding_stack.push_back(embeddings); + } + + // If feature layers are explicitly set, stack them (if we have multiple) + if (!embedding_stack.empty()) { + embeddings = embedding_stack[0]; + for (size_t i = 1; i < embedding_stack.size(); i++) { + embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0); + } + } + // llava projector if (ctx->has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); @@ -1065,6 +1108,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 n_head = hidden_size/d_head; num_query = 64; } + else if (ctx->minicpmv_version == 4) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); @@ -1099,7 +1147,33 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(false); } } - else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + // glm projector + else if (ctx->has_glm_projector) { + if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + size_t gridsz = (size_t)sqrt(embeddings->ne[1]); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); + embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); + embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); + embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); + embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); + //GLU + { + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + embeddings = ggml_gelu_inplace(ctx0, embeddings); + struct ggml_tensor * x = embeddings; + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); + x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); + embeddings = ggml_silu_inplace(ctx0, embeddings); + embeddings = ggml_mul(ctx0, embeddings,x); + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); + } + } else { + GGML_ABORT("fatel error"); + } + } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); @@ -1268,6 +1342,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); } + idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ); + if (idx != 
-1) { + new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx); + } + idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER); if (idx != -1) { new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx); @@ -1292,6 +1371,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); + LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector); LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); } @@ -1402,14 +1482,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); int n = gguf_get_arr_n(ctx, idx); const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) { - hparams.image_grid_pinpoints[i] = pinpoints[i]; + for (int i = 0; i < n; ++i) { + hparams.image_grid_pinpoints.push_back(pinpoints[i]); } - if (n < 32) - hparams.image_grid_pinpoints[n] = 0; - } catch (std::runtime_error & /*e*/) { - hparams.image_grid_pinpoints[0]=0; - } + } catch (std::runtime_error & /*e*/) { } + + // Load the vision feature layer indices if they are explicitly provided; + // if multiple vision feature layers are present, the values will be concatenated + // to form the final visual features. + // NOTE: gguf conversions should standardize the values of the vision feature layer to + // be non-negative, since we use -1 to mark values as unset here. 
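+ // Illustrative aside, not part of the vendored diff: the indices loaded in the try block below
+ // select encoder layers whose outputs were pushed onto embedding_stack in the graph code above
+ // and concatenated along dim 0, so the width of the embeddings handed to the multimodal
+ // projector grows with the number of selected layers. A hypothetical example (the hidden size
+ // and layer indices are illustrative, not taken from any particular model):
+ //
+ //   int hidden_size      = 1152;                            // per-layer embedding width
+ //   int n_feature_layers = 4;                               // e.g. vision_feature_layer = {3, 7, 15, 26}
+ //   int concat_width     = hidden_size * n_feature_layers;  // 4608 values per patch fed to the projector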
+ try { + int idx = get_key_idx(ctx, KEY_FEATURE_LAYER); + int n = gguf_get_arr_n(ctx, idx); + + const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); + + for (int i = 0; i < n; ++i) { + hparams.vision_feature_layer.insert(vision_feature_layer[i]); + } + } catch (std::runtime_error & /*e*/) { } try { int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); @@ -1435,6 +1527,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->image_std[i] = std_data[i]; } + // Calculate the deepest feature layer based on hparams and projector type + new_clip->max_feature_layer = get_deepest_feature_layer(new_clip); + if (verbosity >= 2) { LOG_INF("\n%s: vision model hparams\n", __func__); LOG_INF("image_size %d\n", hparams.image_size); @@ -1448,8 +1543,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); LOG_INF("v_image_grid_pinpoints: "); - for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) { - LOG_INF("%d ", hparams.image_grid_pinpoints[i]); + for (const auto & pp : hparams.image_grid_pinpoints) { + LOG_INF("%d ", pp); + } + LOG_INF("\n"); + LOG_INF("v_vision_feature_layer: "); + for (const auto & feature_layer: hparams.vision_feature_layer) { + LOG_INF("%d ", feature_layer); } LOG_INF("\n"); LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); @@ -1584,6 +1684,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); } + else if (new_clip->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight")); + vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias")); + vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight")); + vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight")); + vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias")); + vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); + vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight")); + vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); + vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W); + vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W); + } else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) { vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); @@ -1676,11 +1788,11 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { } } -static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) { +void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img) { img->nx = nx; img->ny = 
ny; img->buf.resize(3 * nx * ny); - memcpy(img->buf.data(), data, img->buf.size()); + memcpy(img->buf.data(), rgb_pixels, img->buf.size()); } bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { @@ -1690,7 +1802,7 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { LOG_ERR("%s: failed to load image '%s'\n", __func__, fname); return false; } - build_clip_img_from_data(data, nx, ny, img); + clip_build_img_from_pixels(data, nx, ny, img); stbi_image_free(data); return true; } @@ -1702,7 +1814,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length LOG_ERR("%s: failed to decode image bytes\n", __func__); return false; } - build_clip_img_from_data(data, nx, ny, img); + clip_build_img_from_pixels(data, nx, ny, img); stbi_image_free(data); return true; } @@ -2058,6 +2170,7 @@ static std::vector> uhd_slice_image(const clip_imag images[images.size()-1].push_back(patch); } } + clip_image_u8_free(refine_image); } return images; } @@ -2096,6 +2209,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli clip_image_f32_free(res); } } + for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t j = 0; j < imgs[i].size(); ++j) { + if (imgs[i][j] != nullptr) { + clip_image_u8_free(imgs[i][j]); + } + } + } return true; } else if (ctx->has_qwen2vl_merger) { @@ -2116,6 +2236,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli return true; } + if (ctx->has_glm_projector) { + res_imgs->size = 1; + res_imgs->data = new clip_image_f32[res_imgs->size]; + clip_image_u8 resized_image; + int32_t sz=ctx->vision_model.hparams.image_size; + bicubic_resize(*img, resized_image,sz,sz); + clip_image_f32 * res = clip_image_f32_init(); + //clip_image_save_to_bmp(resized_image, "resized.bmp"); + normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std); + res_imgs->data[0] = *res; + clip_image_f32_free(res); + return true; + } + bool pad_to_square = true; if (!ctx->has_vision_encoder) { LOG_ERR("This gguf file seems to have no vision encoder\n"); @@ -2160,10 +2294,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli } } } else { - if (params.image_grid_pinpoints[0] != 0) { + if (!params.image_grid_pinpoints.empty()) { // "spatial_unpad" with "anyres" processing for llava-1.6 std::vector> possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { + for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) { possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); } std::pair best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); @@ -2301,7 +2435,8 @@ void clip_free(clip_ctx * ctx) { } size_t clip_embd_nbytes(const struct clip_ctx * ctx) { - return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float); + int extra_tokens = ctx->has_glm_projector ? 
2 : 0; + return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float); } size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) { @@ -2328,7 +2463,14 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) { } const int32_t * clip_image_grid(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_grid_pinpoints; + if (ctx->vision_model.hparams.image_grid_pinpoints.size()) { + return &ctx->vision_model.hparams.image_grid_pinpoints.front(); + } + return nullptr; +} + +size_t get_clip_image_grid_size(const struct clip_ctx * ctx) { + return ctx->vision_model.hparams.image_grid_pinpoints.size(); } int clip_n_patches(const struct clip_ctx * ctx) { @@ -2343,7 +2485,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { + if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { n_patches /= 4; } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { if (ctx->minicpmv_version == 2) { @@ -2352,6 +2494,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i else if (ctx->minicpmv_version == 3) { n_patches = 64; } + else if (ctx->minicpmv_version == 4) { + n_patches = 64; + } } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); @@ -2473,6 +2618,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if (ctx->has_minicpmv_projector) { GGML_ASSERT(batch_size == 1); } + if (ctx->has_glm_projector) { + GGML_ASSERT(batch_size == 1); + ggml_tensor * boi = ctx->vision_model.boi_w; + ggml_backend_tensor_get(boi,vec,0,ggml_nbytes(boi)); + vec = (float*)(vec+ggml_nelements(boi)); //offset for boi + } // build the inference graph ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); @@ -2531,8 +2682,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - int bucket_coords_h[70]; - int bucket_coords_w[70]; + int bucket_coords_h[1024]; + int bucket_coords_w[1024]; for (int i = 0; i < pos_h; i++){ bucket_coords_h[i] = std::floor(70.0*i/pos_h); } @@ -2560,6 +2711,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima else if (ctx->minicpmv_version == 3) { embed_dim = 3584; } + else if (ctx->minicpmv_version == 4) { + embed_dim = 3584; + } auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); @@ -2622,11 +2776,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); - { + if (!ctx->has_glm_projector) { struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); + // The patches vector is used to get rows to index into the embeds with; + // we should skip dim 0 only if we have CLS to 
avoid going out of bounds + // when retrieving the rows. + int patch_offset = ctx->has_class_embedding ? 1 : 0; int* patches_data = (int*)malloc(ggml_nbytes(patches)); for (int i = 0; i < num_patches; i++) { - patches_data[i] = i + 1; + patches_data[i] = i + patch_offset; } ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); free(patches_data); @@ -2646,14 +2804,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + if (ctx->has_glm_projector) { + //eoi + ggml_tensor * eoi = ctx->vision_model.eoi_w; + int offset = ggml_nelements(embeddings); + ggml_backend_tensor_get(eoi, vec+offset, 0, ggml_nbytes(eoi)); + } + return true; } bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) { - ggml_type type = GGML_TYPE_Q4_1; - assert(itype < GGML_TYPE_COUNT); - type = static_cast(itype); + ggml_type type = static_cast(itype); auto * ctx_clip = clip_model_load(fname_inp, 2); @@ -2706,8 +2869,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i } } - // quantize only 2D tensors - quantize &= (ggml_n_dims(cur) == 2); + // quantize only 2D tensors and bigger than block size + quantize &= (ggml_n_dims(cur) == 2) && cur->ne[0] > ggml_blck_size(type); if (quantize) { new_type = type; @@ -2752,7 +2915,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i total_size_org += orig_size; total_size_new += new_size; gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size); + gguf_set_tensor_data(ctx_out, name.c_str(), new_data); fout.write((const char *)new_data, new_size); size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size; for (size_t j = 0; j < pad; ++j) { @@ -2802,6 +2966,12 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { else if (ctx->minicpmv_version == 3) { return 3584; } + else if (ctx->minicpmv_version == 4) { + return 3584; + } + } + if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){ + return ctx->vision_model.mm_model_mlp_3_w->ne[1]; } if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { return ctx->vision_model.mm_1_b->ne[0]; @@ -2818,10 +2988,35 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) { return 0; } +bool clip_is_glm(const struct clip_ctx * ctx) { + return ctx->has_glm_projector; +} bool clip_is_qwen2vl(const struct clip_ctx * ctx) { return ctx->has_qwen2vl_merger; } +// Determine the number of encoder layers to iterate over +int get_deepest_feature_layer(const struct clip_ctx * ctx) { + // Get the index of the second to last layer; this is the + // default for models that have a llava projector + const auto & hparams = ctx->vision_model.hparams; + int n_layer = hparams.n_layer - 1; + int deepest_feature_layer = -1; + + // Handle other projectors; incrementing here indicates that we + // should use the last encoder layer for the vision features. + if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { + n_layer += 1; + } + + // If we set explicit vision feature layers, only go up to the deepest one + for (const auto & feature_layer : hparams.vision_feature_layer) { + if (feature_layer > deepest_feature_layer) { + deepest_feature_layer = feature_layer; + } + } + return deepest_feature_layer < 0 ? 
n_layer : deepest_feature_layer; +} bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { clip_image_f32 clip_img; diff --git a/llama/llama.cpp/examples/llava/clip.h b/llama/llama.cpp/examples/llava/clip.h index 1603edd26..f9f80d7d9 100644 --- a/llama/llama.cpp/examples/llava/clip.h +++ b/llama/llama.cpp/examples/llava/clip.h @@ -55,6 +55,7 @@ CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx); CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); +CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx); CLIP_API int clip_n_patches (const struct clip_ctx * ctx); CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img); @@ -73,6 +74,9 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); +/** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */ +CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img); + CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ @@ -89,10 +93,14 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); +CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx); + CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); + #ifdef __cplusplus } #endif diff --git a/llama/llama.cpp/examples/llava/llava.cpp b/llama/llama.cpp/examples/llava/llava.cpp index 0f0f3f623..f0e484a1c 100644 --- a/llama/llama.cpp/examples/llava/llava.cpp +++ b/llama/llama.cpp/examples/llava/llava.cpp @@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector return true; } -static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) { +static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { int width = image->nx; int height = image->ny; int num_patches = (height / patch_size) * (width / patch_size); @@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); } else { - int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); - if (has_minicpmv_projector == 2) { - encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); - } - else if (has_minicpmv_projector == 3) { - encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); - } + encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); } if (!encoded) { @@ -313,6 +307,23 @@ static bool encode_image_with_clip(clip_ctx * 
ctx_clip, int n_threads, const cli load_image_size->height = img->ny; clip_add_load_image_size(ctx_clip, load_image_size); LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); + delete[] img_res_v.data; + img_res_v.size = 0; + img_res_v.data = nullptr; + } + else if (clip_is_glm(ctx_clip)){ + struct clip_image_size * load_image_size = clip_image_size_init(); + load_image_size->width = img_res_v.data[0].nx; + load_image_size->height = img_res_v.data[0].ny; + clip_add_load_image_size(ctx_clip, load_image_size); + + bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); + int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2); + *n_img_pos = (pos * pos + 2); + if (!encoded){ + LOG_ERR("Unable to encode image \n"); + return false; + } } else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding @@ -342,9 +353,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); const int32_t * image_grid = clip_image_grid(ctx_clip); + const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); std::vector> grid_pinpoints; - for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) { + for (size_t i = 0; i < num_gridpoints; i += 2) { grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); } @@ -384,7 +396,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); auto n_image_embd = clip_n_mmproj_embd(ctx_clip); if (n_image_embd != n_llama_embd) { LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); @@ -394,10 +406,14 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * } bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { - int num_max_patches = 6; + // Granite vision uses up to 10 patches + base patch + int num_max_patches = 11; if (clip_is_minicpmv(ctx_clip)) { num_max_patches = 10; } + if (clip_is_glm(ctx_clip)) { + num_max_patches = 1; + } float * image_embd; if (clip_is_qwen2vl(ctx_clip)) { // qwen2vl don't split image into chunks, so `num_max_patches` is not needed. 
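// Illustrative aside, not part of the vendored diff: how the GLM-Edge branch above arrives at its
// position count. clip_embd_nbytes() reserves two extra rows when has_glm_projector is set (the
// BOI/EOI embeddings copied in clip_image_batch_encode()), and encode_image_with_clip() reports
// pos*pos + 2 image positions, where pos is the number of 2x2-merged patches per side. The concrete
// sizes below are hypothetical, assuming only a square input whose side is a multiple of twice the
// patch size:
//
//   int pos       = image_size / patch_size / 2;   // e.g. 672 / 14 / 2 = 24
//   int n_img_pos = pos * pos + 2;                  // 576 patch embeddings + BOI + EOI = 578
//   size_t nbytes = (size_t) n_img_pos * clip_n_mmproj_embd(ctx_clip) * sizeof(float);  // should match clip_embd_nbytes()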
@@ -457,7 +473,7 @@ struct llava_embd_batch { }; bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { - int n_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { int n_eval = image_embed->n_image_pos - i; diff --git a/llama/llama.cpp/include/llama-cpp.h b/llama/llama.cpp/include/llama-cpp.h index 1500cb2fc..8f6368177 100644 --- a/llama/llama.cpp/include/llama-cpp.h +++ b/llama/llama.cpp/include/llama-cpp.h @@ -9,7 +9,7 @@ #include "llama.h" struct llama_model_deleter { - void operator()(llama_model * model) { llama_free_model(model); } + void operator()(llama_model * model) { llama_model_free(model); } }; struct llama_context_deleter { @@ -20,11 +20,11 @@ struct llama_sampler_deleter { void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } }; -struct llama_lora_adapter_deleter { - void operator()(llama_lora_adapter * lora_adapter) { llama_lora_adapter_free(lora_adapter); } +struct llama_adapter_lora_deleter { + void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } }; typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr; typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr; typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr; -typedef std::unique_ptr<llama_lora_adapter, llama_lora_adapter_deleter> llama_lora_adapter_ptr; +typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr; diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index 9f411960e..cc9480056 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -34,7 +34,6 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF -// TODO: use everywhere in the implementation #define LLAMA_TOKEN_NULL -1 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' @@ -57,7 +56,7 @@ extern "C" { // TODO: show sample usage // - // struct llama_vocab; // TODO: add in the future + struct llama_vocab; struct llama_model; struct llama_context; struct llama_sampler; @@ -214,7 +213,7 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported }; - // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -290,9 +289,6 @@ extern "C" { // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() const float * tensor_split; - // comma separated list of RPC servers to use for offloading - const char * rpc_servers; - // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. // If the provided progress_callback returns true, model loading continues. // If it returns false, model loading is immediately aborted. 
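// Illustrative aside, not part of the vendored diff: a minimal sketch of the renamed model/context
// lifecycle declared in the hunks that follow. The older llama_load_model_from_file(),
// llama_new_context_with_model(), and llama_free_model() names are kept as DEPRECATED wrappers, so
// existing callers still build; new code is expected to look roughly like this (the model path is a
// hypothetical example, and error checks are omitted):
//
//   llama_backend_init();
//   llama_model_params   mparams = llama_model_default_params();
//   llama_context_params cparams = llama_context_default_params();
//   llama_model       * model = llama_model_load_from_file("model.gguf", mparams);  // hypothetical path
//   llama_context     * lctx  = llama_init_from_model(model, cparams);
//   const llama_vocab * vocab = llama_model_get_vocab(model);
//   // ... tokenize with llama_tokenize(vocab, ...), decode, sample ...
//   llama_free(lctx);
//   llama_model_free(model);
//   llama_backend_free();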
@@ -312,7 +308,7 @@ extern "C" { }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations - // https://github.com/ggerganov/llama.cpp/pull/7544 + // https://github.com/ggml-org/llama.cpp/pull/7544 struct llama_context_params { uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode @@ -325,7 +321,7 @@ extern "C" { enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings - // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model @@ -388,11 +384,10 @@ extern "C" { } llama_chat_message; // lora adapter - // TODO: rename to llama_adapter_lora - struct llama_lora_adapter; + struct llama_adapter_lora; // Helpers for getting default parameters - // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) + // TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172) LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); @@ -403,31 +398,53 @@ extern "C" { // Call once at the start of the program LLAMA_API void llama_backend_init(void); + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); + //optional: LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); // Optional: an auto threadpool gets created in ggml if not passed explicitly LLAMA_API void llama_attach_threadpool( - struct llama_context * ctx, - ggml_threadpool_t threadpool, - ggml_threadpool_t threadpool_batch); + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch); + LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); - // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_backend_free(void); + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params), + "use llama_model_load_from_file instead"); - LLAMA_API struct llama_model * llama_load_model_from_file( + // Load the model from a file + // If the file is split into multiple parts, the file name must follow this pattern: -%05d-of-%05d.gguf + // If the split file name does not follow this pattern, use llama_model_load_from_splits + LLAMA_API struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params); - // TODO: rename to llama_model_free - LLAMA_API void llama_free_model(struct llama_model * model); + // Load the model from multiple splits (support custom naming scheme) + // The paths must be in the correct order + LLAMA_API struct llama_model * llama_model_load_from_splits( + const char ** paths, + size_t n_paths, + struct llama_model_params params); - // TODO: rename to llama_init_from_model - LLAMA_API struct llama_context * llama_new_context_with_model( + 
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), + "use llama_model_free instead"); + + LLAMA_API void llama_model_free(struct llama_model * model); + + LLAMA_API struct llama_context * llama_init_from_model( struct llama_model * model, struct llama_context_params params); + DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params), + "use llama_init_from_model instead"); + // TODO (jmorganca): this should most likely be passed in as part of a batch // and not set on the context for all batches. LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state); @@ -449,20 +466,31 @@ extern "C" { LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); - LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); - LLAMA_API int32_t llama_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_n_layer (const struct llama_model * model); - LLAMA_API int32_t llama_n_head (const struct llama_model * model); + DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); + DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); + DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); + DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); - LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); + + LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); // Get the model's RoPE frequency scaling factor - LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); + LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); + + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab); + + LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); // Functions to access the model's GGUF metadata scalar values // - The functions return the length of the string on success, or -1 on failure @@ -488,6 +516,10 @@ extern "C" { // Returns the total size of all the tensors in the model in bytes 
LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + // Get the default chat template. Returns nullptr if not available + // If name is NULL, returns the default chat template + LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); + // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); @@ -515,34 +547,31 @@ extern "C" { // // Load a LoRA adapter from file - // TODO: rename to llama_adapter_lora_init - LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init( + LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); + // Manually free a LoRA adapter + // Note: loaded adapters will be free when the associated model is deleted + LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + + // The following functions operate on a llama_context, hence the naming: llama_verb_... + // Add a loaded LoRA adapter to given context // This will not modify model's weight - // TODO: rename to llama_set_adapter_lora - LLAMA_API int32_t llama_lora_adapter_set( + LLAMA_API int32_t llama_set_adapter_lora( struct llama_context * ctx, - struct llama_lora_adapter * adapter, + struct llama_adapter_lora * adapter, float scale); // Remove a specific LoRA adapter from given context // Return -1 if the adapter is not present in the context - // TODO: rename to llama_rm_adapter_lora - LLAMA_API int32_t llama_lora_adapter_remove( + LLAMA_API int32_t llama_rm_adapter_lora( struct llama_context * ctx, - struct llama_lora_adapter * adapter); + struct llama_adapter_lora * adapter); // Remove all LoRA adapters from given context - // TODO: rename to llama_clear_adapter_lora - LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx); - - // Manually free a LoRA adapter - // Note: loaded adapters will be free when the associated model is deleted - // TODO: rename to llama_adapter_lora_free - LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter); + LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. @@ -550,9 +579,8 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. - // TODO: rename to llama_adapter_cvec_apply - LLAMA_API int32_t llama_control_vector_apply( - struct llama_context * lctx, + LLAMA_API int32_t llama_apply_adapter_cvec( + struct llama_context * ctx, const float * data, size_t len, int32_t n_embd, @@ -908,41 +936,60 @@ extern "C" { // Vocab // - LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); + LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); - LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); + LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); - LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token); + LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) 
- LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); + LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token); // Identify if Token Id is a control token or a render-able token - LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token); + LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token); // Special tokens - LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence - LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence - LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn - LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification - LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator - LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line - LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding + LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence + LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence + LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn + LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator + LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line + LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding - LLAMA_API bool llama_add_bos_token(const struct llama_model * model); - LLAMA_API bool llama_add_eos_token(const struct llama_model * model); + LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); + LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); - // infill tokens - DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead"); - DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead"); - DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead"); + LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); - LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model); + DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); + DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use 
llama_vocab_get_score instead"); + DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); + DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); + DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); + DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead"); + DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead"); + DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead"); + DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead"); + DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead"); + DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead"); + + // CLS is equivalent to BOS + DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification + "use llama_vocab_bos instead"); // // Tokenization @@ -958,7 +1005,7 @@ extern "C" { /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated /// as plaintext. Does not insert a leading space. LLAMA_API int32_t llama_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const char * text, int32_t text_len, llama_token * tokens, @@ -972,7 +1019,7 @@ extern "C" { // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') // @param special If true, special tokens are rendered in the output. LLAMA_API int32_t llama_token_to_piece( - const struct llama_model * model, + const struct llama_vocab * vocab, llama_token token, char * buf, int32_t length, @@ -986,7 +1033,7 @@ extern "C" { /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. /// @param unparse_special If true, special tokens are rendered in the output. 
LLAMA_API int32_t llama_detokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const llama_token * tokens, int32_t n_tokens, char * text, @@ -1000,7 +1047,7 @@ extern "C" { /// Apply chat template. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" - /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template + /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat @@ -1009,7 +1056,6 @@ extern "C" { /// @param length The size of the allocated buffer /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. LLAMA_API int32_t llama_chat_apply_template( - const struct llama_model * model, const char * tmpl, const struct llama_chat_message * chat, size_t n_msg, @@ -1057,7 +1103,6 @@ extern "C" { // llama_sampler_free(smpl); // // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). - // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab // typedef void * llama_sampler_context_t; @@ -1076,11 +1121,12 @@ extern "C" { }; struct llama_sampler { - struct llama_sampler_i * iface; - llama_sampler_context_t ctx; + const struct llama_sampler_i * iface; + llama_sampler_context_t ctx; }; // mirror of llama_sampler_i: + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -1110,7 +1156,7 @@ extern "C" { /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. 
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), - "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)"); + "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @@ -1118,7 +1164,7 @@ extern "C" { /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); - /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. @@ -1133,6 +1179,9 @@ extern "C" { /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -1157,10 +1206,22 @@ extern "C" { float eta); LLAMA_API struct llama_sampler * llama_sampler_init_grammar( - const struct llama_model * model, + const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); + /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 + /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. 
LLAMA_API struct llama_sampler * llama_sampler_init_penalties( int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) @@ -1169,8 +1230,9 @@ extern "C" { float penalty_present); // 0.0 = disabled /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 - LLAMA_API struct llama_sampler * llama_sampler_init_dry( - const struct llama_model * model, + LLAMA_API struct llama_sampler * llama_sampler_init_dry( + const struct llama_vocab * vocab, + int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, @@ -1204,7 +1266,7 @@ extern "C" { // 3. discard non-EOG tokens with low prob // 4. if no tokens are left -> pick EOT // - LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model); + LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); diff --git a/llama/llama.cpp/src/llama-adapter.cpp b/llama/llama.cpp/src/llama-adapter.cpp index 9fd7edea3..8a0800463 100644 --- a/llama/llama.cpp/src/llama-adapter.cpp +++ b/llama/llama.cpp/src/llama-adapter.cpp @@ -1,5 +1,7 @@ #include "llama-adapter.h" +#include "llama-impl.h" +#include "llama-mmap.h" #include "llama-model.h" #include @@ -9,7 +11,7 @@ // vec -struct ggml_tensor * llama_control_vector::tensor_for(int il) const { +struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { return nullptr; } @@ -17,7 +19,7 @@ struct ggml_tensor * llama_control_vector::tensor_for(int il) const { return tensors[il]; } -struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { +struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { ggml_tensor * layer_dir = tensor_for(il); if (layer_dir != nullptr) { cur = ggml_add(ctx, cur, layer_dir); @@ -26,12 +28,12 @@ struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, s return cur; } -static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) { +bool llama_adapter_cvec::init(const llama_model & model) { const auto & hparams = model.hparams; - GGML_ASSERT(cvec.tensors.empty()); - GGML_ASSERT(cvec.ctxs.empty()); - GGML_ASSERT(cvec.bufs.empty()); + GGML_ASSERT(tensors.empty()); + GGML_ASSERT(ctxs.empty()); + GGML_ASSERT(bufs.empty()); // create a context for each buffer type std::map ctx_map; @@ -50,7 +52,7 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const } ctx_map[buft] = ctx; - cvec.ctxs.emplace_back(ctx); + ctxs.emplace_back(ctx); return ctx; } @@ -59,21 +61,21 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const }; // make tensors - cvec.tensors.reserve(hparams.n_layer); - cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0 + tensors.reserve(hparams.n_layer); + tensors.push_back(nullptr); // there's never a tensor for layer 0 for (size_t il = 1; il < hparams.n_layer; il++) { - ggml_backend_buffer_type_t buft = llama_model_select_buft(model, il); + ggml_backend_buffer_type_t buft = 
model.select_buft(il); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); return false; } ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - cvec.tensors.push_back(tensor); + tensors.push_back(tensor); } // allocate tensors / buffers and zero - cvec.bufs.reserve(ctx_map.size()); + bufs.reserve(ctx_map.size()); for (auto it : ctx_map) { ggml_backend_buffer_type_t buft = it.first; ggml_context * ctx = it.second; @@ -83,14 +85,13 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const return false; } ggml_backend_buffer_clear(buf, 0); - cvec.bufs.emplace_back(buf); + bufs.emplace_back(buf); } return true; } -int32_t llama_control_vector_apply( - struct llama_control_vector & cvec, +int32_t llama_adapter_cvec::apply( const llama_model & model, const float * data, size_t len, @@ -101,8 +102,8 @@ int32_t llama_control_vector_apply( if (data == nullptr) { // disable the current control vector (but leave allocated for later) - cvec.layer_start = -1; - cvec.layer_end = -1; + layer_start = -1; + layer_end = -1; return 0; } @@ -111,21 +112,21 @@ int32_t llama_control_vector_apply( return 1; } - if (cvec.tensors.empty()) { - if (!llama_control_vector_init(cvec, model)) { + if (tensors.empty()) { + if (!init(model)) { return 1; } } - cvec.layer_start = il_start; - cvec.layer_end = il_end; + layer_start = il_start; + layer_end = il_end; for (size_t il = 1; il < hparams.n_layer; il++) { - assert(cvec.tensors[il] != nullptr); + assert(tensors[il] != nullptr); const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present if (off + n_embd <= len) { - ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il])); + ggml_backend_tensor_set(tensors[il], data + off, 0, n_embd * ggml_element_size(tensors[il])); } } @@ -134,7 +135,7 @@ int32_t llama_control_vector_apply( // lora -llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) { +llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) { const std::string name(w->name); const auto pos = ab_map.find(name); @@ -145,11 +146,7 @@ llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) { return nullptr; } -void llama_lora_adapter_free(struct llama_lora_adapter * adapter) { - delete adapter; -} - -static void llama_lora_adapter_init_impl(struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) { +static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) { LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); ggml_context * ctx_init; @@ -221,7 +218,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char }; // bundle lora_a and lora_b into pairs - std::map ab_map; + std::map ab_map; auto str_endswith = [](const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; }; @@ -231,17 +228,21 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char if (str_endswith(name, ".lora_a")) { replace_all(name, ".lora_a", ""); if (ab_map.find(name) == ab_map.end()) { - ab_map[name] = llama_lora_weight(cur, nullptr); + ab_map[name] = llama_adapter_lora_weight(cur, nullptr); } else { ab_map[name].a = cur; } } else if 
(str_endswith(name, ".lora_b")) { replace_all(name, ".lora_b", ""); if (ab_map.find(name) == ab_map.end()) { - ab_map[name] = llama_lora_weight(nullptr, cur); + ab_map[name] = llama_adapter_lora_weight(nullptr, cur); } else { ab_map[name].b = cur; } + } else if (str_endswith(name, "_norm.weight")) { + // TODO: add support for norm vector + // for now, we don't really care because most adapters still work fine without it + continue; } else { throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); } @@ -250,25 +251,33 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char // add tensors for (auto & it : ab_map) { const std::string & name = it.first; - llama_lora_weight & w = it.second; + llama_adapter_lora_weight & w = it.second; + bool is_token_embd = str_endswith(name, "token_embd.weight"); if (!w.a || !w.b) { throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); } // device buft and device ctx - auto * model_tensor = llama_model_get_tensor(model, name.c_str()); + const auto * model_tensor = model.get_tensor(name.c_str()); if (!model_tensor) { - throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model"); + throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); } struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); // validate tensor shape - if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { - throw std::runtime_error("tensor '" + name + "' has incorrect shape"); - } - if (w.a->ne[1] != w.b->ne[0]) { - throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + if (is_token_embd) { + // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() + if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + } else { + if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + if (w.a->ne[1] != w.b->ne[0]) { + throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + } } // save tensor to adapter @@ -276,7 +285,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); ggml_set_name(tensor_a, w.a->name); ggml_set_name(tensor_b, w.b->name); - adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b); + adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b); } // allocate tensors / buffers and zero @@ -318,11 +327,11 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } -struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) { - struct llama_lora_adapter * adapter = new llama_lora_adapter(); +struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) { + struct llama_adapter_lora * adapter = new llama_adapter_lora(); try { - llama_lora_adapter_init_impl(*model, path_lora, *adapter); + llama_adapter_lora_init_impl(*model, 
path_lora, *adapter); return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); @@ -332,3 +341,7 @@ struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, return nullptr; } + +void llama_adapter_lora_free(struct llama_adapter_lora * adapter) { + delete adapter; +} diff --git a/llama/llama.cpp/src/llama-adapter.h b/llama/llama.cpp/src/llama-adapter.h index 5f1870cc8..603fa08f6 100644 --- a/llama/llama.cpp/src/llama-adapter.h +++ b/llama/llama.cpp/src/llama-adapter.h @@ -1,66 +1,74 @@ #pragma once -#include "llama-impl.h" -#include "llama-hparams.h" +#include "llama.h" #include "ggml-cpp.h" +#include #include #include +// TODO: pimpl + // // llama_adapter_cvec // -// TODO: rename to llama_adapter_cvec -struct llama_control_vector { - std::vector ctxs; - std::vector bufs; +struct llama_adapter_cvec { + struct ggml_tensor * tensor_for(int il) const; - std::vector tensors; // per layer + struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; + + int32_t apply( + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end); + +private: + bool init(const llama_model & model); int32_t layer_start = -1; int32_t layer_end = -1; - struct ggml_tensor * tensor_for(int il) const; + std::vector ctxs; + std::vector bufs; - struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; + std::vector tensors; // per layer }; -int32_t llama_control_vector_apply( - struct llama_control_vector & cvec, - const llama_model & model, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end); - // // llama_adapter_lora // -// TODO: rename to llama_adapter_lora_weight -struct llama_lora_weight { +struct llama_adapter_lora_weight { struct ggml_tensor * a = nullptr; struct ggml_tensor * b = nullptr; - llama_lora_weight() = default; - llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} + // get actual scale based on rank and alpha + float get_scale(float alpha, float adapter_scale) const { + const float rank = (float) b->ne[0]; + const float scale = alpha ? 
adapter_scale * alpha / rank : adapter_scale; + return scale; + } + + llama_adapter_lora_weight() = default; + llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} }; -// TODO: rename to llama_adapter_lora -struct llama_lora_adapter { +struct llama_adapter_lora { // map tensor name to lora_a_b - std::unordered_map ab_map; + std::unordered_map ab_map; std::vector ctxs; std::vector bufs; float alpha; - llama_lora_adapter() = default; - ~llama_lora_adapter() = default; + llama_adapter_lora() = default; + ~llama_adapter_lora() = default; - llama_lora_weight * get_weight(struct ggml_tensor * w); + llama_adapter_lora_weight * get_weight(struct ggml_tensor * w); }; diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index b35aeb315..b6f20286b 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -28,6 +28,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, @@ -57,6 +58,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, @@ -107,25 +109,26 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, - { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, - { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, - { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, - { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, - { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, - { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, - { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, - { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, - { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, - { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, - { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, - { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, - { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, - { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, - { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, - { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + 
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, + { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -179,6 +182,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, @@ -622,6 +627,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PHIMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PLAMO, { @@ -1036,6 +1062,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, @@ -1182,6 +1211,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, @@ -1199,6 +1229,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, }, }, + { + LLM_ARCH_RWKV6QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, 
"output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_GRANITE, { @@ -1253,6 +1309,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, + { + LLM_ARCH_SOLAR, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_BSKCN_TV, "bskcn_tv" }, + }, + }, { LLM_ARCH_WAVTOKENIZER_DEC, { @@ -1278,24 +1352,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, - { - LLM_ARCH_SOLAR, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_BSKCN_TV, "bskcn_tv" }, - }, - }, { LLM_ARCH_UNKNOWN, { @@ -1399,6 +1455,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1455,10 +1512,11 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; -LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {} +LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { - return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + return suffix ? 
::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) + : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); } std::string LLM_TN_IMPL::str() const { diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index e8235ae00..ec7422244 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -32,6 +32,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_PHI2, LLM_ARCH_PHI3, + LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, @@ -61,6 +62,7 @@ enum llm_arch { LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, LLM_ARCH_RWKV6, + LLM_ARCH_RWKV6QWEN2, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, @@ -111,6 +113,7 @@ enum llm_kv { LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_RESIDUAL_SCALE, LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -177,6 +180,8 @@ enum llm_kv { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, @@ -256,6 +261,7 @@ enum llm_tensor { LLM_TENSOR_TIME_MIX_LERP_V, LLM_TENSOR_TIME_MIX_LERP_R, LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_LERP_FUSED, LLM_TENSOR_TIME_MIX_FIRST, LLM_TENSOR_TIME_MIX_DECAY, LLM_TENSOR_TIME_MIX_DECAY_W1, @@ -343,9 +349,10 @@ enum llm_tensor_layer { }; struct LLM_KV { - LLM_KV(llm_arch arch); + LLM_KV(llm_arch arch, const char * suffix = nullptr); llm_arch arch; + const char * suffix; std::string operator()(llm_kv kv) const; }; diff --git a/llama/llama.cpp/src/llama-chat.cpp b/llama/llama.cpp/src/llama-chat.cpp index 44670d3d8..028a64794 100644 --- a/llama/llama.cpp/src/llama-chat.cpp +++ b/llama/llama.cpp/src/llama-chat.cpp @@ -35,6 +35,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, @@ -50,6 +51,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, @@ -73,7 +75,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return tmpl.find(haystack) != std::string::npos; }; if (tmpl_contains("<|im_start|>")) { - return LLM_CHAT_TEMPLATE_CHATML; + return tmpl_contains("<|im_sep|>") + ? LLM_CHAT_TEMPLATE_PHI_4 + : LLM_CHAT_TEMPLATE_CHATML; } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { if (tmpl_contains("[SYSTEM_PROMPT]")) { return LLM_CHAT_TEMPLATE_MISTRAL_V7; @@ -112,7 +116,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { return LLM_CHAT_TEMPLATE_PHI_3; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { - return LLM_CHAT_TEMPLATE_FALCON_3; + return tmpl_contains("") ? 
LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { return LLM_CHAT_TEMPLATE_ZEPHYR; } else if (tmpl_contains("bos_token + message['role']")) { @@ -149,7 +153,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_MINICPM; } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { return LLM_CHAT_TEMPLATE_DEEPSEEK_2; - } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) { + } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { return LLM_CHAT_TEMPLATE_DEEPSEEK_3; } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb @@ -269,6 +273,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>\n"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>"; + } + if (add_ass) { + ss << "<|im_start|>assistant<|im_sep|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) { // Falcon 3 for (auto message : chat) { @@ -429,6 +441,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { // MiniCPM-3B-OpenHermes-2.5-v2-GGUF for (auto message : chat) { diff --git a/llama/llama.cpp/src/llama-chat.h b/llama/llama.cpp/src/llama-chat.h index b8e94d9ef..2f6a0e3e2 100644 --- a/llama/llama.cpp/src/llama-chat.h +++ b/llama/llama.cpp/src/llama-chat.h @@ -15,6 +15,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V7, LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_PHI_4, LLM_CHAT_TEMPLATE_FALCON_3, LLM_CHAT_TEMPLATE_ZEPHYR, LLM_CHAT_TEMPLATE_MONARCH, @@ -30,6 +31,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGML_3, LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_GLMEDGE, LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_RWKV_WORLD, diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 9d0e7ca36..7b22fe137 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -1,5 +1,8 @@ #include "llama-context.h" +#include "llama-impl.h" +#include "llama-mmap.h" + #include #include #include @@ -513,7 +516,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { auto * buft = ggml_backend_cpu_buffer_type(); // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory - auto * output_dev = lctx.model.dev_output.dev; + auto * output_dev = lctx.model.dev_output(); auto * output_dev_host_buft = output_dev ? 
ggml_backend_dev_host_buffer_type(output_dev) : nullptr; if (output_dev_host_buft) { buft = output_dev_host_buft; diff --git a/llama/llama.cpp/src/llama-context.h b/llama/llama.cpp/src/llama-context.h index 4980a60e8..cf12c9d72 100644 --- a/llama/llama.cpp/src/llama-context.h +++ b/llama/llama.cpp/src/llama-context.h @@ -22,12 +22,12 @@ struct llama_context { const struct llama_model & model; - struct llama_cparams cparams; - struct llama_sbatch sbatch; // TODO: revisit if needed - struct llama_kv_cache kv_self; - struct llama_control_vector cvec; + struct llama_cparams cparams; + struct llama_sbatch sbatch; // TODO: revisit if needed + struct llama_kv_cache kv_self; + struct llama_adapter_cvec cvec; - std::unordered_map lora_adapters; + std::unordered_map lora; std::vector backends; std::vector> set_n_threads_fns; diff --git a/llama/llama.cpp/src/llama-grammar.cpp b/llama/llama.cpp/src/llama-grammar.cpp index 186dc9a25..98af1ba39 100644 --- a/llama/llama.cpp/src/llama-grammar.cpp +++ b/llama/llama.cpp/src/llama-grammar.cpp @@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence( size_t last_sym_start = rule.size(); const char * pos = src; - auto handle_repetitions = [&](int min_times, int max_times) { + auto handle_repetitions = [&](int min_times, int max_times) { - if (last_sym_start == rule.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - } + if (last_sym_start == rule.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } - // apply transformation to previous symbol (last_sym_start to end) according to - // the following rewrite rules: - // S{m,n} --> S S S (m times) S'(n-m) - // S'(x) ::= S S'(x-1) | - // (... n-m definitions of these S' rules ...) - // S'(1) ::= S | - // S{m,} --> S S S (m times) S' - // S' ::= S S' | - // S* --> S{0,} - // --> S' ::= S S' | - // S+ --> S{1,} - // --> S S' - // S' ::= S S' | - // S? --> S{0,1} - // --> S' - // S' ::= S | + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | - llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); - if (min_times == 0) { - rule.resize(last_sym_start); - } else { - // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { - rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); - } - } - - uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - llama_grammar_rule rec_rule(prev_rule); - for (int i = 0; i < n_opt; i++) { - rec_rule.resize(prev_rule.size()); - uint32_t rec_rule_id = generate_symbol_id( rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? 
rec_rule_id : last_rec_rule_id}); - } - rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule( rec_rule_id, rec_rule); - last_rec_rule_id = rec_rule_id; - } - if (n_opt > 0) { - rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - } - }; - - while (*pos) { - if (*pos == '"') { // literal string - pos++; - last_sym_start = rule.size(); - while (*pos != '"') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '[') { // char range(s) - pos++; - enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - if (*pos == '^') { - pos++; - start_type = LLAMA_GRETYPE_CHAR_NOT; - } - last_sym_start = rule.size(); - while (*pos != ']') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - enum llama_gretype type = last_sym_start < rule.size() - ? LLAMA_GRETYPE_CHAR_ALT - : start_type; - - rule.push_back({type, char_pair.first}); - if (pos[0] == '-' && pos[1] != ']') { - if (!pos[1]) { - throw std::runtime_error("unexpected end of input"); - } - auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - } - } - pos = parse_space(pos + 1, is_nested); - } else if (is_word_char(*pos)) { // rule reference - const char * name_end = parse_name(pos); - uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); - pos = parse_space(name_end, is_nested); - last_sym_start = rule.size(); - rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - } else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule - pos = parse_space(pos + 1, true); - uint32_t sub_rule_id = generate_symbol_id(rule_name); - pos = parse_alternates(pos, rule_name, sub_rule_id, true); - last_sym_start = rule.size(); - // output reference to synthesized rule - rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - if (*pos != ')') { - throw std::runtime_error(std::string("expecting ')' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '.') { // any char - last_sym_start = rule.size(); - rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, -1); - } else if (*pos == '+') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(1, -1); - } else if (*pos == '?') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, 1); - } else if (*pos == '{') { - pos = parse_space(pos + 1, is_nested); - - if (!is_digit_char(*pos)) { - throw std::runtime_error(std::string("expecting an int at ") + pos); - } - const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - - int max_times = -1; - - if (*pos == '}') { - max_times = min_times; - pos = parse_space(pos + 1, is_nested); - } else if (*pos == ',') { - pos = parse_space(pos + 1, is_nested); - - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else 
{ - throw std::runtime_error(std::string("expecting ',' at ") + pos); - } - handle_repetitions(min_times, max_times); - } else { - break; + llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + if (min_times == 0) { + rule.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); } } - return pos; + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + llama_grammar_rule rec_rule(prev_rule); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(prev_rule.size()); + uint32_t rec_rule_id = generate_symbol_id( rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule( rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = rule.size(); + while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = rule.size(); + while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < rule.size() + ? 
LLAMA_GRETYPE_CHAR_ALT + : start_type; + + rule.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(rule_name); + pos = parse_alternates(pos, rule_name, sub_rule_id, true); + last_sym_start = rule.size(); + // output reference to synthesized rule + rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); + } else { + break; + } } + return pos; +} const char * llama_grammar_parser::parse_rule(const char * src) { - const char * name_end = parse_name(src); - const char * pos = parse_space(name_end, false); - size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(src, name_len); - const std::string name(src, name_len); + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(src, name_len); + const std::string name(src, name_len); - if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { - throw std::runtime_error(std::string("expecting ::= at ") + pos); - } - pos = parse_space(pos + 3, true); - - pos = parse_alternates(pos, name, rule_id, false); - - if (*pos == '\r') { - pos += pos[1] == '\n' ? 
2 : 1; - } else if (*pos == '\n') { - pos++; - } else if (*pos) { - throw std::runtime_error(std::string("expecting newline or end at ") + pos); - } - return parse_space(pos, true); + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); +} bool llama_grammar_parser::parse(const char * src) { try { @@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) { } } } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); rules.clear(); return false; } @@ -960,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy =*/ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_tokens = */ {}, + /* .trigger_words = */ {}, + }; } -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it @@ -1035,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, } } while (true); + std::vector vec_trigger_tokens; + std::vector vec_trigger_words; + for (size_t i = 0; i < num_trigger_tokens; i++) { + GGML_ASSERT(trigger_tokens != nullptr); + vec_trigger_tokens.push_back(trigger_tokens[i]); + } + for (size_t i = 0; i < num_trigger_words; i++) { + GGML_ASSERT(trigger_words != nullptr); + vec_trigger_words.push_back(trigger_words[i]); + } + // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. 
- return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + std::move(vec_trigger_tokens), + std::move(vec_trigger_words), + }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1055,6 +1094,11 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.rules, grammar.stacks, grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_tokens, + grammar.trigger_words, }; // redirect elements in stacks to point to new rules @@ -1076,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { GGML_ASSERT(grammar.vocab != nullptr); + if (grammar.awaiting_trigger) { + return; + } + bool allow_eog = false; for (const auto & stack : grammar.stacks) { if (stack.empty()) { @@ -1092,9 +1140,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; - const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); + const std::string & piece = grammar.vocab->token_to_piece(id); - if (llama_token_is_eog_impl(*grammar.vocab, id)) { + if (grammar.vocab->is_eog(id)) { if (!allow_eog) { cur_p->data[i].logit = -INFINITY; } @@ -1115,7 +1163,35 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { GGML_ASSERT(grammar.vocab != nullptr); - if (llama_token_is_eog_impl(*grammar.vocab, token)) { + const auto & piece = grammar.vocab->token_to_piece(token); + + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, piece); + LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); + return; + } else { + // TODO: consider a smarter incremental substring search algorithm (store last position to search from). 
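        // one possible approach for the TODO above (illustrative sketch only, not part of this change):
        // any new match must overlap the freshly appended piece, so it suffices to remember the longest
        // trigger word and restart the search just before the old buffer end, e.g.
        //     size_t longest = 0;
        //     for (const auto & w : grammar.trigger_words) { longest = std::max(longest, w.size()); }
        //     const size_t search_from = grammar.trigger_buffer.size() > longest
        //         ? grammar.trigger_buffer.size() - longest + 1 : 0;
        //     // append `piece`, then call grammar.trigger_buffer.find(word, search_from)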
+ grammar.trigger_buffer += piece; + for (const auto & word : grammar.trigger_words) { + auto pos = grammar.trigger_buffer.find(word); + if (pos != std::string::npos) { + grammar.awaiting_trigger = false; + auto constrained_str = grammar.trigger_buffer.substr(pos); + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, constrained_str); + LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str()); + return; + } + } + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); + return; + } + } + + if (grammar.vocab->is_eog(token)) { for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; @@ -1124,8 +1200,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("fatal error"); } - const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); + llama_grammar_accept_str(grammar, piece); +} +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; @@ -1135,5 +1213,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token } grammar.partial_utf8 = decoded.second; - GGML_ASSERT(!grammar.stacks.empty()); + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); + } } diff --git a/llama/llama.cpp/src/llama-grammar.h b/llama/llama.cpp/src/llama-grammar.h index f8b40c651..b143d834c 100644 --- a/llama/llama.cpp/src/llama-grammar.h +++ b/llama/llama.cpp/src/llama-grammar.h @@ -114,6 +114,15 @@ struct llama_grammar { // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + + // lazy grammars wait for trigger words or tokens before constraining the sampling. + // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // (useful e.g. for tool_choice=required) + bool lazy = false; + bool awaiting_trigger = false; // Initialized to true for lazy grammars only + std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). 
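    // Words that lazily trigger the grammar: while awaiting_trigger is set, accepted pieces are
    // buffered in trigger_buffer; once a trigger token or word is found, the text from the trigger
    // onward is replayed through llama_grammar_accept_str() and constrained sampling resumes.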
+ std::vector trigger_words; }; // @@ -127,7 +136,15 @@ struct llama_grammar * llama_grammar_init_impl( size_t n_rules, size_t start_rule_index); -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); void llama_grammar_free_impl(struct llama_grammar * grammar); @@ -141,3 +158,7 @@ void llama_grammar_apply_impl( void llama_grammar_accept_impl( struct llama_grammar & grammar, llama_token token); + +void llama_grammar_accept_str( + struct llama_grammar & grammar, + const std::string & piece); diff --git a/llama/llama.cpp/src/llama-hparams.cpp b/llama/llama.cpp/src/llama-hparams.cpp index 42f8a58ff..0b8410281 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -54,7 +54,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { uint32_t llama_hparams::n_embd_k_s() const { if (wkv_head_size != 0) { // for RWKV models - return 2 * n_embd; + return token_shift_count * n_embd; } // TODO: maybe support other convolution strides than 1 @@ -82,4 +82,4 @@ bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const { bool llama_hparams::cross_attention_layers(uint32_t il) const { return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); -} +} \ No newline at end of file diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index f826cd9ac..053830461 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -30,7 +30,6 @@ struct llama_hparams { bool use_par_res; bool swin_norm; - uint32_t n_vocab = 0; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; uint32_t n_embd_features = 0; @@ -41,8 +40,8 @@ struct llama_hparams { uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; uint32_t n_expert_used = 0; - uint32_t n_vocab_type = 0; // for BERT-style token types uint32_t n_rel_attn_bkts = 0; + uint32_t n_vocab = 0; // for WavTokenizer struct llama_hparams_posnet posnet; @@ -79,6 +78,7 @@ struct llama_hparams { uint32_t time_mix_extra_dim = 0; uint32_t time_decay_extra_dim = 0; uint32_t wkv_head_size = 0; + uint32_t token_shift_count = 2; float rope_attn_factor = 1.0f; float rope_freq_base_train; @@ -141,7 +141,7 @@ struct llama_hparams { // Block skip connection bool n_bskcn(uint32_t n, uint32_t il) const; - // cross attention layers + // cross attention layers bool cross_attention_layers(uint32_t il) const; }; diff --git a/llama/llama.cpp/src/llama-impl.cpp b/llama/llama.cpp/src/llama-impl.cpp index a05ba4f63..6ec709dd3 100644 --- a/llama/llama.cpp/src/llama-impl.cpp +++ b/llama/llama.cpp/src/llama-impl.cpp @@ -1,5 +1,6 @@ #include "llama-impl.h" +#include "gguf.h" #include "llama.h" #include @@ -138,7 +139,7 @@ std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { { const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = gguf_get_arr_data(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? 
nullptr : gguf_get_arr_data(ctx_gguf, i); std::stringstream ss; ss << "["; for (int j = 0; j < arr_n; j++) { diff --git a/llama/llama.cpp/src/llama-impl.h b/llama/llama.cpp/src/llama-impl.h index 12d1fb082..02b1d07f8 100644 --- a/llama/llama.cpp/src/llama-impl.h +++ b/llama/llama.cpp/src/llama-impl.h @@ -6,13 +6,13 @@ #include #ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif #else -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) +# define LLAMA_ATTRIBUTE_FORMAT(...) #endif // diff --git a/llama/llama.cpp/src/llama-kv-cache.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp index cf814dbe5..b541c5a37 100644 --- a/llama/llama.cpp/src/llama-kv-cache.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -72,39 +72,6 @@ bool llama_kv_cache_init( cache.v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { - // for cross attention layers - if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const llama_model::buft_list_t * buft_list; - if (offload) { - buft_list = model.dev_layer.at(i).buft_list; - } else { - buft_list = &model.cpu_buft_list; - } - ggml_backend_buffer_type_t buft = select_buft(*buft_list, - [&](ggml_context * ctx) { - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) { - return k; - } - ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type); - }); - ggml_context * ctx = ctx_for_buft(buft); - - if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); - return false; - } - ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); - ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); - continue; - } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); @@ -112,7 +79,7 @@ bool llama_kv_cache_init( ggml_backend_buffer_type_t buft; if (offload) { - auto * dev = model.dev_layer.at(i).dev; + auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); } else { buft = ggml_backend_cpu_buffer_type(); @@ -124,8 +91,17 @@ bool llama_kv_cache_init( return false; } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k, *v; + + // for cross attention layers + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) { + k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); + v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); + } else { + k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + } + 
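            // 6404 is presumably a fixed number of cross-attention (vision) positions reserved for the
            // mllama path; the origin of the value is not explained in this diff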
ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); @@ -152,10 +128,10 @@ bool llama_kv_cache_init( struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_ubatch & batch) { - const uint32_t n_tokens = batch.n_tokens; - const uint32_t n_seqs = batch.n_seqs; - const uint32_t n_seq_tokens = batch.n_seq_tokens; + const struct llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; if (cache.recurrent) { // For recurrent state architectures (like Mamba or RWKV), @@ -163,16 +139,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // A slot should be always be contiguous. // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs); int32_t min = cache.size - 1; int32_t max = 0; // everything should fit if all seq_ids are smaller than the max for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = batch.n_seq_id[s]; + const uint32_t n_seq_id = ubatch.n_seq_id[s]; for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + const llama_seq_id seq_id = ubatch.seq_id[s][j]; if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { // too big seq_id @@ -231,7 +207,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // find usable cell range for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; llama_kv_cell & seq_meta = cache.cells[seq_id]; bool has_cell = false; if (seq_meta.tail >= 0) { @@ -270,7 +246,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // gather and re-order for (uint32_t s = 0; s < n_seqs; ++s) { int32_t dst_id = s + min; - int32_t src_id = cache.cells[batch.seq_id[s][0]].tail; + int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail; if (dst_id != src_id) { llama_kv_cell & dst_cell = cache.cells[dst_id]; llama_kv_cell & src_cell = cache.cells[src_id]; @@ -291,7 +267,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // update the pos of the used seqs for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; int32_t cell_id = s + min; llama_kv_cell & cell = cache.cells[cell_id]; @@ -299,12 +275,12 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. 
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); } cell.pos = last_pos; cell.seq_id.clear(); - for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; cell.seq_id.insert(seq_id); cache.cells[seq_id].tail = cell_id; } @@ -358,10 +334,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; - cache.cells[cache.head + k].pos = batch.pos[k]; + cache.cells[cache.head + k].pos = ubatch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { - cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { + cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]); } } } diff --git a/llama/llama.cpp/src/llama-kv-cache.h b/llama/llama.cpp/src/llama-kv-cache.h index dca6f3998..1ed688e3b 100644 --- a/llama/llama.cpp/src/llama-kv-cache.h +++ b/llama/llama.cpp/src/llama-kv-cache.h @@ -37,7 +37,7 @@ struct llama_kv_cache { bool can_shift = false; // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_internal also uses it, so it + // for a free KV slot. llama_decode_impl also uses it, so it // cannot be freely changed after a slot has been allocated. uint32_t head = 0; uint32_t size = 0; diff --git a/llama/llama.cpp/src/llama-mmap.cpp b/llama/llama.cpp/src/llama-mmap.cpp index a99326335..b716630a8 100644 --- a/llama/llama.cpp/src/llama-mmap.cpp +++ b/llama/llama.cpp/src/llama-mmap.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #ifdef __has_include #if __has_include() @@ -35,7 +36,7 @@ // TODO: consider moving to llama-impl.h if needed in more places #if defined(_WIN32) -std::string llama_format_win_err(DWORD err) { +static std::string llama_format_win_err(DWORD err) { LPSTR buf; size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); @@ -241,12 +242,16 @@ llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::size() const { return pimpl->size; } -int llama_file::fileno() const { +int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); +#else +#if defined(fileno) + return fileno(pimpl->fp); #else return ::fileno(pimpl->fp); #endif +#endif } void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } @@ -265,7 +270,7 @@ struct llama_mmap::impl { impl(struct llama_file * file, size_t prefetch, bool numa) { size = file->size(); - int fd = file->fileno(); + int fd = file->file_id(); int flags = MAP_SHARED; if (numa) { prefetch = 0; } #ifdef __linux__ @@ -357,7 +362,7 @@ struct llama_mmap::impl { size = file->size(); - HANDLE hFile = (HANDLE) _get_osfhandle(file->fileno()); + HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); diff --git a/llama/llama.cpp/src/llama-mmap.h b/llama/llama.cpp/src/llama-mmap.h index 6bcddee8c..4e5aec3f4 100644 --- a/llama/llama.cpp/src/llama-mmap.h +++ 
b/llama/llama.cpp/src/llama-mmap.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -18,7 +19,7 @@ struct llama_file { size_t tell() const; size_t size() const; - int fileno() const; + int file_id() const; // fileno overload void seek(size_t offset, int whence) const; diff --git a/llama/llama.cpp/src/llama-model-loader.cpp b/llama/llama.cpp/src/llama-model-loader.cpp index b12d65666..45d087213 100644 --- a/llama/llama.cpp/src/llama-model-loader.cpp +++ b/llama/llama.cpp/src/llama-model-loader.cpp @@ -7,6 +7,10 @@ #include #include +static const size_t kiB = 1024; +static const size_t MiB = 1024*kiB; +static const size_t GiB = 1024*MiB; + const char * llama_file_version_name(llama_fver version) { switch (version) { case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; @@ -17,8 +21,78 @@ const char * llama_file_version_name(llama_fver version) { return "unknown"; } +static std::string llama_model_ftype_name(llama_ftype ftype) { + if (ftype & LLAMA_FTYPE_GUESSED) { + return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; + } + + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + + default: return "unknown, may not work"; + } +} + +// return a list of splits for a given path +// for example, given "-00002-of-00004.gguf", returns list of all 4 splits +static std::vector llama_get_list_splits(const std::string & path, const int idx, const int n_split) { + std::vector paths; + std::string split_prefix; + std::vector buf(llama_path_max(), 0); + + { + int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split); + if (!ret) { + throw std::runtime_error(format("invalid split file name: %s", path.c_str())); + } + split_prefix = std::string(buf.data(), ret); + } + 
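    // e.g. for a (hypothetical) first split "model-00001-of-00004.gguf", split_prefix becomes "model",
    // and llama_split_path() below regenerates "model-00001-of-00004.gguf" ... "model-00004-of-00004.gguf"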
+ if (split_prefix.empty()) { + throw std::runtime_error(format("invalid split file: %s", path.c_str())); + } + + for (int idx = 0; idx < n_split; ++idx) { + int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split); + paths.push_back(std::string(buf.data(), ret)); + } + + return paths; +} + namespace GGUFMeta { - template + template struct GKV_Base_Type { static constexpr gguf_type gt = gt_; @@ -60,10 +134,11 @@ namespace GGUFMeta { public: static constexpr gguf_type gt = GGUF_TYPE_ARRAY; static ArrayInfo getter(const gguf_context *ctx, const int k) { + const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); return ArrayInfo { - gguf_get_arr_type(ctx, k), + arr_type, size_t(gguf_get_arr_n(ctx, k)), - gguf_get_arr_data(ctx, k), + arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), }; } }; @@ -368,7 +443,12 @@ namespace GGUFMeta { template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required); -llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { +llama_model_loader::llama_model_loader( + const std::string & fname, + std::vector & splits, + bool use_mmap, + bool check_tensors, + const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -380,6 +460,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, } } + // Load the main GGUF struct ggml_context * ctx = NULL; struct gguf_init_params params = { /*.no_alloc = */ true, @@ -415,35 +496,54 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, // Load additional GGML contexts if (n_split > 1) { + // make sure the main file is loaded first uint16_t idx = 0; - get_key(llm_kv(LLM_KV_SPLIT_NO), idx); + const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); + get_key(kv_split_no, idx); if (idx != 0) { - throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx)); + throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); } - std::vector split_prefix(llama_path_max(), 0); - if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) { - throw std::runtime_error(format("invalid split file: %s", fname.c_str())); + // generate list of splits if needed + if (splits.empty()) { + splits = llama_get_list_splits(fname, idx, n_split); + } + + // in case user give a custom list of splits, check if it matches the expected number + if (n_split != (uint16_t)splits.size()) { + throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); } if (trace > 0) { LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); } - std::vector split_path(llama_path_max(), 0); + // load other splits for (idx = 1; idx < n_split; idx++) { - llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split); + const char * fname_split = splits[idx].c_str(); struct gguf_init_params split_params = { /*.no_alloc = */ true, /*.ctx = */ &ctx, }; - gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) }; + gguf_context_ptr ctx_gguf { 
gguf_init_from_file(fname_split, split_params) }; if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data())); + throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split)); } - files.emplace_back(new llama_file(split_path.data(), "rb")); + // check idx + { + const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); + if (kid < 0) { + throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + } + int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); + if (idx_gguf != idx) { + throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + } + } + + files.emplace_back(new llama_file(fname_split, "rb")); contexts.emplace_back(ctx); // Save tensors data offset info of the shard. @@ -556,7 +656,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, const enum gguf_type type = gguf_get_kv_type(meta.get(), i); const std::string type_name = type == GGUF_TYPE_ARRAY - ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) : gguf_type_name(type); std::string value = gguf_kv_to_str(meta.get(), i); @@ -722,7 +822,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps for (const auto & file : files) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); - std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn())); + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? 
-1 : 0, is_numa_fn()); mmaps_used.emplace_back(mapping->size(), 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); @@ -1011,3 +1111,17 @@ bool llama_model_loader::load_all_data( return true; } + +std::string llama_model_loader::ftype_name() const { + return llama_model_ftype_name(ftype); +} + +void llama_model_loader::print_info() const { + LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver)); + LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str()); + if (n_bytes < GiB) { + LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements); + } else { + LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements); + } +} diff --git a/llama/llama.cpp/src/llama-model-loader.h b/llama/llama.cpp/src/llama-model-loader.h index 1ec478195..fe35404b2 100644 --- a/llama/llama.cpp/src/llama-model-loader.h +++ b/llama/llama.cpp/src/llama-model-loader.h @@ -90,7 +90,12 @@ struct llama_model_loader { size_t size_data = 0; std::vector> mmaps_used; - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p); + llama_model_loader( + const std::string & fname, + std::vector & splits, // optional, only need if the split does not follow naming scheme + bool use_mmap, + bool check_tensors, + const struct llama_model_kv_override * param_overrides_p); template typename std::enable_if::value, bool>::type @@ -155,4 +160,8 @@ struct llama_model_loader { llama_mlocks * lmlocks, llama_progress_callback progress_callback, void * progress_callback_user_data); + + std::string ftype_name() const; + + void print_info() const; }; diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 4f9bbf90e..218190805 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -1,128 +1,85 @@ #include "llama-model.h" #include "llama-impl.h" +#include "llama-mmap.h" #include "llama-model-loader.h" -#include "unicode.h" // TODO: remove +#include "ggml-cpp.h" #include #include +#include #include +#include #include #include -static const size_t kiB = 1024; -static const size_t MiB = 1024*kiB; -static const size_t GiB = 1024*MiB; - const char * llm_type_name(llm_type type) { switch (type) { - case MODEL_14M: return "14M"; - case MODEL_17M: return "17M"; - case MODEL_22M: return "22M"; - case MODEL_33M: return "33M"; - case MODEL_60M: return "60M"; - case MODEL_70M: return "70M"; - case MODEL_80M: return "80M"; - case MODEL_109M: return "109M"; - case MODEL_137M: return "137M"; - case MODEL_160M: return "160M"; - case MODEL_220M: return "220M"; - case MODEL_250M: return "250M"; - case MODEL_270M: return "270M"; - case MODEL_335M: return "335M"; - case MODEL_410M: return "410M"; - case MODEL_450M: return "450M"; - case MODEL_770M: return "770M"; - case MODEL_780M: return "780M"; - case MODEL_0_5B: return "0.5B"; - case MODEL_1B: return "1B"; - case MODEL_1_3B: return "1.3B"; - case MODEL_1_4B: return "1.4B"; - case MODEL_1_5B: return "1.5B"; - case MODEL_1_6B: return "1.6B"; - case MODEL_2B: return "2B"; - case MODEL_2_8B: return "2.8B"; - case MODEL_3B: return "3B"; - case MODEL_4B: return "4B"; - case MODEL_6B: return "6B"; - case MODEL_6_9B: return "6.9B"; - case MODEL_7B: return "7B"; - case MODEL_8B: return "8B"; - case MODEL_9B: return "9B"; - case MODEL_11B: return "11B"; - case MODEL_12B: return 
"12B"; - case MODEL_13B: return "13B"; - case MODEL_14B: return "14B"; - case MODEL_15B: return "15B"; - case MODEL_16B: return "16B"; - case MODEL_20B: return "20B"; - case MODEL_30B: return "30B"; - case MODEL_32B: return "32B"; - case MODEL_34B: return "34B"; - case MODEL_35B: return "35B"; - case MODEL_40B: return "40B"; - case MODEL_65B: return "65B"; - case MODEL_70B: return "70B"; - case MODEL_236B: return "236B"; - case MODEL_314B: return "314B"; - case MODEL_671B: return "671B"; - case MODEL_SMALL: return "0.1B"; - case MODEL_MEDIUM: return "0.4B"; - case MODEL_LARGE: return "0.8B"; - case MODEL_XL: return "1.5B"; - case MODEL_A1_7B: return "A1.7B"; - case MODEL_A2_7B: return "A2.7B"; - case MODEL_8x7B: return "8x7B"; - case MODEL_8x22B: return "8x22B"; - case MODEL_16x12B: return "16x12B"; - case MODEL_10B_128x3_66B: return "10B+128x3.66B"; - case MODEL_57B_A14B: return "57B.A14B"; - case MODEL_27B: return "27B"; - default: return "?B"; - } -} - -static std::string llama_model_ftype_name(llama_ftype ftype) { - if (ftype & LLAMA_FTYPE_GUESSED) { - return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; - } - - switch (ftype) { - case LLAMA_FTYPE_ALL_F32: return "all F32"; - case LLAMA_FTYPE_MOSTLY_F16: return "F16"; - case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; - case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; - case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; - case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; - case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; - case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; - case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; - case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; - - default: return "unknown, may not work"; + case LLM_TYPE_14M: return "14M"; + case LLM_TYPE_17M: return "17M"; + case LLM_TYPE_22M: return "22M"; + case LLM_TYPE_33M: return "33M"; + case LLM_TYPE_60M: return "60M"; + case LLM_TYPE_70M: return "70M"; + case LLM_TYPE_80M: return "80M"; + case LLM_TYPE_109M: return "109M"; + case LLM_TYPE_137M: return "137M"; + case LLM_TYPE_160M: return "160M"; + case LLM_TYPE_220M: return "220M"; + case LLM_TYPE_250M: return "250M"; + case LLM_TYPE_270M: return "270M"; + case LLM_TYPE_335M: return 
"335M"; + case LLM_TYPE_410M: return "410M"; + case LLM_TYPE_450M: return "450M"; + case LLM_TYPE_770M: return "770M"; + case LLM_TYPE_780M: return "780M"; + case LLM_TYPE_0_5B: return "0.5B"; + case LLM_TYPE_1B: return "1B"; + case LLM_TYPE_1_3B: return "1.3B"; + case LLM_TYPE_1_4B: return "1.4B"; + case LLM_TYPE_1_5B: return "1.5B"; + case LLM_TYPE_1_6B: return "1.6B"; + case LLM_TYPE_2B: return "2B"; + case LLM_TYPE_2_8B: return "2.8B"; + case LLM_TYPE_3B: return "3B"; + case LLM_TYPE_4B: return "4B"; + case LLM_TYPE_6B: return "6B"; + case LLM_TYPE_6_9B: return "6.9B"; + case LLM_TYPE_7B: return "7B"; + case LLM_TYPE_8B: return "8B"; + case LLM_TYPE_9B: return "9B"; + case LLM_TYPE_11B: return "11B"; + case LLM_TYPE_12B: return "12B"; + case LLM_TYPE_13B: return "13B"; + case LLM_TYPE_14B: return "14B"; + case LLM_TYPE_15B: return "15B"; + case LLM_TYPE_16B: return "16B"; + case LLM_TYPE_20B: return "20B"; + case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_32B: return "32B"; + case LLM_TYPE_34B: return "34B"; + case LLM_TYPE_35B: return "35B"; + case LLM_TYPE_40B: return "40B"; + case LLM_TYPE_65B: return "65B"; + case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_236B: return "236B"; + case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_671B: return "671B"; + case LLM_TYPE_SMALL: return "0.1B"; + case LLM_TYPE_MEDIUM: return "0.4B"; + case LLM_TYPE_LARGE: return "0.8B"; + case LLM_TYPE_XL: return "1.5B"; + case LLM_TYPE_A1_7B: return "A1.7B"; + case LLM_TYPE_A2_7B: return "A2.7B"; + case LLM_TYPE_8x7B: return "8x7B"; + case LLM_TYPE_8x22B: return "8x22B"; + case LLM_TYPE_16x12B: return "16x12B"; + case LLM_TYPE_16x3_8B: return "16x3.8B"; + case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B"; + case LLM_TYPE_57B_A14B: return "57B.A14B"; + case LLM_TYPE_27B: return "27B"; + default: return "?B"; } } @@ -134,44 +91,6 @@ static const char * llama_expert_gating_func_name(llama_expert_gating_func_type } } -std::string llama_model_arch_name (const llama_model & model) { - return llm_arch_name(model.arch); -} - -std::string llama_model_type_name (const llama_model & model) { - return llm_type_name(model.type); -} - -std::string llama_model_ftype_name(const llama_model & model) { - return llama_model_ftype_name(model.ftype); -} - -ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il) { - return select_buft( - *model.dev_layer.at(il).buft_list, - [&](ggml_context * ctx) { - ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); - ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); - return ggml_add(ctx, cur, layer_dir); - }); -} - -struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name) { - auto it = std::find_if(model.tensors_by_name.begin(), model.tensors_by_name.end(), - [name](const std::pair & it) { - return it.first == name; - }); - if (it == model.tensors_by_name.end()) { - return nullptr; - } - - return it->second; -} - -size_t llama_model_max_nodes(const llama_model & model) { - return std::max(8192, model.tensors_by_name.size()*5); -} - static const std::map LLAMA_ROPE_SCALING_TYPES = { { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, @@ -189,37 +108,284 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } -// NOTE: avoid ever using this except for building the token_to_piece caches -static std::string llama_token_to_piece(const struct 
llama_model * model, llama_token token, bool special) { - std::string piece; - piece.resize(piece.capacity()); // using string internal cache - const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); - if (n_chars < 0) { - piece.resize(-n_chars); - int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); - GGML_ASSERT(check == -n_chars); - } - else { - piece.resize(n_chars); +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; } - return piece; + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + int n_expert_used = hparams.n_expert_used; + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_MUL: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + int n_embd_head = hparams.n_embd_head_v; + int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + // FIXME + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // FIXME + const int64_t d_state = w->ne[0]; + const int64_t d_inner = w->ne[1]; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 1; + ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); + ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 123; + const int64_t H = 123; 
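// The 123/512-style dimensions in this and the neighbouring cases are deliberate
// placeholders (hence the FIXME): the op tensor built from them is never computed.
// weight_buft_supported() only needs a shape-valid dummy node plus a zero-sized
// buffer attached to the weight so that ggml_backend_dev_supports_op() can be
// queried for this device/buffer-type combination.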
+ const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd = hparams.n_embd; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + + return op_supported; } -void llm_load_stats(llama_model_loader & ml, llama_model & model) { - model.n_elements = ml.n_elements; - model.n_bytes = ml.n_bytes; +// lists of buffer types used for each layer +using buft_list_t = std::vector>; + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { + return cur_buft; + } + } + return nullptr; } -void llm_load_arch(llama_model_loader & ml, llama_model & model) { - model.arch = ml.get_arch(); - if (model.arch == LLM_ARCH_UNKNOWN) { +// CPU: ACCEL -> CPU extra -> GPU host -> CPU +static buft_list_t make_cpu_buft_list(const std::vector & devices) { + buft_list_t buft_list; + + // add ACCEL buffer types + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + auto * buft = ggml_backend_dev_buffer_type(dev); + // skip + if (buft != ggml_backend_cpu_buffer_type()) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add extra buffer types + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // add a host buffer type + // storing the tensors in a host buffer is useful when the processing of large batches + // is offloaded to a GPU device, since it reduces the time spent on data transfers + // generally, this will be done using the first device in the list + // a better approach would be to handle this on a weight-by-weight basis using the offload_op 
+ // function of the device to determine if it would benefit from being stored in a host buffer + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } + } + + // add the CPU buffer type + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) { + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + } + } + + return buft_list; +} + +// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU +static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) { + buft_list_t buft_list; + + // add the device split buffer type if requested and available + if (split_mode == LLAMA_SPLIT_MODE_ROW) { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type"); + if (ggml_backend_split_buffer_type_fn) { + size_t dev_index = [&]() { + auto * reg = ggml_backend_dev_backend_reg(dev); + for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) { + if (ggml_backend_reg_dev_get(reg, i) == dev) { + return i; + } + } + throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev))); + }(); + auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split); + if (buft != nullptr) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add the device default buffer type + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + + return buft_list; +} + +struct llama_model::impl { + impl() {} + ~impl() {} + + uint64_t n_elements = 0; + + size_t n_bytes = 0; + + std::string desc_str; + + // model memory mapped files + llama_mmaps mappings; + + // objects representing data potentially being locked in memory + llama_mlocks mlock_bufs; + llama_mlocks mlock_mmaps; + + // contexts where the model tensors metadata is stored + std::vector ctxs; + + // the model memory buffers for the tensor data + std::vector bufs; + + buft_list_t cpu_buft_list; + std::map gpu_buft_list; + + struct layer_dev { + ggml_backend_dev_t dev; + buft_list_t * buft_list; + }; + + layer_dev dev_input = {}; + layer_dev dev_output = {}; + std::vector dev_layer; +}; + +llama_model::llama_model(const struct llama_model_params & params) : params(params), pimpl(std::make_unique()) { +} + +llama_model::~llama_model() {} + +void llama_model::load_stats(llama_model_loader & ml) { + pimpl->n_elements = ml.n_elements; + pimpl->n_bytes = ml.n_bytes; +} + +void llama_model::load_arch(llama_model_loader & ml) { + arch = ml.get_arch(); + if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } } -void llm_load_hparams(llama_model_loader & ml, llama_model & model) { - auto & hparams = model.hparams; +void llama_model::load_hparams(llama_model_loader & ml) { const gguf_context * ctx = ml.meta.get(); // get metadata as string @@ -230,13 +396,11 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { } const char * name = gguf_get_key(ctx, i); const std::string value = gguf_kv_to_str(ctx, i); - model.gguf_kv.emplace(name, value); + gguf_kv.emplace(name, value); } // get general kv - ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); - - // get hparams kv + 
ml.get_key(LLM_KV_GENERAL_NAME, name, false); ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false); // everything past this point is not vocab-related @@ -249,8 +413,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false); - if (model.arch == LLM_ARCH_WAVTOKENIZER_DEC) { + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); @@ -274,8 +439,8 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false); // n_head_kv is optional, default to n_head @@ -325,7 +490,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) { + if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_MLLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -336,34 +501,36 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { hparams.n_embd_head_v = 0; } - using e_model = llm_type; // TMP + // for differentiating model types + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); // arch-specific KVs - switch (model.arch) { + switch (arch) { case LLM_ARCH_LLAMA: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 8) { switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8x7B; break; - case 56: model.type = e_model::MODEL_8x22B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8x7B; break; + case 56: type = LLM_TYPE_8x22B; break; + default: type = LLM_TYPE_UNKNOWN; } } else { switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_1B; break; // Llama 3.2 1B - case 22: model.type = e_model::MODEL_1B; break; - case 26: model.type = e_model::MODEL_3B; break; - case 28: model.type = e_model::MODEL_3B; break; // Llama 3.2 3B + case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B + case 22: type = LLM_TYPE_1B; break; + case 26: type = LLM_TYPE_3B; break; + case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B // granite uses a vocab with len 49152 - case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? 
e_model::MODEL_7B : e_model::MODEL_8B); break; - case 36: model.type = e_model::MODEL_8B; break; // granite - case 40: model.type = e_model::MODEL_13B; break; - case 48: model.type = e_model::MODEL_34B; break; - case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 36: type = LLM_TYPE_8B; break; // granite + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_34B; break; + case 60: type = LLM_TYPE_30B; break; + case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } } break; @@ -372,42 +539,42 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_11B; break; - case 100: model.type = e_model::MODEL_90B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_11B; break; + case 100: type = LLM_TYPE_90B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DECI: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); switch (hparams.n_layer) { - case 52: model.type = e_model::MODEL_1B; break; - case 40: model.type = e_model::MODEL_2B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 52: type = LLM_TYPE_1B; break; + case 40: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_MINICPM3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); switch (hparams.n_layer) { - case 62: model.type = e_model::MODEL_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 62: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GROK: @@ -415,8 +582,8 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 64: model.type = e_model::MODEL_314B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 64: type = LLM_TYPE_314B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_FALCON: @@ -424,21 +591,21 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { 
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 60: model.type = e_model::MODEL_40B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 60: type = LLM_TYPE_40B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_BAICHUAN: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } - if (model.type == e_model::MODEL_13B) { + if (type == LLM_TYPE_13B) { // TODO: become GGUF KV parameter hparams.f_max_alibi_bias = 8.0f; } @@ -447,19 +614,19 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 36: model.type = e_model::MODEL_3B; break; - case 42: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_15B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + case 42: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_REFACT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_1B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; } // TODO: become GGUF KV parameter @@ -469,48 +636,45 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 3: - model.type = e_model::MODEL_17M; break; // bge-micro + type = LLM_TYPE_17M; break; // bge-micro case 6: - model.type = e_model::MODEL_22M; break; // MiniLM-L6 + type = LLM_TYPE_22M; break; // MiniLM-L6 case 12: switch (hparams.n_embd) { - case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small - case 768: model.type = e_model::MODEL_109M; break; // bge-base - default: model.type = e_model::MODEL_UNKNOWN; + case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small + case 768: type = LLM_TYPE_109M; break; // bge-base + default: type = LLM_TYPE_UNKNOWN; } break; case 24: - model.type = e_model::MODEL_335M; break; // bge-large - default: model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_335M; break; // bge-large + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { - case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small - case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base - default: 
model.type = e_model::MODEL_UNKNOWN; + case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small + case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_NOMIC_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); if (hparams.n_layer == 12 && hparams.n_embd == 768) { - model.type = e_model::MODEL_137M; + type = LLM_TYPE_137M; } } break; case LLM_ARCH_BLOOM: @@ -518,14 +682,14 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; + case 24: type = LLM_TYPE_1B; break; case 30: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - case 4096: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } // TODO: become GGUF KV parameter @@ -538,9 +702,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_30B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_30B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_STABLELM: @@ -548,10 +712,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN: @@ -559,9 +723,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN2VL: @@ -573,27 +737,27 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break; - case 28: model.type = hparams.n_embd == 1536 ? e_model::MODEL_1_5B : e_model::MODEL_7B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 36: model.type = e_model::MODEL_3B; break; - case 40: model.type = hparams.n_head() == 20 ? 
e_model::MODEL_4B : e_model::MODEL_13B; break; - case 48: model.type = e_model::MODEL_14B; break; - case 64: model.type = e_model::MODEL_32B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; + case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; + case 32: type = LLM_TYPE_7B; break; + case 36: type = LLM_TYPE_3B; break; + case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN2MOE: { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_A2_7B; break; - case 28: model.type = e_model::MODEL_57B_A14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_A2_7B; break; + case 28: type = LLM_TYPE_57B_A14B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_PHI2: @@ -601,9 +765,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_PHI3: @@ -611,10 +775,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; } // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931 @@ -633,32 +797,41 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { throw std::runtime_error("invalid value for sliding_window"); } } break; + case LLM_ARCH_PHIMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_16x3_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_PLAMO: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 12: model.type = e_model::MODEL_SMALL; break; - case 24: model.type = e_model::MODEL_MEDIUM; break; - case 36: model.type = e_model::MODEL_LARGE; break; - case 48: model.type = e_model::MODEL_XL; break; - default: model.type = 
e_model::MODEL_UNKNOWN; + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 36: type = LLM_TYPE_LARGE; break; + case 48: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CODESHELL: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 42: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 42: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_ORION: @@ -666,17 +839,17 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_INTERNLM2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GEMMA: @@ -684,37 +857,37 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 18: model.type = e_model::MODEL_2B; break; - case 28: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 18: type = LLM_TYPE_2B; break; + case 28: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GEMMA2: { hparams.n_swa = 4096; // default value of gemma 2 - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); hparams.attn_soft_cap = true; switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_2B; break; - case 42: model.type = e_model::MODEL_9B; break; - case 46: model.type = e_model::MODEL_27B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 26: type = LLM_TYPE_2B; break; + case 42: type = LLM_TYPE_9B; break; + case 46: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 30: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_15B; break; - case 52: model.type = e_model::MODEL_20B; break; // granite - case 88: model.type = e_model::MODEL_34B; break; // granite - default: model.type = e_model::MODEL_UNKNOWN; + case 30: type = LLM_TYPE_3B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + case 52: type = LLM_TYPE_20B; break; // granite + case 88: type = LLM_TYPE_34B; break; // granite + default: type = LLM_TYPE_UNKNOWN; } } 
break; case LLM_ARCH_MAMBA: @@ -730,51 +903,51 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { switch (hparams.n_layer) { case 24: switch (hparams.n_embd) { - case 768: model.type = e_model::MODEL_SMALL; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 48: switch (hparams.n_embd) { - case 1024: model.type = e_model::MODEL_MEDIUM; break; - case 1536: model.type = e_model::MODEL_LARGE; break; - case 2048: model.type = e_model::MODEL_XL; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 64: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - case 80: model.type = e_model::MODEL_65B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 80: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_COMMAND_R: { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_35B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_35B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_COHERE2: { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DBRX: @@ -783,8 +956,8 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_16x12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OLMO: @@ -793,10 +966,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); switch (hparams.n_layer) { - case 22: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 22: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OLMO2: @@ -804,18 +977,18 @@ void 
llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OLMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_A1_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OPENELM: @@ -823,57 +996,57 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_270M; break; - case 20: model.type = e_model::MODEL_450M; break; - case 28: model.type = e_model::MODEL_1B; break; - case 36: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GPTNEOX: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); switch (hparams.n_layer) { case 6: switch (hparams.n_ff()) { - case 512: model.type = e_model::MODEL_14M; break; - case 2048: model.type = e_model::MODEL_70M; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 512: type = LLM_TYPE_14M; break; + case 2048: type = LLM_TYPE_70M; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 12: switch (hparams.n_ff()) { - case 3072: model.type = e_model::MODEL_160M; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 3072: type = LLM_TYPE_160M; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 16: switch (hparams.n_ff()) { - case 8192: model.type = e_model::MODEL_1B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 8192: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 24: switch (hparams.n_ff()) { - case 4096: model.type = e_model::MODEL_410M; break; - case 8192: model.type = e_model::MODEL_1_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 4096: type = LLM_TYPE_410M; break; + case 8192: type = LLM_TYPE_1_4B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 32: switch (hparams.n_ff()) { - case 10240: model.type = e_model::MODEL_2_8B; break; - case 16384: model.type = e_model::MODEL_6_9B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 10240: type = LLM_TYPE_2_8B; break; + case 16384: type = LLM_TYPE_6_9B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 36: switch (hparams.n_ff()) { - case 20480: model.type = e_model::MODEL_12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 20480: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 44: switch (hparams.n_ff()) { - case 24576: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24576: type = LLM_TYPE_20B; break; 
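// Several GPT-NeoX checkpoints share a layer count, so each n_layer case above is
// further keyed on n_ff(): e.g. 24 layers resolves to 410M when n_ff is 4096 but
// to 1.4B when n_ff is 8192.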
+ default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_ARCTIC: @@ -882,40 +1055,40 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { if (hparams.n_expert == 128) { switch (hparams.n_layer) { - case 35: model.type = e_model::MODEL_10B_128x3_66B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 35: type = LLM_TYPE_10B_128x3_66B; break; + default: type = LLM_TYPE_UNKNOWN; } } else { - model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DEEPSEEK: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); switch (hparams.n_layer) { - case 28: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 28: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set @@ -924,19 +1097,31 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); switch (hparams.n_layer) { - case 27: model.type = e_model::MODEL_16B; break; - case 60: model.type = e_model::MODEL_236B; break; - case 61: model.type = e_model::MODEL_671B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 27: type = LLM_TYPE_16B; break; + case 60: type = LLM_TYPE_236B; break; + case 61: type = LLM_TYPE_671B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CHATGLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 28: model.type = e_model::MODEL_6B; break; - case 
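// The trailing `false` marks these keys as optional: get_key() then returns false
// and leaves the hparam at its current default when the key is missing from the
// GGUF instead of throwing, which is what lets the compatibility fallback below
// kick in for older DeepSeek V2 / V2.5 files.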
40: model.type = e_model::MODEL_9B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 28: { + if (hparams.n_head(0) == 16) { + type = LLM_TYPE_1_5B; + } else { + type = LLM_TYPE_6B; + } + } break; + case 40: { + if (hparams.n_head(0) == 24) { + type = LLM_TYPE_4B; + } else { + type = LLM_TYPE_9B; + } + } break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_BITNET: @@ -944,13 +1129,13 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 26: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_T5: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); uint32_t dec_start_token_id; @@ -959,32 +1144,32 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { } switch (hparams.n_layer) { - case 6: model.type = e_model::MODEL_60M; break; // t5-small - case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small + case 6: type = LLM_TYPE_60M; break; // t5-small + case 8: type = LLM_TYPE_80M; break; // flan-t5-small case 12: switch (hparams.n_ff()) { - case 3072: model.type = e_model::MODEL_220M; break; // t5-base - case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base - default: model.type = e_model::MODEL_UNKNOWN; + case 3072: type = LLM_TYPE_220M; break; // t5-base + case 2048: type = LLM_TYPE_250M; break; // flan-t5-base + default: type = LLM_TYPE_UNKNOWN; } break; case 24: switch (hparams.n_ff()) { - case 4096: model.type = e_model::MODEL_770M; break; // t5-large - case 2816: model.type = e_model::MODEL_780M; break; // flan-t5-large - case 16384: model.type = e_model::MODEL_3B; break; // t5-3b - case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl - case 65536: model.type = e_model::MODEL_11B; break; // t5-11b - case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl - default: model.type = e_model::MODEL_UNKNOWN; + case 4096: type = LLM_TYPE_770M; break; // t5-large + case 2816: type = LLM_TYPE_780M; break; // flan-t5-large + case 16384: type = LLM_TYPE_3B; break; // t5-3b + case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl + case 65536: type = LLM_TYPE_11B; break; // t5-11b + case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_T5ENCODER: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); - model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_UNKNOWN; } break; case LLM_ARCH_JAIS: { @@ -992,18 +1177,18 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1_3B; break; - case 40: model.type = e_model::MODEL_13B; break; + case 24: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; /* TODO: add variants */ - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_NEMOTRON: { 
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_EXAONE: @@ -1011,44 +1196,48 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); - ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); - ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1_6B; break; + case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - case 4096: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - case 61: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_3B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CHAMELEON: @@ -1058,9 +1247,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_34B; break; - default: model.type = 
e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_34B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_SOLAR: @@ -1069,13 +1258,13 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { for (size_t i = 0; i < hparams.n_bskcn_arr.max_size(); ++i) { auto & bskcn = hparams.n_bskcn_arr[i]; bskcn.fill(0); - auto kv = LLM_KV(model.arch); + auto kv = LLM_KV(arch); ml.get_key_or_arr(format((kv(LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION) + ".%d").c_str(), i), bskcn, hparams.n_layer, false); } switch (hparams.n_layer) { - case 64: model.type = e_model::MODEL_22B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 64: type = LLM_TYPE_22B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_WAVTOKENIZER_DEC: @@ -1088,724 +1277,2426 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { default: throw std::runtime_error("unsupported model architecture"); } - model.ftype = ml.ftype; + pimpl->n_bytes = ml.n_bytes; + + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); if (hparams.f_max_alibi_bias > 0.0f) { hparams.use_alibi = true; } - hparams.rope_type = llama_rope_type(&model); + hparams.rope_type = llama_model_rope_type(this); } -void llm_load_vocab(llama_model_loader & ml, llama_model & model) { - auto & vocab = model.vocab; +void llama_model::load_vocab(llama_model_loader & ml) { + const auto kv = LLM_KV(arch); - struct gguf_context * ctx = ml.meta.get(); + vocab.load(ml, kv); +} - const auto kv = LLM_KV(model.arch); +bool llama_model::load_tensors(llama_model_loader & ml) { + const auto & split_mode = params.split_mode; + const auto & n_gpu_layers = params.n_gpu_layers; + const auto & use_mlock = params.use_mlock; + const auto & tensor_split = params.tensor_split; - // determine vocab type - { - std::string tokenizer_model; - std::string tokenizer_pre; + const int n_layer = hparams.n_layer; - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); - ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + const bool use_mmap_buffer = true; - if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { - vocab.type = LLAMA_VOCAB_TYPE_NONE; + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? 
"true" : "false"); - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - vocab.linefeed_id = LLAMA_TOKEN_NULL; - - // read vocab size from metadata - if (!ml.get_key(LLM_KV_VOCAB_SIZE, vocab.n_vocab, false)) { - vocab.n_vocab = 0; - LLAMA_LOG_WARN("%s: there is no vocab_size in metadata, vocab.n_vocab will be set to %u\n", __func__, vocab.n_vocab); - } - return; - } - - if (tokenizer_model == "llama") { - vocab.type = LLAMA_VOCAB_TYPE_SPM; - - // default special tokens - vocab.special_bos_id = 1; - vocab.special_eos_id = 2; - vocab.special_unk_id = 0; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - } else if (tokenizer_model == "bert") { - vocab.type = LLAMA_VOCAB_TYPE_WPM; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = 100; - vocab.special_sep_id = 102; - vocab.special_pad_id = 0; - vocab.special_cls_id = 101; - vocab.special_mask_id = 103; - } else if (tokenizer_model == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; - - // read bpe merges and populate bpe ranks - const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); - if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } - - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - - std::string first; - std::string second; - - const size_t pos = word.find(' ', 1); - - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); - } - - vocab.bpe_ranks.emplace(std::make_pair(first, second), i); - } - - // default special tokens - vocab.special_bos_id = 11; - vocab.special_eos_id = 11; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - } else if (tokenizer_model == "t5") { - vocab.type = LLAMA_VOCAB_TYPE_UGM; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = 1; - vocab.special_unk_id = 2; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = 0; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - - const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); - if (precompiled_charsmap_keyidx != -1) { - size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); - const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); - vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap); -#ifdef IS_BIG_ENDIAN - // correct endiannes of data in precompiled_charsmap binary blob - uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0]; - *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); - assert(*xcda_blob_size + sizeof(uint32_t) < 
n_precompiled_charsmap); - size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); - uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)]; - for (size_t i = 0; i < xcda_array_size; ++i) { - xcda_array[i] = __builtin_bswap32(xcda_array[i]); - } -#endif - } - } else if (tokenizer_model == "rwkv") { - vocab.type = LLAMA_VOCAB_TYPE_RWKV; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - } else { - throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); - } - - // for now, only BPE models have pre-tokenizers - if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = true; - if (tokenizer_pre == "default") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if ( - tokenizer_pre == "llama3" || - tokenizer_pre == "llama-v3" || - tokenizer_pre == "llama-bpe"|| - tokenizer_pre == "falcon3") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; - vocab.tokenizer_ignore_merges = true; - vocab.tokenizer_add_bos = true; - } else if ( - tokenizer_pre == "deepseek-llm") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "deepseek-coder") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "deepseek-v3") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "falcon") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; - } else if ( - tokenizer_pre == "mpt") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT; - } else if ( - tokenizer_pre == "starcoder") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; - } else if ( - tokenizer_pre == "gpt-2" || - tokenizer_pre == "phi-2" || - tokenizer_pre == "jina-es" || - tokenizer_pre == "jina-de" || - tokenizer_pre == "gigachat" || - tokenizer_pre == "jina-v1-en" || - tokenizer_pre == "jina-v2-es" || - tokenizer_pre == "jina-v2-de" || - tokenizer_pre == "jina-v2-code" || - tokenizer_pre == "roberta-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; - } else if ( - tokenizer_pre == "refact") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT; - } else if ( - tokenizer_pre == "command-r") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "qwen2") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "stablelm2") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2; - } else if ( - tokenizer_pre == "olmo") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO; - } else if ( - tokenizer_pre == "dbrx") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX; - } else if ( - tokenizer_pre == "smaug-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; - } else if ( - tokenizer_pre == "poro-chat") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "chatglm-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; - vocab.special_bos_id = LLAMA_TOKEN_NULL; - } else if ( - tokenizer_pre == "viking") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "jais") { - 
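// Illustrative sketch, not part of the patch: the long chain being removed here
// maps the tokenizer_pre string (read via LLM_KV_TOKENIZER_PRE above) to a
// LLAMA_VOCAB_PRE_TYPE_* constant, with several names sharing one constant and
// unknown names falling back to the default type with a warning; the patch
// appears to hand this over to the vocab.load() call introduced earlier. The
// same lookup expressed as a table, using placeholder enum values; the real
// chain also toggles per-type flags (clean_spaces, add_bos, ignore_merges)
// that are omitted here.
#include <string>
#include <unordered_map>

enum class pre_type_sketch { DEFAULT, LLAMA3, GPT2, QWEN2, JAIS };

static pre_type_sketch lookup_pre_type(const std::string & tokenizer_pre) {
    static const std::unordered_map<std::string, pre_type_sketch> table = {
        { "llama-bpe", pre_type_sketch::LLAMA3 },
        { "falcon3",   pre_type_sketch::LLAMA3 },  // several names share one pre-tokenizer type
        { "gpt-2",     pre_type_sketch::GPT2   },
        { "phi-2",     pre_type_sketch::GPT2   },
        { "qwen2",     pre_type_sketch::QWEN2  },
        { "megrez",    pre_type_sketch::QWEN2  },  // mapping added by the chain above
        { "jais",      pre_type_sketch::JAIS   },
    };
    const auto it = table.find(tokenizer_pre);
    // unknown or missing names fall back to 'default', as the original chain does
    return it != table.end() ? it->second : pre_type_sketch::DEFAULT;
}

int main() {
    return lookup_pre_type("phi-2") == pre_type_sketch::GPT2 ? 0 : 1;
}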
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS; - } else if ( - tokenizer_pre == "tekken") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_ignore_merges = true; - vocab.tokenizer_add_bos = true; - } else if ( - tokenizer_pre == "smollm") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "codeshell") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; - } else if ( - tokenizer_pre == "bloom") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM; - } else if ( - tokenizer_pre == "gpt3-finnish") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; - } else if ( - tokenizer_pre == "exaone") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; - } else if ( - tokenizer_pre == "chameleon") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "minerva-7b") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MINERVA; - } else if ( - tokenizer_pre == "megrez") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; - } else { - LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } - } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = true; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_add_eos = false; - } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = true; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_add_eos = false; - } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_bos = false; - vocab.tokenizer_add_eos = true; - } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_add_bos = false; - vocab.tokenizer_add_eos = false; - } else { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } - - ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.tokenizer_add_space_prefix, false); - ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false); + // build a list of buffer types for the CPU and GPU devices + pimpl->cpu_buft_list = make_cpu_buft_list(devices); + for (auto * dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); + // add CPU buffer types as a fallback + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); } - const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); - if (token_idx == -1) { - throw std::runtime_error("cannot find tokenizer vocab in model file\n"); - } - - const float * scores = nullptr; - const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); - if (score_idx != -1) { - scores = (const float * ) gguf_get_arr_data(ctx, score_idx); - } - - const int * toktypes = nullptr; - const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); - if (toktype_idx != -1) { - toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); - } - - const uint32_t n_vocab = gguf_get_arr_n(ctx, 
token_idx); - - vocab.n_vocab = n_vocab; - vocab.id_to_token.resize(n_vocab); - - for (uint32_t i = 0; i < n_vocab; i++) { - std::string word = gguf_get_arr_str(ctx, token_idx, i); - if (word.empty()) { - LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); - word = "[EMPTY_" + std::to_string(i) + "]"; + // calculate the split points + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector splits(n_devices()); + if (all_zero) { + // default split, by free memory + for (size_t i = 0; i < n_devices(); ++i) { + ggml_backend_dev_t dev = devices[i]; + size_t total; + size_t free; + ggml_backend_dev_memory(dev, &free, &total); + splits[i] = free; } - - vocab.token_to_id[word] = i; - vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); - - auto & token_data = vocab.id_to_token[i]; - token_data.text = std::move(word); - token_data.score = scores ? scores[i] : 0.0f; - token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; - - if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file - switch(toktypes[i]) { - case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; - case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; - case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; - case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; - case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; - case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; - case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; - default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; - } - } - } - GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); - - vocab.init_tokenizer(); - - // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' - if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - try { - vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n'); - } catch (const std::exception & e) { - LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! 
Using special_pad_id instead.", __func__, e.what()); - vocab.linefeed_id = vocab.special_pad_id; - } - } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { - vocab.linefeed_id = vocab.special_pad_id; - } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { - const std::vector ids = llama_tokenize_internal(vocab, "\n", false); - GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); - vocab.linefeed_id = ids[0]; } else { - const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A - - //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); - if (ids.empty()) { - LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); - vocab.linefeed_id = vocab.special_pad_id; - } else { - vocab.linefeed_id = ids[0]; - } + std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); } - // special tokens - { - const std::vector> special_token_types = { - { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, - { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, - { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, - { LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id }, - { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, - { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, - { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, - { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id }, - { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id }, - { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id }, - { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id }, - { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id }, - { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id }, - { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id }, - { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id }, + // sum and normalize the splits to get the split points + float split_sum = 0.0f; + for (size_t i = 0; i < n_devices(); ++i) { + split_sum += splits[i]; + splits[i] = split_sum; + } + for (size_t i = 0; i < n_devices(); ++i) { + splits[i] /= split_sum; + } - // deprecated - { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_fim_pre_id }, - { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_fim_suf_id }, - { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_fim_mid_id }, + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); + const int act_gpu_layers = devices.empty() ? 
0 : std::min(n_gpu_layers, (int)n_layer + 1); + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev)); + return {cpu_dev, &pimpl->cpu_buft_list}; + } + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + auto * dev = devices.at(layer_gpu); + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev)); + return {dev, &pimpl->gpu_buft_list.at(dev)}; + }; + + // assign the input layer + // there is very little benefit to offloading the input layer, so always keep it on the CPU + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; + + // assign the repeating layers to the devices according to the splits + pimpl->dev_layer.resize(n_layer); + for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer[il] = get_layer_buft_list(il); + } + + // assign the output layer + pimpl->dev_output = get_layer_buft_list(n_layer); + + // one ggml context per buffer type + int max_n_tensors = ml.n_tensors; + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += n_layer*2; // duplicated rope freq tensors + const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; + + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map[buft] = ctx; + pimpl->ctxs.emplace_back(ctx); + + return ctx; + } + return it->second; + }; + + const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + + // create tensors for the weights + { + // note: cast to int64_t since we will use these for the tensor dimensions + const int64_t n_head = hparams.n_head(); + const int64_t n_head_kv = hparams.n_head_kv(); + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_ff = hparams.n_ff(); + const int64_t n_embd_gqa = n_embd_v_gqa; + const int64_t n_vocab = hparams.n_vocab; + const int64_t n_token_types = vocab.n_token_types(); + const int64_t n_rot = hparams.n_rot; + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_ctx_train = hparams.n_ctx_train; + + if (n_expert > 0 && hparams.n_expert_used == 0) { + throw std::runtime_error("model has expert layers but no expert layers are used"); + } + + int n_moved_tensors = 0; + ggml_tensor * first_moved_tensor = nullptr; + ggml_backend_buffer_type_t first_moved_from_buft = nullptr; + ggml_backend_buffer_type_t first_moved_to_buft = nullptr; + + auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { + ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); + + if (!t_meta) { + if (flags & TENSOR_NOT_REQUIRED) { + return nullptr; + } + throw std::runtime_error(format("missing tensor '%s'", 
tn.str().c_str())); + } + + // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops + // the tensor is duplicated + // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor + llm_tensor tn_tensor = tn.tensor; + if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED) { + tn_tensor = LLM_TENSOR_OUTPUT; + } + + llm_tensor_info info; + try { + info = llm_tensor_info_for(tn_tensor); + } catch (const std::out_of_range & e) { + throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); + } + + // skip unused tensors + if (info.op == GGML_OP_NONE) { + LLAMA_LOG_WARN("model has unused tensor %s -- ignoring\n", tn.str().c_str()); + ml.n_created++; + + return nullptr; + } + + // tensors with "bias" suffix are always used with GGML_OP_ADD + ggml_op op; + bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; + if (bias) { + op = GGML_OP_ADD; + } else { + op = info.op; + } + + // sanity checks + if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (tn.bid != -1) { + GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); + } + } else { + if (tn.bid == -1) { + GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); + } + } + + // select the buffer type for this tensor + buft_list_t * buft_list; + switch (info.layer) { + case LLM_TENSOR_LAYER_INPUT: + buft_list = pimpl->dev_input.buft_list; + break; + case LLM_TENSOR_LAYER_OUTPUT: + buft_list = pimpl->dev_output.buft_list; + break; + case LLM_TENSOR_LAYER_REPEATING: + buft_list = pimpl->dev_layer.at(tn.bid).buft_list; + break; + default: + GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); + } + + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } + + // avoid using a host buffer when using mmap + auto * buft_dev = ggml_backend_buft_get_device(buft); + if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + buft = ggml_backend_dev_buffer_type(cpu_dev); + } + + if (buft != buft_list->front().second) { + n_moved_tensors++; + if (!first_moved_tensor) { + first_moved_tensor = t_meta; + first_moved_from_buft = buft_list->front().second; + first_moved_to_buft = buft; + } + } + + ggml_context * ctx = ctx_for_buft(buft); + + // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one + if (flags & TENSOR_DUPLICATED) { + ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); + if (t) { + return t; + } + } + return ml.create_tensor(ctx, tn, ne, flags); }; - for (const auto & it : special_token_types) { - const std::string & key = kv(std::get<0>(it)); - int32_t & id = std::get<1>(it); + layers.resize(n_layer); - uint32_t new_id; - if (!ml.get_key(std::get<0>(it), new_id, false)) { - continue; - } - if (new_id >= vocab.id_to_token.size()) { - LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n", - __func__, key.c_str(), new_id, id); - } else { - id = new_id; - } - } + // TODO: move to a separate function + const auto tn = LLM_TN(arch); + switch (arch) { + case LLM_ARCH_LLAMA: + case LLM_ARCH_REFACT: 
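// Illustrative sketch, not part of the patch: the layer-to-device assignment set
// up above turns the per-device weights (free memory by default, or the
// user-supplied tensor_split) into a normalized cumulative sum, then maps each
// offloaded layer to the first device whose cumulative share exceeds the layer's
// relative position, via std::upper_bound. A self-contained version with
// hypothetical names and inputs (the real code reads ggml_backend_dev_memory()
// and offsets layer indices by i_gpu_start):
#include <algorithm>
#include <cstdio>
#include <vector>

static std::vector<int> assign_layers(std::vector<float> splits, int n_offload_layers) {
    // prefix-sum the weights, then normalize to cumulative fractions in (0, 1]
    float sum = 0.0f;
    for (float & s : splits) { sum += s; s = sum; }
    for (float & s : splits) { s /= sum; }

    std::vector<int> layer_to_dev(n_offload_layers);
    for (int il = 0; il < n_offload_layers; ++il) {
        const float pos = float(il) / n_offload_layers; // relative position of this layer
        layer_to_dev[il] = int(std::upper_bound(splits.begin(), splits.end(), pos) - splits.begin());
    }
    return layer_to_dev;
}

int main() {
    // two devices with 24 GiB and 8 GiB free -> roughly a 3:1 split of 32 layers
    const auto assignment = assign_layers({24.0f, 8.0f}, 32);
    std::printf("layer 0 -> dev %d, layer 23 -> dev %d, layer 24 -> dev %d\n",
                assignment[0], assignment[23], assignment[24]); // 0, 0, 1
}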
+ case LLM_ARCH_MINICPM: + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - // Handle add_bos_token and add_eos_token - { - bool temp = true; + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) { - vocab.tokenizer_add_bos = temp; - } - if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { - vocab.tokenizer_add_eos = temp; - } - } + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } - // auto-detect special tokens by text - // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... - // for now, we apply this workaround to find the tokens based on their text + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; - for (const auto & t : vocab.token_to_id) { - // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. - if (vocab.special_eot_id == LLAMA_TOKEN_NULL) { - if (false - || t.first == "<|eot_id|>" - || t.first == "<|im_end|>" - || t.first == "<|end|>" - || t.first == "" - || t.first == "<|endoftext|>" - || t.first == "" - || t.first == "<|end▁of▁sentence|>" // DeepSeek - ) { - vocab.special_eot_id = t.second; - if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { - LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", - __func__, t.second, t.first.c_str()); - vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } + } break; + case LLM_ARCH_MLLAMA: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + if (hparams.cross_attention_layers(i)) { + layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}, 0); + layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}, 0); + layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}, 0); + layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0); + layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0); + layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0); + layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0); + layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } else { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + 
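// Illustrative sketch, not part of the patch: the rope factor tensors in these
// per-layer lists use a recurring flag idiom — TENSOR_NOT_REQUIRED makes the
// lookup optional (create_tensor returns nullptr instead of throwing), and for
// every layer after the first TENSOR_DUPLICATED is OR-ed in, which per the
// create_tensor comments above means the already-created shared tensor is reused
// rather than counted and allocated again. How the flag word is built (the
// constants below are placeholders, not the loader's real flags):
#include <cstdio>

enum : int { SK_NOT_REQUIRED = 1 << 0, SK_DUPLICATED = 1 << 1 };

static int rope_flags_for_layer(int i) {
    // optional for every layer; marked as a duplicate for all layers except the first
    return SK_NOT_REQUIRED | (i != 0 ? SK_DUPLICATED : 0);
}

int main() {
    std::printf("layer 0: %d, layer 5: %d\n", rope_flags_for_layer(0), rope_flags_for_layer(5)); // 1, 3
}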
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; + case LLM_ARCH_DECI: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + } + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_MINICPM3: + { + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_GROK: + { + if (n_expert == 0) { + throw std::runtime_error("Grok model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_DBRX: + { + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; + case LLM_ARCH_BAICHUAN: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = 
create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_FALCON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_STARCODER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + // needs to be on GPU + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = 
create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_BERT: + case LLM_ARCH_NOMIC_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + if (arch == LLM_ARCH_BERT) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } else { + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + if (arch == LLM_ARCH_BERT) { + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_JINA_BERT_V2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias + + cls 
= create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
+                    cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED);
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i]; // JinaBertLayer
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
+
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
+
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens
+
+                        layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm
+                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+
+                        layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_BLOOM:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_MPT:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    if (!output) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        // AWQ ScaleActivation layer
+                        layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_STABLELM:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors, present in Stable LM 2 1.6B
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+                        // optional q and k layernorms, present in StableLM 2 12B
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED);
+
+                        // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_QWEN:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0);
+                    }
+                } break;
+            case LLM_ARCH_QWEN2:
+            case LLM_ARCH_QWEN2VL:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_QWEN2MOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0 for QWEN2MOE");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE");
+                        }
+
+                        // MoE branch
+                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+
+                        // Shared expert branch
+                        const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
+
+                        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
+                        layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
+                    }
+                } break;
+            case LLM_ARCH_PHI2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+                    output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+                        if (layer.wqkv == nullptr) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
+
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
+
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
+                        }
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_PHI3:
+                {
+                    const int64_t n_embd_head = n_embd / n_head;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
+
+                        layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                    }
+                } break;
+            case LLM_ARCH_PHIMOE:
+                {
+                    const int64_t n_embd_head = n_embd / n_head;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
+                    output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        if (layer.wqkv == nullptr) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
+
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
+
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
+                        }
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+
+                        layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                    }
+                } break;
+            case LLM_ARCH_PLAMO:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GPT2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_CODESHELL:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_ORION:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_INTERNLM2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GEMMA:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GEMMA2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_STARCODER2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+
+                        // optional bias tensors
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_MAMBA:
+                {
+                    const int64_t d_conv = hparams.ssm_d_conv;
+                    const int64_t d_inner = hparams.ssm_d_inner;
+                    const int64_t d_state = hparams.ssm_d_state;
+                    const int64_t dt_rank = hparams.ssm_dt_rank;
+
+                    // only an expansion factor of 2 is supported for now
+                    if (2 * n_embd != d_inner) {
+                        throw std::runtime_error("only an expansion factor of 2 is supported for now");
+                    }
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed, duplicated to allow offloading
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        // norm
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0);
+
+                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0);
+                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0);
+
+                        layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0);
+
+                        layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0);
+                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0);
+
+                        // no "weight" suffix for these
+                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
+                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
+
+                        // out_proj
+                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_XVERSE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_COMMAND_R:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // init output from the input tok embed
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (n_layer >= 64){
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
+                        }
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_COHERE2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    // init output from the input tok embed
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
+                                           TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
                     }
                 }
-            }
+                break;
+            case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-            // find EOM token: "<|eom_id|>"
-            if (vocab.special_eom_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|eom_id|>"
-                        ) {
-                    vocab.special_eom_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    // output
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
-                }
-            }
-            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
-            if (vocab.special_fim_pre_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_prefix|>"  // Qwen
-                        || t.first == "<fim-prefix>"
-                        || t.first == "<|fim▁begin|>" // DeepSeek
-                        || t.first == "<PRE>"
-                        ) {
-                    vocab.special_fim_pre_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
 
-            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
-            if (vocab.special_fim_suf_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_suffix|>" // Qwen
-                        || t.first == "<fim-suffix>"
-                        || t.first == "<|fim▁hole|>" // DeepSeek
-                        || t.first == "<SUF>"
-                        ) {
-                    vocab.special_fim_suf_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
-            if (vocab.special_fim_mid_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_middle|>" // Qwen
-                        || t.first == "<fim-middle>"
-                        || t.first == "<|fim▁end|>"  // DeepSeek
-                        || t.first == "<MID>"
-                        ) {
-                    vocab.special_fim_mid_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
-                }
-            }
+                } break;
+            case LLM_ARCH_OLMO2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
-            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
-            if (vocab.special_fim_pad_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_pad|>" // Qwen
-                        || t.first == "<fim-pad>"
-                        || t.first == "<PAD>"
-                        ) {
-                    vocab.special_fim_pad_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
-            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
-            if (vocab.special_fim_rep_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_repo|>"  // Qwen
-                        || t.first == "<|repo_name|>"
-                        || t.first == "<fim-repo>"
-                        || t.first == "<REPO>"
-                        ) {
-                    vocab.special_fim_rep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
 
-            // find FIM_SEP token: "<|file_sep|>"
-            if (vocab.special_fim_sep_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|file_sep|>" // Qwen
-                        ) {
-                    vocab.special_fim_sep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
-                }
-            }
+                } break;
+            case LLM_ARCH_OLMOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0");
+                        }
+
+                        // MoE branch
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                    }
+                } break;
+            case LLM_ARCH_OPENELM:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // init output from the input tok embed
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        const int64_t n_head      =   hparams.n_head(i);
+                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
+                        const int64_t n_ff        =   hparams.n_ff(i);
+
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GPTNEOX:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_ARCTIC:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
+                    }
+                } break;
+            case LLM_ARCH_DEEPSEEK:
+                {
+
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_DEEPSEEK2:
+                {
+                    const bool is_lite = (hparams.n_layer == 27);
+
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
+                    const int64_t q_lora_rank  = hparams.n_lora_q;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
+
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        if (!is_lite) {
+                            layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
+                        }
+
+                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
+
+                        if (!is_lite) {
+                            layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
+                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
+                        } else {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        }
+
+                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
+                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_BITNET:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm     = create_tensor(tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd}, 0);
+                        layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq       = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wk       = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wv       = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wo       = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm     = create_tensor(tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd}, 0);
+                        layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
+
+                        layer.ffn_gate       = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down       = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up         = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_scale   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_T5:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm     = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.attn_norm_cross  = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        // this tensor seems to be unused in HF transformers implementation
+                        layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_T5ENCODER:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_JAIS:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_CHATGLM:
+                {
+                    tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
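+                        // fall back to separate Q/K/V projections below when the checkpoint has no fused QKV tensor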
+                        if (layer.wqkv == nullptr) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        }
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_NEMOTRON:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        // optional MLP bias
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_EXAONE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM,   "weight", i), {n_embd}, 0);
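+                        // note: rope frequency factors appear to be shared across layers, so layers past the first are marked as duplicates to avoid double-counting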
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN,   "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,     "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_RWKV6:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // Block 0, LN0
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int ffn_size = hparams.n_ff_arr[0];
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, 0);
+
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
+
+                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
+                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+
+                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
+                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
+
+                        layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
+                        layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
+                        layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0);
+                    }
+
+                } break;
+            case LLM_ARCH_RWKV6QWEN2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int n_head_kv = hparams.n_head_kv();
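+                    // K/V projections span the full hidden size unless fewer KV heads are used (GQA-style), in which case they are n_head_kv * head_size wide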
+                    int attn_key_value_size;
+                    if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) {
+                        attn_key_value_size = attn_hidden_size;
+                    } else {
+                        attn_key_value_size = n_head_kv * head_size;
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
+
+                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        // optional bias tensors
+                        layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_CHAMELEON:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i),  {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),  {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_SOLAR:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_WAVTOKENIZER_DEC:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
+
+                    conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
+                    conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {1, hparams.posnet.n_embd}, 0);
+
+                    // posnet
+                    {
+                        const int64_t n_embd = hparams.posnet.n_embd;
+
+                        for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) {
+                            auto & layer = layers[i].posnet;
+
+                            // posnet:
+                            //
+                            //  - resnet
+                            //  - resnet
+                            //  - attn
+                            //  - resnet
+                            //  - resnet
+                            //  - norm
+                            //
+                            switch (i) {
+                                case 0:
+                                case 1:
+                                case 3:
+                                case 4:
+                                    {
+                                        layer.norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0);
+                                        layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.conv1   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0);
+                                        layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.norm2   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0);
+                                        layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.conv2   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0);
+                                        layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                case 2:
+                                    {
+                                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
+                                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_q      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_q_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_k      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_k_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_v      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_v_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_o      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_o_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                case 5:
+                                    {
+                                        layer.norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
+                                        layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                default: GGML_ABORT("unknown posnet layer");
+                            };
+                        }
+                    }
+
+                    GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
+
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {hparams.posnet.n_embd}, 0);
+
+                    // convnext
+                    {
+                        const int64_t n_embd = hparams.convnext.n_embd;
+
+                        for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) {
+                            auto & layer = layers[i].convnext;
+
+                            layer.dw     = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "weight", i), {7, 1, n_embd}, 0);
+                            layer.dw_b   = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "bias",   i), {1, n_embd}, 0);
+
+                            layer.norm   = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "weight", i), {n_embd}, 0);
+                            layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "bias",   i), {n_embd}, 0);
+
+                            layer.pw1    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "weight", i), {n_embd, n_ff}, 0);
+                            layer.pw1_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "bias",   i), {n_ff}, 0);
+
+                            layer.pw2    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "weight", i), {n_ff, n_embd}, 0);
+                            layer.pw2_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "bias",   i), {n_embd}, 0);
+
+                            layer.gamma  = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0);
+                        }
+
+                        // output
+                        output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    }
+
+                    output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
+                    output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
+                } break;
+            default:
+                throw std::runtime_error("unknown architecture");
         }
 
-        // maintain a list of tokens that cause end-of-generation
-        // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
-        vocab.special_eog_ids.clear();
-
-        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
-        }
-
-        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
-        }
-
-        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
-        }
-
-        for (const auto & t : vocab.token_to_id) {
-            if (false
-                    || t.first == "<|eot_id|>"
-                    || t.first == "<|im_end|>"
-                    || t.first == "<|end|>"
-                    || t.first == "<end_of_turn>"
-                    || t.first == "<|endoftext|>"
-                    || t.first == "<|eom_id|>"
-                    || t.first == "<EOT>"
-               ) {
-                vocab.special_eog_ids.insert(t.second);
-                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.second, t.first.c_str());
-                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                }
-            } else {
-                // token is control, but not marked as EOG -> print a debug log
-                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
-                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
-                            __func__, t.second, t.first.c_str());
-                }
-            }
-        }
-
-        // sanity checks
-        if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eos_id);
-            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
-
-        if (vocab.special_eot_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eot_id);
-            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
-
-        if (vocab.special_eom_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eom_id);
-            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        if (n_moved_tensors > 0) {
+            LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n",
+                __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
+                ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
         }
     }
 
-    // build special tokens cache
-    {
-        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
-                vocab.cache_special_tokens.push_back(id);
-            }
+    ml.done_getting_tensors();
+
+    ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr);
+    pimpl->mappings.reserve(ml.mappings.size());
+
+    // create the backend buffers
+    std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
+    ctx_bufs.reserve(ctx_map.size());
+
+    // Ensure we have enough capacity for the maximum backend buffer we will potentially create
+    const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
+    pimpl->bufs.reserve(n_max_backend_buffer);
+
+    for (auto & it : ctx_map) {
+        ggml_backend_buffer_type_t buft = it.first;
+        ggml_context * ctx              = it.second;
+
+        // skip contexts without tensors
+        if (ggml_get_first_tensor(ctx) == nullptr) {
+            continue;
         }
 
-        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
-            [&] (const llama_vocab::id a, const llama_vocab::id b) {
-                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
-            }
-        );
+        llama_buf_map buf_map;
+        buf_map.reserve(n_max_backend_buffer);
 
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
-    }
-
-    // build token to piece cache
-    {
-        size_t size_cache = 0;
-
-        std::vector<std::string> cache_token_to_piece(n_vocab);
-
-        for (uint32_t id = 0; id < n_vocab; ++id) {
-            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
-
-            size_cache += cache_token_to_piece[id].size();
+        // check if it is possible to use buffer_from_host_ptr with this buffer type
+        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+        if (!dev) {
+            // FIXME: workaround for CPU backend buft having a NULL device
+            dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
         }
+        ggml_backend_dev_props props;
+        ggml_backend_dev_get_props(dev, &props);
+        bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
+        bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
 
-        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
-
-        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
-    }
-
-    // Handle per token attributes
-    //NOTE: Each model customizes per token attributes.
-    //NOTE: Per token attributes are missing from the GGUF file.
-    //TODO: Extract attributes from GGUF file.
-    {
-        auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
-            for (auto substr : substrs) {
-                if (str.find(substr) < std::string::npos) {
-                    return true;
+        if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
+            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
+                // only the mmap region containing the tensors in the model is mapped to the backend buffer
+                // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
+                // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
+                void * addr = nullptr;
+                size_t first, last; // NOLINT
+                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
+                if (first >= last) {
+                    continue;
                 }
+                const size_t max_size = ggml_get_max_tensor_size(ctx);
+                ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
+                if (buf == nullptr) {
+                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
+                }
+                pimpl->bufs.emplace_back(buf);
+                buf_map.emplace(idx, buf);
             }
+        }
+        else {
+            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+            if (buf == nullptr) {
+                throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
+            }
+            pimpl->bufs.emplace_back(buf);
+            if (use_mlock && ggml_backend_buffer_is_host(buf)) {
+                pimpl->mlock_bufs.emplace_back(new llama_mlock);
+                auto & mlock_buf = pimpl->mlock_bufs.back();
+                mlock_buf->init   (ggml_backend_buffer_get_base(buf));
+                mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
+            }
+            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
+                buf_map.emplace(idx, buf);
+            }
+        }
+
+        if (pimpl->bufs.empty()) {
+            throw std::runtime_error("failed to allocate buffer");
+        }
+
+        for (auto & buf : buf_map) {
+            // indicate that this buffer contains weights
+            // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight
+            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        }
+
+        ctx_bufs.emplace_back(ctx, buf_map);
+    }
+
+    if (llama_supports_gpu_offload()) {
+        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+        LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
+        }
+
+        const int max_backend_supported_layers = hparams.n_layer + 1;
+        const int max_offloadable_layers       = hparams.n_layer + 1;
+
+        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
+    }
+
+    // print memory requirements per buffer type
+    for (auto & buf : pimpl->bufs) {
+        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
+    }
+
+    // populate tensors_by_name
+    for (auto & ctx : pimpl->ctxs) {
+        for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
+            tensors_by_name.emplace_back(ggml_get_name(cur), cur);
+        }
+    }
+
+    // load tensor data
+    for (auto & it : ctx_bufs) {
+        ggml_context * ctx = it.first;
+        auto & bufs = it.second;
+        if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
             return false;
-        };
-
-        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
-            uint32_t current = vocab.id_to_token.at(id).attr;
-            current = value ? (current | attr) : (current & ~attr);
-            vocab.id_to_token[id].attr = (llama_token_attr) current;
-        };
-
-        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
-            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
-        };
-
-        std::string model_name;
-        std::string tokenizer_pre;
-
-        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
-        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
-
-        // model name to lowercase
-        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
-            [] (const std::string::value_type x) {
-                return std::tolower(x);
-            }
-        );
-
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
-        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
-            for (auto id : vocab.cache_special_tokens) {
-                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"</s>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
-            }
         }
     }
+
+    if (use_mmap_buffer) {
+        for (auto & mapping : ml.mappings) {
+            pimpl->mappings.emplace_back(std::move(mapping));
+        }
+    }
+
+    return true;
 }
 
-void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
-    const auto & hparams = model.hparams;
-    const auto & vocab   = model.vocab;
+std::string llama_model::arch_name() const {
+    return llm_arch_name(arch);
+}
 
+std::string llama_model::type_name() const {
+    return llm_type_name(type);
+}
+
+std::string llama_model::desc() const {
+    return pimpl->desc_str;
+}
+
+size_t llama_model::size() const {
+    return pimpl->n_bytes;
+}
+
+size_t llama_model::max_nodes() const {
+    return std::max<size_t>(8192, tensors_by_name.size()*5);
+}
+
+size_t llama_model::n_devices() const {
+    return devices.size();
+}
+
+uint64_t llama_model::n_elements() const {
+    return pimpl->n_elements;
+}
+
+void llama_model::print_info() const {
     const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
 
     auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
@@ -1838,11 +3729,7 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     };
 
     // hparams
-    LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, llm_arch_name(model.arch));
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, arch_name().c_str());
     LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
 
     if (!hparams.vocab_only) {
@@ -1881,60 +3768,28 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
     }
 
-    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model).c_str());
-    LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model).c_str());
-    if (ml.n_elements >= 1e12) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, ml.n_elements*1e-12);
-    } else if (ml.n_elements >= 1e9) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
-    } else if (ml.n_elements >= 1e6) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, ml.n_elements*1e-6);
+    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, type_name().c_str());
+    if (pimpl->n_elements >= 1e12) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, pimpl->n_elements*1e-12);
+    } else if (pimpl->n_elements >= 1e9) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, pimpl->n_elements*1e-9);
+    } else if (pimpl->n_elements >= 1e6) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, pimpl->n_elements*1e-6);
     } else {
-        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, ml.n_elements*1e-3);
-    }
-    if (ml.n_bytes < GiB) {
-        LLAMA_LOG_INFO("%s: model size       = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0,        ml.n_bytes*8.0/ml.n_elements);
-    } else {
-        LLAMA_LOG_INFO("%s: model size       = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, pimpl->n_elements*1e-3);
     }
 
     // general kv
-    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
+    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, name.c_str());
 
-    // special tokens
-    if (vocab.special_bos_id  != -1)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,     vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id  != -1)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,     vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_eot_id  != -1)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,     vocab.id_to_token[vocab.special_eot_id].text.c_str() );  }
-    if (vocab.special_eom_id  != -1)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,     vocab.id_to_token[vocab.special_eom_id].text.c_str() );  }
-    if (vocab.special_unk_id  != -1)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,     vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id  != -1)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,     vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id  != -1)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,     vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id  != -1)    { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,     vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id != -1)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id,    vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id != -1)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,        vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
-
-    if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); }
-    if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); }
-    if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); }
-    if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); }
-    if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); }
-    if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); }
-
-    for (const auto & id : vocab.special_eog_ids) {
-        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
-    }
-
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
-
-    if (model.arch == LLM_ARCH_DEEPSEEK) {
+    if (arch == LLM_ARCH_DEEPSEEK) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
     }
 
-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
+    if (arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv            = %d\n",     __func__, hparams.n_lora_kv);
@@ -1946,16 +3801,88 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
 
-    if (model.arch == LLM_ARCH_QWEN2MOE) {
+    if (arch == LLM_ARCH_QWEN2MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
     }
 
-    if (model.arch == LLM_ARCH_MINICPM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
+    if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
     }
+
+    vocab.print_info();
+}
+
+ggml_backend_dev_t llama_model::dev_layer(int il) const {
+    return pimpl->dev_layer.at(il).dev;
+}
+
+ggml_backend_dev_t llama_model::dev_output() const {
+    return pimpl->dev_output.dev;
+}
+
+template<typename F>
+static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
+    ggml_init_params params = {
+        /*.mem_size   =*/ ggml_tensor_overhead()*8,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx { ggml_init(params) };
+    if (!ctx) {
+        throw std::runtime_error(format("failed to create ggml context"));
+    }
+
+    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
+    ggml_tensor * op_tensor = fn(ctx.get());
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op_tensor->src[i] != nullptr) {
+            assert(op_tensor->src[i]->buffer == nullptr);
+            op_tensor->src[i]->buffer = buf.get();
+        }
+    }
+
+    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+
+    return op_supported;
+}
+
+template<typename F>
+static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) {
+    for (const auto & cur : buft_list) {
+        ggml_backend_dev_t cur_dev = cur.first;
+        ggml_backend_buffer_type_t cur_buft = cur.second;
+        if (buft_supported(cur_buft, cur_dev, fn)) {
+            return cur_buft;
+        }
+    }
+
+    throw std::runtime_error(format("no suitable buffer type found"));
+}
+
+ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
+    return ::select_buft(
+            *pimpl->dev_layer.at(il).buft_list,
+            [&](ggml_context * ctx) {
+                ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
+                ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
+                return ggml_add(ctx, cur, layer_dir);
+            });
+}
+
+const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
+    auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
+            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
+                return it.first == name;
+            });
+    if (it == tensors_by_name.end()) {
+        return nullptr;
+    }
+
+    return it->second;
 }
 
 //
@@ -1969,7 +3896,6 @@ struct llama_model_params llama_model_default_params() {
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
-        /*.rpc_servers                 =*/ nullptr,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.kv_overrides                =*/ nullptr,
@@ -1987,35 +3913,59 @@ struct llama_model_params llama_model_default_params() {
     return result;
 }
 
+const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model) {
+    return &model->vocab;
+}
+
 void llama_free_model(struct llama_model * model) {
+    llama_model_free(model);
+}
+
+void llama_model_free(struct llama_model * model) {
     delete model;
 }
 
-enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
+int32_t llama_model_n_ctx_train(const struct llama_model * model) {
     return model->hparams.n_ctx_train;
 }
 
-int32_t llama_n_embd(const struct llama_model * model) {
+int32_t llama_model_n_embd(const struct llama_model * model) {
     return model->hparams.n_embd;
 }
 
-int32_t llama_n_layer(const struct llama_model * model) {
+int32_t llama_model_n_layer(const struct llama_model * model) {
     return model->hparams.n_layer;
 }
 
-int32_t llama_n_head(const struct llama_model * model) {
+int32_t llama_model_n_head(const struct llama_model * model) {
     return model->hparams.n_head();
 }
 
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+int32_t llama_model_n_head_kv(const struct llama_model * model) {
+    return model->hparams.n_head_kv();
+}
+
+// deprecated
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return llama_model_n_ctx_train(model);
+}
+
+// deprecated
+int32_t llama_n_embd(const struct llama_model * model) {
+    return llama_model_n_embd(model);
+}
+
+// deprecated
+int32_t llama_n_layer(const struct llama_model * model) {
+    return llama_model_n_layer(model);
+}
+
+// deprecated
+int32_t llama_n_head(const struct llama_model * model) {
+    return llama_model_n_head(model);
+}
+
+enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
@@ -2029,6 +3979,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
         case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
         case LLM_ARCH_WAVTOKENIZER_DEC:
             return LLAMA_ROPE_TYPE_NONE;
 
@@ -2071,6 +4022,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_STARCODER2:
@@ -2093,7 +4045,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     return LLAMA_ROPE_TYPE_NONE;
 }
 
-float llama_rope_freq_scale_train(const struct llama_model * model) {
+float llama_model_rope_freq_scale_train(const struct llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
 
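[Illustrative sketch, not part of the patch] The hunks above keep the old llama_n_* getters as thin wrappers around the renamed llama_model_* functions, so existing callers keep compiling while new code can migrate. A minimal usage sketch, assuming an already-loaded llama_model * model:

    // new names introduced by this change
    const int32_t n_embd  = llama_model_n_embd(model);
    const int32_t n_layer = llama_model_n_layer(model);

    // the deprecated wrappers still work and return the same values
    const int32_t n_embd_legacy = llama_n_embd(model);
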
@@ -2137,18 +4089,26 @@ int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int3
 }
 
 int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s %s %s",
-            llama_model_arch_name (*model).c_str(),
-            llama_model_type_name (*model).c_str(),
-            llama_model_ftype_name(*model).c_str());
+    return snprintf(buf, buf_size, "%s", model->desc().c_str());
 }
 
 uint64_t llama_model_size(const struct llama_model * model) {
-    return model->n_bytes;
+    return model->size();
+}
+
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        return nullptr;
+    }
+
+    return it->second.c_str();
 }
 
 uint64_t llama_model_n_params(const struct llama_model * model) {
-    return model->n_elements;
+    return model->n_elements();
 }
 
 bool llama_model_has_encoder(const struct llama_model * model) {
@@ -2174,6 +4134,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA:  return true;
         case LLM_ARCH_RWKV6:  return true;
+        case LLM_ARCH_RWKV6QWEN2: return true;
         default:              return false;
     }
 }
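
[Illustrative sketch, not part of the patch] Rough shape of the C API surface touched in this file, assuming a model loaded elsewhere; note that llama_free_model() is kept but now just forwards to llama_model_free():

    char desc[128];
    llama_model_desc(model, desc, sizeof(desc));

    const llama_vocab * vocab = llama_model_get_vocab(model);
    const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); // nullptr if the GGUF has no template

    const uint64_t n_params = llama_model_n_params(model);
    const uint64_t n_bytes  = llama_model_size(model);

    llama_model_free(model);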
diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h
index 5b23e2baf..7cf575872 100644
--- a/llama/llama.cpp/src/llama-model.h
+++ b/llama/llama.cpp/src/llama-model.h
@@ -4,81 +4,83 @@
 #include "llama-arch.h"
 #include "llama-hparams.h"
 #include "llama-vocab.h"
-#include "llama-mmap.h"
-
-#include "ggml-cpp.h"
 
+#include 
+#include 
+#include 
 #include 
 #include 
 
+struct llama_model_loader;
+
 // available models
-// TODO: this enum does not follow the enum naming convention
 enum llm_type {
-    MODEL_UNKNOWN,
-    MODEL_14M,
-    MODEL_17M,
-    MODEL_22M,
-    MODEL_33M,
-    MODEL_60M,
-    MODEL_70M,
-    MODEL_80M,
-    MODEL_109M,
-    MODEL_137M,
-    MODEL_160M,
-    MODEL_220M,
-    MODEL_250M,
-    MODEL_270M,
-    MODEL_335M,
-    MODEL_410M,
-    MODEL_450M,
-    MODEL_770M,
-    MODEL_780M,
-    MODEL_0_5B,
-    MODEL_1B,
-    MODEL_1_3B,
-    MODEL_1_4B,
-    MODEL_1_5B,
-    MODEL_1_6B,
-    MODEL_2B,
-    MODEL_2_8B,
-    MODEL_3B,
-    MODEL_4B,
-    MODEL_6B,
-    MODEL_6_9B,
-    MODEL_7B,
-    MODEL_8B,
-    MODEL_9B,
-    MODEL_11B,
-    MODEL_12B,
-    MODEL_13B,
-    MODEL_14B,
-    MODEL_15B,
-    MODEL_16B,
-    MODEL_20B,
-    MODEL_22B,
-    MODEL_30B,
-    MODEL_32B,
-    MODEL_34B,
-    MODEL_35B,
-    MODEL_40B,
-    MODEL_65B,
-    MODEL_70B,
-    MODEL_90B,
-    MODEL_236B,
-    MODEL_314B,
-    MODEL_671B,
-    MODEL_SMALL,
-    MODEL_MEDIUM,
-    MODEL_LARGE,
-    MODEL_XL,
-    MODEL_A1_7B,
-    MODEL_A2_7B,
-    MODEL_8x7B,
-    MODEL_8x22B,
-    MODEL_16x12B,
-    MODEL_10B_128x3_66B,
-    MODEL_57B_A14B,
-    MODEL_27B,
+    LLM_TYPE_UNKNOWN,
+    LLM_TYPE_14M,
+    LLM_TYPE_17M,
+    LLM_TYPE_22M,
+    LLM_TYPE_33M,
+    LLM_TYPE_60M,
+    LLM_TYPE_70M,
+    LLM_TYPE_80M,
+    LLM_TYPE_109M,
+    LLM_TYPE_137M,
+    LLM_TYPE_160M,
+    LLM_TYPE_220M,
+    LLM_TYPE_250M,
+    LLM_TYPE_270M,
+    LLM_TYPE_335M,
+    LLM_TYPE_410M,
+    LLM_TYPE_450M,
+    LLM_TYPE_770M,
+    LLM_TYPE_780M,
+    LLM_TYPE_0_5B,
+    LLM_TYPE_1B,
+    LLM_TYPE_1_3B,
+    LLM_TYPE_1_4B,
+    LLM_TYPE_1_5B,
+    LLM_TYPE_1_6B,
+    LLM_TYPE_2B,
+    LLM_TYPE_2_8B,
+    LLM_TYPE_3B,
+    LLM_TYPE_4B,
+    LLM_TYPE_6B,
+    LLM_TYPE_6_9B,
+    LLM_TYPE_7B,
+    LLM_TYPE_8B,
+    LLM_TYPE_9B,
+    LLM_TYPE_11B,
+    LLM_TYPE_12B,
+    LLM_TYPE_13B,
+    LLM_TYPE_14B,
+    LLM_TYPE_15B,
+    LLM_TYPE_16B,
+    LLM_TYPE_20B,
+    LLM_TYPE_22B,
+    LLM_TYPE_30B,
+    LLM_TYPE_32B,
+    LLM_TYPE_34B,
+    LLM_TYPE_35B,
+    LLM_TYPE_40B,
+    LLM_TYPE_65B,
+    LLM_TYPE_70B,
+    LLM_TYPE_90B,
+    LLM_TYPE_236B,
+    LLM_TYPE_314B,
+    LLM_TYPE_671B,
+    LLM_TYPE_SMALL,
+    LLM_TYPE_MEDIUM,
+    LLM_TYPE_LARGE,
+    LLM_TYPE_XL,
+    LLM_TYPE_A1_7B,
+    LLM_TYPE_A2_7B,
+    LLM_TYPE_8x7B,
+    LLM_TYPE_8x22B,
+    LLM_TYPE_16x12B,
+    LLM_TYPE_16x3_8B,
+    LLM_TYPE_10B_128x3_66B,
+    LLM_TYPE_57B_A14B,
+    LLM_TYPE_27B,
 };
 
 struct llama_layer_posnet {
@@ -243,15 +245,19 @@ struct llama_layer {
     struct ggml_tensor * time_mix_lerp_v     = nullptr;
     struct ggml_tensor * time_mix_lerp_r     = nullptr;
     struct ggml_tensor * time_mix_lerp_g     = nullptr;
+    struct ggml_tensor * time_mix_lerp_fused = nullptr;
 
-    struct ggml_tensor * time_mix_first      = nullptr;
-    struct ggml_tensor * time_mix_decay      = nullptr;
-    struct ggml_tensor * time_mix_decay_w1   = nullptr;
-    struct ggml_tensor * time_mix_decay_w2   = nullptr;
-    struct ggml_tensor * time_mix_key        = nullptr;
-    struct ggml_tensor * time_mix_value      = nullptr;
-    struct ggml_tensor * time_mix_receptance = nullptr;
-    struct ggml_tensor * time_mix_gate       = nullptr;
+    struct ggml_tensor * time_mix_first        = nullptr;
+    struct ggml_tensor * time_mix_decay        = nullptr;
+    struct ggml_tensor * time_mix_decay_w1     = nullptr;
+    struct ggml_tensor * time_mix_decay_w2     = nullptr;
+    struct ggml_tensor * time_mix_key          = nullptr;
+    struct ggml_tensor * time_mix_key_b        = nullptr;
+    struct ggml_tensor * time_mix_value        = nullptr;
+    struct ggml_tensor * time_mix_value_b      = nullptr;
+    struct ggml_tensor * time_mix_receptance   = nullptr;
+    struct ggml_tensor * time_mix_receptance_b = nullptr;
+    struct ggml_tensor * time_mix_gate         = nullptr;
 
     struct ggml_tensor * time_mix_ln     = nullptr;
     struct ggml_tensor * time_mix_ln_b   = nullptr;
@@ -280,7 +286,7 @@ struct llama_layer {
 
     struct ggml_tensor * bskcn_tv = nullptr;
 
-     // cross attention
+    // cross attention
     struct ggml_tensor * cross_attn_k_norm = nullptr;
     struct ggml_tensor * cross_attn_k_proj = nullptr;
     struct ggml_tensor * cross_attn_o_proj = nullptr;
@@ -296,11 +302,9 @@ struct llama_layer {
 };
 
 struct llama_model {
-    llm_type type = MODEL_UNKNOWN;
+    llm_type type = LLM_TYPE_UNKNOWN;
     llm_arch arch = LLM_ARCH_UNKNOWN;
 
-    llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
-
     std::string name = "n/a";
 
     llama_hparams hparams = {};
@@ -329,117 +333,53 @@ struct llama_model {
 
     std::vector<llama_layer> layers;
 
+    llama_model_params params;
+
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
 
-    llama_split_mode split_mode;
-    int main_gpu;
-    int n_gpu_layers;
-
-    std::vector<std::string> rpc_servers;
-
     // list of devices used in this model
     std::vector<ggml_backend_dev_t> devices;
 
-
-    // lists of buffer types used for each layer
-    using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
-    buft_list_t cpu_buft_list;
-    std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
-
-    struct layer_dev {
-        ggml_backend_dev_t dev;
-        buft_list_t * buft_list;
-    };
-
-    layer_dev dev_input = {};
-    layer_dev dev_output = {};
-    std::vector<layer_dev> dev_layer;
-
-    // contexts where the model tensors metadata is stored
-    std::vector<ggml_context_ptr> ctxs;
-
-    // the model memory buffers for the tensor data
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    // model memory mapped files
-    llama_mmaps mappings;
-
-    // objects representing data potentially being locked in memory
-    llama_mlocks mlock_bufs;
-    llama_mlocks mlock_mmaps;
-
     // for quantize-stats only
     std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
 
     int64_t t_load_us  = 0;
     int64_t t_start_us = 0;
 
-    // total number of parameters in the model
-    uint64_t n_elements = 0;
+    explicit llama_model(const struct llama_model_params & params);
+    ~llama_model();
 
-    // total size of all the tensors in the model in bytes
-    size_t  n_bytes     = 0;
+    void load_stats  (llama_model_loader & ml);
+    void load_arch   (llama_model_loader & ml);
+    void load_hparams(llama_model_loader & ml);
+    void load_vocab  (llama_model_loader & ml);
+    bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
+
+    std::string arch_name() const;
+    std::string type_name() const;
+
+    std::string desc() const;
+
+    size_t size() const;
+    size_t max_nodes() const;
+    size_t n_devices() const;
+
+    // total number of parameters in the model
+    uint64_t n_elements() const;
+
+    void print_info() const;
+
+    ggml_backend_dev_t dev_layer(int il) const;
+    ggml_backend_dev_t dev_output() const;
+
+    ggml_backend_buffer_type_t select_buft(int il) const;
+
+    const struct ggml_tensor * get_tensor(const char * name) const;
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };
 
 const char * llm_type_name(llm_type type);
-
-std::string llama_model_arch_name (const llama_model & model);
-std::string llama_model_type_name (const llama_model & model);
-std::string llama_model_ftype_name(const llama_model & model);
-
-template<typename F>
-bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
-    ggml_init_params params = {
-        /*.mem_size   =*/ ggml_tensor_overhead()*8,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-
-    ggml_context_ptr ctx { ggml_init(params) };
-    if (!ctx) {
-        throw std::runtime_error("failed to create ggml context");
-    }
-
-    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
-    ggml_tensor * op_tensor = fn(ctx.get());
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (op_tensor->src[i] != nullptr) {
-            op_tensor->src[i]->buffer = buf.get();
-        }
-    }
-
-    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
-
-    return op_supported;
-}
-
-template<typename F>
-ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
-    for (const auto & cur : buft_list) {
-        ggml_backend_dev_t cur_dev = cur.first;
-        ggml_backend_buffer_type_t cur_buft = cur.second;
-        if (buft_supported(cur_buft, cur_dev, fn)) {
-            return cur_buft;
-        }
-    }
-
-    throw std::runtime_error("no suitable buffer type found");
-}
-
-// used by llama_adapter_cvec
-ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
-
-// used by llama_adapter_lora
-struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);
-
-size_t llama_model_max_nodes(const llama_model & model);
-
-struct llama_model_loader;
-
-// TODO: become llama_model methods
-void llm_load_stats     (llama_model_loader & ml, llama_model & model);
-void llm_load_arch      (llama_model_loader & ml, llama_model & model);
-void llm_load_hparams   (llama_model_loader & ml, llama_model & model);
-void llm_load_vocab     (llama_model_loader & ml, llama_model & model);
-void llm_load_print_meta(llama_model_loader & ml, llama_model & model);
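
[Illustrative sketch, not part of the patch] The free helpers removed above (llama_model_select_buft, llama_model_get_tensor, llama_model_max_nodes) become llama_model member functions, with the per-layer device/buffer bookkeeping hidden behind the new pimpl. A hypothetical caller (e.g. adapter code) would migrate roughly as follows; the tensor name is illustrative only:

    // before: llama_model_select_buft(model, il) / llama_model_get_tensor(model, "output_norm.weight")
    ggml_backend_buffer_type_t buft = model.select_buft(il);
    const ggml_tensor * t           = model.get_tensor("output_norm.weight");
    ggml_backend_dev_t dev          = model.dev_layer(il);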
diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp
index 27def6fd9..6eb1da08e 100644
--- a/llama/llama.cpp/src/llama-quant.cpp
+++ b/llama/llama.cpp/src/llama-quant.cpp
@@ -7,14 +7,12 @@
 #include 
 #include 
 #include 
+#include <cinttypes>
 #include 
 #include 
 #include 
 #include 
 
-// TODO: replace with ggml API call
-#define QK_K 256
-
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -22,7 +20,7 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }
 
-struct quantize_state_internal {
+struct quantize_state_impl {
     const llama_model                 & model;
     const llama_model_quantize_params * params;
 
@@ -43,13 +41,13 @@ struct quantize_state_internal {
     // used to figure out if a model shares tok_embd with the output weight
     bool has_output = false;
 
-    quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
+    quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
         : model(model)
         , params(params)
         {}
 };
 
-static void llama_tensor_dequantize_internal(
+static void llama_tensor_dequantize_impl(
     struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
     const size_t nelements, const int nthread
 ) {
@@ -121,7 +119,7 @@ static void llama_tensor_dequantize_internal(
     workers.clear();
 }
 
-static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
+static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
 
     // TODO: avoid hardcoded tensor names - use the TN_* constants
@@ -154,8 +152,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
         if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
             new_type = qs.params->output_tensor_type;
         } else {
-            int nx = tensor->ne[0];
-            if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
+            const int64_t nx = tensor->ne[0];
+            const int64_t qk_k = ggml_blck_size(new_type);
+
+            if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
                 new_type = GGML_TYPE_Q8_0;
             }
             else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
@@ -235,7 +235,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
         else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                 use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
-        if (qs.model.type == MODEL_70B) {
+        if (qs.model.type == LLM_TYPE_70B) {
             // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
             // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
             // nearly negligible increase in model size by quantizing this tensor with more bits:
@@ -367,20 +367,19 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
     //}
     bool convert_incompatible_tensor = false;
-    if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
-        new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
-        new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
-        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
-        new_type == GGML_TYPE_IQ1_M) {
-        int nx = tensor->ne[0];
-        int ny = tensor->ne[1];
-        if (nx % QK_K != 0) {
-            LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
+    {
+        const int64_t nx = tensor->ne[0];
+        const int64_t ny = tensor->ne[1];
+        const int64_t qk_k = ggml_blck_size(new_type);
+
+        if (nx % qk_k != 0) {
+            LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
             convert_incompatible_tensor = true;
         } else {
             ++qs.n_k_quantized;
         }
     }
+
     if (convert_incompatible_tensor) {
         switch (new_type) {
             case GGML_TYPE_TQ1_0:
@@ -410,7 +409,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     return new_type;
 }
 
-static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
+static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
     if (nthread < 2) {
         // single-thread
         size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
@@ -464,7 +463,7 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
     return new_size;
 }
 
-static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
+static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     ggml_type default_type;
     llama_ftype ftype = params->ftype;
 
@@ -526,18 +525,21 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
-    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
+
+    std::vector<std::string> splits = {};
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
     ml.init_mappings(false); // no prefetching
 
-    llama_model model;
-    llm_load_arch   (ml, model);
-    llm_load_hparams(ml, model);
-    llm_load_stats  (ml, model);
+    llama_model model(llama_model_default_params());
 
-    struct quantize_state_internal qs(model, params);
+    model.load_arch   (ml);
+    model.load_hparams(ml);
+    model.load_stats  (ml);
+
+    struct quantize_state_impl qs(model, params);
 
     if (params->only_copy) {
-        ftype = model.ftype;
+        ftype = ml.ftype;
     }
     const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
     if (params->imatrix) {
@@ -621,7 +623,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
-    // sanity checks
+    // sanity checks for models that have attention layers
+    if (qs.n_attention_wv != 0)
     {
         const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
         // attention layers have a non-zero number of kv heads
@@ -761,6 +764,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= name.find("time_mix_w2.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
 
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
@@ -839,7 +843,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
                 throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
             } else {
-                llama_tensor_dequantize_internal(tensor, f32_conv_buf, workers, nelements, nthread);
+                llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
                 f32_data = (float *) f32_conv_buf.data();
             }
 
@@ -868,7 +872,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
                 const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
 
-                new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
+                new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
             }
             LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
         }
@@ -877,7 +881,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
         // update the gguf meta data as we go
         gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
-        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
+        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
+        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
 
         // write tensor data + padding
         fout.write((const char *) new_data, new_size);
@@ -921,7 +926,7 @@ uint32_t llama_model_quantize(
         const char * fname_out,
         const llama_model_quantize_params * params) {
     try {
-        llama_model_quantize_internal(fname_inp, fname_out, params);
+        llama_model_quantize_impl(fname_inp, fname_out, params);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
         return 1;
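
[Illustrative sketch, not part of the patch] The renamed llama_model_quantize_impl() is still reached through the unchanged public entry point. A minimal driver, with hypothetical file names and assuming the standard llama_model_quantize_default_params() helper from llama.h:

    llama_model_quantize_params qparams = llama_model_quantize_default_params();
    qparams.ftype   = LLAMA_FTYPE_MOSTLY_Q4_K_M;
    qparams.nthread = 8;

    // returns 0 on success, 1 if the impl threw (see the catch block above)
    const uint32_t rc = llama_model_quantize("model-f16.gguf", "model-q4_k_m.gguf", &qparams);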
diff --git a/llama/llama.cpp/src/llama-sampling.cpp b/llama/llama.cpp/src/llama-sampling.cpp
index 69cea2f14..f40bf2db8 100644
--- a/llama/llama.cpp/src/llama-sampling.cpp
+++ b/llama/llama.cpp/src/llama-sampling.cpp
@@ -257,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
             for (int i = 0; i < (int)cur_p->size; ++i) {
                 const float val = cur_p->data[i].logit;
                 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets-1, ib));
+                ib = std::max(0, std::min(nbuckets - 1, ib));
                 bucket_idx[i] = ib;
                 ++histo[ib];
             }
@@ -280,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
             for (int i = 0; i < (int)cur_p->size; ++i) {
                 int j = bucket_idx[i];
                 if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
+                    *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
                 }
             }
 
             ptr = tmp_tokens.data();
             int ndone = 0;
-            for (int j = nbuckets-1; j > ib; --j) {
+            for (int j = nbuckets - 1; j > ib; --j) {
                 std::sort(ptr, ptr + histo[j], comp);
                 ptr += histo[j];
                 ndone += histo[j];
@@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
 
 // llama_sampler API
 
+struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
+    return new llama_sampler {
+        /* .iface = */ iface,
+        /* .ctx   = */ ctx,
+    };
+}
+
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
     if (!smpl->iface) {
         return "(null)";
@@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
     }
 
     if (smpl->ctx == nullptr) {
-        return new llama_sampler {
+        return llama_sampler_init(
             /* .iface = */ smpl->iface,
-            /* .ctx   = */ nullptr,
-        };
+            /* .ctx   = */ nullptr
+        );
     }
 
     GGML_ABORT("the sampler does not support cloning");
@@ -371,7 +378,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
 llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
     const auto * logits = llama_get_logits_ith(ctx, idx);
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
     // TODO: do not allocate each time
     std::vector<llama_token_data> cur;
@@ -469,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
 };
 
 struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_chain_i,
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
             /* .samplers    = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
-        },
-    };
+        }
+    );
 }
 
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@@ -543,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
 };
 
 struct llama_sampler * llama_sampler_init_greedy() {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_greedy_i,
-        /* .ctx   = */ nullptr,
-    };
+        /* .ctx   = */ nullptr
+    );
 }
 
 // dist
@@ -605,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
 
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
             /* .seed     = */ seed,
             /* .seed_cur = */ seed_cur,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // softmax
@@ -635,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
 };
 
 struct llama_sampler * llama_sampler_init_softmax() {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_softmax_i,
-        /* .ctx   = */ nullptr,
-    };
+        /* .ctx   = */ nullptr
+    );
 }
 
 // top-k
@@ -675,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_k_i,
         /* .ctx   = */ new llama_sampler_top_k {
             /* .k = */ k,
-        },
-    };
+        }
+    );
 }
 
 // top-p
@@ -741,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_p_i,
         /* .ctx   = */ new llama_sampler_top_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // min-p
@@ -837,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
 };
 
 struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_min_p_i,
         /* .ctx   = */ new llama_sampler_min_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // typical
@@ -936,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
 };
 
 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx   = */ new llama_sampler_typical {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // temp
@@ -980,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp(float temp) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_i,
         /* .ctx   = */ new llama_sampler_temp {
             /*.temp = */ temp,
-        },
-    };
+        }
+    );
 }
 
 // temp-ext
@@ -1090,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
             /* .temp     = */ temp,
             /* .delta    = */ delta,
             /* .exponent = */ exponent,
-        },
-    };
+        }
+    );
 }
 
 // xtc
@@ -1182,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
 
 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
             /* .probability   = */ p,
@@ -1191,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
             /* .seed          = */ seed,
             /* .seed_cur      = */ seed_cur,
             /* .rng           = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // mirostat
@@ -1289,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
             /* .n_vocab  = */ n_vocab,
@@ -1300,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
             /* .m        = */ m,
             /* .mu       = */ 2.0f*tau,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // mirostat v2
@@ -1388,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx   = */ new llama_sampler_mirostat_v2 {
             /* .seed     = */ seed,
@@ -1397,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
             /* .eta      = */ eta,
             /* .mu       = */ 2.0f*tau,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // grammar
@@ -1430,13 +1440,30 @@ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token
     }
 }
 
+// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
+static struct llama_sampler * llama_sampler_init_grammar_impl(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root,
+                              bool lazy,
+                     const char ** trigger_words,
+                            size_t num_trigger_words,
+               const llama_token * trigger_tokens,
+                            size_t num_trigger_tokens);
+
 static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_grammar *) smpl->ctx;
     if (!ctx->grammar) {
         return;
     }
 
-    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+    std::vector<const char *> trigger_words;
+    for (auto & word : ctx->grammar->trigger_words) {
+        trigger_words.push_back(word.c_str());
+    }
+    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
+                                                 ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
+                                                 ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
 
     llama_grammar_free_impl(ctx->grammar);
     ctx->grammar = grammar_new;
@@ -1445,7 +1472,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
-    auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
 
     // copy the state
     {
@@ -1481,29 +1508,55 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free   = */ llama_sampler_grammar_free,
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+static struct llama_sampler * llama_sampler_init_grammar_impl(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root,
+                              bool lazy,
+                     const char ** trigger_words,
+                            size_t num_trigger_words,
+               const llama_token * trigger_tokens,
+                            size_t num_trigger_tokens) {
     auto * ctx = new llama_sampler_grammar;
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
-            /* .vocab        = */ &vocab,
+            /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
         };
     } else {
         *ctx = {
-            /* .vocab        = */ &vocab,
+            /* .vocab        = */ vocab,
             /* .grammar_str  = */ {},
             /* .grammar_root = */ {},
             /* .grammar      = */ nullptr,
         };
     }
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_grammar_i,
-        /* .ctx   = */ ctx,
-    };
+        /* .ctx   = */ ctx
+    );
+}
+
+struct llama_sampler * llama_sampler_init_grammar(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root) {
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
+}
+
+struct llama_sampler * llama_sampler_init_grammar_lazy(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root,
+                     const char ** trigger_words,
+                            size_t num_trigger_words,
+               const llama_token * trigger_tokens,
+                            size_t num_trigger_tokens) {
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
 }
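
[Illustrative sketch, not part of the patch] Usage of the new lazy-grammar constructor defined above; the GBNF string and trigger word are made up for illustration, and vocab is assumed to come from llama_model_get_vocab():

    const char * grammar_str = "root ::= \"yes\" | \"no\"";
    const char * triggers[]  = { "answer:" };

    llama_sampler * smpl = llama_sampler_init_grammar_lazy(
        vocab, grammar_str, "root",
        triggers, /* num_trigger_words  */ 1,
        nullptr,  /* trigger_tokens     */
        0);       /* num_trigger_tokens */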
 
 // penalties
@@ -1632,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
             /* .penalty_last_n  = */ penalty_last_n,
@@ -1641,8 +1694,75 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalty_present = */ penalty_present,
             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
             /* .token_count     = */ {},
-        },
-    };
+        }
+    );
+}
+
+// top-n-sigma
+
+struct llama_sampler_top_n_sigma {
+    const float n;
+};
+
+static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
+    return "top-n-sigma";
+}
+
+static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
+
+    // find max logit and calculate mean
+    float max = cur_p->data[0].logit;
+    float logits_sum = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > max) {
+            max = cur_p->data[i].logit;
+        }
+        logits_sum += cur_p->data[i].logit;
+    }
+    float mean = logits_sum/cur_p->size;
+
+    // calculate standard deviation
+    float acc = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        acc += pow(cur_p->data[i].logit - mean, 2);
+    }
+    float std = sqrt(acc/cur_p->size);
+
+    //apply mask
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit < max - (ctx->n * std)) {
+            cur_p->data[i].logit = -INFINITY;
+        }
+    }
+    llama_sampler_softmax_impl(cur_p);
+}
+
+static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
+    return llama_sampler_init_top_n_sigma(ctx->n);
+}
+
+static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_n_sigma *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
+    /* .name   = */ llama_sampler_top_n_sigma_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_n_sigma_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_n_sigma_clone,
+    /* .free   = */ llama_sampler_top_n_sigma_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_top_n_sigma_i,
+        /* .ctx   = */ new llama_sampler_top_n_sigma {
+            /* .n = */ n,
+        }
+    );
 }
 
 // DRY
@@ -1663,8 +1783,8 @@ struct llama_sampler_dry {
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
-        std::string word = llama_detokenize(vocab, {token_id}, true);
+    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
+        std::string word = vocab.detokenize({token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
@@ -1681,7 +1801,7 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
                     }
                 }
                 if (match) {
-                    std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
+                    std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
                     if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                         tokenization.resize(max_tail_len);
                     }
@@ -1832,7 +1952,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
                 ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                 if (n > 0) {
                     lt = k;
-                    rt = k+n-1;
+                    rt = k + n - 1;
                 }
             } else {
                 // If k is inside the current Z-box, consider two cases.
@@ -1937,7 +2057,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
     llama_vocab dummy_vocab;
 
     // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
-    auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
+    auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
 
     // Copy the state, including the processed breakers
     {
@@ -1964,7 +2084,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
     /* .free   = */ llama_sampler_dry_free,
 };
 
-struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
     int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
     std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
     const int MAX_CHAR_LEN = 40;
@@ -1991,11 +2111,11 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
                 sequence_break.resize(MAX_CHAR_LEN);
             }
 
-            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
+            get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
         }
     }
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
             /* .total_context_size     = */ context_size,
@@ -2007,14 +2127,14 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
             /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
             /* .dry_max_token_repeat   = */ {},
             /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
-        },
-    };
+        }
+    );
 }
 
 // wrapper for test-sampling.cpp
 struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
     llama_vocab dummy_vocab;
-    auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
+    auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
     auto * ctx = (llama_sampler_dry *) result->ctx;
 
     // Process the token-based sequence breakers
@@ -2109,14 +2229,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
                          int32_t   n_vocab,
                          int32_t   n_logit_bias,
           const llama_logit_bias * logit_bias) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
             /* .n_vocab    = */ n_vocab,
             /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
             /* .to_search  = */ {},
-        },
-    };
+        }
+    );
 }
 
 // infill
@@ -2153,7 +2273,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     float p_eog_sum = 0.0f;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+        if (ctx->vocab->is_eog(cur_p->data[i].id)) {
             p_eog_sum += cur_p->data[i].p;
         } else {
             p_txt_sum += cur_p->data[i].p;
@@ -2175,7 +2295,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         float p_sum = 0.0f;
 
         for (size_t i = 0; i < size_org; ++i) {
-            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            if (ctx->vocab->is_eog(cur_p->data[i].id)) {
                 p_sum += cur_p->data[i].p;
 
                 cur_p->data[cur_p->size++] = cur_p->data[i];
@@ -2203,17 +2323,17 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
                 continue;
             }
 
-            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
             if (len0 < 0) {
                 ctx->buf0.resize(len0);
-                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
                 assert(len0 > 0);
             }
 
-            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
             if (len1 < 0) {
                 ctx->buf1.resize(len1);
-                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
                 assert(len1 > 0);
             }
 
@@ -2248,7 +2368,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2269,7 +2389,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     // if no non-EOG tokens are left -> reduce cur_p to single EOT token
     if (n_non_eog == 0) {
         cur_p->size = 1;
-        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].id = ctx->vocab->token_eot();
         cur_p->data[0].logit = 1.0f;
 
         return;
@@ -2291,7 +2411,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2314,7 +2434,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
 static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
-    return llama_sampler_init_infill_impl(*ctx->vocab);
+    return llama_sampler_init_infill(ctx->vocab);
 }
 
 static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@@ -2330,16 +2450,15 @@ static struct llama_sampler_i llama_sampler_infill_i = {
     /* .free   = */ llama_sampler_infill_free,
 };
 
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab) {
-    return new llama_sampler {
+struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
-            /* .vocab = */ &vocab,
-            /* .buf0 = */ std::vector<char>(512),
-            /* .buf1 = */ std::vector<char>(512),
-        },
-    };
+            /* .vocab = */ vocab,
+            /* .buf0  = */ std::vector<char>(512),
+            /* .buf1  = */ std::vector<char>(512),
+        }
+    );
 }
 
 // utils
diff --git a/llama/llama.cpp/src/llama-sampling.h b/llama/llama.cpp/src/llama-sampling.h
index 919f6fdfc..759dd7dcb 100644
--- a/llama/llama.cpp/src/llama-sampling.h
+++ b/llama/llama.cpp/src/llama-sampling.h
@@ -2,7 +2,9 @@
 
 // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
-#include "llama-grammar.h"
+#include "llama.h"
+
+#include <vector>
 
 struct llama_vocab;
 struct llama_grammar;
@@ -21,24 +23,6 @@ struct llama_sampler_chain {
     mutable int32_t n_sample;
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(
-        const struct llama_vocab & vocab,
-                      const char * grammar_str,
-                      const char * grammar_root);
-
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab);
-
-struct llama_sampler * llama_sampler_init_dry_impl(
-        const struct llama_vocab &  vocab,
-                         int32_t    context_size,
-                           float    dry_multiplier,
-                           float    dry_base,
-                         int32_t    dry_allowed_length,
-                         int32_t    dry_penalty_last_n,
-                      const char ** seq_breakers,
-                          size_t    num_breakers);
-
 struct llama_sampler * llama_sampler_init_dry_testing(
                          int32_t   context_size,
                            float   dry_multiplier,
diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp
index 8f44705ab..1ca827ebd 100644
--- a/llama/llama.cpp/src/llama-vocab.cpp
+++ b/llama/llama.cpp/src/llama-vocab.cpp
@@ -1,6 +1,7 @@
 #include "llama-vocab.h"
 
 #include "llama-impl.h"
+#include "llama-model-loader.h"
 
 #include "unicode.h"
 
@@ -11,8 +12,10 @@
 #include <cstdarg>
 #include <cstring>
 #include <forward_list>
+#include <map>
 #include <queue>
-#include <sstream>
+#include <set>
+#include <unordered_map>
 
 //
 // helpers
@@ -62,96 +65,14 @@ struct naive_trie {
 };
 
 //
-// impl
+// tokenizers
 //
 
 struct llm_tokenizer {
-   llm_tokenizer() {}
-   virtual ~llm_tokenizer() = default;
+    llm_tokenizer() {}
+    virtual ~llm_tokenizer() = default;
 };
 
-llama_vocab::~llama_vocab() {
-    delete tokenizer;
-}
-
-int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
-    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
-    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
-    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
-    GGML_ASSERT(token_right.find('\n') == std::string::npos);
-
-    auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
-    if (it == bpe_ranks.end()) {
-        return -1;
-    }
-
-    return it->second;
-}
-
-static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
-    return vocab.type;
-}
-
-static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
-}
-
-static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
-}
-
-static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
-}
-
-static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
-}
-
-static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
-}
-
-static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
-}
-
-static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    GGML_ASSERT(llama_is_byte_token(vocab, id));
-    const auto & token_data = vocab.id_to_token.at(id);
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            auto buf = token_data.text.substr(3, 2);
-            return strtol(buf.c_str(), NULL, 16);
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            GGML_ABORT("fatal error");
-            //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
-        }
-        case LLAMA_VOCAB_TYPE_WPM: {
-            GGML_ABORT("fatal error");
-        }
-        default:
-            GGML_ABORT("fatal error");
-    }
-}
-
-static void llama_escape_whitespace(std::string & text) {
-    replace_all(text, " ", "\xe2\x96\x81");
-}
-
-static void llama_unescape_whitespace(std::string & word) {
-    replace_all(word, "\xe2\x96\x81", " ");
-}
-
 struct llm_symbol {
     using index = int;
     index prev;
@@ -183,14 +104,13 @@ struct llm_bigram_spm {
 };
 
 struct llm_tokenizer_spm : llm_tokenizer {
-    llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+    llm_tokenizer_spm(const llama_vocab & /*vocab*/) {}
 };
 
 struct llm_tokenizer_spm_session {
     llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
@@ -249,13 +169,13 @@ struct llm_tokenizer_spm_session {
     }
 
 private:
-    void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
+    void resegment(llm_symbol & symbol, std::vector<llama_token> & output) {
         auto text = std::string(symbol.text, symbol.n);
-        auto token = vocab.token_to_id.find(text);
+        auto token = vocab.text_to_token(text);
 
         // Do we need to support is_unused?
-        if (token != vocab.token_to_id.end()) {
-            output.push_back((*token).second);
+        if (token != LLAMA_TOKEN_NULL) {
+            output.push_back(token);
             return;
         }
 
@@ -265,8 +185,8 @@ private:
             // output any symbols that did not form tokens as bytes.
             output.reserve(output.size() + symbol.n);
             for (int j = 0; j < (int)symbol.n; ++j) {
-                llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]);
-                output.push_back(token_id);
+                llama_token id = vocab.byte_to_token(symbol.text[j]);
+                output.push_back(id);
             }
             return;
         }
@@ -280,17 +200,17 @@ private:
             return;
         }
         const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
-        auto token = vocab.token_to_id.find(text);
+        auto token = vocab.text_to_token(text);
 
-        if (token == vocab.token_to_id.end()) {
+        if (token == LLAMA_TOKEN_NULL) {
             return;
         }
 
-        if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
+        if (static_cast<uint32_t>(token) >= vocab.n_tokens()) {
             return;
         }
 
-        const auto & tok_data = vocab.id_to_token[(*token).second];
+        const auto & tok_data = vocab.get_token_data(token);
 
         llm_bigram_spm bigram;
         bigram.left  = left;
@@ -353,9 +273,9 @@ struct llm_bigram_bpe {
 };
 
 struct llm_tokenizer_bpe : llm_tokenizer {
-    llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
-        GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
-        switch (vocab.type_pre) {
+    llm_tokenizer_bpe(const llama_vocab & vocab) {
+        GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_BPE);
+        switch (vocab.get_pre_type()) {
             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
                 regex_exprs = {
                     // original regex from tokenizer.json
@@ -488,39 +408,38 @@ struct llm_tokenizer_bpe : llm_tokenizer {
 };
 
 struct llm_tokenizer_bpe_session {
-    llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
-        bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
+    llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
-    static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output)  {
+    static void append(const llama_token token_id, std::vector<llama_token> & output)  {
         output.push_back(token_id);
     }
 
-    bool append_bos(std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_bos) {
-            GGML_ASSERT(vocab.special_bos_id != -1);
-            output.push_back(vocab.special_bos_id);
+    bool append_bos(std::vector<llama_token> & output) const {
+        if (vocab.get_add_bos()) {
+            GGML_ASSERT(vocab.token_bos() != LLAMA_TOKEN_NULL);
+            output.push_back(vocab.token_bos());
             return true;
         }
         return false;
     }
 
-    bool append_eos(std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_eos) {
-            GGML_ASSERT(vocab.special_eos_id != -1);
-            output.push_back(vocab.special_eos_id);
+    bool append_eos(std::vector<llama_token> & output) const {
+        if (vocab.get_add_eos()) {
+            GGML_ASSERT(vocab.token_eos() != LLAMA_TOKEN_NULL);
+            output.push_back(vocab.token_eos());
             return true;
         }
         return false;
     }
 
-    void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+    void check_double_bos_eos(const std::vector<llama_token> & output) const {
+        if (vocab.get_add_bos() && output.size() >= 2 && output[1] == vocab.token_bos()) {
             LLAMA_LOG_WARN(
                 "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                 "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                 "Are you sure this is what you want?\n", __FUNCTION__);
         }
-        if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+        if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) {
             LLAMA_LOG_WARN(
                 "%s: Added a EOS token to the prompt as specified by the model but the prompt "
                 "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
@@ -528,9 +447,9 @@ struct llm_tokenizer_bpe_session {
         }
     }
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         int final_prev_index = -1;
-        const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
+        const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
 
         symbols_final.clear();
 
@@ -541,7 +460,8 @@ struct llm_tokenizer_bpe_session {
             int index = 0;
             size_t offset = 0;
 
-            if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            //if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
                 symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                 offset = word.size();
             }
@@ -615,18 +535,18 @@ struct llm_tokenizer_bpe_session {
                 }
 
                 const std::string str = std::string(symbol.text, symbol.n);
-                const auto token = vocab.token_to_id.find(str);
+                const auto token = vocab.text_to_token(str);
 
-                if (token == vocab.token_to_id.end()) {
+                if (token == LLAMA_TOKEN_NULL) {
                     for (auto j = str.begin(); j != str.end(); ++j) {
                         std::string byte_str(1, *j);
-                        auto token_multibyte = vocab.token_to_id.find(byte_str);
-                        if (token_multibyte != vocab.token_to_id.end()) {
-                            output.push_back(token_multibyte->second);
+                        auto token_multibyte = vocab.text_to_token(byte_str);
+                        if (token_multibyte != LLAMA_TOKEN_NULL) {
+                            output.push_back(token_multibyte);
                         }
                     }
                 } else {
-                    output.push_back((*token).second);
+                    output.push_back(token);
                 }
             }
         }
@@ -660,7 +580,7 @@ private:
     }
 
     const llama_vocab & vocab;
-    const llm_tokenizer_bpe * bpe_tokenizer;
+    const llm_tokenizer_bpe & tokenizer;
 
     std::vector<llm_symbol> symbols;
     std::vector<llm_symbol> symbols_final;
@@ -672,14 +592,13 @@ private:
 //
 
 struct llm_tokenizer_wpm : llm_tokenizer {
-    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) {}
 };
 
 struct llm_tokenizer_wpm_session {
     llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        const auto & token_map = vocab.token_to_id;
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         // normalize and split by whitespace
         std::vector<std::string> words = preprocess(text);
         // bos token prepended already
@@ -702,10 +621,10 @@ struct llm_tokenizer_wpm_session {
             for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
-                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
-                    auto it = token_map.find(word1.substr(i, j - i));
-                    if (it != token_map.end()) {
-                        output.push_back(it->second);
+                for (int j = std::min(n, i + vocab.max_token_len() + 1); j > i; j--) {
+                    auto id = vocab.text_to_token(word1.substr(i, j - i));
+                    if (id != LLAMA_TOKEN_NULL) {
+                        output.push_back(id);
                         match = true;
                         i = j - 1;
                         break;
@@ -720,7 +639,7 @@ struct llm_tokenizer_wpm_session {
 
             // we didn't find any matches for this word
             if (current_tokens == output.size()) {
-                output.push_back(vocab.special_unk_id);
+                output.push_back(vocab.token_unk());
             }
         }
     }
@@ -789,45 +708,45 @@ private:
 //
 
 struct llm_tokenizer_ugm : llm_tokenizer {
-    llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
-        if (vocab.precompiled_charsmap.size() > 0) {
+    llm_tokenizer_ugm(const llama_vocab & vocab, const std::vector<char> & precompiled_charsmap) {
+        if (precompiled_charsmap.size() > 0) {
             size_t charsmap_offset = 0;
 
             // First four bytes of precompiled_charsmap contains length of binary
             // blob containing XOR-compressed compact double array (XCDA) entries
-            uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0];
+            uint32_t xcda_blob_size = *(const uint32_t *) &precompiled_charsmap[0];
             charsmap_offset += sizeof(xcda_blob_size);
-            if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
+            if (xcda_blob_size + charsmap_offset >= precompiled_charsmap.size()) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
 
             // Next xcda_blob_size bytes contain entries of XOR-compressed compact
             // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
-            xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset];
+            xcda_array = (const uint32_t *) &precompiled_charsmap[charsmap_offset];
             xcda_array_size = xcda_blob_size / sizeof(uint32_t);
             charsmap_offset += xcda_blob_size;
 
             // Remaining bytes of precompiled charsmap contain null-terminated
             // replacement strings for prefixes matched by the XCDA.
-            prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
-            prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
+            prefix_replacements = &precompiled_charsmap[charsmap_offset];
+            prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset;
         }
 
-        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
-            const auto &token_data = vocab.id_to_token[id];
+        for (uint32_t id = 0; id < vocab.n_tokens(); ++id) {
+            const auto & token_data = vocab.get_token_data(id);
 
-            if (llama_is_normal_token(vocab, id)) {
+            if (vocab.is_normal(id)) {
                 min_score = std::min(min_score, token_data.score);
                 max_score = std::max(max_score, token_data.score);
             }
 
-            if (llama_is_normal_token(vocab, id) ||
-                llama_is_user_defined_token(vocab, id) ||
-                llama_is_unused_token(vocab, id)) {
+            if (vocab.is_normal(id) ||
+                vocab.is_user_defined(id) ||
+                vocab.is_unused(id)) {
                 token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
             }
 
-            if (llama_is_user_defined_token(vocab, id)) {
+            if (vocab.is_user_defined(id)) {
                 user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
             }
         }
@@ -856,8 +775,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
 };
 
 struct llm_tokenizer_ugm_session {
-    llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
-        ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
+    llm_tokenizer_ugm_session(const llama_vocab & vocab, const llm_tokenizer_ugm & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
     /* This implementation is based on SentencePiece optimized Viterbi algorithm for
      * unigram language models. The general idea is to:
@@ -872,7 +790,7 @@ struct llm_tokenizer_ugm_session {
      * After processing the whole sequence we backtrack from the end to get
      * the best tokenization.
     */
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         // get current size of output (for reversal later)
         size_t output_size = output.size();
 
@@ -885,9 +803,9 @@ struct llm_tokenizer_ugm_session {
         }
 
         // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
-        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
+        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
         // at the beginning tokenization score is zero
-        tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
+        tokenization_results[0] = { vocab.token_unk(), 0, 0 };
 
         for (size_t input_offset = 0; input_offset < input_len;) {
             size_t prefix_offset = input_offset;
@@ -897,7 +815,7 @@ struct llm_tokenizer_ugm_session {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = tokenizer.token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -907,13 +825,13 @@ struct llm_tokenizer_ugm_session {
                         single_codepoint_token_found = true;
                     }
                     llama_token token_id = node->value;
-                    const auto & token_data = vocab.id_to_token[token_id];
+                    const auto & token_data = vocab.get_token_data(token_id);
 
                     // we set the user-defined token scores to 0 to make them more likely to be selected
                     // (normal token scores are log probabilities, so they are negative)
                     // score type is double here to make tokenization results exactly
                     // the same as in the HF tokenizer using SentencePiece
-                    const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score;
+                    const double token_score = vocab.is_user_defined(token_id) ? 0.0 : token_data.score;
                     const double challenger_score = current_best.score_sum + token_score;
                     struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                     if (challenger_score > current_champ.score_sum) {
@@ -927,11 +845,11 @@ struct llm_tokenizer_ugm_session {
             // if we didn't find a valid token corresponding to the whole UTF code point
             // then use unknown token as the tokenization of this UTF code point
             if (!single_codepoint_token_found) {
-                const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
+                const double challenger_score = current_best.score_sum + tokenizer.unknown_token_score;
                 prefix_offset = input_offset + n_utf8_code_units;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
-                    struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score };
+                    struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
                     current_champ = challenger;
                 }
             }
@@ -944,7 +862,7 @@ struct llm_tokenizer_ugm_session {
         // merge sequences of consecutive unknown tokens into single unknown tokens
         bool is_prev_unknown = false;
         for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) {
-            bool is_unknown = tokenization.token_id == vocab.special_unk_id;
+            bool is_unknown = tokenization.token_id == vocab.token_unk();
             if (!(is_prev_unknown && is_unknown)) {
                 output.push_back(tokenization.token_id);
             }
@@ -971,11 +889,11 @@ private:
         normalized->clear();
         normalized->reserve(input.size() * 3);
 
-        const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
+        const std::string space = vocab.get_escape_whitespaces() ? tokenizer.escaped_space : " ";
 
-        bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
-        bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
-        bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces;
+        const bool shall_prepend_space = !vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix();
+        const bool shall_append_space  =  vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix();
+        const bool shall_merge_spaces  =  vocab.get_remove_extra_whitespaces();
 
         bool is_space_prepended = false;
         bool processing_non_ws = false;
@@ -1067,7 +985,7 @@ private:
 
         // if input prefix matches some user-defined token return this token as normalization result
         auto user_defined_token_match =
-           ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+           tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
         if (user_defined_token_match.second > 0) {
             return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
         }
@@ -1075,8 +993,8 @@ private:
         size_t longest_prefix_length = 0;
         size_t longest_prefix_offset = 0;
 
-        if (ugm_tokenizer->xcda_array_size > 0) {
-            struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
+        if (tokenizer.xcda_array_size > 0) {
+            struct xcda_array_view xcda_view(tokenizer.xcda_array, tokenizer.xcda_array_size);
 
             // Find the longest normalized sequence matching the input prefix by walking
             // the XOR-compressed compact double array (XCDA) starting from the root node
@@ -1112,10 +1030,10 @@ private:
 
         if (longest_prefix_length > 0) {
             // we have a match, so return the replacement sequence
-            if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
+            if (longest_prefix_offset >= tokenizer.prefix_replacements_size) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
-            const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
+            const char * prefix_replacement = &(tokenizer.prefix_replacements)[longest_prefix_offset];
             return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
         }
 
@@ -1132,7 +1050,7 @@ private:
     }
 
     const llama_vocab & vocab;
-    const llm_tokenizer_ugm * ugm_tokenizer;
+    const llm_tokenizer_ugm & tokenizer;
 };
 
 //
@@ -1194,15 +1112,15 @@ static std::vector llama_unescape_rwkv_token(const std::string & escape
 }
 
 struct llm_tokenizer_rwkv : llm_tokenizer {
-    llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
+    llm_tokenizer_rwkv(const llama_vocab & vocab) {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
 
         // build trie
-        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
-            const auto & token = vocab.id_to_token[id];
-            const auto data = llama_unescape_rwkv_token(token.text);
-            token_matcher.insert((const char *) data.data(), data.size(), id);
+        for (uint32_t id = 0; id < vocab.n_tokens(); ++id) {
+            const auto & data = vocab.get_token_data(id);
+            const auto text = llama_unescape_rwkv_token(data.text);
+            token_matcher.insert((const char *) text.data(), text.size(), id);
         }
     }
 
@@ -1210,16 +1128,15 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
 };
 
 struct llm_tokenizer_rwkv_session {
-    llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
-        rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
+    llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         uint32_t position = 0;
         while (position < text.size()) {
-            const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
+            const struct naive_trie * node = tokenizer.token_matcher.traverse(text[position]);
             if (node == NULL) {
                 // no matching token found, add unknown token
-                output.push_back(vocab.special_unk_id);
+                output.push_back(vocab.token_unk());
                 position += 1;
                 continue;
             }
@@ -1243,33 +1160,11 @@ struct llm_tokenizer_rwkv_session {
 
 private:
     const llama_vocab & vocab;
-    const llm_tokenizer_rwkv & rwkv_tokenizer;
+    const llm_tokenizer_rwkv & tokenizer;
 };
 
-void llama_vocab::init_tokenizer() {
-    switch (type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            tokenizer = new llm_tokenizer_spm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            tokenizer = new llm_tokenizer_bpe(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            tokenizer = new llm_tokenizer_wpm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            tokenizer = new llm_tokenizer_ugm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_RWKV:
-            tokenizer = new llm_tokenizer_rwkv(*this);
-            break;
-        default:
-            GGML_ABORT("unsupported vocab type");
-    }
-}
-
 //
-// (de-) tokenize
+// impl
 //
 
 typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
@@ -1278,7 +1173,7 @@ typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
 } FRAGMENT_BUFFER_VARIANT_TYPE;
 
 struct fragment_buffer_variant {
-    fragment_buffer_variant(llama_vocab::id _token)
+    fragment_buffer_variant(llama_token _token)
     :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
         token(_token),
@@ -1289,7 +1184,7 @@ struct fragment_buffer_variant {
     fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
     :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
-        token((llama_vocab::id) - 1),
+        token((llama_token) - 1),
         raw_text(_raw_text),
         offset(_offset),
         length(_length){
@@ -1299,20 +1194,955 @@ struct fragment_buffer_variant {
         }
 
     const FRAGMENT_BUFFER_VARIANT_TYPE type;
-    const llama_vocab::id token;
+    const llama_token token;
     const std::string _dummy;
     const std::string & raw_text;
     const uint64_t offset;
     const uint64_t length;
 };
 
+struct llama_vocab::impl {
+    uint32_t n_token_types = 0; // for BERT-style token types
+
+    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
+    enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+
+    int max_token_len = 0; // used for optimizing longest token search
+
+    // default LLaMA special tokens
+    // TODO: should we set all of these to LLAMA_TOKEN_NULL?
+    llama_token special_bos_id  = 1;
+    llama_token special_eos_id  = 2;
+    llama_token special_eot_id  = LLAMA_TOKEN_NULL;
+    llama_token special_eom_id  = LLAMA_TOKEN_NULL;
+    llama_token special_unk_id  = 0;
+    llama_token special_sep_id  = LLAMA_TOKEN_NULL;
+    llama_token special_pad_id  = LLAMA_TOKEN_NULL;
+    llama_token special_mask_id = LLAMA_TOKEN_NULL;
+
+    llama_token linefeed_id = 13;
+
+    // fim tokens
+    llama_token special_fim_pre_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_suf_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_mid_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_pad_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
+    llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
+
+    // tokenizer flags
+    bool add_space_prefix           = false;
+    bool add_bos                    = false;
+    bool add_eos                    = false;
+    bool ignore_merges              = false;
+    bool clean_spaces               = false;  // clean_up_tokenization_spaces
+    bool remove_extra_whitespaces   = false;
+    bool escape_whitespaces         = true;
+    bool treat_whitespace_as_suffix = false;
+
+    std::unordered_map<std::string, llama_token> token_to_id;
+    std::vector<token_data>                      id_to_token;
+
+    std::vector<llama_token> cache_special_tokens;
+    std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true);
+    struct pair_hash {
+        size_t operator()(const std::pair & p) const {
+            return std::hash{}(p.first) ^  //create some hash for pair
+                   (std::hash{}(p.second) << 1);
+        }
+    };
+    std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks;
+
+    // set of all tokens that cause "end of generation"
+    std::set<llama_token> special_eog_ids;
+
+    std::unique_ptr<llm_tokenizer> tokenizer;
+
+    std::vector<char> precompiled_charsmap;
+
+    impl(const llama_vocab & vocab) : vocab(vocab) {
+    }
+
+    ~impl() = default;
+
+    void load(llama_model_loader & ml, const LLM_KV & kv);
+
+    enum llama_vocab_type get_type() const;
+
+    std::string type_name() const;
+
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
+
+    uint8_t token_to_byte(llama_token id) const;
+
+    llama_token_attr token_get_attr(llama_token id) const;
+
+    void init_tokenizer(enum llama_vocab_type type);
+
+    void tokenizer_st_partition(std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) const;
+
+    std::string token_to_piece_for_cache(
+                  llama_token   token,
+                         bool   special) const;
+
+
+    std::vector<llama_token> tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector<llama_token> & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
+    const llama_vocab & vocab;
+};
+
+void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+    struct gguf_context * ctx = ml.meta.get();
+
+    // determine vocab type
+    {
+        std::string tokenizer_model;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
+        ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
+
+        ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false);
+
+        if (tokenizer_model == "no_vocab" || tokenizer_model == "none") {
+            type = LLAMA_VOCAB_TYPE_NONE;
+
+            // default special tokens
+            special_bos_id  = LLAMA_TOKEN_NULL;
+            special_eos_id  = LLAMA_TOKEN_NULL;
+            special_unk_id  = LLAMA_TOKEN_NULL;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+            linefeed_id     = LLAMA_TOKEN_NULL;
+
+            // read vocab size from metadata
+            uint32_t n_tokens = 0;
+            if (ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) {
+                LLAMA_LOG_WARN("%s: adding %u dummy tokens\n", __func__, n_tokens);
+                id_to_token.resize(n_tokens);
+            }
+
+            return;
+        }
+
+        if (tokenizer_model == "llama") {
+            type = LLAMA_VOCAB_TYPE_SPM;
+
+            // default special tokens
+            special_bos_id  = 1;
+            special_eos_id  = 2;
+            special_unk_id  = 0;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "bert") {
+            type = LLAMA_VOCAB_TYPE_WPM;
+
+            // default special tokens
+            special_bos_id  = 101;
+            special_eos_id  = LLAMA_TOKEN_NULL;
+            special_unk_id  = 100;
+            special_sep_id  = 102;
+            special_pad_id  = 0;
+            special_mask_id = 103;
+        } else if (tokenizer_model == "gpt2") {
+            type = LLAMA_VOCAB_TYPE_BPE;
+
+            // read bpe merges and populate bpe ranks
+            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
+            if (merges_keyidx == -1) {
+                throw std::runtime_error("cannot find tokenizer merges in model file\n");
+            }
+
+            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
+            for (int i = 0; i < n_merges; i++) {
+                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
+                //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
+
+                std::string first;
+                std::string second;
+
+                const size_t pos = word.find(' ', 1);
+
+                if (pos != std::string::npos) {
+                    first  = word.substr(0, pos);
+                    second = word.substr(pos + 1);
+                }
+
+                bpe_ranks.emplace(std::make_pair(first, second), i);
+            }
+
+            // default special tokens
+            special_bos_id  = 11;
+            special_eos_id  = 11;
+            special_unk_id  = LLAMA_TOKEN_NULL;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "t5") {
+            type = LLAMA_VOCAB_TYPE_UGM;
+
+            // default special tokens
+            special_bos_id  = LLAMA_TOKEN_NULL;
+            special_eos_id  = 1;
+            special_unk_id  = 2;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = 0;
+            special_mask_id = LLAMA_TOKEN_NULL;
+
+            const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
+            if (precompiled_charsmap_keyidx != -1) {
+                size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+                const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
+                precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
+#ifdef IS_BIG_ENDIAN
+                // correct endiannes of data in precompiled_charsmap binary blob
+                uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0];
+                *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
+                assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap);
+                size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t);
+                uint32_t * xcda_array = (uint32_t *) &precompiled_charsmap[sizeof(uint32_t)];
+                for (size_t i = 0; i < xcda_array_size; ++i) {
+                    xcda_array[i] = __builtin_bswap32(xcda_array[i]);
+                }
+#endif
+            }
+        } else if (tokenizer_model == "rwkv") {
+            type = LLAMA_VOCAB_TYPE_RWKV;
+
+            // default special tokens
+            special_bos_id = LLAMA_TOKEN_NULL;
+            special_eos_id = LLAMA_TOKEN_NULL;
+            special_unk_id = LLAMA_TOKEN_NULL;
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = LLAMA_TOKEN_NULL;
+        } else {
+            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
+        }
+
+        // for now, only BPE models have pre-tokenizers
+        if (type == LLAMA_VOCAB_TYPE_BPE) {
+            add_space_prefix = false;
+            clean_spaces = true;
+            if (tokenizer_pre == "default") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            } else if (
+                    tokenizer_pre == "llama3"   ||
+                    tokenizer_pre == "llama-v3" ||
+                    tokenizer_pre == "llama-bpe"||
+                    tokenizer_pre == "falcon3") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
+                ignore_merges = true;
+                add_bos = true;
+            } else if (
+                    tokenizer_pre == "deepseek-llm") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-coder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-v3") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "falcon") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
+            } else if (
+                    tokenizer_pre == "mpt") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_MPT;
+            } else if (
+                    tokenizer_pre == "starcoder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_STARCODER;
+            } else if (
+                    tokenizer_pre == "gpt-2"   ||
+                    tokenizer_pre == "phi-2"   ||
+                    tokenizer_pre == "jina-es" ||
+                    tokenizer_pre == "jina-de" ||
+                    tokenizer_pre == "gigachat"   ||
+                    tokenizer_pre == "jina-v1-en" ||
+                    tokenizer_pre == "jina-v2-es" ||
+                    tokenizer_pre == "jina-v2-de" ||
+                    tokenizer_pre == "jina-v2-code" ||
+                    tokenizer_pre == "roberta-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+            } else if (
+                    tokenizer_pre == "refact") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
+            } else if (
+                tokenizer_pre == "command-r") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "qwen2" ||
+                    tokenizer_pre == "deepseek-r1-qwen") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "stablelm2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
+            } else if (
+                tokenizer_pre == "olmo") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_OLMO;
+            } else if (
+                tokenizer_pre == "dbrx") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DBRX;
+            } else if (
+                tokenizer_pre == "smaug-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SMAUG;
+            } else if (
+                tokenizer_pre == "poro-chat") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "chatglm-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
+                special_bos_id = LLAMA_TOKEN_NULL;
+            } else if (
+                tokenizer_pre == "viking") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_VIKING;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "jais") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS;
+            } else if (
+                tokenizer_pre == "tekken") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
+                clean_spaces = false;
+                ignore_merges = true;
+                add_bos = true;
+            } else if (
+                tokenizer_pre == "smollm") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "codeshell") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
+            } else if (
+                tokenizer_pre == "bloom") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_BLOOM;
+            } else if (
+                tokenizer_pre == "gpt3-finnish") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH;
+            } else if (
+                tokenizer_pre == "exaone") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+            } else if (
+                tokenizer_pre == "chameleon") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
+                add_bos = true;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "minerva-7b") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_MINERVA;
+            } else if (
+                tokenizer_pre == "megrez") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+            } else {
+                LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            }
+        } else if (type == LLAMA_VOCAB_TYPE_SPM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = true;
+            clean_spaces = false;
+            add_bos = true;
+            add_eos = false;
+        } else if (type == LLAMA_VOCAB_TYPE_WPM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = false;
+            clean_spaces = true;
+            add_bos = true;
+            add_eos = false;
+        } else if (type == LLAMA_VOCAB_TYPE_UGM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_bos = false;
+            add_eos = true;
+        } else if (type == LLAMA_VOCAB_TYPE_RWKV) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = false;
+            clean_spaces = false;
+            add_bos = false;
+            add_eos = false;
+        } else {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+        }
+
+        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX,      add_space_prefix,         false);
+        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false);
+    }
+
+    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
+    if (token_idx == -1) {
+        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
+    }
+
+    const float * scores = nullptr;
+    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
+    if (score_idx != -1) {
+        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+    }
+
+    const int * toktypes = nullptr;
+    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
+    if (toktype_idx != -1) {
+        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+    }
+
+    uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx);
+    id_to_token.resize(n_tokens);
+
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        std::string word = gguf_get_arr_str(ctx, token_idx, i);
+        if (word.empty()) {
+            LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
+            word = "[EMPTY_" + std::to_string(i) + "]";
+        }
+
+        token_to_id[word] = i;
+        max_token_len = std::max(max_token_len, (int) word.size());
+
+        auto & token_data = id_to_token[i];
+        token_data.text  = std::move(word);
+        token_data.score = scores ? scores[i] : 0.0f;
+        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
+
+        if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
+            switch(toktypes[i]) {
+                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
+                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
+                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
+                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
+                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
+                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
+                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+            }
+        }
+    }
+    GGML_ASSERT(id_to_token.size() == token_to_id.size());
+
+    init_tokenizer(type);
+
+    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
+    if (type == LLAMA_VOCAB_TYPE_SPM) {
+        try {
+            linefeed_id = vocab.byte_to_token('\n');
+        } catch (const std::exception & e) {
+            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
+            linefeed_id = special_pad_id;
+        }
+    } else if (type == LLAMA_VOCAB_TYPE_WPM) {
+        linefeed_id = special_pad_id;
+    } else if (type == LLAMA_VOCAB_TYPE_RWKV) {
+        const std::vector<llama_token> ids = tokenize("\n", false);
+        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        linefeed_id = ids[0];
+    } else {
+        const std::vector<llama_token> ids = tokenize("\n", false);
+
+        //GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        if (ids.empty()) {
+            LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
+            linefeed_id = special_pad_id;
+        } else {
+            linefeed_id = ids[0];
+        }
+    }
+
+    // special tokens
+    {
+        const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
+            { LLM_KV_TOKENIZER_BOS_ID,     special_bos_id     },
+            { LLM_KV_TOKENIZER_EOS_ID,     special_eos_id     },
+            { LLM_KV_TOKENIZER_EOT_ID,     special_eot_id     },
+            { LLM_KV_TOKENIZER_EOM_ID,     special_eom_id     },
+            { LLM_KV_TOKENIZER_UNK_ID,     special_unk_id     },
+            { LLM_KV_TOKENIZER_SEP_ID,     special_sep_id     },
+            { LLM_KV_TOKENIZER_PAD_ID,     special_pad_id     },
+            { LLM_KV_TOKENIZER_MASK_ID,    special_mask_id    },
+            { LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id },
+            { LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id },
+            { LLM_KV_TOKENIZER_FIM_MID_ID, special_fim_mid_id },
+            { LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id },
+            { LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id },
+            { LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id },
+
+            // deprecated
+            { LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id },
+            { LLM_KV_TOKENIZER_SUFFIX_ID, special_fim_suf_id },
+            { LLM_KV_TOKENIZER_MIDDLE_ID, special_fim_mid_id },
+        };
+
+        for (const auto & it : special_token_types) {
+            const std::string & key = kv(std::get<0>(it));
+            int32_t & id = std::get<1>(it);
+
+            uint32_t new_id;
+            if (!ml.get_key(std::get<0>(it), new_id, false)) {
+                continue;
+            }
+            if (new_id >= id_to_token.size()) {
+                LLAMA_LOG_WARN("%s: bad special token: '%s' = %u, using default id %d\n",
+                    __func__, key.c_str(), new_id, id);
+            } else {
+                id = new_id;
+            }
+        }
+
+        // Handle add_bos and add_eos
+        {
+            bool temp = true;
+
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
+                add_bos = temp;
+            }
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
+                add_eos = temp;
+            }
+        }
+
+        // auto-detect special tokens by text
+        // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
+        //       for now, we apply this workaround to find the tokens based on their text
+
+        for (const auto & t : token_to_id) {
+            // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc.
+            if (special_eot_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == ""
+                        || t.first == "<|endoftext|>"
+                        || t.first == ""
+                        || t.first == "<|end▁of▁sentence|>" // DeepSeek
+                   ) {
+                    special_eot_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find EOM token: "<|eom_id|>"
+            if (special_eom_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eom_id|>"
+                        ) {
+                    special_eom_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
+            if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "
"
+                        ) {
+                    special_fim_pre_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
+            if (special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == ""
+                        ) {
+                    special_fim_suf_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
+            if (special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == ""
+                        ) {
+                    special_fim_mid_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
+            if (special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""
+                        ) {
+                    special_fim_pad_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
+            if (special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == ""
+                        || t.first == ""
+                        ) {
+                    special_fim_rep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    special_fim_sep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        special_eog_ids.clear();
+
+        if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
+            special_eog_ids.insert(special_fim_pad_id);
+        }
+
+        if (special_fim_rep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_rep_id) == 0) {
+            special_eog_ids.insert(special_fim_rep_id);
+        }
+
+        if (special_fim_sep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_sep_id) == 0) {
+            special_eog_ids.insert(special_fim_sep_id);
+        }
+
+        for (const auto & t : token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == ""
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == ""
+               ) {
+                special_eog_ids.insert(t.second);
+                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // sanity checks
+        if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
+            special_eog_ids.insert(special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eot_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eot_id) == 0) {
+            special_eog_ids.insert(special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eom_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eom_id) == 0) {
+            special_eog_ids.insert(special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+    }
+
+    // build special tokens cache
+    {
+        for (llama_token id = 0; id < (llama_token) n_tokens; ++id) {
+            if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                cache_special_tokens.push_back(id);
+            }
+        }
+
+        std::sort(cache_special_tokens.begin(), cache_special_tokens.end(),
+            [&] (const llama_token a, const llama_token b) {
+                return id_to_token[a].text.size() > id_to_token[b].text.size();
+            }
+        );
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t) cache_special_tokens.size());
+    }
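The cache above keeps the longest token texts first. When raw text is later partitioned on special tokens, trying longer markers before shorter ones keeps a marker that is merely a prefix of a longer one from claiming the match. A minimal standalone sketch of that ordering effect; the token texts here are illustrative, not taken from any particular model:

// longest-first matching of special-token texts (standalone sketch)
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
    // "<|end|>" is a prefix of "<|endoftext|>"
    std::vector<std::string> specials = { "<|end|>", "<|endoftext|>", "<|im_end|>" };

    // longest text first, mirroring the sort of cache_special_tokens above
    std::sort(specials.begin(), specials.end(),
        [](const std::string & a, const std::string & b) { return a.size() > b.size(); });

    const std::string text = "hello<|endoftext|>";
    for (const auto & s : specials) {
        const size_t pos = text.find(s);
        if (pos != std::string::npos) {
            std::cout << "matched '" << s << "' at offset " << pos << "\n";
            break; // the longer marker wins; "<|end|>" never gets a chance to shadow it
        }
    }
    return 0;
}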
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector<std::string> cache(n_tokens);
+
+        for (uint32_t id = 0; id < n_tokens; ++id) {
+            cache[id] = token_to_piece_for_cache(id, true);
+
+            size_cache += cache[id].size();
+        }
+
+        std::swap(cache_token_to_piece, cache);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string & str, const std::vector<std::string> & substrs) -> bool {
+            for (const auto & substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_token id, llama_token_attr attr, bool value) {
+            uint32_t current = id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer name
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
+            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {""}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {"", "", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
+}
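The _set_tokenid_attr helper above flips a single flag in the per-token attribute bitmask. A small self-contained sketch of the same set/clear idiom, using made-up flag values rather than the library's llama_token_attr constants:

// set/clear one flag in an attribute bitmask (sketch; flag values are made up)
#include <cstdint>
#include <cstdio>

enum token_attr : uint32_t {
    ATTR_NORMAL = 1u << 0,
    ATTR_LSTRIP = 1u << 4,
    ATTR_RSTRIP = 1u << 5,
};

// same idiom as _set_tokenid_attr: OR the flag in, or AND it out
static uint32_t set_attr(uint32_t current, uint32_t attr, bool value) {
    return value ? (current | attr) : (current & ~attr);
}

int main() {
    uint32_t attrs = ATTR_NORMAL;
    attrs = set_attr(attrs, ATTR_RSTRIP, true);   // mark the token as right-stripping
    attrs = set_attr(attrs, ATTR_RSTRIP, false);  // and clear the flag again
    std::printf("attrs = 0x%x\n", attrs);         // prints attrs = 0x1
    return 0;
}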
+
+enum llama_vocab_type llama_vocab::impl::get_type() const {
+    return type;
+}
+
+std::string llama_vocab::impl::type_name() const{
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
+        default:                    return "unknown";
+    }
+}
+
+bool llama_vocab::impl::is_normal(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
+}
+
+bool llama_vocab::impl::is_unknown(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
+}
+
+bool llama_vocab::impl::is_control(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
+}
+
+bool llama_vocab::impl::is_byte(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
+}
+
+bool llama_vocab::impl::is_user_defined(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+}
+
+bool llama_vocab::impl::is_unused(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
+}
+
+bool llama_vocab::impl::is_eog(llama_token id) const {
+    return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
+}
+
+uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    GGML_ASSERT(is_byte(id));
+    const auto & token_data = id_to_token.at(id);
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ABORT("fatal error");
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ABORT("fatal error");
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
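For SPM and UGM vocabularies, byte tokens are stored as text of the form "<0xAB>", so substr(3, 2) picks out the two hex digits and strtol converts them to the byte value. A tiny standalone sketch of that parsing, independent of the vocab structures:

// parse an SPM-style byte token "<0x0A>" into its byte value (standalone sketch)
#include <cstdio>
#include <cstdlib>
#include <string>

int main() {
    const std::string token_text = "<0x0A>";          // byte token for '\n'
    const std::string hex = token_text.substr(3, 2);  // "0A"
    const long value = std::strtol(hex.c_str(), nullptr, 16);
    std::printf("byte value = %ld\n", value);          // prints byte value = 10
    return 0;
}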
+
+llama_token_attr llama_vocab::impl::token_get_attr(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token.at(id).attr;
+}
+
+void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
+    LLAMA_LOG_DEBUG("%s: initializing tokenizer for type %d\n", __func__, type);
+
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = std::make_unique<llm_tokenizer_spm>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = std::make_unique<llm_tokenizer_bpe>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = std::make_unique<llm_tokenizer_wpm>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = std::make_unique<llm_tokenizer_ugm>(vocab, precompiled_charsmap);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
+}
+
+//
+// (de-) tokenize
+//
+
 // #define PRETOKENIZERDEBUG
 
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
+void llama_vocab::impl::tokenizer_st_partition(std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) const {
     // for each special token
-    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
-        const auto & data = vocab.id_to_token[special_id];
-        const auto & special_token = data.text;
+    for (const llama_token special_id : cache_special_tokens) {
+        const auto & data = vocab.get_token_data(special_id);
+        const auto & text = data.text;
 
         if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
             // Ignore control and unknown tokens when parse_special == false
@@ -1339,13 +2169,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     // find the first occurrence of a given special token in this fragment
                     //  passing offset argument only limit the "search area" but match coordinates
                     //  are still relative to the source full raw_text
-                    auto match = raw_text.find(special_token, raw_text_base_offset);
+                    auto match = raw_text.find(text, raw_text_base_offset);
 
                     // no occurrences found, stop processing this fragment for a given special token
                     if (match == std::string::npos) break;
 
                     // check if match is within bounds of offset <-> length
-                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+                    if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
 
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
@@ -1380,9 +2210,9 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     it++;
 
                     // right
-                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                        int64_t right_reminder_offset = match + special_token.length();
-                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+                    if (match + text.length() < raw_text_base_offset + raw_text_base_length) {
+                        int64_t right_reminder_offset = match + text.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + text.length());
 
                         if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
                             while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
@@ -1403,7 +2233,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
                         }
 
                         // repeat for the right side
@@ -1417,7 +2247,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
                         }
                         break;
                     }
@@ -1428,322 +2258,29 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
     }
 }
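tokenizer_st_partition repeatedly searches each text fragment for the cached special tokens and splits it into plain-text and special-token fragments. A condensed standalone sketch of the idea, handling a single special token and none of the strip/attribute logic:

// split raw text into plain-text / special-token fragments (condensed sketch)
#include <iostream>
#include <string>
#include <vector>

struct fragment {          // illustrative stand-in for the fragment buffer entries
    bool        is_special;
    std::string text;
};

int main() {
    const std::string special = "<|im_end|>";
    const std::string raw = "hello<|im_end|>world<|im_end|>";

    std::vector<fragment> out;
    size_t start = 0;
    for (;;) {
        const size_t match = raw.find(special, start);
        if (match == std::string::npos) {
            if (start < raw.size()) {
                out.push_back({false, raw.substr(start)});   // trailing plain text
            }
            break;
        }
        if (match > start) {
            out.push_back({false, raw.substr(start, match - start)});  // left side
        }
        out.push_back({true, special});                                // the marker itself
        start = match + special.size();                                // continue on the right
    }

    for (const auto & f : out) {
        std::cout << (f.is_special ? "[special] " : "[text]    ") << f.text << "\n";
    }
    return 0;
}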
 
-std::vector<llama_vocab::id> llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special) {
-    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
-
-    std::vector<llama_vocab::id> output;
-    std::forward_list<fragment_buffer_variant> fragment_buffer;
-
-    if (!raw_text.empty()) {
-        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
+// NOTE: avoid ever using this except for building the token_to_piece caches
+std::string llama_vocab::impl::token_to_piece_for_cache(llama_token token, bool special) const {
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache
+    const int n_chars = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
     }
 
-    switch (vocab.type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            {
-                // OG tokenizer behavior:
-                //
-                // tokenizer.encode('', add_special_tokens=True)  returns [1]
-                // tokenizer.encode('', add_special_tokens=False) returns []
-
-                bool is_prev_special = true;  // prefix with space if first token
-
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                    is_prev_special = true;
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-                        // prefix with space if previous is special
-                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
-                            raw_text = " " + raw_text;
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llama_escape_whitespace(raw_text);
-                        llm_tokenizer_spm_session session(vocab);
-                        session.tokenize(raw_text, output);
-                        is_prev_special = false;
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                        is_prev_special = true;
-                    }
-                }
-
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            {
-                llm_tokenizer_bpe_session session(vocab);
-                // it calls some other methods that are not exist in llm_tokenizer,
-                // here just cast it to bpe tokenizer object
-                if (add_special) {
-                    session.append_bos(output);
-                }
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        session.append(fragment.token, output);
-                    }
-                }
-
-                if (add_special) {
-                    session.append_eos(output);
-                    session.check_double_bos_eos(output);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            {
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_cls_id != -1);
-                    output.push_back(vocab.special_cls_id);
-                }
-
-                llm_tokenizer_wpm_session session(vocab);
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_sep_id != -1);
-                    output.push_back(vocab.special_sep_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            {
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                }
-                llm_tokenizer_ugm_session session(vocab);
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_RWKV:
-            {
-                llm_tokenizer_rwkv_session session(vocab);
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_NONE:
-            GGML_ABORT("fatal error");
-    }
-
-    return output;
+    return piece;
 }
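token_to_piece_for_cache relies on the converter's convention of returning the negated required size when the destination buffer is too small: write into whatever capacity the string already has, and retry once after resizing if the first call reports a negative count. A generic sketch of that two-call contract against a stand-in converter; the real call is vocab.token_to_piece:

// two-call "negative size means required size" contract (sketch with a stand-in converter)
#include <cstring>
#include <iostream>
#include <string>

// stand-in: copies src into buf, or returns -needed when buf is too small
static int fake_to_piece(const char * src, char * buf, int length) {
    const int needed = (int) std::strlen(src);
    if (length < needed) {
        return -needed;
    }
    std::memcpy(buf, src, needed);
    return needed;
}

int main() {
    const char * piece_src = "\xe2\x96\x81hello";  // "▁hello", an illustrative piece

    std::string piece;
    piece.resize(piece.capacity());                // reuse whatever space the string has
    int n = fake_to_piece(piece_src, &piece[0], (int) piece.size());
    if (n < 0) {
        piece.resize(-n);                          // the negative value is the needed size
        n = fake_to_piece(piece_src, &piece[0], (int) piece.size());
    } else {
        piece.resize(n);
    }
    std::cout << piece << " (" << piece.size() << " bytes)\n";
    return 0;
}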
 
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    static const char * hex = "0123456789ABCDEF";
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            auto token = vocab.token_to_id.find(buf);
-            if (token != vocab.token_to_id.end()) {
-                return (*token).second;
-            }
-            // Try to fall back to just the byte as a string
-            const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
-        }
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_BPE: {
-            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
-        }
-        default:
-            GGML_ABORT("fatal error");
-    }
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
 }
 
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].text.c_str();
-}
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].score;
-}
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].attr;
-}
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && vocab.special_eog_ids.count(token) > 0;
-}
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
-    return llama_is_control_token(vocab, token);
-}
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
-    return vocab.type != LLAMA_VOCAB_TYPE_WPM ? vocab.special_bos_id : vocab.special_cls_id;
-}
-
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eos_id;
-}
-
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eot_id;
-}
-
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eom_id;
-}
-
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
-    return vocab.special_cls_id;
-}
-
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_sep_id;
-}
-
-llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
-    return vocab.linefeed_id;
-}
-
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_pad_id;
-}
-
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_bos;
-}
-
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_eos;
-}
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pre_id;
-}
-
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_mid_id;
-}
-
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_suf_id;
-}
-
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pre_id;
-}
-
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_suf_id;
-}
-
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_mid_id;
-}
-
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pad_id;
-}
-
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_rep_id;
-}
-
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_sep_id;
-}
-
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special) {
-    auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
-    if (n_tokens_max < (int) res.size()) {
-        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
-    }
-
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
-    }
-
-    return res.size();
+static void llama_unescape_whitespace(std::string & word) {
+    replace_all(word, "\xe2\x96\x81", " ");
 }
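llama_escape_whitespace and llama_unescape_whitespace swap ordinary spaces with the SentencePiece marker U+2581 (UTF-8 "\xe2\x96\x81") around SPM tokenization. A tiny round-trip sketch using a plain replace loop in place of the vendored replace_all helper:

// round-trip of the SentencePiece whitespace marker (standalone sketch)
#include <iostream>
#include <string>

static void replace_all_sketch(std::string & s, const std::string & from, const std::string & to) {
    for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
        s.replace(pos, from.size(), to);
    }
}

int main() {
    std::string text = "hello world";
    replace_all_sketch(text, " ", "\xe2\x96\x81");  // escape: ' ' -> '▁'
    std::cout << text << "\n";                      // hello▁world
    replace_all_sketch(text, "\xe2\x96\x81", " ");  // unescape: '▁' -> ' '
    std::cout << text << "\n";                      // hello world
    return 0;
}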
 
 static std::string llama_decode_text(const std::string & text) {
@@ -1766,11 +2303,185 @@ static std::string llama_decode_text(const std::string & text) {
     return decoded_text;
 }
 
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
+std::vector<llama_token> llama_vocab::impl::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    std::vector<llama_token> output;
+    std::forward_list<fragment_buffer_variant> fragment_buffer;
+
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+        tokenizer_st_partition(fragment_buffer, parse_special);
+    }
+
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                bool is_prev_special = true;  // prefix with space if first token
+
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                    is_prev_special = true;
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text;
+
+                        // prefix with space if previous is special
+                        if (add_space_prefix && is_prev_special) {
+                            text = ' ';
+                        }
+
+                        text += fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        llama_escape_whitespace(text);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(text, output);
+                        is_prev_special = false;
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                        is_prev_special = true;
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                llm_tokenizer_bpe_session session(vocab, *static_cast<const llm_tokenizer_bpe *>(tokenizer.get()));
+                // the session calls methods that do not exist on the base llm_tokenizer,
+                // so cast the stored tokenizer to the bpe tokenizer object here
+                if (add_special) {
+                    session.append_bos(output);
+                }
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        session.append(fragment.token, output);
+                    }
+                }
+
+                if (add_special) {
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
+                if (add_special) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+
+                llm_tokenizer_wpm_session session(vocab);
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special) {
+                    GGML_ASSERT(special_sep_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_sep_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+                llm_tokenizer_ugm_session session(vocab, *static_cast<const llm_tokenizer_ugm *>(tokenizer.get()));
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                llm_tokenizer_rwkv_session session(vocab, *static_cast<const llm_tokenizer_rwkv *>(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ABORT("fatal error");
+    }
+
+    return output;
+}
+
+int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
     // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
-    const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
+    const llama_token_attr attr = token_get_attr(token);
     if (!special && (attr & attr_special)) {
         return 0;
     }
@@ -1791,7 +2502,7 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
 
     // if we have a cache - use it
     {
-        const auto & cache = vocab.cache_token_to_piece;
+        const auto & cache = cache_token_to_piece;
 
         if (!cache.empty()) {
             const auto & result = cache.at(token);
@@ -1799,9 +2510,9 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
         }
     }
 
-    if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
-        const std::string & token_text = vocab.id_to_token[token].text;
-        switch (llama_vocab_get_type(vocab)) {
+    if (0 <= token && token < (int32_t) id_to_token.size()) {
+        const std::string & token_text = id_to_token[token].text;
+        switch (get_type()) {
             case LLAMA_VOCAB_TYPE_WPM:
             case LLAMA_VOCAB_TYPE_SPM:
             case LLAMA_VOCAB_TYPE_UGM: {
@@ -1816,7 +2527,7 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                     return _try_copy(result.data(), result.size());
                 }
                 if (attr & LLAMA_TOKEN_ATTR_BYTE) {
-                    char byte = (char) llama_token_to_byte(vocab, token);
+                    char byte = (char) token_to_byte(token);
                     return _try_copy((char*) &byte, 1);
                 }
                 break;
@@ -1852,43 +2563,46 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
+const std::string & llama_vocab::impl::token_to_piece(llama_token token) const {
+    return cache_token_to_piece.at(token);
+}
+
+int32_t llama_vocab::impl::detokenize(
                const llama_token * tokens,
                          int32_t   n_tokens,
                             char * text,
                          int32_t   text_len_max,
                             bool   remove_special,
-                            bool   unparse_special) {
-    if (vocab.type == LLAMA_VOCAB_TYPE_NONE) {
+                            bool   unparse_special) const {
+    if (type == LLAMA_VOCAB_TYPE_NONE) {
         return 0;
     }
 
-    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
 
     int32_t avail = text_len_max;
     int32_t total = 0;
 
     // remove the leading space
-    bool remove_space = vocab.tokenizer_add_space_prefix;
+    bool remove_space = add_space_prefix;
 
-    if (remove_special && vocab.tokenizer_add_bos) {
-        if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
+    if (remove_special && add_bos) {
+        if (n_tokens > 0 && tokens[0] == special_bos_id) {
             remove_space = false;
             n_tokens--;
             tokens++;
         }
     }
 
-    if (remove_special && vocab.tokenizer_add_eos) {
-        if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
+    if (remove_special && add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens - 1] == special_eos_id) {
             n_tokens--;
         }
     }
 
     for (int32_t i = 0; i < n_tokens; ++i) {
         GGML_ASSERT(avail >= 0);
-        int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
+        int32_t n_chars = token_to_piece(tokens[i], text, avail, remove_space, unparse_special);
         remove_space = false;
         if (n_chars < 0) {
             avail = 0;
@@ -1904,7 +2618,7 @@ int32_t llama_detokenize_impl(
         return -total;
     }
 
-    if (vocab.tokenizer_clean_spaces) {
+    if (clean_spaces) {
         text -= total;  // restart text
 
         // first pass: characters ?!.,  //TODO: where do these characters come from?
@@ -1965,13 +2679,321 @@ int32_t llama_detokenize_impl(
     return total <= text_len_max ? total : -total;
 }
 
-std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector & tokens, bool special) {
+void llama_vocab::impl::print_info() const {
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+
+    // special tokens
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token[special_bos_id].text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token[special_eos_id].text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token[special_eot_id].text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token[special_eom_id].text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token[special_unk_id].text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token[special_sep_id].text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token[special_pad_id].text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token[special_mask_id].text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token[linefeed_id].text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token[special_fim_pre_id].text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token[special_fim_suf_id].text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token[special_fim_mid_id].text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token[special_fim_pad_id].text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token[special_fim_rep_id].text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token[special_fim_sep_id].text.c_str() ); }
+
+    for (const auto & id : special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token[id].text.c_str() );
+    }
+
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+}
+
+llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
+}
+
+llama_vocab::~llama_vocab() {
+}
+
+void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
+    pimpl->load(ml, kv);
+}
+
+enum llama_vocab_type llama_vocab::get_type() const {
+    return pimpl->type;
+}
+
+enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
+    return pimpl->pre_type;
+}
+
+uint32_t llama_vocab::n_tokens() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
+uint32_t llama_vocab::n_token_types() const {
+    return (uint32_t) pimpl->n_token_types;
+}
+
+std::string llama_vocab::type_name() const{
+    return pimpl->type_name();
+}
+
+bool llama_vocab::is_normal(llama_token id) const {
+    return pimpl->is_normal(id);
+}
+
+bool llama_vocab::is_unknown(llama_token id) const {
+    return pimpl->is_unknown(id);
+}
+
+bool llama_vocab::is_control(llama_token id) const {
+    return pimpl->is_control(id);
+}
+
+bool llama_vocab::is_byte(llama_token id) const {
+    return pimpl->is_byte(id);
+}
+
+bool llama_vocab::is_user_defined(llama_token id) const {
+    return pimpl->is_user_defined(id);
+}
+
+bool llama_vocab::is_unused(llama_token id) const {
+    return pimpl->is_unused(id);
+}
+
+bool llama_vocab::is_eog(llama_token id) const {
+    return pimpl->is_eog(id);
+}
+
+uint8_t llama_vocab::token_to_byte(llama_token id) const {
+    return pimpl->token_to_byte(id);
+}
+
+llama_token llama_vocab::byte_to_token(uint8_t ch) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    static const char * hex = "0123456789ABCDEF";
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+            auto token = pimpl->token_to_id.find(buf);
+            if (token != pimpl->token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return pimpl->token_to_id.at(buf2);
+        }
+        case LLAMA_VOCAB_TYPE_WPM:
+        case LLAMA_VOCAB_TYPE_BPE: {
+            return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token llama_vocab::text_to_token(const std::string & text) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    auto it = pimpl->token_to_id.find(text);
+    if (it != pimpl->token_to_id.end()) {
+        return (*it).second;
+    }
+    return LLAMA_TOKEN_NULL;
+}
+
+const llama_vocab::token_data & llama_vocab::get_token_data(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id);
+}
+
+const char * llama_vocab::token_get_text(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).text.c_str();
+}
+
+float llama_vocab::token_get_score(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).score;
+}
+
+llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
+    return pimpl->token_get_attr(id);
+}
+
+llama_token llama_vocab::token_bos() const {
+    return pimpl->special_bos_id;
+}
+
+llama_token llama_vocab::token_eos() const {
+    return pimpl->special_eos_id;
+}
+
+llama_token llama_vocab::token_eot() const {
+    return pimpl->special_eot_id;
+}
+
+llama_token llama_vocab::token_eom() const {
+    return pimpl->special_eom_id;
+}
+
+llama_token llama_vocab::token_unk() const {
+    return pimpl->special_unk_id;
+}
+
+llama_token llama_vocab::token_sep() const {
+    return pimpl->special_sep_id;
+}
+
+llama_token llama_vocab::token_nl() const {
+    return pimpl->linefeed_id;
+}
+
+llama_token llama_vocab::token_pad() const {
+    return pimpl->special_pad_id;
+}
+
+llama_token llama_vocab::token_prefix() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_middle() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_suffix() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_pre() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_fim_suf() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_mid() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_fim_pad() const {
+    return pimpl->special_fim_pad_id;
+}
+
+llama_token llama_vocab::token_fim_rep() const {
+    return pimpl->special_fim_rep_id;
+}
+
+llama_token llama_vocab::token_fim_sep() const {
+    return pimpl->special_fim_sep_id;
+}
+
+bool llama_vocab::get_add_space_prefix() const {
+    return pimpl->add_space_prefix;
+}
+
+bool llama_vocab::get_add_bos() const {
+    return pimpl->add_bos;
+}
+
+bool llama_vocab::get_add_eos() const {
+    return pimpl->add_eos;
+}
+
+bool llama_vocab::get_ignore_merges() const {
+    return pimpl->ignore_merges;
+}
+
+bool llama_vocab::get_clean_spaces() const {
+    return pimpl->clean_spaces;
+}
+
+bool llama_vocab::get_remove_extra_whitespaces() const {
+    return pimpl->remove_extra_whitespaces;
+}
+
+bool llama_vocab::get_escape_whitespaces() const {
+    return pimpl->escape_whitespaces;
+}
+
+bool llama_vocab::get_treat_whitespace_as_suffix() const {
+    return pimpl->treat_whitespace_as_suffix;
+}
+
+int llama_vocab::max_token_len() const {
+    return pimpl->max_token_len;
+}
+
+int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
+    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
+    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
+    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
+    GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
+    auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
+    if (it == pimpl->bpe_ranks.end()) {
+        return -1;
+    }
+
+    return it->second;
+}
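find_bpe_rank maps an adjacent token pair to its merge priority and returns -1 when the pair is not a known merge. A small sketch of the same lookup with a std::map keyed by the pair; the vendored bpe_ranks container and its hashing are not reproduced here:

// merge-rank lookup keyed by an adjacent token pair (sketch with std::map)
#include <iostream>
#include <map>
#include <string>
#include <utility>

int main() {
    // illustrative merge table: lower rank = merged earlier
    std::map<std::pair<std::string, std::string>, int> ranks = {
        {{"h", "e"},  0},
        {{"he", "l"}, 1},
    };

    auto find_rank = [&](const std::string & left, const std::string & right) -> int {
        const auto it = ranks.find(std::make_pair(left, right));
        return it == ranks.end() ? -1 : it->second;
    };

    std::cout << find_rank("h", "e") << "\n";  // 0
    std::cout << find_rank("l", "o") << "\n";  // -1 -> not a known merge
    return 0;
}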
+
+int32_t llama_vocab::tokenize(
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) const {
+    auto res = tokenize(std::string(text, text_len), add_special, parse_special);
+    if (n_tokens_max < (int) res.size()) {
+        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
+
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
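This tokenize overload returns the negated token count when n_tokens_max is too small, so callers can size their buffer in two passes. A hedged caller-side sketch against a stand-in function that follows the same convention; it is not the real API entry point:

// caller-side two-pass buffer sizing (sketch; fake_tokenize is a stand-in)
#include <cstdint>
#include <iostream>
#include <vector>

// follows the same convention as the overload above: negative return = required count
static int32_t fake_tokenize(int32_t * tokens, int32_t n_tokens_max) {
    const std::vector<int32_t> res = {1, 15043, 3186};  // illustrative token ids
    if (n_tokens_max < (int32_t) res.size()) {
        return -((int32_t) res.size());
    }
    for (size_t i = 0; i < res.size(); ++i) {
        tokens[i] = res[i];
    }
    return (int32_t) res.size();
}

int main() {
    std::vector<int32_t> tokens(1);                      // deliberately too small
    int32_t n = fake_tokenize(tokens.data(), (int32_t) tokens.size());
    if (n < 0) {
        tokens.resize(-n);                               // second pass with the exact size
        n = fake_tokenize(tokens.data(), (int32_t) tokens.size());
    }
    std::cout << "got " << n << " tokens\n";
    return 0;
}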
+
+std::vector<llama_token> llama_vocab::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    return pimpl->tokenize(raw_text, add_special, parse_special);
+}
+
+const std::string & llama_vocab::token_to_piece(llama_token token) const {
+    return pimpl->token_to_piece(token);
+}
+
+int32_t llama_vocab::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    return pimpl->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_vocab::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    return pimpl->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
+std::string llama_vocab::detokenize(const std::vector<llama_token> & tokens, bool special) const {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    int32_t n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
     if (n_chars < 0) {
         text.resize(-n_chars);
-        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
         GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
     }
 
@@ -1980,3 +3002,243 @@ std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector
     // NOTE: the original tokenizer decodes bytes after collecting the pieces.
     return text;
 }
+
+void llama_vocab::print_info() const {
+    pimpl->print_info();
+}
+
+//
+// interface implementation
+//
+
+int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab) {
+    return vocab->n_tokens();
+}
+
+// deprecated
+int32_t llama_n_vocab(const struct llama_vocab * vocab) {
+    return llama_vocab_n_tokens(vocab);
+}
+
+enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
+    return vocab->get_type();
+}
+
+const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_text(token);
+}
+
+float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_score(token);
+}
+
+enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_attr(token);
+}
+
+bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_eog(token);
+}
+
+bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_control(token);
+}
+
+llama_token llama_vocab_bos(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_eos(const struct llama_vocab * vocab) {
+    return vocab->token_eos();
+}
+
+llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
+    return vocab->token_eot();
+}
+
+// deprecated
+llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
+    return vocab->token_sep();
+}
+
+llama_token llama_vocab_nl (const struct llama_vocab * vocab) {
+    return vocab->token_nl();
+}
+
+llama_token llama_vocab_pad(const struct llama_vocab * vocab) {
+    return vocab->token_pad();
+}
+
+bool llama_vocab_get_add_bos(const struct llama_vocab * vocab) {
+    return vocab->get_add_bos();
+}
+
+bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
+    return vocab->get_add_eos();
+}
+
+llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pre();
+}
+
+llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab) {
+    return vocab->token_fim_suf();
+}
+
+llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab) {
+    return vocab->token_fim_mid();
+}
+
+llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pad();
+}
+
+llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_rep();
+}
+
+llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_sep();
+}
+
+// deprecated
+const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_text(vocab, token);
+}
+
+// deprecated
+float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_score(vocab, token);
+}
+
+// deprecated
+enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_attr(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_eog(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_control(vocab, token);
+}
+
+// deprecated
+llama_token llama_token_bos(const struct llama_vocab * vocab) {
+    return llama_vocab_bos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eos(const struct llama_vocab * vocab) {
+    return llama_vocab_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eot(const struct llama_vocab * vocab) {
+    return llama_vocab_eot(vocab);
+}
+
+// deprecated
+llama_token llama_token_cls(const struct llama_vocab * vocab) {
+    //return llama_vocab_cls(vocab);
+    return llama_vocab_bos(vocab); // avoid deprecation warning
+}
+
+// deprecated
+llama_token llama_token_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_sep(vocab);
+}
+
+// deprecated
+llama_token llama_token_nl (const struct llama_vocab * vocab) {
+    return llama_vocab_nl(vocab);
+}
+
+// deprecated
+llama_token llama_token_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_pad(vocab);
+}
+
+// deprecated
+bool llama_add_bos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_bos(vocab);
+}
+
+// deprecated
+bool llama_add_eos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pre(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pre(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_suf(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_suf(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_mid(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_mid(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pad(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_rep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_rep(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_sep(vocab);
+}
+
+//
+// tokenization
+//
+
+int32_t llama_tokenize(
+    const struct llama_vocab * vocab,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    return vocab->tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
+}
+
+int32_t llama_token_to_piece(
+    const struct llama_vocab * vocab,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return vocab->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_detokenize(
+    const struct llama_vocab * vocab,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
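Note (editor's illustration, not part of the diff): llama_tokenize and llama_detokenize above keep the llama.cpp convention that a negative return value is the negated size the output buffer would need (see the std::string detokenize wrapper earlier in this file). A minimal two-pass caller against the new vocab-based C API could look like the sketch below; it assumes a `const llama_vocab *` obtained from a loaded model (for example via llama_model_get_vocab) and does no further error handling.

    // sketch: two-pass tokenization using the negative-return sizing convention
    #include "llama.h"

    #include <string>
    #include <vector>

    static std::vector<llama_token> tokenize_text(const llama_vocab * vocab, const std::string & text) {
        // first pass with an empty buffer: a negative result is -(number of tokens required)
        int32_t n = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                                   nullptr, 0, /*add_special=*/true, /*parse_special=*/false);
        std::vector<llama_token> tokens(n < 0 ? -n : n);
        // second pass fills the correctly sized buffer
        n = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                           tokens.data(), (int32_t) tokens.size(), true, false);
        tokens.resize(n < 0 ? 0 : n);
        return tokens;
    }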
diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h
index 0d00086da..5ce355214 100644
--- a/llama/llama.cpp/src/llama-vocab.h
+++ b/llama/llama.cpp/src/llama-vocab.h
@@ -4,179 +4,122 @@
 
 #include <string>
 #include <vector>
-#include <unordered_map>
-#include <map>
-#include <set>
+#include <memory>
 
-static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
-    switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
-    }
-}
-
-struct llm_tokenizer;
+struct LLM_KV;
+struct llama_model_loader;
 
 struct llama_vocab {
-    using id    = llama_token;
-    using token = std::string;
-    using tattr = llama_token_attr;
-
     struct token_data {
-        token text;
-        float score;
-        tattr attr;
+        std::string      text;
+        float            score;
+        llama_token_attr attr;
     };
 
-    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
-
-    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
-    enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-
-    int max_token_len = 0; // used for optimizing longest token search
-
-    std::unordered_map<token, id> token_to_id;
-    std::vector<token_data>       id_to_token;
-
-    std::vector<id>    cache_special_tokens;
-    std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
-
-    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
-
-    // default LLaMA special tokens
-    // TODO: should we set all of these to LLAMA_TOKEN_NULL?
-    id special_bos_id  = 1;
-    id special_eos_id  = 2;
-    id special_eot_id  = LLAMA_TOKEN_NULL;
-    id special_eom_id  = LLAMA_TOKEN_NULL;
-    id special_unk_id  = 0;
-    id special_sep_id  = LLAMA_TOKEN_NULL;
-    id special_pad_id  = LLAMA_TOKEN_NULL;
-    id special_cls_id  = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
-    id special_mask_id = LLAMA_TOKEN_NULL;
-
-    id linefeed_id = 13;
-
-    // fim tokens
-    id special_fim_pre_id = LLAMA_TOKEN_NULL;
-    id special_fim_suf_id = LLAMA_TOKEN_NULL;
-    id special_fim_mid_id = LLAMA_TOKEN_NULL;
-    id special_fim_pad_id = LLAMA_TOKEN_NULL;
-    id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
-    id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
-
-    // set of all tokens that cause "end of generation"
-    std::set<id> special_eog_ids;
-
-    // tokenizer flags
-    bool tokenizer_add_space_prefix           = false;
-    bool tokenizer_add_bos                    = false;
-    bool tokenizer_add_eos                    = false;
-    bool tokenizer_ignore_merges              = false;
-    bool tokenizer_clean_spaces               = false;  // clean_up_tokenization_spaces
-    bool tokenizer_remove_extra_whitespaces   = false;
-    bool tokenizer_escape_whitespaces         = true;
-    bool tokenizer_treat_whitespace_as_suffix = false;
-
-    std::vector<char> precompiled_charsmap;
-
-    llm_tokenizer * tokenizer = nullptr;
-
-    llama_vocab() = default;
+    llama_vocab();
     ~llama_vocab();
 
+    void load(llama_model_loader & ml, const LLM_KV & kv);
+
+    enum llama_vocab_type     get_type()     const;
+    enum llama_vocab_pre_type get_pre_type() const;
+
+    uint32_t n_tokens() const;
+    uint32_t n_token_types() const;
+
+    std::string type_name() const;
+
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
+
+    uint8_t     token_to_byte(llama_token id) const;
+    llama_token byte_to_token(uint8_t ch)     const;
+
+    llama_token text_to_token(const std::string & text) const;
+
+    const token_data & get_token_data(llama_token id) const;
+
+    const char *     token_get_text (llama_token id) const;
+    float            token_get_score(llama_token id) const;
+    llama_token_attr token_get_attr (llama_token id) const;
+
+    llama_token token_bos() const;
+    llama_token token_eos() const;
+    llama_token token_eot() const;
+    llama_token token_eom() const;
+    llama_token token_unk() const;
+    llama_token token_sep() const;
+    llama_token token_nl () const;
+    llama_token token_pad() const;
+
+    llama_token token_prefix() const;
+    llama_token token_middle() const;
+    llama_token token_suffix() const;
+
+    llama_token token_fim_pre() const;
+    llama_token token_fim_suf() const;
+    llama_token token_fim_mid() const;
+    llama_token token_fim_pad() const;
+    llama_token token_fim_rep() const;
+    llama_token token_fim_sep() const;
+
+    bool get_add_space_prefix          () const;
+    bool get_add_bos                   () const;
+    bool get_add_eos                   () const;
+    bool get_ignore_merges             () const;
+    bool get_clean_spaces              () const;
+    bool get_remove_extra_whitespaces  () const;
+    bool get_escape_whitespaces        () const;
+    bool get_treat_whitespace_as_suffix() const;
+
+    int max_token_len() const;
+
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
 
-    void init_tokenizer();
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    std::vector<llama_token> tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector<llama_token> & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
+    struct impl;
+    std::unique_ptr pimpl;
 };
-
-//
-// internal API
-//
-
-// TODO: rename to llama_tokenize_impl
-// TODO: This should probably be in llama.h
-std::vector<llama_vocab::id> llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special = false);
-
-// TODO: move the API below as member functions of llama_vocab
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
-
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
-
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
-
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special);
-
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token,
-                            char * buf,
-                         int32_t   length,
-                         int32_t   lstrip,
-                            bool   special);
-
-// check if token0 is contained as a prefix in token1
-bool llama_token_is_prefix_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token0,
-                     llama_token   token1);
-
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special);
-
-std::string llama_detokenize(
-        const struct llama_vocab & vocab,
-  const std::vector<llama_token> & tokens,
-                            bool   special);
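Note (editor's illustration, not from the diff): the new header hides all vocab state behind a pimpl. Because std::unique_ptr<impl> holds an incomplete type in the header, the constructor and destructor are only declared there and must be defined in the .cpp where `impl` is complete, which is why the defaulted constructor became an out-of-line declaration above. A minimal sketch of the pattern, using a hypothetical `widget` type:

    #include <memory>

    // header: impl is only forward-declared
    struct widget {
        widget();
        ~widget();                    // defaulting this inline would require a complete impl
    private:
        struct impl;
        std::unique_ptr<impl> pimpl;
    };

    // source file: impl is complete here, so the special members can be defined
    struct widget::impl { int state = 0; };
    widget::widget() : pimpl(std::make_unique<impl>()) {}
    widget::~widget() = default;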
diff --git a/llama/llama.cpp/src/llama.cpp b/llama/llama.cpp/src/llama.cpp
index c95da45d3..01854fcef 100644
--- a/llama/llama.cpp/src/llama.cpp
+++ b/llama/llama.cpp/src/llama.cpp
@@ -8,7 +8,6 @@
 #include "llama-kv-cache.h"
 #include "llama-model-loader.h"
 #include "llama-model.h"
-#include "llama-quant.h"
 
 #include "ggml.h"
 #include "ggml-alloc.h"
@@ -18,2560 +17,60 @@
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
-#include 
 #include 
-#include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
-#include 
-#include 
-#include 
-#include 
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-//
-// tensor loading (TODO: add llama_tesor_loader?)
-//
-
-static int llama_get_device_count(const llama_model & model) {
-    return (int) model.devices.size();
-}
-
-// checks if the weight tensor can be used with the specified buffer type and device
-static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
-    GGML_ASSERT(w != nullptr);
-
-    if (op == GGML_OP_NONE) {
-        return true;
-    }
-
-    ggml_init_params params = {
-        /*.mem_size   =*/ ggml_tensor_overhead()*8,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-    ggml_context_ptr ctx_ptr { ggml_init(params) };
-    if (!ctx_ptr) {
-        throw std::runtime_error(format("failed to create ggml context"));
-    }
-    ggml_context * ctx = ctx_ptr.get();
-
-    ggml_tensor * op_tensor = nullptr;
-
-    switch (op) {
-        case GGML_OP_GET_ROWS:
-            {
-                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
-                op_tensor = ggml_get_rows(ctx, w, b);
-            } break;
-        case GGML_OP_MUL_MAT:
-            {
-                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
-                op_tensor = ggml_mul_mat(ctx, w, b);
-            } break;
-        case GGML_OP_MUL_MAT_ID:
-            {
-                int n_expert_used = hparams.n_expert_used;
-                ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
-                ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
-                op_tensor = ggml_mul_mat_id(ctx, w, b, ids);
-            } break;
-        case GGML_OP_ADD:
-            {
-                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
-                op_tensor = ggml_add(ctx, a, w);
-            } break;
-        case GGML_OP_MUL:
-            {
-                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
-                op_tensor = ggml_mul(ctx, a, w);
-            } break;
-        case GGML_OP_DIV:
-            {
-                ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]);
-                op_tensor = ggml_div(ctx, a, w);
-            } break;
-        case GGML_OP_ROPE:
-            {
-                int n_embd_head = hparams.n_embd_head_v;
-                int n_head = hparams.n_head();
-                ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
-                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
-                op_tensor = ggml_rope_ext(
-                    ctx, a, b, w,
-                    0, 0, 0, 0, 0,
-                    0, 0, 0, 0
-                );
-
-            } break;
-        case GGML_OP_SSM_CONV:
-            {
-                // FIXME
-                ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
-                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
-            } break;
-        case GGML_OP_SSM_SCAN:
-            {
-                // FIXME
-                const int64_t d_state      = w->ne[0];
-                const int64_t d_inner      = w->ne[1];
-                const int64_t n_seq_tokens = 512;
-                const int64_t n_seqs       = 1;
-                ggml_tensor * s  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
-                ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-                ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-                ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-                ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-                op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
-            } break;
-        case GGML_OP_RWKV_WKV6:
-            {
-                // FIXME
-                const int64_t S = 123;
-                const int64_t H = 123;
-                const int64_t n_tokens = 123;
-                const int64_t n_seqs = 123;
-                ggml_tensor  * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
-                ggml_tensor  * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * tf = w;
-                ggml_tensor  * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
-                op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
-            } break;
-        case GGML_OP_IM2COL:
-            {
-                const int n_embd = hparams.n_embd;
-                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
-                op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
-            } break;
-        default:
-            GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
-    }
-
-    // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
-    GGML_ASSERT(w->buffer == nullptr);
-    w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
-    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
-    ggml_backend_buffer_free(w->buffer);
-    w->buffer = nullptr;
-
-    return op_supported;
-}
-
-// find the first buffer type in the list that can use the tensor
-static ggml_backend_buffer_type_t select_weight_buft(const llama_model & model, ggml_tensor * tensor, ggml_op op, const llama_model::buft_list_t & buft_list) {
-    GGML_ASSERT(!buft_list.empty());
-    for (const auto & cur : buft_list) {
-        ggml_backend_dev_t cur_dev = cur.first;
-        ggml_backend_buffer_type_t cur_buft = cur.second;
-        if (weight_buft_supported(model.hparams, tensor, op, cur_buft, cur_dev)) {
-            return cur_buft;
-        }
-    }
-    return nullptr;
-}
-
-// CPU: ACCEL -> CPU extra -> GPU host -> CPU
-static llama_model::buft_list_t make_cpu_buft_list(llama_model & model) {
-    llama_model::buft_list_t buft_list;
-
-    // add ACCEL buffer types
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
-            auto * buft = ggml_backend_dev_buffer_type(dev);
-            // skip
-            if (buft != ggml_backend_cpu_buffer_type()) {
-                buft_list.emplace_back(dev, buft);
-            }
-        }
-    }
-
-    // add extra buffer types
-    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
-    auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
-        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
-    if (ggml_backend_dev_get_extra_bufts_fn) {
-        ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
-        while (extra_bufts && *extra_bufts) {
-            buft_list.emplace_back(cpu_dev, *extra_bufts);
-            ++extra_bufts;
-        }
-    }
-
-    // add a host buffer type
-    // storing the tensors in a host buffer is useful when the processing of large batches
-    // is offloaded to a GPU device, since it reduces the time spent on data transfers
-    // generally, this will be done using the first device in the list
-    // a better approach would be to handle this on a weight-by-weight basis using the offload_op
-    // function of the device to determine if it would benefit from being stored in a host buffer
-    for (auto * dev : model.devices) {
-        ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
-        if (buft) {
-            buft_list.emplace_back(dev, buft);
-            break;
-        }
-    }
-
-    // add the CPU buffer type
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
-            buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
-        }
-    }
-
-    return buft_list;
-}
-
-// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU
-static llama_model::buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) {
-    llama_model::buft_list_t buft_list;
-
-    // add the device split buffer type if requested and available
-    if (split_mode == LLAMA_SPLIT_MODE_ROW) {
-        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
-        auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t)
-            ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type");
-        if (ggml_backend_split_buffer_type_fn) {
-            size_t dev_index = [&]() {
-                auto * reg = ggml_backend_dev_backend_reg(dev);
-                for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
-                    if (ggml_backend_reg_dev_get(reg, i) == dev) {
-                        return i;
-                    }
-                }
-                throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev)));
-            }();
-            auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split);
-            if (buft != nullptr) {
-                buft_list.emplace_back(dev, buft);
-            }
-        }
-    }
-
-    // add the device default buffer type
-    buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
-
-    return buft_list;
-}
-
-// Returns false if cancelled by progress_callback
-static bool llm_load_tensors(
-        llama_model_loader & ml,
-        llama_model & model,
-        int n_gpu_layers,
-        enum llama_split_mode split_mode,
-        int main_gpu,
-        const float * tensor_split,
-        bool use_mlock,
-        llama_progress_callback progress_callback,
-        void * progress_callback_user_data) {
-    auto & hparams = model.hparams;
-
-    model.split_mode   = split_mode;
-    model.main_gpu     = main_gpu;
-    model.n_gpu_layers = n_gpu_layers;
-
-    const int n_layer = hparams.n_layer;
-
-    bool use_mmap_buffer = true;
-
-    // build a list of buffer types for the CPU and GPU devices
-    model.cpu_buft_list = make_cpu_buft_list(model);
-    for (auto * dev : model.devices) {
-        llama_model::buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
-        // add CPU buffer types as a fallback
-        buft_list.insert(buft_list.end(), model.cpu_buft_list.begin(), model.cpu_buft_list.end());
-        model.gpu_buft_list.emplace(dev, std::move(buft_list));
-    }
-
-    // calculate the split points
-    int device_count = llama_get_device_count(model);
-    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
-    std::vector<float> splits(device_count);
-    if (all_zero) {
-        // default split, by free memory
-        for (int i = 0; i < device_count; ++i) {
-            ggml_backend_dev_t dev = model.devices[i];
-            size_t total;
-            size_t free;
-            ggml_backend_dev_memory(dev, &free, &total);
-            splits[i] = free;
-        }
-    } else {
-        std::copy(tensor_split, tensor_split + device_count, splits.begin());
-    }
-
-    // sum and normalize the splits to get the split points
-    float split_sum = 0.0f;
-    for (int i = 0; i < device_count; ++i) {
-        split_sum += splits[i];
-        splits[i] = split_sum;
-    }
-    for (int i = 0; i < device_count; ++i) {
-        splits[i] /= split_sum;
-    }
-
-    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
-    const int act_gpu_layers = model.devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
-    auto get_layer_buft_list = [&](int il) -> llama_model::layer_dev {
-        if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
-            return {cpu_dev, &model.cpu_buft_list};
-        }
-        int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
-        auto * dev = model.devices.at(layer_gpu);
-        return {dev, &model.gpu_buft_list.at(dev)};
-    };
-
-    // assign the input layer
-    // there is very little benefit to offloading the input layer, so always keep it on the CPU
-    model.dev_input = { cpu_dev, &model.cpu_buft_list };
-
-    // assign the repeating layers to the devices according to the splits
-    model.dev_layer.resize(n_layer);
-    for (int il = 0; il < n_layer; ++il) {
-        model.dev_layer[il] = get_layer_buft_list(il);
-    }
-    // assign the output layer
-    model.dev_output = get_layer_buft_list(n_layer);
-
-    // one ggml context per buffer type
-    int max_n_tensors = ml.n_tensors;
-    max_n_tensors += 1;         // duplicated output tensor
-    max_n_tensors += n_layer*2; // duplicated rope freq tensors
-    const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
-
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            ggml_init_params params = {
-                /*.mem_size   =*/ ctx_size,
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                throw std::runtime_error(format("failed to create ggml context"));
-            }
-            ctx_map[buft] = ctx;
-            model.ctxs.emplace_back(ctx);
-            return ctx;
-        }
-        return it->second;
-    };
-
-    // create tensors for the weights
-    {
-        // note: cast to int64_t since we will use these for the tensor dimensions
-        const int64_t n_head        = hparams.n_head();
-        const int64_t n_head_kv     = hparams.n_head_kv();
-        const int64_t n_embd        = hparams.n_embd;
-        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-        const int64_t n_embd_head_v = hparams.n_embd_head_v;
-        const int64_t n_ff          = hparams.n_ff();
-        const int64_t n_embd_gqa    = n_embd_v_gqa;
-        const int64_t n_vocab       = hparams.n_vocab;
-        const int64_t n_vocab_type  = hparams.n_vocab_type;
-        const int64_t n_rot         = hparams.n_rot;
-        const int64_t n_expert      = hparams.n_expert;
-        const int64_t n_expert_used = hparams.n_expert_used;
-        const int64_t n_ctx_train   = hparams.n_ctx_train;
-
-        if (n_expert > 0 && hparams.n_expert_used == 0) {
-            throw std::runtime_error("model has expert layers but no expert layers are used");
-        }
-
-        int n_moved_tensors = 0;
-        ggml_tensor * first_moved_tensor = nullptr;
-        ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
-        ggml_backend_buffer_type_t first_moved_to_buft = nullptr;
-
-        auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * {
-            ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());
-
-            if (!t_meta) {
-                if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
-                    return nullptr;
-                }
-                throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
-            }
-
-            // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
-            // the tensor is duplicated
-            // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor
-            llm_tensor tn_tensor = tn.tensor;
-            if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & llama_model_loader::TENSOR_DUPLICATED) {
-                tn_tensor = LLM_TENSOR_OUTPUT;
-            }
-
-            llm_tensor_info info;
-            try {
-                info = llm_tensor_info_for(tn_tensor);
-            } catch (const std::out_of_range & e) {
-                throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
-            }
-
-            // tensors with "bias" suffix are always used with GGML_OP_ADD
-            ggml_op op;
-            bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
-            if (bias) {
-                op = GGML_OP_ADD;
-            } else {
-                op = info.op;
-            }
-
-            // sanity checks
-            if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
-                if (tn.bid != -1) {
-                    GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
-                }
-            } else {
-                if (tn.bid == -1) {
-                    GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str());
-                }
-            }
-
-            // select the buffer type for this tensor
-            llama_model::buft_list_t * buft_list;
-            switch (info.layer) {
-                case LLM_TENSOR_LAYER_INPUT:
-                    buft_list = model.dev_input.buft_list;
-                    break;
-                case LLM_TENSOR_LAYER_OUTPUT:
-                    buft_list = model.dev_output.buft_list;
-                    break;
-                case LLM_TENSOR_LAYER_REPEATING:
-                    buft_list = model.dev_layer.at(tn.bid).buft_list;
-                    break;
-                default:
-                    GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
-            }
-
-            ggml_backend_buffer_type_t buft = select_weight_buft(model, t_meta, op, *buft_list);
-            if (!buft) {
-                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
-            }
-
-            // avoid using a host buffer when using mmap
-            auto * buft_dev = ggml_backend_buft_get_device(buft);
-            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
-                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-                buft = ggml_backend_dev_buffer_type(cpu_dev);
-            }
-
-            if (buft != buft_list->front().second) {
-                n_moved_tensors++;
-                if (!first_moved_tensor) {
-                    first_moved_tensor = t_meta;
-                    first_moved_from_buft = buft_list->front().second;
-                    first_moved_to_buft   = buft;
-                }
-            }
-
-            ggml_context * ctx = ctx_for_buft(buft);
-
-            // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one
-            if (flags & llama_model_loader::TENSOR_DUPLICATED) {
-                ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str());
-                if (t) {
-                    return t;
-                }
-            }
-            return ml.create_tensor(ctx, tn, ne, flags);
-        };
-
-        model.layers.resize(n_layer);
-
-        // TODO: move to a separate function
-        const auto tn = LLM_TN(model.arch);
-        switch (model.arch) {
-            case LLM_ARCH_LLAMA:
-            case LLM_ARCH_REFACT:
-            case LLM_ARCH_MINICPM:
-            case LLM_ARCH_GRANITE:
-            case LLM_ARCH_GRANITE_MOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
-                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-                        else {
-                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-
-                        if (n_expert == 0) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                            // optional MLP bias
-                            layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        } else {
-                            layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_MLLAMA:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
-
-                    // output
-                    {
-                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        if (hparams.cross_attention_layers(i)) {
-                            layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128}, 0);
-                            layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024}, 0);
-                            layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd}, 0);
-                            layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0);
-                            layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0);
-                            layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0);
-                            layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0);
-                            layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0);
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        } else {
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_DECI:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-                        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(i);
-                        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(i);
-                        const int64_t n_embd_gqa    = hparams.n_embd_v_gqa(i);
-                        const int64_t n_ff          = hparams.n_ff(i);
-                        const int64_t n_head        = hparams.n_head(i);
-                        const int64_t n_head_kv     = hparams.n_head_kv(i);
-
-                        if (n_head_kv == 0 && n_head > 0) {
-                            // linear attention for DeciLMCausalModel
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        }
-                        else if (n_head_kv > 0) {
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                        }
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
-                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-                        else {
-                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // optional MLP bias
-                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_MINICPM3:
-                {
-                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
-                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-
-                    const int64_t q_lora_rank  = hparams.n_lora_q;
-                    const int64_t kv_lora_rank = hparams.n_lora_kv;
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
-
-                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
-
-                        layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
-                        layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
-
-                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
-                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
-                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                    }
-                } break;
-            case LLM_ARCH_GROK:
-                {
-                    if (n_expert == 0) {
-                        throw std::runtime_error("Grok model cannot have zero experts");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_DBRX:
-                {
-                    if (n_expert == 0) {
-                        throw std::runtime_error("DBRX model cannot have zero experts");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BAICHUAN:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    {
-                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_FALCON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    {
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-
-                        model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_STARCODER:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
-
-                    // output
-                    {
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                        model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            // needs to be on GPU
-                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
-                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BERT:
-            case LLM_ARCH_NOMIC_BERT:
-                {
-                    model.tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
-                    model.type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0);
-
-                    if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train}, 0);
-
-                        model.cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        model.cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        model.cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd}, 0);
-
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa}, 0);
-
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa}, 0);
-                        } else {
-                            layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        }
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
-
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
-                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                        } else {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        }
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_JINA_BERT_V2:
-                {
-                    model.tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0); // word_embeddings
-                    model.type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); // token_type_embeddings
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0); //LayerNorm bias
-
-                    model.cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i]; // JinaBertLayer
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa}, 0);
-
-                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0); //output_dens
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm
-                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BLOOM:
-                {
-                    model.tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_MPT:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    if (!model.output) {
-                        model.output    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // AWQ ScaleActivation layer
-                        layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_STABLELM:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm =   create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors, present in Stable LM 2 1.6B
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // optional q and k layernorms, present in StableLM 2 12B
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd*3}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff/2}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN2:
-            case LLM_ARCH_QWEN2VL:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN2MOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                        if (n_expert == 0) {
-                            throw std::runtime_error("n_expert must be > 0 for QWEN2MOE");
-                        }
-                        if (n_expert_used == 0) {
-                            throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE");
-                        }
-
-                        // MoE branch
-                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
-
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                        // Shared expert branch
-                        const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
-
-                        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp}, 0);
-                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd}, 0);
-                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp}, 0);
-                    }
-                } break;
-            case LLM_ARCH_PHI2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-                    model.output_b      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        if (layer.wqkv == nullptr) {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
-                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
-
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa}, 0);
-
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa}, 0);
-                        }
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_PHI3:
-                {
-                    const int64_t n_embd_head = n_embd / n_head;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
-
-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                    }
-                } break;
-            case LLM_ARCH_PLAMO:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GPT2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_CODESHELL:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_ORION:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_INTERNLM2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GEMMA:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GEMMA2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_STARCODER2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // optional bias tensors
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP ,  "bias", i), {  n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_MAMBA:
-                {
-                    const int64_t d_conv  = hparams.ssm_d_conv;
-                    const int64_t d_inner = hparams.ssm_d_inner;
-                    const int64_t d_state = hparams.ssm_d_state;
-                    const int64_t dt_rank = hparams.ssm_dt_rank;
-
-                    // only an expansion factor of 2 is supported for now
-                    if (2 * n_embd != d_inner) {
-                        throw std::runtime_error("only an expansion factor of 2 is supported for now");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed, duplicated to allow offloading
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        // norm
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0);
-
-                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0);
-                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0);
-
-                        layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0);
-
-                        layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0);
-                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0);
-
-                        // no "weight" suffix for these
-                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
-                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
-
-                        // out_proj
-                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_XVERSE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_COMMAND_R:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    // init output from the input tok embed
-                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (n_layer >= 64){
-                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
-                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
-                        }
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_COHERE2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
-                    // init output from the input tok embed
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
-                                                      llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
-                    }
-                }
-                break;
-            case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OLMO2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OLMOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                        if (n_expert == 0) {
-                            throw std::runtime_error("n_expert must be > 0");
-                        }
-                        if (n_expert_used == 0) {
-                            throw std::runtime_error("n_expert_used must be > 0");
-                        }
-
-                        // MoE branch
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OPENELM:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    // init output from the input tok embed
-                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        const int64_t n_head      =   hparams.n_head(i);
-                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
-                        const int64_t n_ff        =   hparams.n_ff(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GPTNEOX:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_ARCTIC:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_DEEPSEEK:
-                {
-
-                    const int64_t n_ff_exp        = hparams.n_ff_exp;
-                    const int64_t n_expert_shared = hparams.n_expert_shared;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (i < (int) hparams.n_layer_dense_lead) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        } else {
-                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                            if (n_expert == 0) {
-                                throw std::runtime_error("n_expert must be > 0");
-                            }
-                            if (n_expert_used == 0) {
-                                throw std::runtime_error("n_expert_used must be > 0");
-                            }
-
-                            // MoE branch
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                            // Shared expert branch
-                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
-                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_DEEPSEEK2:
-                {
-                    const bool is_lite = (hparams.n_layer == 27);
-
-                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
-                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-
-                    const int64_t q_lora_rank  = hparams.n_lora_q;
-                    const int64_t kv_lora_rank = hparams.n_lora_kv;
-
-                    const int64_t n_ff_exp        = hparams.n_ff_exp;
-                    const int64_t n_expert_shared = hparams.n_expert_shared;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        if (!is_lite) {
-                            layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
-                        }
-
-                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
-
-                        if (!is_lite) {
-                            layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
-                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
-                        } else {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        }
-
-                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
-                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
-                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (i < (int) hparams.n_layer_dense_lead) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        } else {
-                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                            if (n_expert == 0) {
-                                throw std::runtime_error("n_expert must be > 0");
-                            }
-                            if (n_expert_used == 0) {
-                                throw std::runtime_error("n_expert_used must be > 0");
-                            }
-
-                            // MoE branch
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                            // Shared expert branch
-                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
-                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_BITNET:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm     = create_tensor(tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd}, 0);
-                        layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq       = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wk       = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wv       = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo       = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm     = create_tensor(tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd}, 0);
-                        layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
-
-                        layer.ffn_gate       = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down       = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up         = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_scale   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_T5:
-                {
-                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm     = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.attn_norm_cross  = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        // this tensor seems to be unused in HF transformers implementation
-                        layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_T5ENCODER:
-                {
-                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_JAIS:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_CHATGLM:
-                {
-                    model.tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_NEMOTRON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // optional MLP bias
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_EXAONE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN,   "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,     "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_RWKV6:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // Block 0, LN0
-                    model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
-
-                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
-                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
-                    const int head_size = hparams.wkv_head_size;
-                    const int attn_hidden_size = n_embd;
-                    const int ffn_size = hparams.n_ff_arr[0];
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, 0);
-
-                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
-                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
-
-                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
-
-                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
-                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
-                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
-                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
-                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
-
-                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
-                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
-                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
-
-                        layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-
-                        layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
-                        layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
-                        layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0);
-                    }
-
-                } break;
-            case LLM_ARCH_CHAMELEON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i),  {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),  {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_SOLAR:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    {
-                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_WAVTOKENIZER_DEC:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
-
-                    model.conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
-                    model.conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {1, hparams.posnet.n_embd}, 0);
-
-                    // posnet
-                    {
-                        const int64_t n_embd = hparams.posnet.n_embd;
-
-                        for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) {
-                            auto & layer = model.layers[i].posnet;
-
-                            // posnet:
-                            //
-                            //  - resnet
-                            //  - resnet
-                            //  - attn
-                            //  - resnet
-                            //  - resnet
-                            //  - norm
-                            //
-                            switch (i) {
-                                case 0:
-                                case 1:
-                                case 3:
-                                case 4:
-                                    {
-                                        layer.norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0);
-                                        layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.conv1   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0);
-                                        layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.norm2   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0);
-                                        layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.conv2   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0);
-                                        layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                case 2:
-                                    {
-                                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
-                                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_q      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_q_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_k      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_k_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_v      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_v_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_o      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_o_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                case 5:
-                                    {
-                                        layer.norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
-                                        layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                default: GGML_ABORT("unknown posnet layer");
-                            };
-                        }
-                    }
-
-                    GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {hparams.posnet.n_embd}, 0);
-
-                    // convnext
-                    {
-                        const int64_t n_embd = hparams.convnext.n_embd;
-
-                        for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) {
-                            auto & layer = model.layers[i].convnext;
-
-                            layer.dw     = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "weight", i), {7, 1, n_embd}, 0);
-                            layer.dw_b   = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "bias",   i), {1, n_embd}, 0);
-
-                            layer.norm   = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "weight", i), {n_embd}, 0);
-                            layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "bias",   i), {n_embd}, 0);
-
-                            layer.pw1    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "weight", i), {n_embd, n_ff}, 0);
-                            layer.pw1_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "bias",   i), {n_ff}, 0);
-
-                            layer.pw2    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "weight", i), {n_ff, n_embd}, 0);
-                            layer.pw2_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "bias",   i), {n_embd}, 0);
-
-                            layer.gamma  = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0);
-                        }
-
-                        // output
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    }
-
-                    model.output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
-                    model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
-                } break;
-            default:
-                throw std::runtime_error("unknown architecture");
-        }
-
-        if (n_moved_tensors > 0) {
-            LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n",
-                __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
-                ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
-        }
-    }
-
-    ml.done_getting_tensors();
-
-    ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
-    model.mappings.reserve(ml.mappings.size());
-
-    // create the backend buffers
-    std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
-    ctx_bufs.reserve(ctx_map.size());
-
-    // Ensure we have enough capacity for the maximum backend buffer we will potentially create
-    const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
-    model.bufs.reserve(n_max_backend_buffer);
-
-    for (auto & it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx              = it.second;
-
-        // skip contexts without tensors
-        if (ggml_get_first_tensor(ctx) == nullptr) {
-            continue;
-        }
-
-        llama_buf_map bufs;
-        bufs.reserve(n_max_backend_buffer);
-
-        // check if it is possible to use buffer_from_host_ptr with this buffer type
-        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
-        if (!dev) {
-            // FIXME: workaround for CPU backend buft having a NULL device
-            dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-        }
-        ggml_backend_dev_props props;
-        ggml_backend_dev_get_props(dev, &props);
-        bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
-        bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
-
-        if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                // only the mmap region containing the tensors in the model is mapped to the backend buffer
-                // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
-                // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
-                void * addr = nullptr;
-                size_t first, last; // NOLINT
-                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
-                if (first >= last) {
-                    continue;
-                }
-                const size_t max_size = ggml_get_max_tensor_size(ctx);
-                ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
-                if (buf == nullptr) {
-                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
-                }
-                model.bufs.emplace_back(buf);
-                bufs.emplace(idx, buf);
-            }
-        }
-        else {
-            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-            if (buf == nullptr) {
-                throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
-            }
-            model.bufs.emplace_back(buf);
-            if (use_mlock && ggml_backend_buffer_is_host(buf)) {
-                model.mlock_bufs.emplace_back(new llama_mlock);
-                auto & mlock_buf = model.mlock_bufs.back();
-                mlock_buf->init   (ggml_backend_buffer_get_base(buf));
-                mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
-            }
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                bufs.emplace(idx, buf);
-            }
-        }
-
-        if (bufs.empty()) {
-            throw std::runtime_error("failed to allocate buffer");
-        }
-
-        for (auto & buf : bufs) {
-            // indicate that this buffer contains weights
-            // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight
-            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
-        }
-
-        ctx_bufs.emplace_back(ctx, bufs);
-    }
-
-    if (llama_supports_gpu_offload()) {
-        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
-
-        LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
-        }
-
-        const int max_backend_supported_layers = hparams.n_layer + 1;
-        const int max_offloadable_layers       = hparams.n_layer + 1;
-
-        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-    }
-
-    // print memory requirements per buffer type
-    for (auto & buf : model.bufs) {
-        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
-    }
-
-    // populate tensors_by_name
-    for (auto & ctx : model.ctxs) {
-        for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
-            model.tensors_by_name.emplace_back(ggml_get_name(cur), cur);
-        }
-    }
-
-    // load tensor data
-    for (auto & it : ctx_bufs) {
-        ggml_context * ctx = it.first;
-        auto & bufs = it.second;
-        if (!ml.load_all_data(ctx, bufs, use_mlock ? &model.mlock_mmaps : NULL, progress_callback, progress_callback_user_data)) {
-            return false;
-        }
-    }
-
-    if (use_mmap_buffer) {
-        for (auto & mapping : ml.mappings) {
-            model.mappings.emplace_back(std::move(mapping));
-        }
-    }
-
-    return true;
-}
-
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
-    model.t_start_us = ggml_time_us();
+static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
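+    // note: `splits` is assumed to hold the paths of any extra GGUF shards of a
+    // split model; here it is only forwarded to llama_model_loader below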
+    // loading time will be recalculated after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = 0;
+    time_meas tm(model.t_load_us);
+
+    model.t_start_us = tm.t_start_us;
 
     try {
-        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+
+        ml.print_info();
 
         model.hparams.vocab_only = params.vocab_only;
 
         try {
-            llm_load_arch(ml, model);
+            model.load_arch(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
         }
         try {
-            llm_load_hparams(ml, model);
+            model.load_hparams(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
         }
         try {
-            llm_load_vocab(ml, model);
+            model.load_vocab(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
         }
 
-        llm_load_stats(ml, model);
-        llm_load_print_meta(ml, model);
-
-        if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.id_to_token.size()) {
-            LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
-        }
+        model.load_stats(ml);
+        model.print_info();
 
         if (params.vocab_only) {
             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
             return 0;
         }
 
-        if (!llm_load_tensors(
-            ml, model, params.n_gpu_layers, params.split_mode,  params.main_gpu, params.tensor_split, params.use_mlock,
-            params.progress_callback, params.progress_callback_user_data
-        )) {
+        if (!model.load_tensors(ml)) {
             return -2;
         }
     } catch (const std::exception & err) {
@@ -2579,10 +78,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         return -1;
     }
 
-    // loading time will be recalculate after the first eval, so
-    // we take page faults deferred by mmap() into consideration
-    model.t_load_us = ggml_time_us() - model.t_start_us;
-
     return 0;
 }
 
@@ -2615,21 +110,36 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
+
+        // apply lora for embedding tokens if needed
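+        // in effect: inpL += scale * (lora_b @ lora_a[inp_tokens]); get_scale is
+        // assumed to reduce to adapter_scale * alpha / rank as in the helpers below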
+        for (auto & it : lctx.lora) {
+            struct llama_adapter_lora_weight * lw = it.first->get_weight(tok_embd);
+            if (lw == nullptr) {
+                continue;
+            }
+            const float adapter_scale = it.second;
+            const float scale = lw->get_scale(it.first->alpha, adapter_scale);
+            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
+                ctx, lw->b, // non-transposed lora_b
+                ggml_get_rows(ctx, lw->a, lctx.inp_tokens)
+            ), scale);
+            inpL = ggml_add(ctx, inpL, inpL_delta);
+        }
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -2710,17 +220,16 @@ static struct ggml_tensor * llm_build_lora_mm(
           struct ggml_tensor * w,
           struct ggml_tensor * cur) {
     struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
-    for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_weight * lora = it.first->get_weight(w);
-        if (lora == nullptr) {
+    for (auto & it : lctx.lora) {
+        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
+        if (lw == nullptr) {
             continue;
         }
-        const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
+        const float adapter_scale = it.second;
+        const float scale = lw->get_scale(it.first->alpha, adapter_scale);
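+        // note: get_scale is assumed to reproduce the inline formula it replaces,
+        // i.e. alpha != 0 ? adapter_scale * alpha / rank(b) : adapter_scale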
         struct ggml_tensor * ab_cur = ggml_mul_mat(
-            ctx0, lora->b,
-            ggml_mul_mat(ctx0, lora->a, cur)
+            ctx0, lw->b,
+            ggml_mul_mat(ctx0, lw->a, cur)
         );
         ab_cur = ggml_scale(ctx0, ab_cur, scale);
         res = ggml_add(ctx0, res, ab_cur);
@@ -2736,17 +245,17 @@ static struct ggml_tensor * llm_build_lora_mm_id(
           struct ggml_tensor * cur, // struct ggml_tensor * b
           struct ggml_tensor * ids) {
     struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
-    for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_weight * lora = it.first->get_weight(w);
-        if (lora == nullptr) {
+    for (auto & it : lctx.lora) {
+        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
+        if (lw == nullptr) {
             continue;
         }
         const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
+        const float rank  = (float) lw->b->ne[0];
         const float scale = alpha ? it.second * alpha / rank : it.second;
         struct ggml_tensor * ab_cur = ggml_mul_mat_id(
-            ctx0, lora->b,
-            ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ctx0, lw->b,
+            ggml_mul_mat_id(ctx0, lw->a, cur, ids),
             ids
         );
         ab_cur = ggml_scale(ctx0, ab_cur, scale);
@@ -3239,7 +748,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -3255,17 +764,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs  = batch.n_seqs;
+    const int64_t n_seqs  = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba) do apply layer norm on B and Dt layers
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];
@@ -3377,16 +886,20 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
         struct ggml_tensor * x_prev,
-        struct ggml_tensor ** wkv_state) {
+        struct ggml_tensor ** wkv_state,
+        size_t wkv_head_size,
+        size_t head_count_kv) {
     size_t n_embd       = cur->ne[0];
     size_t n_seq_tokens = cur->ne[1];
     size_t n_seqs       = cur->ne[2];
 
-    size_t head_size  = layer->time_mix_first->ne[0];
-    size_t head_count = layer->time_mix_first->ne[1];
+    size_t head_size  = wkv_head_size;
+    size_t head_count = n_embd / head_size;
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
+    bool is_qrwkv = layer->time_mix_first == nullptr;
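+    // time_mix_first is absent in QRWKV6-style checkpoints (cf. build_rwkv6qwen2
+    // below); its absence selects the gated-linear-attention path further down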
+
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
 
     sx  = ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
@@ -3415,69 +928,64 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         xxx
     );
 
-    struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-    struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-    struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-    struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-    struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights gives a small performance improvement
+        sx  = ggml_reshape_3d(ctx, sx,  n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
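+        // the fused add/mul above computes cur + (xxx + lerp_fused) * sx for all
+        // five streams at once; the views below slice them back into xw/xk/xv/xr/xg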
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        // for backward compatibility
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
 
-    struct ggml_tensor * xw = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mw, layer->time_mix_lerp_w),
-            sx
-        ),
-        cur
-    );
+        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
+    }
 
-    struct ggml_tensor * xk = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mk, layer->time_mix_lerp_k),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
+    struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk);
+    struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv);
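+    // optional per-projection biases: some RWKV6 variants ship them, so each one
+    // is applied only when the corresponding tensor is present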
+    if (layer->time_mix_receptance_b) {
+        r = ggml_add(ctx, r, layer->time_mix_receptance_b);
+    }
+    if (layer->time_mix_key_b) {
+        k = ggml_add(ctx, k, layer->time_mix_key_b);
+    }
+    if (layer->time_mix_value_b) {
+        v = ggml_add(ctx, v, layer->time_mix_value_b);
+    }
 
-    struct ggml_tensor * xv = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mv, layer->time_mix_lerp_v),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx, g);
+    } else {
+        g = ggml_silu(ctx, g);
+    }
 
-    struct ggml_tensor * xr = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mr, layer->time_mix_lerp_r),
-            sx
-        ),
-        cur
-    );
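+    // grouped k/v heads: when the model uses fewer k/v heads than time-mix heads,
+    // repeat k and v so the wkv kernel below sees one k/v head per head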
+    if (head_count_kv != head_count) {
+        GGML_ASSERT(head_count % head_count_kv == 0);
+        k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
+        v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
+        k = ggml_repeat(ctx, k, tmp);
+        v = ggml_repeat(ctx, v, tmp);
+    }
 
-    struct ggml_tensor * xg = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mg, layer->time_mix_lerp_g),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens);
-    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * g = ggml_silu(
-        ctx,
-        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
-    );
+    k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
+    v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
+    r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
 
     struct ggml_tensor * w = ggml_mul_mat(
         ctx,
@@ -3488,25 +996,35 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         )
     );
 
-    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
+    w = ggml_add(ctx, w, layer->time_mix_decay);
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
-    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+    w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
 
-    k = ggml_transpose(ctx, k);
-    v = ggml_transpose(ctx, v);
-    r = ggml_transpose(ctx, r);
+    if (is_qrwkv) {
+        // k = k * (1 - w)
+        k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
+    }
 
-    struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
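+    // without time_mix_first (the QRWKV case) use gated linear attention scaled
+    // by 1/sqrt(head_size); otherwise keep the standard rwkv6 wkv kernel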
+    struct ggml_tensor * wkv_output;
+    if (!layer->time_mix_first) {
+        wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
+    } else {
+        wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    }
     cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
 
-    // group norm with head_count groups
-    cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
-    cur = ggml_norm(ctx, cur, 64e-5f);
+    if (!is_qrwkv) {
+        // group norm with head_count groups
+        cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
+        cur = ggml_norm(ctx, cur, 64e-5f);
 
-    // Convert back to regular vectors.
-    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+        // Convert back to regular vectors.
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+        cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+    } else {
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+    }
 
     cur = ggml_mul(ctx, cur, g);
     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
@@ -3670,7 +1188,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -3720,7 +1238,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_defrag(const std::vector & moves) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         for (const auto & move : moves) {
             for (int il = 0; il < n_layer; ++il) {
@@ -3965,7 +1483,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4059,6 +1577,7 @@ struct llm_build_context {
 
             // feed-forward network
             if (model.layers[il].ffn_gate_inp == nullptr) {
+
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -4129,8 +1648,8 @@ struct llm_build_context {
         return gf;
     }
 
-        struct ggml_cgraph * build_mllama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+    struct ggml_cgraph * build_mllama() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4364,7 +1883,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deci() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4525,7 +2044,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -4537,7 +2056,7 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
+        struct ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -4562,7 +2081,7 @@ struct llm_build_context {
                 cb(Vcur, "Vcur", il);
 
                 switch (model.type) {
-                    case MODEL_7B:
+                    case LLM_TYPE_7B:
                         Qcur = ggml_rope_ext(
                             ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -4574,7 +2093,7 @@ struct llm_build_context {
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
-                    case MODEL_13B:
+                    case LLM_TYPE_13B:
                         Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
                         Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
                         break;
@@ -4640,7 +2159,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -4743,7 +2262,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -4863,7 +2382,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -5022,7 +2541,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -5150,7 +2669,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5254,7 +2773,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5348,7 +2867,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5542,7 +3061,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5643,7 +3162,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5933,7 +3452,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6045,7 +3564,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6157,7 +3676,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2vl() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
         GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -6275,7 +3794,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -6423,7 +3942,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6544,7 +4063,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -6577,7 +4096,7 @@ struct llm_build_context {
 
                 struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm,
-                    NULL,
+                    model.layers[il].attn_norm_b,
                     LLM_NORM_RMS, cb, il);
                 cb(attn_norm_output, "attn_norm", il);
 
@@ -6592,8 +4111,7 @@ struct llm_build_context {
                     Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
                     Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
-                }
-                else {
+                } else {
                     Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
@@ -6637,14 +4155,12 @@ struct llm_build_context {
             residual = cur;
 
             cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_norm, NULL,
+                model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
                 LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            // FF
-            // special-case: the up and gate tensors are merged into a single tensor
-            // TOOD: support into llm_build_ffn
-            {
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
                 cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
@@ -6652,6 +4168,20 @@ struct llm_build_context {
                         NULL,
                         LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
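+                // softmax gating over n_expert experts with n_expert_used active per
+                // token; this branch is taken only when the checkpoint defines ffn_gate_inp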
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        cb, il);
+                cb(cur, "ffn_moe_out", il);
             }
 
             cur = ggml_add(ctx0, residual, cur);
@@ -6664,11 +4194,16 @@ struct llm_build_context {
 
         cur = llm_build_norm(ctx0, inpL, hparams,
             model.output_norm,
-            NULL,
+            model.output_norm_b,
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (model.output_b != nullptr) {
+            cb(cur, "result_output_no_bias", -1);
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -6782,7 +4317,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6887,7 +4422,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6998,7 +4533,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7116,7 +4651,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7234,7 +4769,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_minicpm3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         //TODO: if the model varies, these parameters need to be read from the model
         const int64_t n_embd_base = 256;
@@ -7318,7 +4853,8 @@ struct llm_build_context {
                         ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
 
-                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -7350,7 +4886,7 @@ struct llm_build_context {
                     0);
                 cb(v_states, "v_states", il);
 
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -7359,7 +4895,7 @@ struct llm_build_context {
                 cb(q_pe, "q_pe", il);
 
                 // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 k_pe = ggml_rope_ext(
                     ctx0, k_pe, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -7443,7 +4979,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -7551,7 +5087,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -7601,9 +5137,9 @@ struct llm_build_context {
 
                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                 switch (model.type) {
-                    case llm_type::MODEL_2B:
-                    case llm_type::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
-                    case llm_type::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    case LLM_TYPE_2B:
+                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
                     default: GGML_ABORT("fatal error");
                 };
                 cb(Qcur, "Qcur_scaled", il);
@@ -7687,7 +5223,7 @@ struct llm_build_context {
 
 
     struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7806,7 +5342,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -7861,7 +5397,7 @@ struct llm_build_context {
 
     struct ggml_cgraph * build_command_r() {
 
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8009,7 +5545,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_cohere2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8146,7 +5682,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8270,7 +5806,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_olmo2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8398,7 +5934,7 @@ struct llm_build_context {
     //   * removed bias
     //   * added q, k norm
     struct ggml_cgraph * build_olmoe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8524,7 +6060,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8649,7 +6185,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -8791,7 +6327,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8925,7 +6461,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9082,7 +6618,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9172,7 +6708,8 @@ struct llm_build_context {
                         ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
 
-                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -9204,7 +6741,7 @@ struct llm_build_context {
                     0);
                 cb(v_states, "v_states", il);
 
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9213,7 +6750,7 @@ struct llm_build_context {
                 cb(q_pe, "q_pe", il);
 
                 // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 k_pe = ggml_rope_ext(
                     ctx0, k_pe, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9312,7 +6849,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9463,7 +7000,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_t5_enc() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9595,7 +7132,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_t5_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9800,7 +7337,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9892,7 +7429,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9923,17 +7460,30 @@ struct llm_build_context {
                 struct ggml_tensor * Qcur = nullptr;
                 struct ggml_tensor * Kcur = nullptr;
                 struct ggml_tensor * Vcur = nullptr;
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
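+                // some GLM checkpoints store separate q/k/v projections instead of a
+                // fused wqkv tensor; support both layouts, with optional biases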
+                if (model.layers[il].wqkv == nullptr) {
+                    Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                    if (model.layers[il].bq) {
+                        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    }
+                    Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                    if (model.layers[il].bk) {
+                        Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    }
+                    Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                    if (model.layers[il].bv) {
+                        Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    }
+                } else {
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
+                    cb(cur, "wqkv", il);
+                    if (model.layers[il].bqkv) {
+                        cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                        cb(cur, "bqkv", il);
+                    }
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                }
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
@@ -10006,7 +7556,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_nemotron() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10127,7 +7677,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_exaone() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10254,7 +7804,7 @@ struct llm_build_context {
     }
 
     ggml_cgraph * build_rwkv6() {
-        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // Token shift state dimensions should be 2 * n_emb
         GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
@@ -10299,7 +7849,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -10366,6 +7916,114 @@ struct llm_build_context {
         return gf;
     }
 
+    // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
+    ggml_cgraph * build_rwkv6qwen2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+        const int64_t n_seqs = ubatch.n_seqs;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            // (ab)using the KV cache to store the states
+            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
+                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
+                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
+
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
+            struct ggml_tensor * x_prev = ggml_concat(
+                ctx0,
+                token_shift,
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+                1
+            );
+
+            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
+                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+                )
+            );
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
+            ggml_build_forward_expand(gf, ffn_inp);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // ref: https://github.com/facebookresearch/chameleon
     // based on the original build_llama() function, changes:
     //   * qk-norm
@@ -10373,7 +8031,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_chameleon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10538,14 +8196,12 @@ struct llm_build_context {
         cb(img_logits, "img_logits", -1);
         cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
         cb(cur, "result_output", -1);
-
         ggml_build_forward_expand(gf, cur);
-
         return gf;
-    }
+   }
 
-    ggml_cgraph * build_solar() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+   ggml_cgraph * build_solar() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10681,23 +8337,19 @@ struct llm_build_context {
         }
 
         cur = inpL;
-
         cur = llm_build_norm(ctx0, cur, hparams,
                 model.output_norm, NULL,
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
-
         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
-
         ggml_build_forward_expand(gf, cur);
-
         return gf;
     }
 
     struct ggml_cgraph * build_wavtokenizer_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -10906,12 +8558,12 @@ static struct ggml_cgraph * llama_build_graph(
 
         // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
         // FIXME: fix in ggml_backend_sched
-        const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer;
+        const bool full_offload = lctx.model.params.n_gpu_layers > (int) lctx.model.hparams.n_layer;
         if (ubatch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
-                const auto & dev_layer = lctx.model.dev_layer.at(il);
+                const auto & dev_layer = lctx.model.dev_layer(il);
                 for (auto & backend : lctx.backends) {
-                    if (ggml_backend_get_device(backend.get()) == dev_layer.dev) {
+                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
                         if (ggml_backend_supports_op(backend.get(), cur)) {
                             ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get());
                         }
@@ -11003,6 +8655,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_phi2();
             } break;
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
             {
                 result = llm.build_phi3();
             } break;
@@ -11130,6 +8783,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_rwkv6();
             } break;
+        case LLM_ARCH_RWKV6QWEN2:
+            {
+                result = llm.build_rwkv6qwen2();
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 result = llm.build_chameleon();
@@ -11183,73 +8840,33 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_internal(
-         llama_context & lctx,
-           llama_batch   inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
-
+static int llama_prepare_sbatch(
+        llama_context     & lctx,
+        const llama_batch & batch,
+        uint32_t          & n_outputs) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const  int64_t n_embd       = hparams.n_embd;
 
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
     }
-
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
     lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = hparams.n_vocab;
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     lctx.embd_seq.clear();
 
     // count outputs
@@ -11265,7 +8882,7 @@ static int llama_decode_internal(
     }
 
     lctx.sbatch.from_batch(batch, batch.n_embd,
-        /* simple_split */ !kv_self.recurrent,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -11274,75 +8891,152 @@ static int llama_decode_internal(
         return -2;
     };
 
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+        llama_context          & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch           & ubatch,
+        const uint32_t           n_outputs,
+        const uint32_t           n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto       & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
+        } else {
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
+        }
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
+
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
+
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
+            }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
+        }
+
+        auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            llama_kv_cache_defrag(kv_self);
+            llama_kv_cache_update(&lctx);
+            slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        }
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
+
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
+
+    return 0;
+}
+
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_impl(
+         llama_context & lctx,
+           llama_batch   inp_batch) {
+
+    lctx.is_encoding = false;
+
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = hparams.n_vocab;
+
+    uint32_t n_outputs = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
     while (lctx.sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
         {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const int         n_threads  = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool   : lctx.threadpool_batch;
 
         GGML_ASSERT(n_threads > 0);
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                llama_kv_cache_defrag(kv_self);
-                llama_kv_cache_update(&lctx);
-                slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            }
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
-        }
-
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -11397,7 +9091,7 @@ static int llama_decode_internal(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
@@ -11510,12 +9204,14 @@ static int llama_decode_internal(
     //llama_synchronize(&lctx);
 
     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
+    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        // - do not defrag small contexts (i.e. < 2048 tokens)
+        // - count the padding towards the number of used tokens
+        const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + llama_kv_cache_get_padding(cparams))/float(kv_self.n)) : 0.0f;
 
         // queue defragmentation for next llama_kv_cache_update
         if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
 
             llama_kv_cache_defrag(kv_self);
         }
@@ -11537,7 +9233,7 @@ static int llama_decode_internal(
 // return positive int on warning
 // return negative int on error
 //
-static int llama_encode_internal(
+static int llama_encode_impl(
          llama_context & lctx,
            llama_batch   inp_batch) {
 
@@ -11562,7 +9258,7 @@ static int llama_encode_internal(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
@@ -11719,7 +9415,7 @@ static int llama_encode_internal(
 }
 
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
 
     const auto & hparams = lctx.model.hparams;
@@ -11739,9 +9435,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
+    //const uint32_t max_moves = model.max_nodes()/(6*n_layer);
     // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);
+    const uint32_t max_moves = (lctx.model.max_nodes() - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
@@ -11934,7 +9630,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
 }
 
-static void llama_kv_cache_update_internal(struct llama_context & lctx) {
+static void llama_kv_cache_update_impl(struct llama_context & lctx) {
     bool need_reserve = false;
 
     if (lctx.kv_self.has_shift) {
@@ -11970,7 +9666,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
-        llama_kv_cache_defrag_internal(lctx);
+        llama_kv_cache_defrag_impl(lctx);
 
         need_reserve = true;
 
@@ -11983,7 +9679,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         // build worst-case graph
         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+        llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
 
@@ -11995,45 +9691,38 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 }
 
-int32_t llama_lora_adapter_set(
+int32_t llama_set_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter,
+            struct llama_adapter_lora * adapter,
             float scale) {
-    if (ctx->cparams.flash_attn) {
-        LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
-        return -1;
-    }
-
-    ctx->lora_adapters[adapter] = scale;
-
+    ctx->lora[adapter] = scale;
     return 0;
 }
 
-int32_t llama_lora_adapter_remove(
+int32_t llama_rm_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter) {
-    auto pos = ctx->lora_adapters.find(adapter);
-    if (pos != ctx->lora_adapters.end()) {
-        ctx->lora_adapters.erase(pos);
+            struct llama_adapter_lora * adapter) {
+    auto pos = ctx->lora.find(adapter);
+    if (pos != ctx->lora.end()) {
+        ctx->lora.erase(pos);
         return 0;
     }
 
     return -1;
 }
 
-void llama_lora_adapter_clear(struct llama_context * ctx) {
-    ctx->lora_adapters.clear();
+void llama_clear_adapter_lora(struct llama_context * ctx) {
+    ctx->lora.clear();
 }
 
-// TODO: tmp
-int32_t llama_control_vector_apply(
-        struct llama_context * lctx,
+int32_t llama_apply_adapter_cvec(
+        struct llama_context * ctx,
                  const float * data,
                       size_t   len,
                      int32_t   n_embd,
                      int32_t   il_start,
                      int32_t   il_end) {
-    return llama_control_vector_apply(lctx->cvec, lctx->model, data, len, n_embd, il_start, il_end);
+    return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
 }
 
 //
@@ -12134,13 +9823,12 @@ int64_t llama_time_us(void) {
     return ggml_time_us();
 }
 
-struct llama_model * llama_load_model_from_file(
-        const char * path_model,
+static struct llama_model * llama_model_load_from_file_impl(
+        const std::string & path_model,
+        std::vector<std::string> & splits,
         struct llama_model_params params) {
     ggml_time_init();
 
-    llama_model * model = new llama_model;
-
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
         params.progress_callback_user_data = &cur_percentage;
@@ -12158,46 +9846,7 @@ struct llama_model * llama_load_model_from_file(
         };
     }
 
-    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
-        // split the servers set them into model->rpc_servers
-        std::string servers(params.rpc_servers);
-        size_t pos = 0;
-        while ((pos = servers.find(',')) != std::string::npos) {
-            std::string server = servers.substr(0, pos);
-            model->rpc_servers.push_back(server);
-            servers.erase(0, pos + 1);
-        }
-        model->rpc_servers.push_back(servers);
-    }
-
-    // add RPC devices
-    if (!model->rpc_servers.empty()) {
-        ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
-        if (!rpc_reg) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
-            llama_free_model(model);
-            return nullptr;
-        }
-
-        typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
-        ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-        if (!ggml_backend_rpc_add_device_fn) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
-            llama_free_model(model);
-            return nullptr;
-        }
-
-        for (const std::string & server : model->rpc_servers) {
-            ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
-            if (dev) {
-                model->devices.push_back(dev);
-            } else {
-                LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
-                llama_free_model(model);
-                return nullptr;
-            }
-        }
-    }
+    llama_model * model = new llama_model(params);
 
     // create list of devices to use with this model
     if (params.devices) {
@@ -12205,6 +9854,7 @@ struct llama_model * llama_load_model_from_file(
             model->devices.push_back(*dev);
         }
     } else {
+        std::vector<ggml_backend_dev_t> rpc_servers;
         // use all available devices
         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
             ggml_backend_dev_t dev = ggml_backend_dev_get(i);
@@ -12215,17 +9865,26 @@ struct llama_model * llama_load_model_from_file(
                     break;
 
                 case GGML_BACKEND_DEVICE_TYPE_GPU:
-                    model->devices.push_back(dev);
+                    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+                    if (ggml_backend_reg_name(reg) == std::string("RPC")) {
+                        rpc_servers.push_back(dev);
+                    } else {
+                        model->devices.push_back(dev);
+                    }
                     break;
             }
         }
+        // add RPC servers at the front of the list
+        if (!rpc_servers.empty()) {
+            model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
+        }
     }
 
     // if using single GPU mode, remove all except the main GPU
     if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
         if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
             LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
-            llama_free_model(model);
+            llama_model_free(model);
             return nullptr;
         }
         ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
@@ -12239,7 +9898,7 @@ struct llama_model * llama_load_model_from_file(
         LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
     }
 
-    int status = llama_model_load(path_model, *model, params);
+    const int status = llama_model_load(path_model, splits, *model, params);
     GGML_ASSERT(status <= 0);
     if (status < 0) {
         if (status == -1) {
@@ -12248,14 +9907,43 @@ struct llama_model * llama_load_model_from_file(
             LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
         }
 
-        llama_free_model(model);
+        llama_model_free(model);
         return nullptr;
     }
 
     return model;
 }
 
-struct llama_context * llama_new_context_with_model(
+// deprecated
+struct llama_model * llama_load_model_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    return llama_model_load_from_file(path_model, params);
+}
+
+struct llama_model * llama_model_load_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    std::vector<std::string> splits = {};
+    return llama_model_load_from_file_impl(path_model, splits, params);
+}
+
+struct llama_model * llama_model_load_from_splits(
+        const char ** paths,
+        size_t n_paths,
+        struct llama_model_params params) {
+    std::vector<std::string> splits;
+    if (n_paths == 0) {
+        LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
+        return nullptr;
+    }
+    for (size_t i = 0; i < n_paths; ++i) {
+        splits.push_back(paths[i]);
+    }
+    return llama_model_load_from_file_impl(splits.front(), splits, params);
+}
+
+struct llama_context * llama_init_from_model(
                  struct llama_model * model,
         struct llama_context_params   params) {
 
@@ -12513,7 +10201,7 @@ struct llama_context * llama_new_context_with_model(
                 backend_ptrs.push_back(backend.get());
             }
 
-            const size_t max_nodes = llama_model_max_nodes(*model);
+            const size_t max_nodes = model->max_nodes();
 
             // buffer used to store the computation graph and the tensor meta data
             ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
@@ -12521,9 +10209,9 @@ struct llama_context * llama_new_context_with_model(
             // TODO: move these checks to ggml_backend_sched
             // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
             bool pipeline_parallel =
-                llama_get_device_count(*model) > 1 &&
-                model->n_gpu_layers > (int)model->hparams.n_layer &&
-                model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
+                model->n_devices() > 1 &&
+                model->params.n_gpu_layers > (int)model->hparams.n_layer &&
+                model->params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
                 params.offload_kqv;
 
             // pipeline parallelism requires support for async compute and events in all devices
@@ -12554,7 +10242,7 @@ struct llama_context * llama_new_context_with_model(
             // initialize scheduler with the worst-case graph
             uint32_t n_seqs = 1; // TODO: worst-case number of sequences
             uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-            llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
             ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
@@ -12606,6 +10294,12 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
 
+struct llama_context * llama_new_context_with_model(
+                 struct llama_model * model,
+        struct llama_context_params   params) {
+    return llama_init_from_model(model, params);
+}
+
 //
 // kv cache
 //
@@ -12672,7 +10366,7 @@ void llama_kv_cache_defrag(struct llama_context * ctx) {
 }
 
 void llama_kv_cache_update(struct llama_context * ctx) {
-    llama_kv_cache_update_internal(*ctx);
+    llama_kv_cache_update_impl(*ctx);
 }
 
 bool llama_kv_cache_can_shift(struct llama_context * ctx) {
@@ -12684,7 +10378,7 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_encode_internal(*ctx, batch);
+    const int ret = llama_encode_impl(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -12695,7 +10389,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
+    const int ret = llama_decode_impl(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -12703,166 +10397,18 @@ int32_t llama_decode(
     return ret;
 }
 
-//
-// vocab
-//
-
-// TODO: tmp bridges below until `struct llama_vocab` is exposed through the public API
-
-const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
-    return llama_token_get_text_impl(model->vocab, token);
-}
-
-float llama_token_get_score(const struct llama_model * model, llama_token token) {
-    return llama_token_get_score_impl(model->vocab, token);
-}
-
-enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
-    return llama_token_get_attr_impl(model->vocab, token);
-}
-
-bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
-    return llama_token_is_eog_impl(model->vocab, token);
-}
-
-bool llama_token_is_control(const struct llama_model * model, llama_token token) {
-    return llama_token_is_control_impl(model->vocab, token);
-}
-
-llama_token llama_token_bos(const struct llama_model * model) {
-    return llama_token_bos_impl(model->vocab);
-}
-
-llama_token llama_token_eos(const struct llama_model * model) {
-    return llama_token_eos_impl(model->vocab);
-}
-
-llama_token llama_token_eot(const struct llama_model * model) {
-    return llama_token_eot_impl(model->vocab);
-}
-
-llama_token llama_token_cls(const struct llama_model * model) {
-    return llama_token_cls_impl(model->vocab);
-}
-
-llama_token llama_token_sep(const struct llama_model * model) {
-    return llama_token_sep_impl(model->vocab);
-}
-
-llama_token llama_token_nl (const struct llama_model * model) {
-    return llama_token_nl_impl(model->vocab);
-}
-
-llama_token llama_token_pad(const struct llama_model * model) {
-    return llama_token_pad_impl(model->vocab);
-}
-
-bool llama_add_bos_token(const struct llama_model * model) {
-    return llama_add_bos_token_impl(model->vocab);
-}
-
-bool llama_add_eos_token(const struct llama_model * model) {
-    return llama_add_eos_token_impl(model->vocab);
-}
-
-llama_token llama_token_prefix(const struct llama_model * model) {
-    return llama_token_prefix_impl(model->vocab);
-}
-
-llama_token llama_token_middle(const struct llama_model * model) {
-    return llama_token_middle_impl(model->vocab);
-}
-
-llama_token llama_token_suffix(const struct llama_model * model) {
-    return llama_token_suffix_impl(model->vocab);
-}
-
-llama_token llama_token_fim_pre(const struct llama_model * model) {
-    return llama_token_fim_pre_impl(model->vocab);
-}
-
-llama_token llama_token_fim_suf(const struct llama_model * model) {
-    return llama_token_fim_suf_impl(model->vocab);
-}
-
-llama_token llama_token_fim_mid(const struct llama_model * model) {
-    return llama_token_fim_mid_impl(model->vocab);
-}
-
-llama_token llama_token_fim_pad(const struct llama_model * model) {
-    return llama_token_fim_pad_impl(model->vocab);
-}
-
-llama_token llama_token_fim_rep(const struct llama_model * model) {
-    return llama_token_fim_rep_impl(model->vocab);
-}
-
-llama_token llama_token_fim_sep(const struct llama_model * model) {
-    return llama_token_fim_sep_impl(model->vocab);
-}
-
-//
-// tokenization
-//
-
-int32_t llama_tokenize(
-    const struct llama_model * model,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
-    return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
-}
-
-int32_t llama_token_to_piece(
-    const struct llama_model * model,
-                 llama_token   token,
-                        char * buf,
-                     int32_t   length,
-                     int32_t   lstrip,
-                        bool   special) {
-    return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
-}
-
-int32_t llama_detokenize(
-    const struct llama_model * model,
-           const llama_token * tokens,
-                     int32_t   n_tokens,
-                        char * text,
-                     int32_t   text_len_max,
-                        bool   remove_special,
-                        bool   unparse_special) {
-    return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
-}
-
 //
 // chat templates
 //
 
 int32_t llama_chat_apply_template(
-                const struct llama_model * model,
                               const char * tmpl,
          const struct llama_chat_message * chat,
                                   size_t   n_msg,
                                     bool   add_ass,
                                     char * buf,
                                  int32_t   length) {
-    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
-    if (tmpl == nullptr) {
-        GGML_ASSERT(model != nullptr);
-
-        // load template from model, if available
-        const auto & it = model->gguf_kv.find("tokenizer.chat_template");
-        if (it != model->gguf_kv.end() && it->second.size() > 0) {
-            curr_tmpl = it->second;
-        }
-        else {
-            // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml";  // see llm_chat_apply_template
-        }
-    }
+    const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
 
     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
@@ -12886,23 +10432,6 @@ int32_t llama_chat_apply_template(
     return res;
 }
 
-//
-// sampling
-//
-
-// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
-struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
-    return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
-}
-
-struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
-    return llama_sampler_init_infill_impl(model->vocab);
-}
-
-struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
-    return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
-}
-
 //
 // model split
 //
@@ -12915,16 +10444,16 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
     return 0;
 }
 
-int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
+int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
     std::string str_split_path(split_path);
     char postfix[32];
     snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
     std::string str_postfix(postfix);
 
-    // check if dest ends with postfix
+    // check if split_prefix ends with postfix
     int size_prefix = str_split_path.size() - str_postfix.size();
     if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
+        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
         return size_prefix;
     }
 
@@ -12933,6 +10462,8 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
 
 const char * llama_print_system_info(void) {
     static std::string s;
+    s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
+
 
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
         auto * reg = ggml_backend_reg_get(i);
diff --git a/llama/llama.cpp/src/unicode.cpp b/llama/llama.cpp/src/unicode.cpp
index 6155da800..9dd53b9ad 100644
--- a/llama/llama.cpp/src/unicode.cpp
+++ b/llama/llama.cpp/src/unicode.cpp
@@ -12,18 +12,17 @@
 
 #include 
 #include 
+#include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
-#include 
-#include 
 
 size_t unicode_len_utf8(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -641,7 +640,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
-        result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
     }
     return result;
 }
@@ -724,7 +730,7 @@ std::vector unicode_regex_split(const std::string & text, const std
     const auto cpts = unicode_cpts_from_utf8(text);
 
     // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
     std::string text_collapsed;
     if (need_collapse) {
         // collapse all unicode categories
diff --git a/llama/llama.go b/llama/llama.go
index a20f23578..6eed3d477 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -14,6 +14,7 @@ package llama
 #include "llama.h"
 #include "clip.h"
 #include "llava.h"
+#include "gguf.h"
 
 #include "mllama.h"
 #include "sampling_ext.h"
@@ -293,29 +294,29 @@ func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
 }
 
 func (m *Model) NumVocab() int {
-	return int(C.llama_n_vocab(m.c))
+	return int(C.llama_n_vocab(m.Vocab()))
 }
 
 func (m *Model) TokenIsEog(token int) bool {
-	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
+	return bool(C.llama_token_is_eog(m.Vocab(), C.llama_token(token)))
 }
 
 func (m *Model) AddBOSToken() bool {
-	return bool(C.llama_add_bos_token(m.c))
+	return bool(C.llama_add_bos_token(m.Vocab()))
 }
 
 func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))
 
-	loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
+	loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath)
 	if loraAdapter == nil {
 		return errors.New("unable to load lora")
 	}
 
 	err := -1
 	if loraAdapter != nil {
-		err = int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale)))
+		err = int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale)))
 	}
 	if err != 0 {
 		return errors.New("error applying lora from file")
@@ -324,6 +325,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 	return nil
 }
 
+func (m *Model) Vocab() *C.struct_llama_vocab {
+	return C.llama_model_get_vocab(m.c)
+}
+
 type Batch struct {
 	c         C.struct_llama_batch
 	batchSize int
@@ -414,7 +419,7 @@ func (m *Model) TokenToPiece(token int) string {
 	tokenLen := 12
 	buf := make([]byte, tokenLen)
 	tokenLen = int(C.llama_token_to_piece(
-		m.c,
+		m.Vocab(),
 		C.int32_t(token),
 		(*C.char)(unsafe.Pointer(&buf[0])),
 		C.int32_t(tokenLen),
@@ -426,7 +431,7 @@ func (m *Model) TokenToPiece(token int) string {
 
 		buf = make([]byte, tokenLen)
 		C.llama_token_to_piece(
-			m.c,
+			m.Vocab(),
 			C.int32_t(token),
 			(*C.char)(unsafe.Pointer(&buf[0])),
 			C.int32_t(tokenLen),
@@ -444,7 +449,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
 	defer C.free(unsafe.Pointer(cText))
 
 	result := C.llama_tokenize(
-		m.c,
+		m.Vocab(),
 		cText,
 		C.int32_t(len(text)),
 		&cTokens[0],
@@ -458,7 +463,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
 		maxTokens = int(-result)
 		cTokens = make([]C.llama_token, maxTokens)
 		result = C.llama_tokenize(
-			m.c,
+			m.Vocab(),
 			cText,
 			C.int32_t(len(text)),
 			&cTokens[0],
diff --git a/llama/mllama.cpp b/llama/mllama.cpp
index 4e84c60ae..1ba8f5bef 100644
--- a/llama/mllama.cpp
+++ b/llama/mllama.cpp
@@ -5,6 +5,7 @@
 #include "ggml-backend.h"
 #include "ggml-cpu.h"
 #include "ggml.h"
+#include "gguf.h"
 
 #ifdef GGML_USE_CUDA
 #include "ggml-cuda.h"
diff --git a/llama/patches/0001-cuda.patch b/llama/patches/0001-cuda.patch
index 0bf338f2b..a766c30c5 100644
--- a/llama/patches/0001-cuda.patch
+++ b/llama/patches/0001-cuda.patch
@@ -10,7 +10,7 @@ Subject: [PATCH] cuda
  3 files changed, 2 insertions(+), 1 deletion(-)
 
 diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
-index e2d6c405..a12172dc 100644
+index dba7be33..1ca40b2c 100644
 --- a/ggml/src/ggml-backend.cpp
 +++ b/ggml/src/ggml-backend.cpp
 @@ -106,7 +106,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
@@ -22,10 +22,10 @@ index e2d6c405..a12172dc 100644
  
  size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 0b06be72..be29e979 100644
+index ebb2ccae..b094929b 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -424,6 +424,7 @@ struct ggml_backend_cuda_buffer_context {
+@@ -529,6 +529,7 @@ struct ggml_backend_cuda_buffer_context {
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
      delete ctx;
@@ -34,10 +34,10 @@ index 0b06be72..be29e979 100644
  
  static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index a85502ee..cd8ef741 100644
+index c550142a..fd9a4e77 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -4187,6 +4187,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
+@@ -4350,6 +4350,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
      }
  
      free(ctx);
diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch
index 189a996fc..93a5ce594 100644
--- a/llama/patches/0002-pretokenizer.patch
+++ b/llama/patches/0002-pretokenizer.patch
@@ -4,17 +4,17 @@ Date: Mon, 16 Sep 2024 15:53:13 -0700
 Subject: [PATCH] pretokenizer
 
 ---
- src/llama-model.cpp | 14 +++-----------
+ src/llama-vocab.cpp | 14 +++-----------
  1 file changed, 3 insertions(+), 11 deletions(-)
 
-diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 405e0528..00b80c52 100644
---- a/src/llama-model.cpp
-+++ b/src/llama-model.cpp
-@@ -1249,16 +1249,7 @@ void llm_load_vocab(llama_model_loader & ml, llama_model & model) {
-         if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
-             vocab.tokenizer_add_space_prefix = false;
-             vocab.tokenizer_clean_spaces = true;
+diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
+index ad9ffe66..a4eee9b8 100644
+--- a/src/llama-vocab.cpp
++++ b/src/llama-vocab.cpp
+@@ -1468,16 +1468,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+         if (type == LLAMA_VOCAB_TYPE_BPE) {
+             add_space_prefix = false;
+             clean_spaces = true;
 -            if (tokenizer_pre.empty()) {
 -                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
 -                LLAMA_LOG_WARN("%s:                                             \n", __func__);
@@ -23,19 +23,19 @@ index 405e0528..00b80c52 100644
 -                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
 -                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
 -                LLAMA_LOG_WARN("%s:                                             \n", __func__);
--                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+-                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 -            } else if (tokenizer_pre == "default") {
 +            if (tokenizer_pre == "default") {
-                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              } else if (
                      tokenizer_pre == "llama3"   ||
-@@ -1373,7 +1364,8 @@ void llm_load_vocab(llama_model_loader & ml, llama_model & model) {
+@@ -1593,7 +1584,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                  tokenizer_pre == "megrez") {
-                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+                 pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
              } else {
 -                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
 +                LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
-+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
++                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              }
-         } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+         } else if (type == LLAMA_VOCAB_TYPE_SPM) {
+             pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
diff --git a/llama/patches/0003-embeddings.patch b/llama/patches/0003-embeddings.patch
index c04ee563b..176cb41da 100644
--- a/llama/patches/0003-embeddings.patch
+++ b/llama/patches/0003-embeddings.patch
@@ -9,10 +9,10 @@ Subject: [PATCH] embeddings
  2 files changed, 5 insertions(+), 3 deletions(-)
 
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index 38a55fb2..b9c4a5bf 100644
+index 671d2a81..47e79ed4 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -475,7 +475,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+@@ -479,7 +479,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
      const auto n_embd  = hparams.n_embd;
  
      // TODO: use a per-batch flag for logits presence instead
@@ -22,10 +22,10 @@ index 38a55fb2..b9c4a5bf 100644
  
      const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
 diff --git a/src/llama.cpp b/src/llama.cpp
-index ea78ea48..4eb3f6b9 100644
+index 607f2786..ac85bfed 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -10876,7 +10876,6 @@ static int llama_decode_internal(
+@@ -8652,7 +8652,6 @@ static int llama_decode_impl(
              res  = nullptr;
              embd = nullptr;
          } else if (cparams.embeddings) {
@@ -33,7 +33,7 @@ index ea78ea48..4eb3f6b9 100644
              embd = nullptr;
              for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
                  if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-@@ -10884,12 +10883,15 @@ static int llama_decode_internal(
+@@ -8660,12 +8659,15 @@ static int llama_decode_impl(
                      break;
                  }
              }
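Note on the second hunk above: the pooled-embedding output is located by name rather than by position, by scanning the graph backwards for the tensor called "result_embd_pooled". A self-contained sketch of that lookup, using only the ggml calls visible in the hunk (the helper name is illustrative and not part of the patch):

#include <cstring>

#include "ggml.h"

// Walk the graph from the last node backwards and return the first node whose
// name matches, or nullptr if the graph exposes no such output.
static struct ggml_tensor * find_named_output(struct ggml_cgraph * gf, const char * name) {
    for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
        struct ggml_tensor * node = ggml_graph_node(gf, i);
        if (strcmp(node->name, name) == 0) {
            return node;
        }
    }
    return nullptr;
}

// usage sketch: embd = find_named_output(gf, "result_embd_pooled");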
diff --git a/llama/patches/0004-clip-unicode.patch b/llama/patches/0004-clip-unicode.patch
index 9c90cfd0a..50a12aadd 100644
--- a/llama/patches/0004-clip-unicode.patch
+++ b/llama/patches/0004-clip-unicode.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] clip-unicode
  1 file changed, 39 insertions(+), 1 deletion(-)
 
 diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index 3cd0d2fa..b3c1829f 100644
+index 76d4a785..205af1eb 100644
 --- a/examples/llava/clip.cpp
 +++ b/examples/llava/clip.cpp
-@@ -56,6 +56,19 @@
+@@ -58,6 +58,19 @@
  #   define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
  #endif // defined(LLAVA_LOG_OFF)
  
@@ -31,7 +31,7 @@ index 3cd0d2fa..b3c1829f 100644
  //#define CLIP_DEBUG_FUNCTIONS
  
  // RGB uint8 image
-@@ -1322,8 +1335,29 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+@@ -1402,8 +1415,29 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
              gguf_free(ctx);
              return nullptr;
          }
@@ -62,7 +62,7 @@ index 3cd0d2fa..b3c1829f 100644
          if (!fin) {
              LOG_ERR("cannot open model file for loading tensors\n");
              clip_free(new_clip);
-@@ -1363,7 +1397,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+@@ -1443,7 +1477,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
                  ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
              }
          }
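The visible changes to the clip-unicode patch are only updated file hashes and hunk offsets for the new clip.cpp. Per its subject and the touched code around the tensor-loading fstream, the patch's purpose is to open model files whose paths contain non-ASCII characters on Windows; the added lines themselves are not reproduced in this diff. A minimal sketch of the usual conversion, assuming the Win32 MultiByteToWideChar API:

#ifdef _WIN32
#include <windows.h>

#include <fstream>
#include <string>

// Convert a UTF-8 path to UTF-16 so the MSVC std::ifstream wide-path overload can open it.
static std::wstring utf8_to_wide(const std::string & utf8) {
    int len = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), (int) utf8.size(), nullptr, 0);
    std::wstring wide(len, L'\0');
    MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), (int) utf8.size(), &wide[0], len);
    return wide;
}

// usage sketch: std::ifstream fin(utf8_to_wide(fname), std::ios::binary);
#endif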
diff --git a/llama/patches/0005-solar-pro.patch b/llama/patches/0005-solar-pro.patch
index 33be79c2d..e201682bf 100644
--- a/llama/patches/0005-solar-pro.patch
+++ b/llama/patches/0005-solar-pro.patch
@@ -11,21 +11,21 @@ tensor to store the scalar. the scalar is implemented as a 1-dimensional
 tensor with 2 elements derived from the model's bskcn_tv configuration.
 in general, the values are (bskcn_tv, 1 - bskcn_tv)
 ---
- src/llama-arch.cpp         |  53 +++++++----
+ src/llama-arch.cpp         |  21 +++++
  src/llama-arch.h           |   3 +
  src/llama-hparams.cpp      |   8 ++
- src/llama-hparams.h        |   5 +
+ src/llama-hparams.h        |   5 ++
  src/llama-model-loader.cpp |   1 +
- src/llama-model.cpp        |  16 ++++
+ src/llama-model.cpp        |  44 +++++++++++
  src/llama-model.h          |   3 +
- src/llama.cpp              | 185 +++++++++++++++++++++++++++++++++++++
- 8 files changed, 258 insertions(+), 16 deletions(-)
+ src/llama.cpp              | 152 ++++++++++++++++++++++++++++++++++++-
+ 8 files changed, 236 insertions(+), 1 deletion(-)
 
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index 007d79f8..5b376c5e 100644
+index 97a1e7e5..a1e0ebcc 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
-@@ -59,6 +59,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+@@ -61,6 +61,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
      { LLM_ARCH_GRANITE,          "granite"          },
      { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
      { LLM_ARCH_CHAMELEON,        "chameleon"        },
@@ -33,48 +33,16 @@ index 007d79f8..5b376c5e 100644
      { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
      { LLM_ARCH_UNKNOWN,          "(unknown)"        },
  };
-@@ -106,22 +107,23 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
-     { LLM_KV_RESIDUAL_SCALE,                    "%s.residual_scale"                    },
-     { LLM_KV_EMBEDDING_SCALE,                   "%s.embedding_scale"                   },
- 
--    { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
--    { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
--    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,         "%s.attention.max_alibi_bias"         },
--    { LLM_KV_ATTENTION_CLAMP_KQV,              "%s.attention.clamp_kqv"              },
--    { LLM_KV_ATTENTION_KEY_LENGTH,             "%s.attention.key_length"             },
--    { LLM_KV_ATTENTION_VALUE_LENGTH,           "%s.attention.value_length"           },
--    { LLM_KV_ATTENTION_LAYERNORM_EPS,          "%s.attention.layer_norm_epsilon"     },
--    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,      "%s.attention.layer_norm_rms_epsilon" },
--    { LLM_KV_ATTENTION_GROUPNORM_EPS,          "%s.attention.group_norm_epsilon"     },
--    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,       "%s.attention.group_norm_groups"      },
--    { LLM_KV_ATTENTION_CAUSAL,                 "%s.attention.causal"                 },
--    { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
--    { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
--    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
--    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
--    { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
-+    { LLM_KV_ATTENTION_HEAD_COUNT,               "%s.attention.head_count"               },
-+    { LLM_KV_ATTENTION_HEAD_COUNT_KV,            "%s.attention.head_count_kv"            },
-+    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,           "%s.attention.max_alibi_bias"           },
-+    { LLM_KV_ATTENTION_CLAMP_KQV,                "%s.attention.clamp_kqv"                },
-+    { LLM_KV_ATTENTION_KEY_LENGTH,               "%s.attention.key_length"               },
-+    { LLM_KV_ATTENTION_VALUE_LENGTH,             "%s.attention.value_length"             },
-+    { LLM_KV_ATTENTION_LAYERNORM_EPS,            "%s.attention.layer_norm_epsilon"       },
-+    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,        "%s.attention.layer_norm_rms_epsilon"   },
-+    { LLM_KV_ATTENTION_GROUPNORM_EPS,            "%s.attention.group_norm_epsilon"       },
-+    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,         "%s.attention.group_norm_groups"        },
-+    { LLM_KV_ATTENTION_CAUSAL,                   "%s.attention.causal"                   },
-+    { LLM_KV_ATTENTION_Q_LORA_RANK,              "%s.attention.q_lora_rank"              },
-+    { LLM_KV_ATTENTION_KV_LORA_RANK,             "%s.attention.kv_lora_rank"             },
-+    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,   "%s.attention.relative_buckets_count"   },
-+    { LLM_KV_ATTENTION_SLIDING_WINDOW,           "%s.attention.sliding_window"           },
-+    { LLM_KV_ATTENTION_SCALE,                    "%s.attention.scale"                    },
-+    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,    "%s.attention.block_skip_connection"    },
+@@ -125,6 +126,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
+     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
+     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
++    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,  "%s.attention.block_skip_connection"  },
  
      { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
      { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-@@ -1240,6 +1242,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
-             { LLM_TENSOR_POS_NET_ATTN_OUT,  "posnet.%d.attn_output" },
+@@ -1271,6 +1273,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
+             { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
          },
      },
 +    {
@@ -96,9 +64,9 @@ index 007d79f8..5b376c5e 100644
 +        },
 +    },
      {
-         LLM_ARCH_UNKNOWN,
+         LLM_ARCH_WAVTOKENIZER_DEC,
          {
-@@ -1372,6 +1392,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
+@@ -1429,6 +1449,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
      {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
@@ -107,10 +75,10 @@ index 007d79f8..5b376c5e 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index 45e458bb..eac7055b 100644
+index 122fdceb..77919578 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
-@@ -63,6 +63,7 @@ enum llm_arch {
+@@ -65,6 +65,7 @@ enum llm_arch {
      LLM_ARCH_GRANITE,
      LLM_ARCH_GRANITE_MOE,
      LLM_ARCH_CHAMELEON,
@@ -118,7 +86,7 @@ index 45e458bb..eac7055b 100644
      LLM_ARCH_WAVTOKENIZER_DEC,
      LLM_ARCH_UNKNOWN,
  };
-@@ -126,6 +127,7 @@ enum llm_kv {
+@@ -129,6 +130,7 @@ enum llm_kv {
      LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
      LLM_KV_ATTENTION_SLIDING_WINDOW,
      LLM_KV_ATTENTION_SCALE,
@@ -126,7 +94,7 @@ index 45e458bb..eac7055b 100644
  
      LLM_KV_ROPE_DIMENSION_COUNT,
      LLM_KV_ROPE_DIMENSION_SECTIONS,
-@@ -305,6 +307,7 @@ enum llm_tensor {
+@@ -311,6 +313,7 @@ enum llm_tensor {
      LLM_TENSOR_ENC_OUTPUT_NORM,
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
@@ -135,7 +103,7 @@ index 45e458bb..eac7055b 100644
      LLM_TENSOR_CONVNEXT_DW,
      LLM_TENSOR_CONVNEXT_NORM,
 diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
-index c4053469..450738da 100644
+index ea87b295..f3955de9 100644
 --- a/src/llama-hparams.cpp
 +++ b/src/llama-hparams.cpp
 @@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
@@ -152,10 +120,10 @@ index c4053469..450738da 100644
 +}
 \ No newline at end of file
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index a29f20ec..fd898e27 100644
+index 1fe45410..1bdcdfd5 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
-@@ -52,6 +52,8 @@ struct llama_hparams {
+@@ -50,6 +50,8 @@ struct llama_hparams {
      std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
      std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
  
@@ -164,7 +132,7 @@ index a29f20ec..fd898e27 100644
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
      uint32_t n_lora_kv          = 0;
-@@ -134,6 +136,9 @@ struct llama_hparams {
+@@ -133,6 +135,9 @@ struct llama_hparams {
  
      // dimension of the recurrent state embeddings
      uint32_t n_embd_v_s() const;
@@ -175,23 +143,23 @@ index a29f20ec..fd898e27 100644
  
  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
 diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
-index 7743b465..422524a8 100644
+index 05d58ad9..1252aca1 100644
 --- a/src/llama-model-loader.cpp
 +++ b/src/llama-model-loader.cpp
-@@ -364,6 +364,7 @@ namespace GGUFMeta {
+@@ -439,6 +439,7 @@ namespace GGUFMeta {
      // TODO: this is not very clever - figure out something better
      template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required);
      template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required);
 +    template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required);
  
- llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
-     int trace = 0;
+ llama_model_loader::llama_model_loader(
+         const std::string & fname,
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 00b80c52..306c557d 100644
+index 36a0a009..ad1315c6 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -1091,6 +1091,21 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
-                     default: model.type = e_model::MODEL_UNKNOWN;
+@@ -1238,6 +1238,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+                     default: type = LLM_TYPE_UNKNOWN;
                 }
              } break;
 +        case LLM_ARCH_SOLAR:
@@ -200,52 +168,19 @@ index 00b80c52..306c557d 100644
 +                for (size_t i = 0; i < hparams.n_bskcn_arr.max_size(); ++i) {
 +                    auto & bskcn = hparams.n_bskcn_arr[i];
 +                    bskcn.fill(0);
-+                    auto kv = LLM_KV(model.arch);
++                    auto kv = LLM_KV(arch);
 +                    ml.get_key_or_arr(format((kv(LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION) + ".%d").c_str(), i), bskcn, hparams.n_layer, false);
 +                }
 +
 +                switch (hparams.n_layer) {
-+                    case 64: model.type = e_model::MODEL_22B; break;
-+                    default: model.type = e_model::MODEL_UNKNOWN;
++                    case 64: type = LLM_TYPE_22B; break;
++                    default: type = LLM_TYPE_UNKNOWN;
 +                }
 +            } break;
          case LLM_ARCH_WAVTOKENIZER_DEC:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-@@ -2065,6 +2080,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
-         case LLM_ARCH_GRANITE:
-         case LLM_ARCH_GRANITE_MOE:
-         case LLM_ARCH_CHAMELEON:
-+        case LLM_ARCH_SOLAR:
-             return LLAMA_ROPE_TYPE_NORM;
- 
-         // the pairs of head values are offset by n_rot/2
-diff --git a/src/llama-model.h b/src/llama-model.h
-index ce038932..c1b9c0a1 100644
---- a/src/llama-model.h
-+++ b/src/llama-model.h
-@@ -54,6 +54,7 @@ enum llm_type {
-     MODEL_15B,
-     MODEL_16B,
-     MODEL_20B,
-+    MODEL_22B,
-     MODEL_30B,
-     MODEL_32B,
-     MODEL_34B,
-@@ -275,6 +276,8 @@ struct llama_layer {
-     struct ggml_tensor * ffn_up_scale   = nullptr;
-     struct ggml_tensor * ffn_down_scale = nullptr;
- 
-+    struct ggml_tensor * bskcn_tv = nullptr;
-+
-     struct llama_layer_posnet posnet;
- 
-     struct llama_layer_convnext convnext;
-diff --git a/src/llama.cpp b/src/llama.cpp
-index 4eb3f6b9..7dec50ae 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -2206,6 +2206,35 @@ static bool llm_load_tensors(
+@@ -3316,6 +3331,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  
                          layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
  
@@ -256,16 +191,16 @@ index 4eb3f6b9..7dec50ae 100644
 +                } break;
 +            case LLM_ARCH_SOLAR:
 +                {
-+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
++                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 +
 +                    // output
 +                    {
-+                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-+                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
++                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
++                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
 +                    }
 +
 +                    for (int i = 0; i < n_layer; ++i) {
-+                        auto & layer = model.layers[i];
++                        auto & layer = layers[i];
 +
 +                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 +
@@ -277,16 +212,53 @@ index 4eb3f6b9..7dec50ae 100644
 +                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 +
 +                        layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-+
                          layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                          layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                          layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-@@ -10226,6 +10255,158 @@ struct llm_build_context {
-         return gf;
-     }
+@@ -3900,6 +3943,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
+         case LLM_ARCH_GRANITE:
+         case LLM_ARCH_GRANITE_MOE:
+         case LLM_ARCH_CHAMELEON:
++        case LLM_ARCH_SOLAR:
+             return LLAMA_ROPE_TYPE_NORM;
  
-+    ggml_cgraph * build_solar() {
-+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+         // the pairs of head values are offset by n_rot/2
+diff --git a/src/llama-model.h b/src/llama-model.h
+index a7c30444..1afb0024 100644
+--- a/src/llama-model.h
++++ b/src/llama-model.h
+@@ -55,6 +55,7 @@ enum llm_type {
+     LLM_TYPE_15B,
+     LLM_TYPE_16B,
+     LLM_TYPE_20B,
++    LLM_TYPE_22B,
+     LLM_TYPE_30B,
+     LLM_TYPE_32B,
+     LLM_TYPE_34B,
+@@ -281,6 +282,8 @@ struct llama_layer {
+     struct ggml_tensor * ffn_up_scale   = nullptr;
+     struct ggml_tensor * ffn_down_scale = nullptr;
+ 
++    struct ggml_tensor * bskcn_tv = nullptr;
++
+     struct llama_layer_posnet posnet;
+ 
+     struct llama_layer_convnext convnext;
+diff --git a/src/llama.cpp b/src/llama.cpp
+index ac85bfed..6d320ea4 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -7953,9 +7953,155 @@ struct llm_build_context {
+         cb(img_logits, "img_logits", -1);
+         cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
+         cb(cur, "result_output", -1);
+-
+         ggml_build_forward_expand(gf, cur);
++        return gf;
++   }
++
++   ggml_cgraph * build_solar() {
++        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 +
 +        // mutable variable, needed during the last layer of the computation to skip unused tokens
 +        int32_t n_tokens = this->n_tokens;
@@ -333,7 +305,7 @@ index 4eb3f6b9..7dec50ae 100644
 +                   ggml_mul(ctx0, bskcn_2, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
 +                   ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
 +            }
-+
+ 
 +            // norm
 +            cur = llm_build_norm(ctx0, inpL, hparams,
 +                    model.layers[il].attn_norm, NULL,
@@ -422,25 +394,18 @@ index 4eb3f6b9..7dec50ae 100644
 +        }
 +
 +        cur = inpL;
-+
 +        cur = llm_build_norm(ctx0, cur, hparams,
 +                model.output_norm, NULL,
 +                LLM_NORM_RMS, cb, -1);
 +        cb(cur, "result_norm", -1);
-+
 +        // lm_head
 +        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 +        cb(cur, "result_output", -1);
-+
 +        ggml_build_forward_expand(gf, cur);
-+
-+        return gf;
-+    }
-+
-     struct ggml_cgraph * build_wavtokenizer_dec() {
-         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+         return gf;
+     }
  
-@@ -10660,6 +10841,10 @@ static struct ggml_cgraph * llama_build_graph(
+@@ -8398,6 +8544,10 @@ static struct ggml_cgraph * llama_build_graph(
              {
                  result = llm.build_chameleon();
              } break;
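As the patch message says, SOLAR's block skip connection mixes a hidden state saved at an earlier block back into a later one through a two-element tensor holding (bskcn_tv, 1 - bskcn_tv). The build_solar() hunk above does this inline with ggml_view_1d, ggml_mul and ggml_add; the same computation as a self-contained sketch (the helper name is illustrative, not part of the patch):

#include "ggml.h"

// mixed = t * skip + (1 - t) * cur, where bskcn_tv stores (t, 1 - t).
static struct ggml_tensor * bskcn_mix(
        struct ggml_context * ctx,
        struct ggml_tensor  * bskcn_tv,   // shape {2}
        struct ggml_tensor  * skip,       // hidden state saved at an earlier block
        struct ggml_tensor  * cur) {      // current hidden state
    struct ggml_tensor * t     = ggml_view_1d(ctx, bskcn_tv, 1, 0);
    struct ggml_tensor * one_t = ggml_view_1d(ctx, bskcn_tv, 1, ggml_element_size(bskcn_tv));
    return ggml_add(ctx,
            ggml_mul(ctx, skip, t),
            ggml_mul(ctx, cur,  one_t));
}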
diff --git a/llama/patches/0006-conditional-fattn.patch b/llama/patches/0006-conditional-fattn.patch
index 739905780..63af1f5c4 100644
--- a/llama/patches/0006-conditional-fattn.patch
+++ b/llama/patches/0006-conditional-fattn.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] conditional-fattn
  1 file changed, 2 insertions(+)
 
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index be29e979..aaa79ea4 100644
+index b094929b..36165840 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2159,9 +2159,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+@@ -2282,9 +2282,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
          case GGML_OP_ARGSORT:
              ggml_cuda_op_argsort(ctx, dst);
              break;
diff --git a/llama/patches/0007-add-mllama-support.patch b/llama/patches/0007-add-mllama-support.patch
index 678dabad4..efed0923b 100644
--- a/llama/patches/0007-add-mllama-support.patch
+++ b/llama/patches/0007-add-mllama-support.patch
@@ -15,27 +15,27 @@ remaining is to implement the cross attention mask
  examples/llava/llava.cpp      |   5 +-
  ggml/src/ggml-backend-reg.cpp |   6 +-
  include/llama.h               |   6 +
- src/llama-arch.cpp            |  44 +++++
+ src/llama-arch.cpp            |  44 ++++++
  src/llama-arch.h              |  10 ++
  src/llama-batch.cpp           |   3 +
- src/llama-context.cpp         |  19 ++-
+ src/llama-context.cpp         |  28 ++--
  src/llama-context.h           |   2 +
  src/llama-cparams.h           |   1 +
- src/llama-hparams.cpp         |   8 +-
- src/llama-hparams.h           |   4 +
- src/llama-kv-cache.cpp        |  33 ++++
+ src/llama-hparams.cpp         |   6 +
+ src/llama-hparams.h           |   5 +
+ src/llama-kv-cache.cpp        |  13 +-
  src/llama-model-loader.cpp    |   2 +
- src/llama-model.cpp           |  59 ++-----
- src/llama-model.h             |  51 ++++++
+ src/llama-model.cpp           |  65 ++++++++-
+ src/llama-model.h             |  12 ++
  src/llama-quant.cpp           |   4 +-
- src/llama.cpp                 | 307 +++++++++++++++++++++++++++++++++-
- 17 files changed, 508 insertions(+), 56 deletions(-)
+ src/llama.cpp                 | 262 +++++++++++++++++++++++++++++++++-
+ 17 files changed, 452 insertions(+), 22 deletions(-)
 
 diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
-index 16f30c56..0f0f3f62 100644
+index 518aad3f..f0e484a1 100644
 --- a/examples/llava/llava.cpp
 +++ b/examples/llava/llava.cpp
-@@ -429,7 +429,7 @@ struct llava_embd_batch {
+@@ -445,7 +445,7 @@ struct llava_embd_batch {
      std::vector<llama_seq_id *> seq_ids;
      std::vector<int8_t>         logits;
      llama_batch batch;
@@ -44,7 +44,7 @@ index 16f30c56..0f0f3f62 100644
          pos     .resize(n_tokens);
          n_seq_id.resize(n_tokens);
          seq_ids .resize(n_tokens + 1);
-@@ -441,6 +441,7 @@ struct llava_embd_batch {
+@@ -457,6 +457,7 @@ struct llava_embd_batch {
              /*n_tokens       =*/ n_tokens,
              /*tokens         =*/ nullptr,
              /*embd           =*/ embd,
@@ -52,7 +52,7 @@ index 16f30c56..0f0f3f62 100644
              /*pos            =*/ pos.data(),
              /*n_seq_id       =*/ n_seq_id.data(),
              /*seq_id         =*/ seq_ids.data(),
-@@ -464,7 +465,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
+@@ -480,7 +481,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
              n_eval = n_batch;
          }
          float * embd = image_embed->embed+i*n_embd;
@@ -62,7 +62,7 @@ index 16f30c56..0f0f3f62 100644
              LOG_ERR("%s : failed to eval\n", __func__);
              return false;
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 7ddd178b..899d16f2 100644
+index 955ed505..95036ef8 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -171,9 +171,9 @@ struct ggml_backend_registry {
@@ -79,10 +79,10 @@ index 7ddd178b..899d16f2 100644
          register_backend(ggml_backend_rpc_reg());
  #endif
 diff --git a/include/llama.h b/include/llama.h
-index a0d5ba5d..9f411960 100644
+index 47919602..cc948005 100644
 --- a/include/llama.h
 +++ b/include/llama.h
-@@ -250,6 +250,7 @@ extern "C" {
+@@ -249,6 +249,7 @@ extern "C" {
  
          llama_token  *  token;
          float        *  embd;
@@ -90,7 +90,7 @@ index a0d5ba5d..9f411960 100644
          llama_pos    *  pos;
          int32_t      *  n_seq_id;
          llama_seq_id ** seq_id;
-@@ -347,6 +348,7 @@ extern "C" {
+@@ -343,6 +344,7 @@ extern "C" {
          bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
          bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
          bool no_perf;     // whether to measure performance timings
@@ -98,9 +98,9 @@ index a0d5ba5d..9f411960 100644
  
          // Abort callback
          // if it returns true, execution of llama_decode() will be aborted
-@@ -426,6 +428,10 @@ extern "C" {
-                      struct llama_model * model,
-             struct llama_context_params   params);
+@@ -443,6 +445,10 @@ extern "C" {
+             struct llama_context_params   params),
+             "use llama_init_from_model instead");
  
 +    // TODO (jmorganca): this should most likely be passed in as part of a batch
 +    // and not set on the context for all batches.
@@ -110,7 +110,7 @@ index a0d5ba5d..9f411960 100644
      LLAMA_API void llama_free(struct llama_context * ctx);
  
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index 5b376c5e..b35aeb31 100644
+index a1e0ebcc..b6f20286 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
 @@ -6,6 +6,7 @@
@@ -121,15 +121,15 @@ index 5b376c5e..b35aeb31 100644
      { LLM_ARCH_DECI,             "deci"             },
      { LLM_ARCH_FALCON,           "falcon"           },
      { LLM_ARCH_GROK,             "grok"             },
-@@ -124,6 +125,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
-     { LLM_KV_ATTENTION_SLIDING_WINDOW,           "%s.attention.sliding_window"           },
-     { LLM_KV_ATTENTION_SCALE,                    "%s.attention.scale"                    },
-     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,    "%s.attention.block_skip_connection"    },
-+    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,   "%s.attention.cross_attention_layers"   },
+@@ -127,6 +128,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
+     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
+     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
+     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,  "%s.attention.block_skip_connection"  },
++    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
  
      { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
      { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-@@ -220,6 +222,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
+@@ -225,6 +227,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
              { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
          },
      },
@@ -170,7 +170,7 @@ index 5b376c5e..b35aeb31 100644
      {
          LLM_ARCH_DECI,
          {
-@@ -1393,6 +1429,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
+@@ -1450,6 +1486,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
      {LLM_TENSOR_BSKCN_TV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -186,7 +186,7 @@ index 5b376c5e..b35aeb31 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index eac7055b..e8235ae0 100644
+index 77919578..ec742224 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
 @@ -10,6 +10,7 @@
@@ -197,7 +197,7 @@ index eac7055b..e8235ae0 100644
      LLM_ARCH_DECI,
      LLM_ARCH_FALCON,
      LLM_ARCH_BAICHUAN,
-@@ -128,6 +129,7 @@ enum llm_kv {
+@@ -131,6 +132,7 @@ enum llm_kv {
      LLM_KV_ATTENTION_SLIDING_WINDOW,
      LLM_KV_ATTENTION_SCALE,
      LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
@@ -205,7 +205,7 @@ index eac7055b..e8235ae0 100644
  
      LLM_KV_ROPE_DIMENSION_COUNT,
      LLM_KV_ROPE_DIMENSION_SECTIONS,
-@@ -308,6 +310,14 @@ enum llm_tensor {
+@@ -314,6 +316,14 @@ enum llm_tensor {
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
      LLM_TENSOR_BSKCN_TV,
@@ -249,10 +249,10 @@ index 01d5ca57..8682b0e6 100644
          batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
      }
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index b9c4a5bf..9d0e7ca3 100644
+index 47e79ed4..7b22fe13 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -71,10 +71,19 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
+@@ -74,10 +74,19 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
      }
  
      if (ubatch.embd) {
@@ -275,7 +275,30 @@ index b9c4a5bf..9d0e7ca3 100644
      }
  
      if (ubatch.pos && lctx.inp_pos) {
-@@ -653,6 +662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+@@ -470,12 +479,11 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
+ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+     const auto & cparams = lctx.cparams;
+     const auto & hparams = lctx.model.hparams;
+-    const auto & vocab   = lctx.model.vocab;
+ 
+     const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+ 
+     const auto n_batch = cparams.n_batch;
+-    const auto n_vocab = vocab.n_tokens();
++    const auto n_vocab = hparams.n_vocab;
+     const auto n_embd  = hparams.n_embd;
+ 
+     // TODO: use a per-batch flag for logits presence instead
+@@ -542,7 +550,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+ void llama_output_reorder(struct llama_context & ctx) {
+     std::vector & out_ids = ctx.sbatch.out_ids;
+     if (!out_ids.empty()) {
+-        const uint32_t n_vocab = ctx.model.vocab.n_tokens();
++        const uint32_t n_vocab = ctx.model.hparams.n_vocab;
+         const uint32_t n_embd  = ctx.model.hparams.n_embd;
+ 
+         const int32_t n_outputs = ctx.n_outputs;
+@@ -657,6 +665,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
      ctx->cparams.causal_attn = causal_attn;
  }
  
@@ -286,8 +309,26 @@ index b9c4a5bf..9d0e7ca3 100644
  void llama_synchronize(struct llama_context * ctx) {
      ggml_backend_sched_synchronize(ctx->sched.get());
  
+@@ -726,7 +738,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+         }
+ 
+-        return ctx->logits + j*ctx->model.vocab.n_tokens();
++        return ctx->logits + j*ctx->model.hparams.n_vocab;
+     } catch (const std::exception & err) {
+         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+ #ifndef NDEBUG
+@@ -886,7 +898,7 @@ struct llama_data_write {
+     }
+ 
+     void write_logits(const struct llama_context * ctx) {
+-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
++        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
+ 
+         write(&logits_size, sizeof(logits_size));
+ 
 diff --git a/src/llama-context.h b/src/llama-context.h
-index 0d163c47..4980a60e 100644
+index a9268b29..cf12c9d7 100644
 --- a/src/llama-context.h
 +++ b/src/llama-context.h
 @@ -107,6 +107,8 @@ struct llama_context {
@@ -312,7 +353,7 @@ index 252012f3..9681e5a0 100644
      enum llama_pooling_type pooling_type;
  
 diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
-index 450738da..42f8a58f 100644
+index f3955de9..0b841028 100644
 --- a/src/llama-hparams.cpp
 +++ b/src/llama-hparams.cpp
 @@ -2,6 +2,8 @@
@@ -328,18 +369,25 @@ index 450738da..42f8a58f 100644
      }
  
      GGML_ABORT("fatal error");
--}
-\ No newline at end of file
 +}
 +
 +bool llama_hparams::cross_attention_layers(uint32_t il) const {
 +    return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
-+}
+ }
+\ No newline at end of file
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index fd898e27..f826cd9a 100644
+index 1bdcdfd5..05383046 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
-@@ -53,6 +53,7 @@ struct llama_hparams {
+@@ -41,6 +41,7 @@ struct llama_hparams {
+     uint32_t n_expert = 0;
+     uint32_t n_expert_used = 0;
+     uint32_t n_rel_attn_bkts = 0;
++    uint32_t n_vocab = 0;
+ 
+     // for WavTokenizer
+     struct llama_hparams_posnet   posnet;
+@@ -51,6 +52,7 @@ struct llama_hparams {
      std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
  
      std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr = {};
@@ -347,65 +395,45 @@ index fd898e27..f826cd9a 100644
  
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
-@@ -139,6 +140,9 @@ struct llama_hparams {
+@@ -138,6 +140,9 @@ struct llama_hparams {
  
      // Block skip connection
      bool n_bskcn(uint32_t n, uint32_t il) const;
 +
-+    // cross attention layers   
++    // cross attention layers
 +    bool cross_attention_layers(uint32_t il) const;
  };
  
  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
 diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
-index 53379253..cf814dbe 100644
+index feffdf0d..b541c5a3 100644
 --- a/src/llama-kv-cache.cpp
 +++ b/src/llama-kv-cache.cpp
-@@ -72,6 +72,39 @@ bool llama_kv_cache_init(
-     cache.v_l.reserve(n_layer);
+@@ -91,8 +91,17 @@ bool llama_kv_cache_init(
+             return false;
+         }
  
-     for (int i = 0; i < n_layer; i++) {
+-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
++        ggml_tensor * k, *v;
++
 +        // for cross attention layers
 +        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
-+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-+            const llama_model::buft_list_t * buft_list;
-+            if (offload) {
-+                buft_list = model.dev_layer.at(i).buft_list;
-+            } else {
-+                buft_list = &model.cpu_buft_list;
-+            }
-+            ggml_backend_buffer_type_t buft = select_buft(*buft_list,
-+                [&](ggml_context * ctx) {
-+                    ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-+                    if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
-+                        return k;
-+                    }
-+                    ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-+                    return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
-+                });
-+            ggml_context * ctx = ctx_for_buft(buft);
-+
-+            if (!ctx) {
-+                LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-+                return false;
-+            }
-+            ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
-+            ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
-+            ggml_format_name(k, "cache_k_l%d", i);
-+            ggml_format_name(v, "cache_v_l%d", i);
-+            cache.k_l.push_back(k);
-+            cache.v_l.push_back(v);
-+            continue;
++            k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
++            v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
++        } else {
++            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
++            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
 +        }
 +
-         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
- 
+         ggml_format_name(k, "cache_k_l%d", i);
+         ggml_format_name(v, "cache_v_l%d", i);
+         cache.k_l.push_back(k);
 diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
-index 422524a8..b12d6566 100644
+index 1252aca1..45d08721 100644
 --- a/src/llama-model-loader.cpp
 +++ b/src/llama-model-loader.cpp
-@@ -240,6 +240,8 @@ namespace GGUFMeta {
+@@ -315,6 +315,8 @@ namespace GGUFMeta {
          return true;
      }
  
@@ -415,80 +443,47 @@ index 422524a8..b12d6566 100644
      bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) {
          const int kid = gguf_find_key(meta.get(), key.c_str());
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 306c557d..4f9bbf90 100644
+index ad1315c6..21819080 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -146,46 +146,6 @@ std::string llama_model_ftype_name(const llama_model & model) {
-     return llama_model_ftype_name(model.ftype);
- }
+@@ -401,6 +401,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  
--template
--static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
--    ggml_init_params params = {
--        /*.mem_size   =*/ ggml_tensor_overhead()*8,
--        /*.mem_buffer =*/ NULL,
--        /*.no_alloc   =*/ true,
--    };
--
--    ggml_context_ptr ctx { ggml_init(params) };
--    if (!ctx) {
--        throw std::runtime_error(format("failed to create ggml context"));
--    }
--
--    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
--    ggml_tensor * op_tensor = fn(ctx.get());
--    for (int i = 0; i < GGML_MAX_SRC; i++) {
--        if (op_tensor->src[i] != nullptr) {
--            assert(op_tensor->src[i]->buffer == nullptr);
--            op_tensor->src[i]->buffer = buf.get();
--        }
--    }
--
--    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
--
--    return op_supported;
--}
--
--template
--static ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
--    for (const auto & cur : buft_list) {
--        ggml_backend_dev_t cur_dev = cur.first;
--        ggml_backend_buffer_type_t cur_buft = cur.second;
--        if (buft_supported(cur_buft, cur_dev, fn)) {
--            return cur_buft;
--        }
--    }
--
--    throw std::runtime_error(format("no suitable buffer type found"));
--}
--
- ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il) {
-     return select_buft(
-             *model.dev_layer.at(il).buft_list,
-@@ -312,9 +272,11 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
+     // get general kv
+     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
++    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
+ 
+     // everything past this point is not vocab-related
+     if (hparams.vocab_only) {
+@@ -412,6 +413,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+     ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
+     ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
+     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
++    ml.get_key(LLM_KV_VOCAB_SIZE,        hparams.n_vocab,       false);
+ 
+     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
+         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
+@@ -435,9 +437,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
      std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
      std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
      std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
 +    std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1);
  
--    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
--    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
-+    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,       hparams.n_ff_arr,   hparams.n_layer, false);
-+    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT,      hparams.n_head_arr, hparams.n_layer, false);
+     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
+     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
 +    ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
  
      // n_head_kv is optional, default to n_head
      hparams.n_head_kv_arr = hparams.n_head_arr;
-@@ -363,7 +325,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
+@@ -486,7 +490,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  
          ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
  
--        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) {
-+        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) {
+-        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
++        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_MLLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
              if (hparams.n_rot != hparams.n_embd_head_k) {
                  throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
              }
-@@ -405,6 +367,16 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
+@@ -530,6 +534,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                      }
                  }
              } break;
@@ -497,145 +492,44 @@ index 306c557d..4f9bbf90 100644
 +                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 +
 +                switch (hparams.n_layer) {
-+                    case 40: model.type = e_model::MODEL_11B; break;
-+                    case 100: model.type = e_model::MODEL_90B; break;
-+                    default: model.type = e_model::MODEL_UNKNOWN;
++                    case 40: type = LLM_TYPE_11B; break;
++                    case 100: type = LLM_TYPE_90B; break;
++                    default: type = LLM_TYPE_UNKNOWN;
 +                }
 +            } break;
          case LLM_ARCH_DECI:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-@@ -2062,6 +2034,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
- 
-         // use what we call a normal RoPE, operating on pairs of consecutive head values
-         case LLM_ARCH_LLAMA:
-+        case LLM_ARCH_MLLAMA:
-         case LLM_ARCH_DECI:
-         case LLM_ARCH_BAICHUAN:
-         case LLM_ARCH_STARCODER:
-diff --git a/src/llama-model.h b/src/llama-model.h
-index c1b9c0a1..5b23e2ba 100644
---- a/src/llama-model.h
-+++ b/src/llama-model.h
-@@ -9,6 +9,7 @@
- #include "ggml-cpp.h"
- 
- #include 
-+#include 
- 
- // available models
- // TODO: this enum does not follow the enum naming convention
-@@ -62,6 +63,7 @@ enum llm_type {
-     MODEL_40B,
-     MODEL_65B,
-     MODEL_70B,
-+    MODEL_90B,
-     MODEL_236B,
-     MODEL_314B,
-     MODEL_671B,
-@@ -278,6 +280,16 @@ struct llama_layer {
- 
-     struct ggml_tensor * bskcn_tv = nullptr;
- 
-+     // cross attention
-+    struct ggml_tensor * cross_attn_k_norm = nullptr;
-+    struct ggml_tensor * cross_attn_k_proj = nullptr;
-+    struct ggml_tensor * cross_attn_o_proj = nullptr;
-+    struct ggml_tensor * cross_attn_q_norm = nullptr;
-+    struct ggml_tensor * cross_attn_q_proj = nullptr;
-+    struct ggml_tensor * cross_attn_v_proj = nullptr;
-+    struct ggml_tensor * cross_attn_attn_gate = nullptr;
-+    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
-+
-     struct llama_layer_posnet posnet;
- 
-     struct llama_layer_convnext convnext;
-@@ -376,6 +388,45 @@ std::string llama_model_arch_name (const llama_model & model);
- std::string llama_model_type_name (const llama_model & model);
- std::string llama_model_ftype_name(const llama_model & model);
- 
-+template
-+bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
-+    ggml_init_params params = {
-+        /*.mem_size   =*/ ggml_tensor_overhead()*8,
-+        /*.mem_buffer =*/ NULL,
-+        /*.no_alloc   =*/ true,
-+    };
-+
-+    ggml_context_ptr ctx { ggml_init(params) };
-+    if (!ctx) {
-+        throw std::runtime_error("failed to create ggml context");
-+    }
-+
-+    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
-+    ggml_tensor * op_tensor = fn(ctx.get());
-+    for (int i = 0; i < GGML_MAX_SRC; i++) {
-+        if (op_tensor->src[i] != nullptr) {
-+            op_tensor->src[i]->buffer = buf.get();
-+        }
-+    }
-+
-+    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
-+
-+    return op_supported;
-+}
-+
-+template
-+ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
-+    for (const auto & cur : buft_list) {
-+        ggml_backend_dev_t cur_dev = cur.first;
-+        ggml_backend_buffer_type_t cur_buft = cur.second;
-+        if (buft_supported(cur_buft, cur_dev, fn)) {
-+            return cur_buft;
-+        }
-+    }
-+
-+    throw std::runtime_error("no suitable buffer type found");
-+}
-+
- // used by llama_adapter_cvec
- ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
- 
-diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
-index 42974f8f..27def6fd 100644
---- a/src/llama-quant.cpp
-+++ b/src/llama-quant.cpp
-@@ -629,7 +629,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
-         if (llama_model_has_encoder(&model)) {
-             n_attn_layer *= 3;
-         }
--        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
-+        if (qs.n_attention_wv != n_attn_layer) {
-+            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
-+        }
-     }
- 
-     size_t total_size_org = 0;
-diff --git a/src/llama.cpp b/src/llama.cpp
-index 7dec50ae..bac66c24 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -563,6 +563,52 @@ static bool llm_load_tensors(
+@@ -1398,7 +1412,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+         const int64_t n_embd_head_v = hparams.n_embd_head_v;
+         const int64_t n_ff          = hparams.n_ff();
+         const int64_t n_embd_gqa    = n_embd_v_gqa;
+-        const int64_t n_vocab       = vocab.n_tokens();
++        const int64_t n_vocab       = hparams.n_vocab;
+         const int64_t n_token_types = vocab.n_token_types();
+         const int64_t n_rot         = hparams.n_rot;
+         const int64_t n_expert      = hparams.n_expert;
+@@ -1581,6 +1595,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                          }
                      }
                  } break;
 +            case LLM_ARCH_MLLAMA:
 +                {
-+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
++                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
 +
 +                    // output
 +                    {
-+                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-+                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
++                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
++                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
 +
 +                        // if output is NULL, init from the input tok embed
-+                        if (model.output == NULL) {
-+                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
++                        if (output == NULL) {
++                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
 +                        }
 +                    }
 +
 +                    for (int i = 0; i < n_layer; ++i) {
-+                        auto & layer = model.layers[i];
++                        auto & layer = layers[i];
 +
 +                        if (hparams.cross_attention_layers(i)) {
 +                            layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128}, 0);
@@ -667,17 +561,72 @@ index 7dec50ae..bac66c24 100644
 +                } break;
              case LLM_ARCH_DECI:
                  {
-                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-@@ -2514,7 +2560,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
+                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+@@ -3925,6 +3985,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
  
-         if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
-             model.hparams.n_vocab != model.vocab.id_to_token.size()) {
--            throw std::runtime_error("vocab size mismatch");
-+            LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
+         // use what we call a normal RoPE, operating on pairs of consecutive head values
+         case LLM_ARCH_LLAMA:
++        case LLM_ARCH_MLLAMA:
+         case LLM_ARCH_DECI:
+         case LLM_ARCH_BAICHUAN:
+         case LLM_ARCH_STARCODER:
+diff --git a/src/llama-model.h b/src/llama-model.h
+index 1afb0024..7cf57587 100644
+--- a/src/llama-model.h
++++ b/src/llama-model.h
+@@ -9,6 +9,7 @@
+ #include 
+ #include 
+ #include 
++#include 
+ 
+ struct llama_model_loader;
+ 
+@@ -63,6 +64,7 @@ enum llm_type {
+     LLM_TYPE_40B,
+     LLM_TYPE_65B,
+     LLM_TYPE_70B,
++    LLM_TYPE_90B,
+     LLM_TYPE_236B,
+     LLM_TYPE_314B,
+     LLM_TYPE_671B,
+@@ -284,6 +286,16 @@ struct llama_layer {
+ 
+     struct ggml_tensor * bskcn_tv = nullptr;
+ 
++    // cross attention
++    struct ggml_tensor * cross_attn_k_norm = nullptr;
++    struct ggml_tensor * cross_attn_k_proj = nullptr;
++    struct ggml_tensor * cross_attn_o_proj = nullptr;
++    struct ggml_tensor * cross_attn_q_norm = nullptr;
++    struct ggml_tensor * cross_attn_q_proj = nullptr;
++    struct ggml_tensor * cross_attn_v_proj = nullptr;
++    struct ggml_tensor * cross_attn_attn_gate = nullptr;
++    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
++
+     struct llama_layer_posnet posnet;
+ 
+     struct llama_layer_convnext convnext;
+diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
+index fb798265..6eb1da08 100644
+--- a/src/llama-quant.cpp
++++ b/src/llama-quant.cpp
+@@ -632,7 +632,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
+         if (llama_model_has_encoder(&model)) {
+             n_attn_layer *= 3;
          }
+-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
++        if (qs.n_attention_wv != n_attn_layer) {
++            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
++        }
+     }
  
-         if (params.vocab_only) {
-@@ -2598,6 +2644,21 @@ static struct ggml_tensor * llm_build_inp_embd(
+     size_t total_size_org = 0;
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 6d320ea4..8f7902df 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -154,6 +154,21 @@ static struct ggml_tensor * llm_build_inp_embd(
      return inpL;
  }
  
@@ -699,7 +648,7 @@ index 7dec50ae..bac66c24 100644
  static void llm_build_kv_store(
          struct ggml_context * ctx,
          const llama_hparams & hparams,
-@@ -3593,6 +3654,7 @@ struct llm_build_context {
+@@ -1157,6 +1172,7 @@ struct llm_build_context {
          lctx.inp_pos_bucket    = nullptr;
          lctx.inp_embd_enc      = nullptr;
          lctx.inp_KQ_mask_cross = nullptr;
@@ -707,12 +656,12 @@ index 7dec50ae..bac66c24 100644
      }
  
      void free() {
-@@ -4074,6 +4136,240 @@ struct llm_build_context {
+@@ -1639,6 +1655,240 @@ struct llm_build_context {
          return gf;
      }
  
-+        struct ggml_cgraph * build_mllama() {
-+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
++    struct ggml_cgraph * build_mllama() {
++        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 +
 +        // mutable variable, needed during the last layer of the computation to skip unused tokens
 +        int32_t n_tokens = this->n_tokens;
@@ -946,9 +895,9 @@ index 7dec50ae..bac66c24 100644
 +    }
 +
      struct ggml_cgraph * build_deci() {
-         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
  
-@@ -10646,6 +10942,10 @@ static struct ggml_cgraph * llama_build_graph(
+@@ -8344,6 +8594,10 @@ static struct ggml_cgraph * llama_build_graph(
              {
                  result = llm.build_llama();
              } break;
@@ -959,16 +908,33 @@ index 7dec50ae..bac66c24 100644
          case LLM_ARCH_DECI:
              {
                  result = llm.build_deci();
-@@ -10971,7 +11271,7 @@ static int llama_decode_internal(
+@@ -8634,7 +8888,7 @@ static int llama_prepare_sbatch(
          n_outputs = 1;
      }
  
 -    lctx.sbatch.from_batch(batch, n_embd,
 +    lctx.sbatch.from_batch(batch, batch.n_embd,
-         /* simple_split */ !kv_self.recurrent,
+         /* simple_split */ !lctx.kv_self.recurrent,
          /* logits_all   */ n_outputs == n_tokens_all);
  
-@@ -11282,7 +11582,7 @@ static int llama_encode_internal(
+@@ -8749,7 +9003,6 @@ static int llama_decode_impl(
+     const llama_batch & batch = batch_allocr.batch;
+ 
+     const auto & model   = lctx.model;
+-    const auto & vocab   = model.vocab;
+     const auto & hparams = model.hparams;
+     const auto & cparams = lctx.cparams;
+ 
+@@ -8760,7 +9013,7 @@ static int llama_decode_impl(
+     llama_kv_slot_restorer kv_slot_restorer(kv_self);
+ 
+     const int64_t n_embd  = hparams.n_embd;
+-    const int64_t n_vocab = vocab.n_tokens();
++    const int64_t n_vocab = hparams.n_vocab;
+ 
+     uint32_t n_outputs = 0;
+     uint32_t n_outputs_prev = 0;
+@@ -9025,7 +9278,7 @@ static int llama_encode_impl(
  
      const int64_t n_embd = hparams.n_embd;
  
@@ -977,7 +943,7 @@ index 7dec50ae..bac66c24 100644
  
      const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
  
-@@ -11775,6 +12075,7 @@ struct llama_context_params llama_context_default_params() {
+@@ -9511,6 +9764,7 @@ struct llama_context_params llama_context_default_params() {
          /*.offload_kqv                 =*/ true,
          /*.flash_attn                  =*/ false,
          /*.no_perf                     =*/ true,
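The mllama patch drives its per-layer behaviour off one rule: a layer runs cross attention over the image embeddings passed in through cross_attn_state iff its index appears in the array read from "%s.attention.cross_attention_layers"; every other layer keeps regular self-attention. That is the std::find membership test added to llama_hparams above, shown here as a standalone sketch (the container type is illustrative; the patch keeps the indices in a fixed-size array inside llama_hparams):

#include <algorithm>
#include <cstdint>
#include <vector>

static bool is_cross_attention_layer(const std::vector<int32_t> & cross_attn_layers, uint32_t il) {
    return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), (int32_t) il)
        != cross_attn_layers.end();
}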
diff --git a/llama/patches/0008-add-unpad-operator.patch b/llama/patches/0008-add-unpad-operator.patch
index fd070df9c..bfa82de2b 100644
--- a/llama/patches/0008-add-unpad-operator.patch
+++ b/llama/patches/0008-add-unpad-operator.patch
@@ -15,10 +15,10 @@ Subject: [PATCH] add unpad operator
  8 files changed, 220 insertions(+), 2 deletions(-)
 
 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
-index c714fc8c..1bc50fca 100644
+index dd0c6a96..8d269a9c 100644
 --- a/ggml/include/ggml.h
 +++ b/ggml/include/ggml.h
-@@ -499,6 +499,7 @@ extern "C" {
+@@ -487,6 +487,7 @@ extern "C" {
          GGML_OP_UPSCALE, // nearest interpolate
          GGML_OP_PAD,
          GGML_OP_PAD_REFLECT_1D,
@@ -26,7 +26,7 @@ index c714fc8c..1bc50fca 100644
          GGML_OP_ARANGE,
          GGML_OP_TIMESTEP_EMBEDDING,
          GGML_OP_ARGSORT,
-@@ -1735,6 +1736,15 @@ extern "C" {
+@@ -1743,6 +1744,15 @@ extern "C" {
              int                   p0,
              int                   p1);
  
@@ -43,10 +43,10 @@ index c714fc8c..1bc50fca 100644
      // timesteps: [N,]
      // return: [N, dim]
 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
-index b7fefb9d..b307d554 100644
+index 72325349..2f606d82 100644
 --- a/ggml/src/ggml-cpu/ggml-cpu.c
 +++ b/ggml/src/ggml-cpu/ggml-cpu.c
-@@ -10588,6 +10588,59 @@ static void ggml_compute_forward_pad_reflect_1d(
+@@ -10844,6 +10844,59 @@ static void ggml_compute_forward_pad_reflect_1d(
      }
  }
  
@@ -106,7 +106,7 @@ index b7fefb9d..b307d554 100644
  // ggml_compute_forward_arange
  
  static void ggml_compute_forward_arange_f32(
-@@ -12690,6 +12743,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
+@@ -13137,6 +13190,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
              {
                  ggml_compute_forward_pad_reflect_1d(params, tensor);
              } break;
@@ -117,7 +117,7 @@ index b7fefb9d..b307d554 100644
          case GGML_OP_ARANGE:
              {
                  ggml_compute_forward_arange(params, tensor);
-@@ -13033,6 +13090,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+@@ -13484,6 +13541,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
          case GGML_OP_UPSCALE:
          case GGML_OP_PAD:
          case GGML_OP_PAD_REFLECT_1D:
@@ -126,10 +126,10 @@ index b7fefb9d..b307d554 100644
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_ARGSORT:
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index aaa79ea4..9286f866 100644
+index 36165840..1adf08fa 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2082,6 +2082,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+@@ -2198,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
          case GGML_OP_PAD:
              ggml_cuda_op_pad(ctx, dst);
              break;
@@ -139,8 +139,8 @@ index aaa79ea4..9286f866 100644
          case GGML_OP_ARANGE:
              ggml_cuda_op_arange(ctx, dst);
              break;
-@@ -3010,6 +3013,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
-         case GGML_OP_GROUP_NORM:
+@@ -3197,6 +3200,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
+             return ggml_is_contiguous(op->src[0]);
          case GGML_OP_UPSCALE:
          case GGML_OP_PAD:
 +        case GGML_OP_UNPAD:
@@ -148,7 +148,7 @@ index aaa79ea4..9286f866 100644
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_LEAKY_RELU:
 diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
-index aba539e8..39fd4b16 100644
+index aba539e8..b4b87409 100644
 --- a/ggml/src/ggml-cuda/pad.cu
 +++ b/ggml/src/ggml-cuda/pad.cu
 @@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -201,6 +201,7 @@ index aba539e8..39fd4b16 100644
 +        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
 +        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
 +}
+\ No newline at end of file
 diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
 index 8fd386b0..e2ededc3 100644
 --- a/ggml/src/ggml-cuda/pad.cuh
@@ -211,10 +212,10 @@ index 8fd386b0..e2ededc3 100644
  void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 +void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index cd8ef741..318addec 100644
+index fd9a4e77..e4c093f9 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -311,6 +311,7 @@ enum ggml_metal_kernel_type {
+@@ -331,6 +331,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
      GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
      GGML_METAL_KERNEL_TYPE_PAD_F32,
      GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
@@ -222,7 +223,7 @@ index cd8ef741..318addec 100644
      GGML_METAL_KERNEL_TYPE_ARANGE_F32,
      GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
      GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
-@@ -910,6 +911,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
+@@ -946,6 +947,7 @@ @implementation GGMLMetalClass
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,            pad_reflect_1d_f32,             true);
@@ -230,7 +231,7 @@ index cd8ef741..318addec 100644
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
-@@ -1145,6 +1147,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
+@@ -1254,6 +1256,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
          case GGML_OP_UPSCALE:
          case GGML_OP_PAD:
          case GGML_OP_PAD_REFLECT_1D:
@@ -238,7 +239,7 @@ index cd8ef741..318addec 100644
          case GGML_OP_ARANGE:
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_ARGSORT:
-@@ -3348,6 +3351,36 @@ static void ggml_metal_encode_node(
+@@ -3469,6 +3472,36 @@ static void ggml_metal_encode_node(
  
                  const int nth = MIN(1024, ne0);
  
@@ -276,10 +277,10 @@ index cd8ef741..318addec 100644
              } break;
          case GGML_OP_ARANGE:
 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index 8ba43904..204c93e6 100644
+index d092a169..f38909d0 100644
 --- a/ggml/src/ggml-metal/ggml-metal.metal
 +++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -2944,6 +2944,51 @@ kernel void kernel_pad_reflect_1d_f32(
+@@ -2953,6 +2953,51 @@ kernel void kernel_pad_reflect_1d_f32(
      }
  }
  
@@ -332,10 +333,10 @@ index 8ba43904..204c93e6 100644
      device        char * dst,
      constant   int64_t & ne0,
 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
-index 2bbe5f48..7ffcd907 100644
+index 7fc06724..635aa299 100644
 --- a/ggml/src/ggml.c
 +++ b/ggml/src/ggml.c
-@@ -954,6 +954,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+@@ -962,6 +962,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "UPSCALE",
      "PAD",
      "PAD_REFLECT_1D",
@@ -343,16 +344,16 @@ index 2bbe5f48..7ffcd907 100644
      "ARANGE",
      "TIMESTEP_EMBEDDING",
      "ARGSORT",
-@@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+@@ -996,7 +997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "OPT_STEP_ADAMW",
  };
  
--static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
-+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
++static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
  
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "none",
-@@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+@@ -1059,6 +1060,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "upscale(x)",
      "pad(x)",
      "pad_reflect_1d(x)",
@@ -360,16 +361,16 @@ index 2bbe5f48..7ffcd907 100644
      "arange(start, stop, step)",
      "timestep_embedding(timesteps, dim, max_period)",
      "argsort(x)",
-@@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+@@ -1093,7 +1095,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "adamw(x)",
  };
  
--static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
-+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
++static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
  
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
  
-@@ -4214,6 +4216,25 @@ struct ggml_tensor * ggml_pad_reflect_1d(
+@@ -4225,6 +4227,25 @@ struct ggml_tensor * ggml_pad_reflect_1d(
      return result;
  }
  
diff --git a/llama/patches/0009-fix-deepseek-deseret-regex.patch b/llama/patches/0009-fix-deepseek-deseret-regex.patch
index 5c334cfd9..715c5206f 100644
--- a/llama/patches/0009-fix-deepseek-deseret-regex.patch
+++ b/llama/patches/0009-fix-deepseek-deseret-regex.patch
@@ -11,10 +11,10 @@ the characters
  2 files changed, 23 insertions(+), 1 deletion(-)
 
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index 3fcfcaa3..8f44705a 100644
+index a4eee9b8..1ca827eb 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
-@@ -375,7 +375,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
+@@ -295,7 +295,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
              case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
                  regex_exprs = {
                      "[\r\n]",
@@ -24,7 +24,7 @@ index 3fcfcaa3..8f44705a 100644
                      "\\s+$",
                      "[一-龥ࠀ-一가-퟿]+",
 diff --git a/src/unicode.cpp b/src/unicode.cpp
-index 7aca6544..6155da80 100644
+index e63bb4ab..9dd53b9a 100644
 --- a/src/unicode.cpp
 +++ b/src/unicode.cpp
 @@ -2,6 +2,11 @@
@@ -39,7 +39,7 @@ index 7aca6544..6155da80 100644
  #include "unicode.h"
  #include "unicode-data.h"
  
-@@ -201,6 +206,22 @@ static std::unordered_map unicode_utf8_to_byte_map() {
+@@ -200,6 +205,22 @@ static std::unordered_map unicode_utf8_to_byte_map() {
  }
  
  static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
@@ -62,7 +62,7 @@ index 7aca6544..6155da80 100644
  #if defined(__clang__)
      // disable C++17 deprecation warning for std::codecvt_utf8
  #    pragma clang diagnostic push
-@@ -214,6 +235,7 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
+@@ -213,6 +234,7 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
  #endif
  
      return conv.from_bytes(s);
diff --git a/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch b/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
index 33b504ec1..1e930fb29 100644
--- a/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
+++ b/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
@@ -8,11 +8,11 @@ Subject: [PATCH] Maintain ordering for rules for grammar
  1 file changed, 1 insertion(+), 1 deletion(-)
 
 diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
-index dadc18c8..2a8dbd22 100644
+index 3ebcc3d9..30c28808 100644
 --- a/common/json-schema-to-grammar.cpp
 +++ b/common/json-schema-to-grammar.cpp
-@@ -391,7 +391,7 @@ class SchemaConverter {
- private:
+@@ -346,7 +346,7 @@ private:
+     friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
      std::function<json(const std::string &)> _fetch_json;
      bool _dotall;
 -    std::map<std::string, std::string> _rules;
diff --git a/llama/patches/0011-fix-missing-arg-in-static-assert-on-windows.patch b/llama/patches/0011-fix-missing-arg-in-static-assert-on-windows.patch
deleted file mode 100644
index 8c43ad3d4..000000000
--- a/llama/patches/0011-fix-missing-arg-in-static-assert-on-windows.patch
+++ /dev/null
@@ -1,22 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Sat, 14 Dec 2024 12:54:00 -0800
-Subject: [PATCH] fix missing arg in static assert on windows
-
----
- ggml/src/ggml-cuda/concat.cu | 2 +-
- 1 file changed, 1 insertion(+), 1 deletion(-)
-
-diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
-index 2f42b8a9..5eb9f08d 100644
---- a/ggml/src/ggml-cuda/concat.cu
-+++ b/ggml/src/ggml-cuda/concat.cu
-@@ -124,7 +124,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
-           uint64_t   nb1,
-           uint64_t   nb2,
-           uint64_t   nb3){
--    static_assert(dim >= 0 && dim <= 3);
-+    static_assert(dim >= 0 && dim <= 3, "dim must be between 0 and 3");
- 
-     const int64_t i3 = blockIdx.z;
-     const int64_t i2 = blockIdx.y;
diff --git a/llama/patches/0012-llama-Ensure-KV-cache-is-fully-defragmented.patch b/llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
similarity index 84%
rename from llama/patches/0012-llama-Ensure-KV-cache-is-fully-defragmented.patch
rename to llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
index 3ef51f4ef..ff0575390 100644
--- a/llama/patches/0012-llama-Ensure-KV-cache-is-fully-defragmented.patch
+++ b/llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
@@ -19,10 +19,10 @@ multiple batches of processing until everything is complete.
  1 file changed, 46 insertions(+), 53 deletions(-)
 
 diff --git a/src/llama.cpp b/src/llama.cpp
-index bac66c24..c95da45d 100644
+index 8f7902df..01854fce 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -3536,6 +3536,13 @@ static struct ggml_tensor * llm_build_rwkv6_channel_mix(
+@@ -1054,6 +1054,13 @@ static struct ggml_tensor * llm_build_rwkv6_channel_mix(
      return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
  }
  
@@ -36,13 +36,13 @@ index bac66c24..c95da45d 100644
  struct llm_build_context {
      const llama_model    & model;
            llama_context  & lctx;
-@@ -3712,35 +3719,23 @@ struct llm_build_context {
+@@ -1230,35 +1237,23 @@ struct llm_build_context {
          return gf;
      }
  
 -    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
 +    struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
-         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
  
 -        for (uint32_t i = 0; i < ids.size(); ++i) {
 -            const uint32_t id = ids[i];
@@ -78,7 +78,7 @@ index bac66c24..c95da45d 100644
  
                  ggml_tensor * view_v_src;
                  ggml_tensor * view_v_dst;
-@@ -3748,31 +3743,29 @@ struct llm_build_context {
+@@ -1266,31 +1261,29 @@ struct llm_build_context {
                  if (flash_attn) {
                      // NOTE: the V cache is not transposed when using flash attention
                      view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
@@ -118,7 +118,7 @@ index bac66c24..c95da45d 100644
          }
  
          //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-@@ -10856,7 +10849,7 @@ struct llm_build_context {
+@@ -8508,7 +8501,7 @@ struct llm_build_context {
      }
  };
  
@@ -127,7 +127,7 @@ index bac66c24..c95da45d 100644
      llama_ubatch dummy = {};
      dummy.equal_seqs = true;
  
-@@ -10866,7 +10859,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
+@@ -8518,7 +8511,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
  
      llm.init();
  
@@ -136,21 +136,21 @@ index bac66c24..c95da45d 100644
  
      llm.free();
  
-@@ -11329,7 +11322,12 @@ static int llama_decode_internal(
-                 kv_self.head = 0;
-             }
+@@ -8956,7 +8949,12 @@ static int llama_prepare_ubatch(
+             kv_self.head = 0;
+         }
  
--            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+            auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+            if (!slot) {
-+                llama_kv_cache_defrag(kv_self);
-+                llama_kv_cache_update(&lctx);
-+                slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+            }
-             if (!slot) {
-                 return 1;
-             }
-@@ -11735,8 +11733,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+-        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
++        auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
++        if (!slot) {
++            llama_kv_cache_defrag(kv_self);
++            llama_kv_cache_update(&lctx);
++            slot = llama_kv_cache_find_slot(kv_self, ubatch);
++        }
+         if (!slot) {
+             return 1;
+         }
+@@ -9431,8 +9429,8 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
  
      //const int64_t t_start = ggml_time_us();
  
@@ -161,7 +161,7 @@ index bac66c24..c95da45d 100644
  
      // each move requires 6*n_layer tensors (see build_defrag)
      //   - source view, destination view, copy operation
-@@ -11800,19 +11798,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+@@ -9496,19 +9494,11 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
          // are we moving a continuous block of memory?
          bool cont = false;
  
@@ -181,7 +181,7 @@ index bac66c24..c95da45d 100644
                  cont = false;
                  continue;
              }
-@@ -11828,8 +11818,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+@@ -9524,8 +9514,10 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
              kv_self.head = n_used;
  
              if (!cont) {
@@ -193,7 +193,7 @@ index bac66c24..c95da45d 100644
              }
  
              nf++;
-@@ -11839,22 +11831,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+@@ -9535,22 +9527,16 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
              }
          }
  
@@ -218,7 +218,7 @@ index bac66c24..c95da45d 100644
  
  #if 0
      // CPU defrag
-@@ -11929,11 +11915,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+@@ -9625,11 +9611,18 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
  #else
      // ggml_graph defrag
  
diff --git a/llama/patches/0013-use-dynamic-backend-loading-for-clip.patch b/llama/patches/0012-use-dynamic-backend-loading-for-clip.patch
similarity index 94%
rename from llama/patches/0013-use-dynamic-backend-loading-for-clip.patch
rename to llama/patches/0012-use-dynamic-backend-loading-for-clip.patch
index e283a857f..e28571695 100644
--- a/llama/patches/0013-use-dynamic-backend-loading-for-clip.patch
+++ b/llama/patches/0012-use-dynamic-backend-loading-for-clip.patch
@@ -8,12 +8,12 @@ Subject: [PATCH] use dynamic backend loading for clip
  1 file changed, 27 insertions(+), 47 deletions(-)
 
 diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index b3c1829f..86b91d5c 100644
+index 205af1eb..560021c7 100644
 --- a/examples/llava/clip.cpp
 +++ b/examples/llava/clip.cpp
-@@ -8,25 +8,25 @@
- #include "ggml-alloc.h"
+@@ -9,25 +9,25 @@
  #include "ggml-backend.h"
+ #include "gguf.h"
  
 -//#ifdef GGML_USE_CUDA
 -//#include "ggml-cuda.h"
@@ -56,7 +56,7 @@ index b3c1829f..86b91d5c 100644
  
  #define STB_IMAGE_IMPLEMENTATION
  #include "stb_image.h"
-@@ -1235,35 +1235,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+@@ -1309,35 +1309,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
          }
      }
  
diff --git a/llama/patches/0014-sort-devices-by-score.patch b/llama/patches/0013-sort-devices-by-score.patch
similarity index 98%
rename from llama/patches/0014-sort-devices-by-score.patch
rename to llama/patches/0013-sort-devices-by-score.patch
index d57a43661..7640a8db0 100644
--- a/llama/patches/0014-sort-devices-by-score.patch
+++ b/llama/patches/0013-sort-devices-by-score.patch
@@ -8,7 +8,7 @@ Subject: [PATCH] sort devices by score
  1 file changed, 13 insertions(+), 8 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 899d16f2..135f7df0 100644
+index 95036ef8..98d5e14d 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
diff --git a/llama/patches/0015-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
similarity index 86%
rename from llama/patches/0015-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
rename to llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
index e68950a57..f263ece17 100644
--- a/llama/patches/0015-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
+++ b/llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants
  1 file changed, 2 insertions(+)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 84101c32..72b488dd 100644
+index 0002ac18..0a8d1092 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -278,6 +278,7 @@ function(ggml_add_cpu_backend_variant tag_name)
+@@ -297,6 +297,7 @@ function(ggml_add_cpu_backend_variant tag_name)
      endforeach()
  
      ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -19,7 +19,7 @@ index 84101c32..72b488dd 100644
  endfunction()
  
  ggml_add_backend(CPU)
-@@ -286,6 +287,7 @@ if (GGML_CPU_ALL_VARIANTS)
+@@ -305,6 +306,7 @@ if (GGML_CPU_ALL_VARIANTS)
      if (NOT GGML_BACKEND_DL)
          message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
      endif()
diff --git a/llama/patches/0017-try-catch-backend-load.patch b/llama/patches/0015-try-catch-backend-load.patch
similarity index 99%
rename from llama/patches/0017-try-catch-backend-load.patch
rename to llama/patches/0015-try-catch-backend-load.patch
index e7f71c7c6..9aea61836 100644
--- a/llama/patches/0017-try-catch-backend-load.patch
+++ b/llama/patches/0015-try-catch-backend-load.patch
@@ -8,7 +8,7 @@ Subject: [PATCH] try/catch backend load
  1 file changed, 23 insertions(+), 22 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 135f7df0..84b21dd8 100644
+index 98d5e14d..1c19129a 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
diff --git a/llama/patches/0016-remove-sgemm-global-variables.patch b/llama/patches/0016-remove-sgemm-global-variables.patch
deleted file mode 100644
index 31a59aeaf..000000000
--- a/llama/patches/0016-remove-sgemm-global-variables.patch
+++ /dev/null
@@ -1,55 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Sun, 9 Feb 2025 17:22:15 -0800
-Subject: [PATCH] remove sgemm global variables
-
-removes the 'iq4nlt' global variable in sgemm.cpp that causes
-a runtime crash when calling dlopen on ggml-cpu libraries as
-its initialization depends on AVX instructions the host machine
-may not have
----
- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 17 +++++++++--------
- 1 file changed, 9 insertions(+), 8 deletions(-)
-
-diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
-index 8fce576c..3f260ce5 100644
---- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
-+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
-@@ -279,14 +279,6 @@ template <> inline __m256bh load(const float *p) {
- }
- #endif
- 
--////////////////////////////////////////////////////////////////////////////////////////////////////
--// CONSTANTS
--
--#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
--static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
--static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
--#endif
--
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // FLOATING POINT MATRIX MULTIPLICATION
- 
-@@ -613,6 +605,14 @@ class tinyBLAS_Q0_AVX {
-                     TC *C, int64_t ldc,
-                     int ith, int nth)
-         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-+        const int8_t kvalues_iq4nl[16] = {
-+            -127, -104, -83, -65,
-+            -49,  -35,  -22, -10,
-+              1,   13,   25,  38,
-+             53,   69,   89, 113
-+        };
-+
-+        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
-     }
- 
-     void matmul(int64_t m, int64_t n) {
-@@ -1037,6 +1037,7 @@ class tinyBLAS_Q0_AVX {
-     const int64_t ldc;
-     const int ith;
-     const int nth;
-+    __m128i iq4nlt;
- };
- #endif // __AVX__
- 
diff --git a/llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch b/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch
similarity index 99%
rename from llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch
rename to llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch
index 95144fb45..d60066c13 100644
--- a/llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch
+++ b/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch
@@ -8,7 +8,7 @@ Subject: [PATCH] use std::filesystem::path instead of wstring
  1 file changed, 58 insertions(+), 86 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 84b21dd8..e35a6936 100644
+index 1c19129a..c854e6bb 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -66,26 +66,6 @@
diff --git a/llama/patches/0019-remove-amx.patch b/llama/patches/0017-remove-amx.patch
similarity index 89%
rename from llama/patches/0019-remove-amx.patch
rename to llama/patches/0017-remove-amx.patch
index 5428ee64a..234d51cc7 100644
--- a/llama/patches/0019-remove-amx.patch
+++ b/llama/patches/0017-remove-amx.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] remove amx
  1 file changed, 4 deletions(-)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 72b488dd..50828717 100644
+index 0a8d1092..4564df91 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS)
+@@ -312,10 +312,6 @@ if (GGML_CPU_ALL_VARIANTS)
      ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 FMA AVX512)
      ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
      ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 FMA AVX_VNNI)
@@ -19,6 +19,6 @@ index 72b488dd..50828717 100644
 -        # MSVC doesn't support AMX
 -        ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
 -    endif()
- else ()
+ elseif (GGML_CPU)
      ggml_add_cpu_backend_variant_impl("")
  endif()
diff --git a/llama/patches/0018-fix-clip-compiler-error.patch b/llama/patches/0018-fix-clip-compiler-error.patch
new file mode 100644
index 000000000..ef6e247bd
--- /dev/null
+++ b/llama/patches/0018-fix-clip-compiler-error.patch
@@ -0,0 +1,36 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Tue, 25 Feb 2025 19:14:51 -0800
+Subject: [PATCH] fix-clip-compiler-error
+
+---
+ examples/llava/clip.cpp | 2 +-
+ examples/llava/clip.h   | 2 +-
+ 2 files changed, 2 insertions(+), 2 deletions(-)
+
+diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
+index 560021c7..54265beb 100644
+--- a/examples/llava/clip.cpp
++++ b/examples/llava/clip.cpp
+@@ -1788,7 +1788,7 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch  * batch) {
+     }
+ }
+ 
+-void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
++void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img) {
+     img->nx = nx;
+     img->ny = ny;
+     img->buf.resize(3 * nx * ny);
+diff --git a/examples/llava/clip.h b/examples/llava/clip.h
+index ce6f6194..f9f80d7d 100644
+--- a/examples/llava/clip.h
++++ b/examples/llava/clip.h
+@@ -75,7 +75,7 @@ CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);
+ CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
+ 
+ /** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */
+-CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img);
++CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
+ 
+ CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
+ 
diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h
index 7221a0830..fc9571c82 100644
--- a/ml/backend/ggml/ggml/include/ggml-backend.h
+++ b/ml/backend/ggml/ggml/include/ggml-backend.h
@@ -203,6 +203,8 @@ extern "C" {
     // Backend registry
     //
 
+    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
+
     // Backend (reg) enumeration
     GGML_API size_t             ggml_backend_reg_count(void);
     GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
diff --git a/ml/backend/ggml/ggml/include/ggml-cpp.h b/ml/backend/ggml/ggml/include/ggml-cpp.h
index 219361af4..a12342c25 100644
--- a/ml/backend/ggml/ggml/include/ggml-cpp.h
+++ b/ml/backend/ggml/ggml/include/ggml-cpp.h
@@ -7,6 +7,7 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "gguf.h"
 #include <memory>
 
 // Smart pointers for ggml types
diff --git a/ml/backend/ggml/ggml/include/ggml-cpu.h b/ml/backend/ggml/ggml/include/ggml-cpu.h
index 3aa71badb..b48cc560e 100644
--- a/ml/backend/ggml/ggml/include/ggml-cpu.h
+++ b/ml/backend/ggml/ggml/include/ggml-cpu.h
@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
     // the compute plan that needs to be prepared for ggml_graph_compute()
-    // since https://github.com/ggerganov/ggml/issues/287
+    // since https://github.com/ggml-org/ggml/issues/287
     struct ggml_cplan {
         size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
         uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
@@ -95,9 +95,11 @@ extern "C" {
     GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
     GGML_BACKEND_API int ggml_cpu_has_sve        (void);
     GGML_BACKEND_API int ggml_cpu_get_sve_cnt    (void);  // sve vector length in bytes
+    GGML_BACKEND_API int ggml_cpu_has_sme        (void);
     // other
     GGML_BACKEND_API int ggml_cpu_has_riscv_v    (void);
     GGML_BACKEND_API int ggml_cpu_has_vsx        (void);
+    GGML_BACKEND_API int ggml_cpu_has_vxe        (void);
     GGML_BACKEND_API int ggml_cpu_has_wasm_simd  (void);
     GGML_BACKEND_API int ggml_cpu_has_llamafile  (void);
 
diff --git a/ml/backend/ggml/ggml/include/ggml-metal.h b/ml/backend/ggml/ggml/include/ggml-metal.h
index 669c1f84a..a61069442 100644
--- a/ml/backend/ggml/ggml/include/ggml-metal.h
+++ b/ml/backend/ggml/ggml/include/ggml-metal.h
@@ -45,7 +45,7 @@ GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
 GGML_DEPRECATED(
         GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
-        "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
+        "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
 
 GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
 
diff --git a/ml/backend/ggml/ggml/include/ggml-vulkan.h b/ml/backend/ggml/ggml/include/ggml-vulkan.h
index 53cdba072..ed5ea5f79 100644
--- a/ml/backend/ggml/ggml/include/ggml-vulkan.h
+++ b/ml/backend/ggml/ggml/include/ggml-vulkan.h
@@ -10,8 +10,6 @@ extern "C" {
 #define GGML_VK_NAME "Vulkan"
 #define GGML_VK_MAX_DEVICES 16
 
-GGML_BACKEND_API void ggml_vk_instance_init(void);
-
 // backend API
 GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
 
diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h
index 1bc50fcae..8d269a9cc 100644
--- a/ml/backend/ggml/ggml/include/ggml.h
+++ b/ml/backend/ggml/ggml/include/ggml.h
@@ -198,7 +198,7 @@
 
 #ifndef __GNUC__
 #    define GGML_ATTRIBUTE_FORMAT(...)
-#elif defined(__MINGW32__)
+#elif defined(__MINGW32__) && !defined(__clang__)
 #    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 #else
 #    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
@@ -241,12 +241,6 @@
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
 
-#define GGUF_MAGIC "GGUF"
-
-#define GGUF_VERSION 3
-
-#define GGUF_DEFAULT_ALIGNMENT 32
-
 #define GGML_UNUSED(x) (void)(x)
 
 #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -403,12 +397,6 @@ extern "C" {
         GGML_PREC_F32,
     };
 
-    enum ggml_backend_type {
-        GGML_BACKEND_TYPE_CPU = 0,
-        GGML_BACKEND_TYPE_GPU = 10,
-        GGML_BACKEND_TYPE_GPU_SPLIT = 20,
-    };
-
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN        = -1,
@@ -514,6 +502,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
+        GGML_OP_GATED_LINEAR_ATTN,
 
         GGML_OP_UNARY,
 
@@ -588,8 +577,6 @@ extern "C" {
     struct ggml_tensor {
         enum ggml_type type;
 
-        GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
-
         struct ggml_backend_buffer * buffer;
 
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -1398,16 +1385,20 @@ extern "C" {
             float                 scale,
             float                 max_bias);
 
-    GGML_API struct ggml_tensor * ggml_soft_max_back(
+    GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            float                 scale,
+            float                 max_bias);
 
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+    GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            float                 scale,
+            float                 max_bias);
 
     // rotary position embedding
     // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
@@ -1514,7 +1505,7 @@ extern "C" {
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
-    GGML_API struct ggml_tensor * ggml_rope_back(
+    GGML_API struct ggml_tensor * ggml_rope_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a, // gradients of ggml_rope result
             struct ggml_tensor  * b, // positions
@@ -1529,6 +1520,23 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
+    GGML_API struct ggml_tensor * ggml_rope_multi_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[4],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+
     // clamp
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_clamp(
@@ -1777,7 +1785,7 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   k);
 
-#define GGML_KQ_MASK_PAD 32
+#define GGML_KQ_MASK_PAD 64
 
     // q:    [n_embd, n_batch,     n_head,    1]
     // k:    [n_embd, n_kv,        n_head_kv, 1]
@@ -1883,6 +1891,15 @@ extern "C" {
             struct ggml_tensor  * td,
             struct ggml_tensor  * state);
 
+    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * state,
+            float scale);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -2121,132 +2138,6 @@ extern "C" {
                    int64_t   n_per_row,
                const float * imatrix);
 
-    //
-    // gguf
-    //
-
-    enum gguf_type {
-        GGUF_TYPE_UINT8   = 0,
-        GGUF_TYPE_INT8    = 1,
-        GGUF_TYPE_UINT16  = 2,
-        GGUF_TYPE_INT16   = 3,
-        GGUF_TYPE_UINT32  = 4,
-        GGUF_TYPE_INT32   = 5,
-        GGUF_TYPE_FLOAT32 = 6,
-        GGUF_TYPE_BOOL    = 7,
-        GGUF_TYPE_STRING  = 8,
-        GGUF_TYPE_ARRAY   = 9,
-        GGUF_TYPE_UINT64  = 10,
-        GGUF_TYPE_INT64   = 11,
-        GGUF_TYPE_FLOAT64 = 12,
-        GGUF_TYPE_COUNT,       // marks the end of the enum
-    };
-
-    struct gguf_context;
-
-    struct gguf_init_params {
-        bool no_alloc;
-
-        // if not NULL, create a ggml_context and allocate the tensor data in it
-        struct ggml_context ** ctx;
-    };
-
-    GGML_API struct gguf_context * gguf_init_empty(void);
-    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
-    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
-    GGML_API void gguf_free(struct gguf_context * ctx);
-
-    GGML_API const char * gguf_type_name(enum gguf_type type);
-
-    GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
-    GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
-
-    GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
-    GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
-    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
-
-    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
-    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
-
-    // will abort if the wrong type is used for the key
-    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
-    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
-    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
-    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
-    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
-    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
-    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
-    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
-    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
-    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
-    GGML_API int            gguf_get_n_tensors    (const struct gguf_context * ctx);
-    GGML_API int            gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
-    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
-    GGML_API char *         gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
-    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int i);
-
-    // removes key if it exists
-    GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
-
-    // overrides existing values or adds a new one
-    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
-    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
-    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
-    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
-    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
-    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
-    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
-    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
-    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
-    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
-    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
-    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
-    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
-    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
-
-    // set or add KV pairs from another context
-    GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
-
-    // manage tensor info
-    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
-    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
-    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
-
-    // writing gguf files can be done in 2 ways:
-    //
-    // - write the entire gguf_context to a binary file in a single pass:
-    //
-    //   gguf_write_to_file(ctx, fname);
-    //
-    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
-    //
-    //   FILE * f = fopen(fname, "wb");
-    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
-    //   fwrite(f, ...);
-    //   void * data = gguf_meta_get_meta_data(ctx);
-    //   fseek(f, 0, SEEK_SET);
-    //   fwrite(f, data, gguf_get_meta_size(ctx));
-    //   free(data);
-    //   fclose(f);
-    //
-
-    // write the entire context to a binary file
-    GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
-
-    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
-    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
-    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
-
 #ifdef __cplusplus
     // restrict not standard in C++
 #    if defined(__GNUC__)
diff --git a/ml/backend/ggml/ggml/include/gguf.h b/ml/backend/ggml/ggml/include/gguf.h
new file mode 100644
index 000000000..79ee20206
--- /dev/null
+++ b/ml/backend/ggml/ggml/include/gguf.h
@@ -0,0 +1,202 @@
+// This file contains functionality related to "GGUF" files, the binary file format used by ggml.
+// GGUF files have the following structure:
+//
+// 1. File magic "GGUF" (4 bytes).
+// 2. File version (uint32_t).
+// 3. Number of ggml tensors in file (int64_t).
+// 4. Number of key-value-pairs in file (int64_t).
+// 5. For each KV pair:
+//   1. The key (string).
+//   2. The value type (gguf_type).
+//   3a. If the value type is GGUF_TYPE_ARRAY:
+//     1. The type of the array (gguf_type).
+//     2. The number of elements in the array (uint64_t).
+//     3. The binary representation of each element in the array.
+//   3b. Otherwise:
+//     1. The binary representation of the value.
+// 6. For each ggml tensor:
+//   1. The tensor name (string).
+//   2. The number of dimensions of the tensor (uint32_t).
+//   3. For each dimension:
+//     1. The size of the tensor in the dimension (int64_t).
+//   4. The tensor data type (ggml_type).
+//   5. The tensor data offset in the tensor data binary blob (uint64_t).
+// 7. The tensor data binary blob (optional, aligned).
+//
+// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator.
+// All enums are stored as int32_t.
+// All bool values are stored as int8_t.
+// If the special key "general.alignment" (uint32_t) is defined it is used for alignment,
+//   otherwise GGUF_DEFAULT_ALIGNMENT is used.
+//
+// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
+
+#pragma once
+
+#include "ggml.h"
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#define GGUF_MAGIC   "GGUF"
+#define GGUF_VERSION 3
+
+#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment"
+
+#define GGUF_DEFAULT_ALIGNMENT 32
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    // types that can be stored as GGUF KV data
+    enum gguf_type {
+        GGUF_TYPE_UINT8   = 0,
+        GGUF_TYPE_INT8    = 1,
+        GGUF_TYPE_UINT16  = 2,
+        GGUF_TYPE_INT16   = 3,
+        GGUF_TYPE_UINT32  = 4,
+        GGUF_TYPE_INT32   = 5,
+        GGUF_TYPE_FLOAT32 = 6,
+        GGUF_TYPE_BOOL    = 7,
+        GGUF_TYPE_STRING  = 8,
+        GGUF_TYPE_ARRAY   = 9,
+        GGUF_TYPE_UINT64  = 10,
+        GGUF_TYPE_INT64   = 11,
+        GGUF_TYPE_FLOAT64 = 12,
+        GGUF_TYPE_COUNT,       // marks the end of the enum
+    };
+
+    struct gguf_context;
+
+    struct gguf_init_params {
+        bool no_alloc;
+
+        // if not NULL, create a ggml_context and allocate the tensor data in it
+        struct ggml_context ** ctx;
+    };
+
+    GGML_API struct gguf_context * gguf_init_empty(void);
+    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
+    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
+
+    GGML_API void gguf_free(struct gguf_context * ctx);
+
+    GGML_API const char * gguf_type_name(enum gguf_type type);
+
+    GGML_API uint32_t gguf_get_version    (const struct gguf_context * ctx);
+    GGML_API size_t   gguf_get_alignment  (const struct gguf_context * ctx);
+    GGML_API size_t   gguf_get_data_offset(const struct gguf_context * ctx);
+
+    GGML_API int64_t      gguf_get_n_kv(const struct gguf_context * ctx);
+    GGML_API int64_t      gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id);
+
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id);
+
+    // will abort if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API size_t       gguf_get_arr_n   (const struct gguf_context * ctx, int64_t key_id);
+
+    // get raw pointer to the first element of the array with the given key_id
+    // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+
+    // get ith C string from array with given key_id
+    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
+
+    GGML_API int64_t        gguf_get_n_tensors    (const struct gguf_context * ctx);
+    GGML_API int64_t        gguf_find_tensor      (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found
+    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API const char *   gguf_get_tensor_name  (const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API size_t         gguf_get_tensor_size  (const struct gguf_context * ctx, int64_t tensor_id);
+
+    // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist)
+    GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key);
+
+    // overrides an existing KV pair or adds a new one, the new KV pair is always at the back
+    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t      val);
+    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t       val);
+    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t     val);
+    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t      val);
+    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t     val);
+    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t      val);
+    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float        val);
+    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t     val);
+    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t      val);
+    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double       val);
+    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool         val);
+    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
+
+    // creates a new array with n elements of the given type and copies the corresponding number of bytes from data
+    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n);
+
+    // creates a new array with n strings and copies the corresponding strings from data
+    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n);
+
+    // set or add KV pairs from another context
+    GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src);
+
+    // add tensor to GGUF context, tensor name must be unique
+    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
+
+    // after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated
+    //   in such a way that the tensor data remains as one contiguous block (except for padding)
+    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
+
+    // assumes that at least gguf_get_tensor_size bytes can be read from data
+    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data);
+
+    // writing gguf files can be done in 3 ways:
+    //
+    // - write the entire gguf_context to a binary file in a single pass:
+    //
+    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ false);
+    //
+    // - write only the meta data to a file, then re-open the file and append the tensor data:
+    //
+    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ true);
+    //   FILE * f = fopen(fname, "ab");
+    //   fwrite(f, ...); // write tensor data
+    //   fclose(f);
+    //
+    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   const size_t size_meta = gguf_get_meta_size(ctx);
+    //   fseek(f, size_meta, SEEK_SET);
+    //   fwrite(f, ...); // write tensor data
+    //   void * data = malloc(size_meta);
+    //   gguf_get_meta_data(ctx, data);
+    //   rewind(f);
+    //   fwrite(data, 1, size_meta, f);
+    //   free(data);
+    //   fclose(f);
+    //
+
+    // write the entire context to a binary file
+    GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+
+    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
+
+    // writes the meta data to pointer "data"
+    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+
+#ifdef  __cplusplus
+}
+#endif
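The new gguf.h above is now the sole public entry point for reading and writing GGUF metadata. A minimal sketch of how these declarations compose, assuming a hypothetical model file at "model.gguf" and using only functions declared in the header above (error handling abbreviated):

```c
// Minimal sketch of reading GGUF metadata with the API declared in gguf.h above.
// "model.gguf" is a hypothetical path used only for illustration.
#include "gguf.h"

#include <stdio.h>

int main(void) {
    struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ NULL };
    struct gguf_context * ctx = gguf_init_from_file("model.gguf", params);
    if (ctx == NULL) {
        return 1;
    }

    // gguf_find_key returns -1 when the key is not present
    int64_t key_id = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT);
    if (key_id >= 0 && gguf_get_kv_type(ctx, key_id) == GGUF_TYPE_UINT32) {
        printf("alignment: %u\n", gguf_get_val_u32(ctx, key_id));
    }

    // enumerate tensor metadata (name, size, offset into the data blob)
    for (int64_t i = 0; i < gguf_get_n_tensors(ctx); ++i) {
        printf("%s: %zu bytes at offset %zu\n",
               gguf_get_tensor_name(ctx, i),
               gguf_get_tensor_size(ctx, i),
               gguf_get_tensor_offset(ctx, i));
    }

    gguf_free(ctx);
    return 0;
}
```

The tensor data itself is not touched here; per the format description in the header, it lives in the trailing binary blob starting at gguf_get_data_offset(ctx), with each tensor found at its reported offset within that blob.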
diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt
index 508287172..4564df91a 100644
--- a/ml/backend/ggml/ggml/src/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/CMakeLists.txt
@@ -93,12 +93,18 @@ endif()
 
 if (GGML_CCACHE)
     find_program(GGML_CCACHE_FOUND ccache)
+    find_program(GGML_SCCACHE_FOUND sccache)
 
-    if (GGML_CCACHE_FOUND)
+    if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND)
+        if(GGML_CCACHE_FOUND)
+            set(GGML_CCACHE_VARIANT ccache)
+        else()
+            set(GGML_CCACHE_VARIANT sccache)
+        endif()
         # TODO: should not be set globally
-        set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+        set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
         set(ENV{CCACHE_SLOPPINESS} time_macros)
-        message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
+        message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
     else()
         message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
     endif ()
@@ -208,6 +214,7 @@ add_library(ggml-base
             ../include/ggml-backend.h
             ../include/ggml-cpp.h
             ../include/ggml-opt.h
+            ../include/gguf.h
             ggml.c
             ggml-alloc.c
             ggml-backend.cpp
@@ -215,7 +222,8 @@ add_library(ggml-base
             ggml-threading.cpp
             ggml-threading.h
             ggml-quants.c
-            ggml-quants.h)
+            ggml-quants.h
+            gguf.cpp)
 
 target_include_directories(ggml-base PRIVATE .)
 
@@ -248,6 +256,17 @@ function(ggml_add_backend_library backend)
         target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)
         target_compile_definitions(${backend} PUBLIC  GGML_BACKEND_SHARED)
     endif()
+
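+    # keep a de-duplicated, cached list of built backends (used for the CMake package, per the cache entry's description)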
+    if(NOT GGML_AVAILABLE_BACKENDS)
+        set(GGML_AVAILABLE_BACKENDS "${backend}"
+            CACHE INTERNAL "List of backends for cmake package")
+    else()
+        list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend)
+        if(has_backend EQUAL -1)
+            set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}"
+                CACHE INTERNAL "List of backends for cmake package")
+        endif()
+    endif()
 endfunction()
 
 function(ggml_add_backend backend)
@@ -293,7 +312,7 @@ if (GGML_CPU_ALL_VARIANTS)
     ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 FMA AVX512)
     ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
     ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 FMA AVX_VNNI)
-else ()
+elseif (GGML_CPU)
     ggml_add_cpu_backend_variant_impl("")
 endif()
 
diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c
index 8dc8226ac..7244a9cbb 100644
--- a/ml/backend/ggml/ggml/src/ggml-alloc.c
+++ b/ml/backend/ggml/ggml/src/ggml-alloc.c
@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
     return true;
 }
 
+// ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
         case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_LOG:
         case GGML_OP_UNARY:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+        case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
             return true;
 
         default:
@@ -984,19 +989,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
             this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
         }
 
-        if (this_size > max_size) {
-            GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
-                    __func__, t->name,
-                    ggml_backend_buft_name(buft),
-                    this_size, max_size);
-            for (size_t i = 0; i < n_buffers; i++) {
-                ggml_backend_buffer_free(buffers[i]);
-            }
-            free(buffers);
-            return NULL;
-        }
-
-        if ((cur_buf_size + this_size) > max_size) {
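+        // if this tensor does not fit in the current (non-empty) buffer, close that buffer and start a new one;
+        // a tensor larger than max_size is now attempted in its own buffer instead of failing outright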
+        if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) {
             // allocate tensors in the current buffer
             if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
                 return NULL;
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-impl.h b/ml/backend/ggml/ggml/src/ggml-backend-impl.h
index 36d72e95f..d1c2d76d8 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-backend-impl.h
@@ -208,7 +208,6 @@ extern "C" {
 
     // Internal backend registry API
     GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
 
     // Add backend dynamic loading support to the backend
 
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
index e35a69369..c854e6bb2 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
@@ -552,4 +552,9 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     ggml_backend_load_best("opencl", silent, dir_path);
     ggml_backend_load_best("musa", silent, dir_path);
     ggml_backend_load_best("cpu", silent, dir_path);
+    // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
+    const char * backend_path = std::getenv("GGML_BACKEND_PATH");
+    if (backend_path) {
+        ggml_backend_load(backend_path);
+    }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp
index a12172dcd..1ca40b2c4 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp
@@ -763,7 +763,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
             int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
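+            // (only when the weights live in a host buffer; weights already resident on a device buffer are left where they are)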
-            if (src_backend_id == sched->n_backends - 1) {
+            if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
                 for (int b = 0; b < src_backend_id; b++) {
                     if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
                         SET_CAUSE(tensor, "1.off");
diff --git a/ml/backend/ggml/ggml/src/ggml-common.h b/ml/backend/ggml/ggml/src/ggml-common.h
index f13fd4dea..6c02b69ea 100644
--- a/ml/backend/ggml/ggml/src/ggml-common.h
+++ b/ml/backend/ggml/ggml/src/ggml-common.h
@@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 GGML_TABLE_END()
 
-//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
@@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
     0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
 GGML_TABLE_END()
-//#endif
 
 
 GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
index 6b3641c42..aa5ad5d8d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
@@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 function(check_arm_feature tag code)
                     set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
                     set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
-                    check_cxx_source_runs(
-                        "${code}"
-                        GGML_MACHINE_SUPPORTS_${tag}
-                    )
+                    check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
                     if (GGML_MACHINE_SUPPORTS_${tag})
                         set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
                     else()
-                        set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+                        set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
+                        check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
+                        if (GGML_MACHINE_SUPPORTS_no${tag})
+                            set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+                        endif()
                     endif()
                     set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
                 endfunction()
@@ -126,6 +127,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
                 check_arm_feature(i8mm    "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
                 check_arm_feature(sve     "#include <arm_sve.h>\nint main()  { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
+                check_arm_feature(sme     "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
 
                 list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
             else()
@@ -150,7 +152,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             if (ARM_FEATURE_RESULT)
                 message(WARNING "Failed to get ARM features")
             else()
-                foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
+                foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
                     string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
                     if (NOT ${feature_pos} EQUAL -1)
                         message(STATUS "ARM feature ${feature} enabled")
@@ -308,6 +310,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         if (GGML_RVV)
             list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
         endif()
+    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
+        message(STATUS "s390x detected")
+        file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
+        string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
+
+        # TODO: Separation to determine activation of VX/VXE/VXE2
+        if (${S390X_M} MATCHES "8561|8562")
+            message(STATUS "z15 target")
+            list(APPEND ARCH_FLAGS -march=z15 -mtune=z15)
+        elseif (${S390X_M} MATCHES "3931")
+            message(STATUS "z16 target")
+            list(APPEND ARCH_FLAGS -march=z16 -mtune=z16)
+        else()
+            message(STATUS "Unknown target")
+            message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
+            list(APPEND ARCH_FLAGS -march=native -mtune=native)
+        endif()
+
+        if (GGML_VXE)
+            list(APPEND ARCH_FLAGS -mvx -mzvector)
+        endif()
     else()
         message(STATUS "Unknown architecture")
     endif()
@@ -316,6 +339,94 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
     endif()
 
+    if (GGML_CPU_KLEIDIAI)
+        message(STATUS "Using KleidiAI optimized kernels if applicable")
+
+        # Disable the KleidiAI tests
+        set(KLEIDIAI_BUILD_TESTS  OFF)
+
+        # Fetch KleidiAI sources:
+        include(FetchContent)
+        set(KLEIDIAI_COMMIT_TAG "v1.3.0")
+        set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
+        set(KLEIDIAI_ARCHIVE_MD5  "060bd2dc64642b091f461cc8dd7426d9")
+
+        if (POLICY CMP0135)
+            cmake_policy(SET CMP0135 NEW)
+        endif()
+
+        FetchContent_Declare(KleidiAI_Download
+            URL ${KLEIDIAI_DOWNLOAD_URL}
+            DOWNLOAD_EXTRACT_TIMESTAMP NEW
+            URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
+
+        FetchContent_MakeAvailable(KleidiAI_Download)
+        FetchContent_GetProperties(KleidiAI_Download
+            SOURCE_DIR  KLEIDIAI_SRC
+            POPULATED   KLEIDIAI_POPULATED)
+
+        if (NOT KLEIDIAI_POPULATED)
+            message(FATAL_ERROR "KleidiAI source download failed.")
+        endif()
+
+        add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
+
+        # Exclude the kleidiai target from the default build after fetching it
+        if (TARGET kleidiai)
+            set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
+        endif()
+
+        list(APPEND GGML_CPU_SOURCES
+            ggml-cpu/kleidiai/kleidiai.cpp
+            ggml-cpu/kleidiai/kernels.cpp
+            ggml-cpu/kleidiai/kleidiai.h
+            ggml-cpu/kleidiai/kernels.h
+            )
+
+        # KleidiAI
+        include_directories(
+            ${KLEIDIAI_SRC}/
+            ${KLEIDIAI_SRC}/kai/
+            ${KLEIDIAI_SRC}/kai/ukernels/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
+
+        set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
+        if (NOT ARCH_FLAGS_TEMP)
+            string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}")
+        endif()
+        string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
+
+        set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
+
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+
+        if (NOT DOTPROD_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+        endif()
+
+        if (NOT I8MM_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
+        endif()
+
+        if (NOT SME_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
+            set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
+        endif()
+
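+        # note: these flags (extended with +sve+sve2 when the SME kernels are included) apply only to the KleidiAI sources listed above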
+        set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
+        list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
+    endif()
+
     message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
     target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
     target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
index 622c63f1f..b311a5b1c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
@@ -4169,6 +4169,8 @@ static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(g
     buffer->buft              = buft;
     buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
     buffer->iface.set_tensor  = ggml_backend_cpu_aarch64_buffer_set_tensor;
+    buffer->iface.get_tensor  = nullptr;
+    buffer->iface.cpy_tensor  = nullptr;
     return buffer;
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
index d71076ad1..7f7d210cb 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -59,6 +59,15 @@ struct ggml_compute_params {
 #endif
 #endif
 
+#if defined(__s390x__) && defined(__VEC__)
+#ifndef __VXE__
+#define __VXE__
+#endif
+#ifndef __VXE2__
+#define __VXE2__
+#endif
+#endif
+
 #if defined(__ARM_FEATURE_SVE)
 #include <arm_sve.h>
 #include <sys/prctl.h>
@@ -359,22 +368,158 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
 #endif
 #endif
 
-#if defined(__loongarch_asx)
+#if defined(__VXE__) || defined(__VXE2__)
+#include <vecintrin.h>
 
-typedef union {
-    int32_t i;
-    float f;
-} ft_union;
+#define vec_neg(a)    (-(a))                // Vector Negate
+#define vec_add(a, b) ((a) + (b))           // Vector Add
+#define vec_sub(a, b) ((a) - (b))           // Vector Subtract
+#define vec_mul(a, b) ((a) * (b))           // Vector Multiply
+#define vec_div(a, b) ((a) / (b))           // Vector Divide
+#define vec_sl(a, b)  ((a) << (b))          // Vector Shift Left
+#define vec_sra(a, b) ((a) >> (b))          // Vector Shift Right Algebraic
+#define vec_sr(a, b)  ((a) >> (b))          // Vector Shift Right
+#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
+#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet
 
-/* float type data load instructions */
-static __m128 __lsx_vreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+#ifndef vec_and
+#define vec_and(a, b) ((a) & (b)) // Vector AND
+#endif
+
+#ifndef vec_or
+#define vec_or(a, b)  ((a) | (b)) // Vector OR
+#endif
+
+#ifndef vec_xor
+#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
+#endif
+
+typedef signed char char8x16_t __attribute__((vector_size(16)));
+typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
+
+typedef int8_t  int8x16_t __attribute__((vector_size(16)));
+typedef int16_t int16x8_t __attribute__((vector_size(16)));
+typedef int32_t int32x4_t __attribute__((vector_size(16)));
+
+typedef uint8_t  uint8x16_t __attribute__((vector_size(16)));
+typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
+typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
+
+typedef float float32x4_t __attribute__((vector_size(16)));
+typedef double double64x2_t __attribute__((vector_size(16)));
+
+typedef signed long long long64x2_t __attribute__((vector_size(16)));
+typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vec_xl( 0, ptr);
+    res.val[1] = vec_xl(16, ptr);
+
+    return res;
 }
 
-static __m256 __lasx_xvreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vec_xl( 0, ptr);
+    res.val[1] = vec_xl(16, ptr);
+    res.val[2] = vec_xl(32, ptr);
+    res.val[3] = vec_xl(48, ptr);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vec_xl( 0, ptr);
+    res.val[1] = vec_xl(16, ptr);
+    res.val[2] = vec_xl(32, ptr);
+    res.val[3] = vec_xl(48, ptr);
+
+    return res;
+}
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vec_xl( 0, ptr);
+    res.val[1] = vec_xl(16, ptr);
+
+    return res;
+}
+
+/*
+    ! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs
+    !          or iq4_nl for an example implementation.
+*/
+inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {
+    int8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
+inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
+    const uchar8x16_t v_maske = {  0,  1,  4,  5,  8,  9, 12, 13,
+                                  16, 17, 20, 21, 24, 25, 28, 29 };
+
+    const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);
+    const int16x8_t v_abe = vec_perm(a, b, v_maske);
+    return v_abo + v_abe;
+}
+
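+// widening dot product for the VXE path: pairwise multiply-add of adjacent int8 lanes
+// into int16 (even/odd products), then widen both halves to int32 and add to the accumulator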
+inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
+    return acc + (vec_unpackh(p) + vec_unpackl(p));
+}
+
+#endif
+
+#if defined(__loongarch_asx)
+/* float type data load instructions */
+static __m128 __lsx_vreplfr2vr_s(const float val) {
+    v4f32 res = {val, val, val, val};
+    return (__m128)res;
+}
+
+static __m256 __lasx_xvreplfr2vr_s(const float val) {
+    v8f32 res = {val, val, val, val, val, val, val, val};
+    return (__m256)res;
 }
 #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
index 8e1472266..8d5e3e20b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
@@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
 static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
 #endif
 
+#if defined(__loongarch_sx)
+
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_w(a, 15);
+    tmp1 = __lsx_vsat_w(b, 15);
+    return __lsx_vpickev_h(tmp1, tmp);
+}
+
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_h(a, 7);
+    tmp1 = __lsx_vsat_h(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_hu(a, 7);
+    tmp1 = __lsx_vsat_hu(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_h_b(a, b);
+    tmp2 = __lsx_vmulwod_h_b(a, b);
+    return __lsx_vsadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_w_h(a, b);
+    tmp2 = __lsx_vmulwod_w_h(a, b);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+    v4i32 __ret = {d, c, b, a};
+    return (__m128i)__ret;
+}
+
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+    __m128i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lsx_vreplgr2vr_b(f);
+    zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive
+    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
+    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
+    return __lsx_vshuf_b(a, zero, tmp2);
+}
+
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_h(b, a);
+    __m128i tmp2 = __lsx_vpickod_h(b, a);
+    return __lsx_vadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_w(b, a);
+    __m128i tmp2 = __lsx_vpickod_w(b, a);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+    return __lsx_vfadd_s(tmp1, tmp2);
+}
+
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 = lsx_hadd_s(a, b);
+    __m128 res_1 = lsx_hadd_s(c, d);
+    __m128 res   = lsx_hadd_s(res_0, res_1);
+    res = lsx_hadd_s(res, res);
+    res = lsx_hadd_s(res, res);
+
+    return ((v4f32)res)[0];
+}
+#endif
+
 #if defined(__loongarch_asx)
 
 #ifdef __clang__
@@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
     return (__m256i)__ret;
 }
 
-static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
-    v4i32 __ret = {d, c, b, a};
-    return (__m128i)__ret;
-}
-
 static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
     v4i64 __ret = {d, c, b, a};
     return (__m256i)__ret;
@@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
     return lasx_set_q(x, y);
 }
 
-static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
-    __m128i mask_f, zero, tmp0, tmp2, mask;
-    int f = 0x8f;
-    mask_f = __lsx_vreplgr2vr_b(f);
-    zero = __lsx_vldi(0);
-    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
-    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive
-    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
-    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
-    return __lsx_vshuf_b(a, zero, tmp2);
-}
-
 static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
     __m256i mask_f, zero, tmp0, tmp2, mask;
     int f = 0x8f;
@@ -434,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
 }
 
 static __m256i lasx_extu8_16(__m128i a) {
-    __m128i zero = __lsx_vldi(0);
-    __m128i vlo = __lsx_vilvl_b(zero, a);
-    __m128i vhi = __lsx_vilvh_b(zero, a);
-    return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_hu_bu(____m256i(a));
 }
 
 static __m256i lasx_ext8_16(__m128i a) {
-     __m128i sign = __lsx_vslti_b(a, 0);
-     __m128i vlo = __lsx_vilvl_b(sign, a);
-     __m128i vhi = __lsx_vilvh_b(sign, a);
-     return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_h_b(____m256i(a));
 }
 
 static __m256i lasx_ext16_32(__m128i a) {
-    __m256i tmp1;
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
-    return tmp1;
+    return __lasx_vext2xv_w_h(____m256i(a));
 }
 
 static __m128i lasx_extracti128( __m256i a, int pos) {
@@ -482,25 +534,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
     return ret;
 }
 
-static __m128i lsx_hadd_h(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_h(b, a);
-    __m128i tmp2 = __lsx_vpickod_h(b, a);
-    return __lsx_vadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_hadd_w(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_w(b, a);
-    __m128i tmp2 = __lsx_vpickod_w(b, a);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
-static __m128 lsx_hadd_s(__m128 a, __m128 b) {
-    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
-    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
-
-    return __lsx_vfadd_s(tmp1, tmp2);
-}
-
 static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
     __m256i tmp1, tmp2;
     tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -529,40 +562,39 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
     return __lasx_xvpickev_b(tmp1, tmp);
 }
 
-static __m128i lsx_packs_w(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_w(a, 15);
-    tmp1 = __lsx_vsat_w(b, 15);
-    return __lsx_vpickev_h(tmp1, tmp);
+static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
+    __m256i tmp1, tmp2;
+    tmp1 = __lasx_xvmulwev_h_b(a, b);
+    tmp2 = __lasx_xvmulwod_h_b(a, b);
+    return __lasx_xvadd_h(tmp1, tmp2);
 }
 
-static __m128i lsx_packs_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_h(a, 7);
-    tmp1 = __lsx_vsat_h(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
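+// helpers for LASX intrinsics that require an immediate operand: dispatch a run-time
+// lane/bit index onto the corresponding compile-time constant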
+static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
+    switch (b) {
+        case 0: return __lasx_xvrepl128vei_h(a, 0);
+        case 1: return __lasx_xvrepl128vei_h(a, 1);
+        case 2: return __lasx_xvrepl128vei_h(a, 2);
+        case 3: return __lasx_xvrepl128vei_h(a, 3);
+        case 4: return __lasx_xvrepl128vei_h(a, 4);
+        case 5: return __lasx_xvrepl128vei_h(a, 5);
+        case 6: return __lasx_xvrepl128vei_h(a, 6);
+        case 7: return __lasx_xvrepl128vei_h(a, 7);
+        default: __builtin_unreachable();
+    }
 }
 
-static __m128i lsx_packus_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_hu(a, 7);
-    tmp1 = __lsx_vsat_hu(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-
-static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_h_b(a, b);
-    tmp2 = __lsx_vmulwod_h_b(a, b);
-    return __lsx_vsadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_madd_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_w_h(a, b);
-    tmp2 = __lsx_vmulwod_w_h(a, b);
-    return __lsx_vadd_w(tmp1, tmp2);
+static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
+    switch (b) {
+        case 0: return __lasx_xvandi_b(a, 1 << 0);
+        case 1: return __lasx_xvandi_b(a, 1 << 1);
+        case 2: return __lasx_xvandi_b(a, 1 << 2);
+        case 3: return __lasx_xvandi_b(a, 1 << 3);
+        case 4: return __lasx_xvandi_b(a, 1 << 4);
+        case 5: return __lasx_xvandi_b(a, 1 << 5);
+        case 6: return __lasx_xvandi_b(a, 1 << 6);
+        case 7: return __lasx_xvandi_b(a, 1 << 7);
+        default: __builtin_unreachable();
+    }
 }
 
 // multiply int8_t, add results pairwise twice
@@ -580,12 +612,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
 // horizontally add 8 floats
 static inline float hsum_float_8(const __m256 x) {
     __m128 res = lasx_extractf128(x, 1);
-    ft_union tmp;
     res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
-    tmp.i = __lsx_vpickve2gr_w(res, 0);
-    return tmp.f;
+    return ((v4f32)res)[0];
 }
 
 // horizontally add 8 int32_t
@@ -661,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)
 
 // multiply int8_t, add results pairwise twice and return as float vector
 static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-
-    // Get absolute values of x vectors
-    const __m256i ax = __lasx_xvsigncov_b(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = __lasx_xvsigncov_b(x, y);
-
-    return mul_sum_us8_pairs_float(ax, sy);
+    const __m256i dot = lasx_madd_h_b(x, y);
+    return sum_i16_pairs_float(dot);
 }
 
 static inline __m128i packNibbles( __m256i bytes ) {
@@ -747,7 +772,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
             y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
         }
     }
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
     for (int i = 0; i < nb; i++) {
         v128_t srcv [8];
         v128_t asrcv[8];
@@ -927,7 +952,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union fi;
         __m256 v0 = (__m256)__lasx_xvld( x , 0);
         __m256 v1 = (__m256)__lasx_xvld( x , 32);
         __m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -945,8 +969,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
-        fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-        const float max_scalar = fi.f;
+        const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -988,6 +1011,38 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
 
     }
+#elif defined(__VXE__) || defined(__VXE2__)
+    for (int i = 0; i < nb; i++) {
+        __vector float srcv [8];
+        __vector float asrcv[8];
+        __vector float amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
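+        // amaxv[0] now holds the per-lane maxima of |x| over the whole 32-float block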
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f / d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const __vector float v = vec_mul(srcv[j], vec_splats(id));
+            const __vector int32_t vi = vec_signed(v);
+
+            y[i].qs[4*j + 0] = vec_extract(vi, 0);
+            y[i].qs[4*j + 1] = vec_extract(vi, 1);
+            y[i].qs[4*j + 2] = vec_extract(vi, 2);
+            y[i].qs[4*j + 3] = vec_extract(vi, 3);
+        }
+    }
 #else
     GGML_UNUSED(nb);
     // scalar
@@ -1037,7 +1092,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
         y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
     }
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
     for (int i = 0; i < nb; i++) {
         v128_t srcv [8];
         v128_t asrcv[8];
@@ -1251,7 +1306,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union ft;
         __m256 v0 = (__m256)__lasx_xvld( x , 0 );
         __m256 v1 = (__m256)__lasx_xvld( x , 32 );
         __m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1269,8 +1323,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
-        ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-        const float max_scalar = ft.f;
+        const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -1316,6 +1369,44 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);
         __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
     }
+#elif defined(__VXE__) || defined(__VXE2__)
+    for (int i = 0; i < nb; i++) {
+        __vector float srcv [8];
+        __vector float asrcv[8];
+        __vector float amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f / d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        __vector int32_t acc = vec_splats(0);
+
+        for (int j = 0; j < 8; j++) {
+            const __vector float v = vec_mul(srcv[j], vec_splats(id));
+            const __vector int32_t vi = vec_signed(v);
+
+            y[i].qs[4*j + 0] = vec_extract(vi, 0);
+            y[i].qs[4*j + 1] = vec_extract(vi, 1);
+            y[i].qs[4*j + 2] = vec_extract(vi, 2);
+            y[i].qs[4*j + 3] = vec_extract(vi, 3);
+
+            acc = vec_add(acc, vi);
+        }
+
+        y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3]));
+    }
 #else
     GGML_UNUSED(nb);
     // scalar
@@ -1653,7 +1744,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
+#ifdef __wasm_simd128__
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+    block_q8_K * restrict yc = y; // Cast to proper type
+
+    for (int i = 0; i < nb; i++) {
+        const float * x_block = x + i * QK_K;
+
+        v128_t min_vec = wasm_v128_load(x_block);
+        v128_t max_vec = min_vec;
+
+        for (int j = 4; j < QK_K; j += 4) {
+            v128_t x_vec = wasm_v128_load(x_block + j);
+            max_vec = wasm_f32x4_pmax(max_vec, x_vec);
+            min_vec = wasm_f32x4_pmin(min_vec, x_vec);
+        }
+        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
+        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
+        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
+        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
+        float max = wasm_f32x4_extract_lane(max_vec, 0);
+        float min = wasm_f32x4_extract_lane(min_vec, 0);
+        float amax = -min > max ? min : max;
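+        // amax keeps the sign of the value with the largest magnitude, matching quantize_row_q8_K_ref (iscale = -127/amax below)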
+
+        if (amax == 0.0f) {
+            yc[i].d = 0.0f;
+            const v128_t zero = wasm_i8x16_splat(0);
+            for (int j = 0; j < QK_K; j += 16) {
+                wasm_v128_store(yc[i].qs + j, zero);
+            }
+            continue;
+        }
+
+        const float iscale = -127.0f / amax;
+        const v128_t scale_vec = wasm_f32x4_splat(iscale);
+
+        // Process 16 elements per iteration
+        for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
+            // Load and quantize 16 floats
+            v128_t x0 = wasm_v128_load(x_block + j);
+            v128_t x1 = wasm_v128_load(x_block + j + 4);
+            v128_t x2 = wasm_v128_load(x_block + j + 8);
+            v128_t x3 = wasm_v128_load(x_block + j + 12);
+
+            v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
+            v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
+            v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
+            v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
+
+            // Convert to i32 with saturation
+            v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
+            v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
+            v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
+            v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
+
+            // Pack into 16 i8 values
+            v128_t i8 = wasm_i8x16_narrow_i16x8(
+                wasm_i16x8_narrow_i32x4(i0, i1),
+                wasm_i16x8_narrow_i32x4(i2, i3)
+            );
+            wasm_v128_store(yc[i].qs + j, i8);
+
+            // Calculate bsums using SIMD
+            v128_t sum16 = wasm_i16x8_add(
+                wasm_i16x8_extend_low_i8x16(i8),
+                wasm_i16x8_extend_high_i8x16(i8)
+            );
+            v128_t sum32 = wasm_i32x4_add(
+                wasm_i32x4_extend_low_i16x8(sum16),
+                wasm_i32x4_extend_high_i16x8(sum16)
+            );
+            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
+            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
+            yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
+        }
+
+        yc[i].d = 1.0f / iscale;
+    }
+#else
     quantize_row_q8_K_ref(x, y, k);
+#endif
 }
 
 //===================================== Dot products =================================
@@ -2011,6 +2182,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined __wasm_simd128__
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    const v128_t m4b = wasm_i8x16_splat(0x0F);
+    const v128_t s8b = wasm_i8x16_splat(0x8);
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_0 * restrict x0 = &x[ib];
+        const block_q4_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        // Load and process x0
+        v128_t v0_0 = wasm_v128_load(x0->qs);
+        v128_t v0_0l = wasm_v128_and(v0_0, m4b);
+        v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
+        v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
+        v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
+
+        // Load y0 vectors
+        v128_t y0_l = wasm_v128_load(y0->qs);
+        v128_t y0_h = wasm_v128_load(y0->qs + 16);
+
+        // Extend to i16x8 and compute dot products
+        v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
+        v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
+        v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
+        v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
+
+        v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
+        v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
+        v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
+        v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
+
+        v128_t dp0 = wasm_i32x4_add(
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx0l, dy0ll),
+                wasm_i32x4_dot_i16x8(dx0h, dy0lh)
+            ),
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
+                wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
+            )
+        );
+
+        // Load and process x1
+        v128_t v0_1 = wasm_v128_load(x1->qs);
+        v128_t v0_1l = wasm_v128_and(v0_1, m4b);
+        v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
+        v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
+        v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
+
+        // Load y1 vectors
+        v128_t y1_l = wasm_v128_load(y1->qs);
+        v128_t y1_h = wasm_v128_load(y1->qs + 16);
+
+        // Extend to i16x8 and compute dot products
+        v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
+        v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
+        v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
+        v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
+
+        v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
+        v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
+        v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
+        v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
+
+        v128_t dp1 = wasm_i32x4_add(
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx1l, dy1ll),
+                wasm_i32x4_dot_i16x8(dx1h, dy1lh)
+            ),
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
+                wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
+            )
+        );
+
+        // Accumulate results with scaling
+        float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
+        float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d);
+
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2232,21 +2491,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_8(acc);
+
 #elif defined(__loongarch_sx)
     // set constants
     const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
     const __m128i off = __lsx_vreplgr2vr_b(8);
 
     // Initialize accumulator with zeros
-    __m128 acc_0 = __lsx_vldi(0);
-    __m128 acc_1 = __lsx_vldi(0);
-    __m128 acc_2 = __lsx_vldi(0);
-    __m128 acc_3 = __lsx_vldi(0);
+    __m128 acc_0 = (__m128)__lsx_vldi(0);
+    __m128 acc_1 = (__m128)__lsx_vldi(0);
+    __m128 acc_2 = (__m128)__lsx_vldi(0);
+    __m128 acc_3 = (__m128)__lsx_vldi(0);
 
     for (; ib + 1 < nb; ib += 2) {
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+        const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
 
         const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
 
@@ -2264,7 +2524,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+        const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
 
         const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
 
@@ -2298,6 +2558,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__VXE__) || defined(__VXE2__)
+    __vector float acc = vec_splats(0.0f);
+
+    const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F);
+    const __vector int8_t  v_s = vec_splats( (const int8_t)0x08);
+
+    for (; ib < nb; ++ib) {
+        const __vector uint8_t v_x = vec_xl(0, x[ib].qs);
+        const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m);
+        const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4);
+
+        const __vector int8_t v_xls = vec_sub(v_xl, v_s);
+        const __vector int8_t v_xhs = vec_sub(v_xh, v_s);
+
+        const __vector int8_t v_yl = vec_xl(0      , y[ib].qs);
+        const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
+
+        const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl);
+        const __vector int16_t v_xylse = vec_mule(v_xls, v_yl);
+        const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh);
+        const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh);
+
+        __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
+
+        const __vector float v_xy = vec_float(vec_unpackh(v_xy_));
+        const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        acc = vec_madd(v_xy, v_d, acc);
+    }
+
+    sumf = acc[0] + acc[1] + acc[2] + acc[3];
 #endif
     for (; ib < nb; ++ib) {
         int sumi0 = 0;
@@ -2591,6 +2882,35 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_8(acc) + summs;
+#elif defined(__VXE__) || defined(__VXE2__)
+    float summs = 0;
+    float32x4_t acc = vec_splats(0.0f);
+
+    const uint8x16_t v_m = vec_splat_u8(0x0F);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        const uint8x16_t v_x = vec_xl(0, x[ib].qs);
+        const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
+        const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
+
+        const int8x16_t v_yl = vec_xl(0      , y[ib].qs);
+        const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs);
+
+        const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
+        const float32x4_t v_xy = vec_float(v_xy_);
+
+        const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        acc = vec_madd(v_xy, v_d, acc);
+    }
+
+    sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
 #endif
     for (; ib < nb; ++ib) {
         int sumi0 = 0;
@@ -2696,10 +3016,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
     v128_t sumv = wasm_f32x4_splat(0.0f);
 
-    uint32_t qh;
+    uint32_t qh_;
     uint64_t tmp[4];
 
     // TODO: check if unrolling this is better
@@ -2710,12 +3030,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         const v128_t m4b  = wasm_i8x16_splat(0x0F);
 
         // extract the 5th bit
-        memcpy(&qh, x0->qh, sizeof(qh));
+        memcpy(&qh_, x0->qh, sizeof(qh_));
 
-        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_1[(qh >> 24)       ];
+        tmp[0] = table_b2b_1[(qh_ >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh_ >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh_ >> 24)       ];
 
         const v128_t qhl = wasm_v128_load(tmp + 0);
         const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3057,12 +3377,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
     v128_t sumv = wasm_f32x4_splat(0.0f);
 
     float summs = 0.0f;
 
-    uint32_t qh;
+    uint32_t qh_;
     uint64_t tmp[4];
 
     // TODO: check if unrolling this is better
@@ -3075,12 +3395,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         const v128_t m4b = wasm_i8x16_splat(0x0F);
 
         // extract the 5th bit
-        memcpy(&qh, x0->qh, sizeof(qh));
+        memcpy(&qh_, x0->qh, sizeof(qh_));
 
-        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_0[(qh >> 24)       ];
+        tmp[0] = table_b2b_0[(qh_ >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh_ >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh_ >> 24)       ];
 
         const v128_t qhl = wasm_v128_load(tmp + 0);
         const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3573,6 +3893,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined __wasm_simd128__
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    for (; ib < nb; ++ib) {
+        const block_q8_0 * restrict x0 = &x[ib];
+        const block_q8_0 * restrict y0 = &y[ib];
+
+        const v128_t x0_0 = wasm_v128_load(x0->qs);
+        const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
+        const v128_t y0_0 = wasm_v128_load(y0->qs);
+        const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
+
+        // Extend 8-bit to 16-bit
+        const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
+        const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
+        const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
+        const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
+
+        const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
+        const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
+        const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
+        const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
+
+        // Compute dot products
+        const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
+        const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
+        const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
+        const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
+
+        // Sum all dot products
+        const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
+
+        // Convert to float and accumulate
+        const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -3686,6 +4045,27 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_8(acc);
+#elif defined(__VXE__) || defined(__VXE2__)
+    __vector float acc = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        const int8x16_t v_xl = vec_xl(0      , x[ib].qs);
+        const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs);
+        const int8x16_t v_yl = vec_xl(0      , y[ib].qs);
+        const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
+
+        const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
+        const float32x4_t v_xy = vec_float(v_xy_);
+        const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        acc = vec_madd(v_xy, v_d, acc);
+    }
+
+    sumf = acc[0] + acc[1] + acc[2] + acc[3];
 #endif
     for (; ib < nb; ++ib) {
         int sumi = 0;
@@ -4447,6 +4827,106 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc);
 
+#elif defined __wasm_simd128__
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * q2 = x[i].qs;
+        const int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        // Vectorized summs calculation
+        v128_t summs_vec = wasm_i32x4_splat(0);
+        {
+            v128_t sc_vec = wasm_v128_load(sc);
+            v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
+
+            v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
+            v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
+
+            v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
+            v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
+
+            summs_vec = wasm_i32x4_add(
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
+                               wasm_i32x4_dot_i16x8(sc_high, bsums2)),
+                summs_vec
+            );
+
+            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
+            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
+        }
+        int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
+
+        // Vectorized isum calculation
+        int32_t isum = 0;
+        const uint8_t * sc_ptr = sc;
+        const int k_iters = QK_K/128;
+
+        for (int k = 0; k < k_iters; ++k) {
+            v128_t isum_vec = wasm_i32x4_splat(0);
+            int shift = 0;
+
+            for (int j = 0; j < 4; ++j) {
+                const int d0 = (sc_ptr[0] & 0xF);
+                const int d1 = (sc_ptr[1] & 0xF);
+                sc_ptr += 2;
+
+                // Process first 16 elements
+                v128_t q2_0 = wasm_v128_load(q2);
+                v128_t q8_0 = wasm_v128_load(q8);
+                v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
+                v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
+
+                // Process next 16 elements
+                v128_t q2_1 = wasm_v128_load(q2 + 16);
+                v128_t q8_1 = wasm_v128_load(q8 + 16);
+                v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
+                v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
+
+                // Calculate dot products
+                v128_t p0 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_low_i8x16(q8_0),
+                    wasm_i16x8_extend_low_i8x16(q2_bits_0)
+                );
+                v128_t p1 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_high_i8x16(q8_0),
+                    wasm_i16x8_extend_high_i8x16(q2_bits_0)
+                );
+                v128_t p2 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_low_i8x16(q8_1),
+                    wasm_i16x8_extend_low_i8x16(q2_bits_1)
+                );
+                v128_t p3 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_high_i8x16(q8_1),
+                    wasm_i16x8_extend_high_i8x16(q2_bits_1)
+                );
+
+                // Accumulate scaled results
+                v128_t scaled = wasm_i32x4_add(
+                    wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
+                    wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
+                );
+
+                isum_vec = wasm_i32x4_add(isum_vec, scaled);
+                q8 += 32;
+                shift += 2;
+            }
+            q2 += 32;
+
+            // Horizontal sum of isum_vec
+            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
+            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
+            isum += wasm_i32x4_extract_lane(isum_vec, 0);
+        }
+
+        const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf += dall * isum - dmin * summs;
+    }
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     float sumf = 0;
@@ -4666,9 +5146,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-    const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
-
     __m256 acc = (__m256)__lasx_xvldi(0);
 
     for (int i = 0; i < nb; ++i) {
@@ -4679,18 +5156,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
-        const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
-        const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
-        const __m256i mins = lasx_ext8_16(mins8);
+        const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
+        const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
         const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
 
         acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
 
-        const __m256i all_scales = lasx_ext8_16(scales8);
-        const __m128i l_scales = lasx_extracti128(all_scales, 0);
-        const __m128i h_scales = lasx_extracti128(all_scales, 1);
-        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
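+        // Split the 16 sub-block scales even/odd across the two 128-bit lanes so
+        // that lasx_xvrepl128vei_h() below broadcasts the correct per-lane scale.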
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
@@ -4703,20 +5177,20 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
-            const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
-            const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
-            const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
+            const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
+            const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
+            const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
+            const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
 
-            __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
-            __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
-            __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
-            __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
+            __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
+            __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
+            __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
+            __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
 
-            p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
-            p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
-            p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
-            p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
+            p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
+            p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
+            p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
+            p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
 
             p0 = __lasx_xvadd_w(p0, p1);
             p2 = __lasx_xvadd_w(p2, p3);
@@ -4789,7 +5263,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#if defined(__ARM_FEATURE_SVE)
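+    // Q3_K on SVE: the 2 low bits of each quant come from x[i].qs and the third
+    // bit from x[i].hmask; where that bit is clear, 4 is subtracted (via the
+    // bic/sub below), giving signed values in [-4, 3]. Separate paths handle
+    // 128-bit and 256/512-bit vector lengths.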
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const int8_t m32 = 32;
+    const int vector_length = svcntb()*8;
+    const svuint8_t m3b_sv = svdup_n_u8(0x3);
+    const svint32_t vzero_sv = svdup_n_s32(0);
+
+    const svuint8_t m0_sv = svdup_n_u8(1);
+    const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
+    const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
+    const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3_sv = x[i].qs;
+        const uint8_t * restrict qh_sv = x[i].hmask;
+        const int8_t  * restrict q8_sv = y[i].qs;
+
+        // Set up scales
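+        // (16 six-bit sub-block scales are packed into these 12 bytes; they are
+        //  re-biased by -32 after unpacking)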
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+
+        for (int j = 0; j < 16; ++j) scale[j] -= m32;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
+                    svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
+                    svuint8_t q3h_sv;
+
+                    svint32_t sumi1_1 = svdup_n_s32(0);
+                    svint8_t q3bytes_sv;
+
+                    for (int j = 0; j < QK_K/128; ++j) {
+
+                        const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
+                        const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
+                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
+
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
+
+                        scale += 4;
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
+
+                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
+
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
+
+                        if (j == 0) {
+                            qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
+                            qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
+                        }
+
+                        scale += 4;
+                    }
+
+                    sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
+                } break;
+            case 256:
+            case 512:
+                {
+                    svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
+                    svuint8_t q3h_sv;
+
+                    svint32_t sumi1_1 = svdup_n_s32(0);
+                    svint8_t q3bytes_sv;
+
+                    for (int j = 0; j < QK_K/128; ++j) {
+
+                        const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
+                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
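+                        // svsel with a 4-lane predicate applies one scale to the
+                        // first 128 bits of the dot product and the next scale to
+                        // the second 128 bits.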
+                        svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
+
+                        scale += 4;
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                        q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
+
+                        if (j == 0) {
+                            qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
+                        }
+
+                        scale += 4;
+                    }
+
+                    sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sum;
+
+#elif __ARM_NEON
 
     uint32_t aux[3];
     uint32_t utmp[4];
@@ -5129,6 +5778,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc);
 
+#elif defined __wasm_simd128__
+    int8_t  aux8[QK_K];
+    float   sums[8] = {0};
+    uint32_t auxs[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Process blocks with SIMD
+        int8_t * a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int shift = 0; shift <= 6; shift += 2) {
+                v128_t v_m = wasm_i8x16_splat(m);
+                for (int l = 0; l < 32; l += 16) {
+                    v128_t v_q3 = wasm_v128_load(q3 + l);
+                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
+                    v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
+
+                    v128_t v_hm = wasm_v128_load(hm + l);
+                    v128_t v_mask = wasm_v128_and(v_hm, v_m);
+                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
+
+                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
+                    wasm_v128_store(a + l, v_low2);
+                }
+                a += 32;
+                m <<= 1;
+            }
+            q3 += 32;
+        }
+
+        // Extract scales
+        memcpy(auxs, x[i].scales, 12);
+        uint32_t tmp = auxs[2];
+        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        const int8_t * scales = (const int8_t *)auxs;
+
+        // SIMD dot product with register accumulators
+        v128_t v_acc0 = wasm_i32x4_splat(0);
+        v128_t v_acc1 = wasm_i32x4_splat(0);
+        a = aux8;
+        for (int j = 0; j < QK_K/16; ++j) {
+            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
+
+            // Process 16 elements per iteration
+            for (int k = 0; k < 2; ++k) {
+                const v128_t v_q8 = wasm_i16x8_load8x8(q8);
+                const v128_t v_a = wasm_i16x8_load8x8(a);
+
+                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
+                v_prod = wasm_i16x8_mul(v_prod, v_scale);
+
+                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
+                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
+
+                q8 += 8;
+                a += 8;
+            }
+        }
+
+        // Accumulate results
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const v128_t v_d = wasm_f32x4_splat(d);
+        v128_t v_sum = wasm_f32x4_add(
+            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
+            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
+        );
+
+        // Accumulate into sums vector
+        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
+    }
+
+    // Horizontal sum
+    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
+    sumf = wasm_f32x4_extract_lane(v_sum, 0) +
+           wasm_f32x4_extract_lane(v_sum, 1) +
+           wasm_f32x4_extract_lane(v_sum, 2) +
+           wasm_f32x4_extract_lane(v_sum, 3);
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     uint32_t aux[3];
@@ -5384,8 +6121,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-    const __m256i mone = __lasx_xvreplgr2vr_b(1);
     const __m128i m32 = __lsx_vreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -5405,10 +6140,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
                 (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
                 (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
         scales128 = __lsx_vsub_b(scales128, m32);
-        const __m256i all_scales = lasx_ext8_16(scales128);
-        const __m128i l_scales = lasx_extracti128(all_scales, 0);
-        const __m128i h_scales = lasx_extracti128(all_scales, 1);
-        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         // high bit
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
@@ -5416,35 +6150,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         // integer accumulator
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        int is  = 0;
-        __m256i xvbit;
-
-
         for (int j = 0; j < QK_K/128; ++j) {
             // load low 2 bits
             const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit);
             // prepare low and high bits
-            const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
-            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
-            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
-            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
-            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
+            const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
+            const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
+            const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
+            const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
+            const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
+            const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
+            const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
+            const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
+            const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
+            const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
+            const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
+            const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
 
             // load Q8 quants
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
@@ -5452,29 +6174,16 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
-            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
-            // and 2 if the high bit was set)
-            __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
-            __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
-
-            __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+            __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
+            __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
+            __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
+            __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
 
             // multiply with scales
-            p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
-            p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
-            p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
-            p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             // accumulate
             p16_0 = __lasx_xvadd_w(p16_0, p16_1);
@@ -5482,7 +6191,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             sumi  = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
         }
         // multiply with block scale and accumulate
-        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
     }
 
     *s = hsum_float_8(acc);
@@ -5573,7 +6282,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
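+    // Q4_K on SVE: 8 six-bit scales and 8 six-bit mins are packed into the
+    // 12-byte scales field. The mins are folded in first through the precomputed
+    // q8 bsums (sumf -= dmin * sum(min_j * bsum_j)); then the low and high
+    // nibbles of qs are dotted with q8 and weighted by the matching scale.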
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits  = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif defined __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
 
@@ -5636,6 +6426,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = sumf;
 
+#elif defined __wasm_simd128__
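+    // WASM fallback for Q4_K: same scale/min unpacking as the scalar path; each
+    // i32x4.dot_i16x8 below handles 8 products of sign-extended nibbles and q8.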
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // dmin stays positive; the mins term is subtracted below
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Process scales and mins
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        // Sum mins * q8sums
+        int32_t sumi = 0;
+        const int16_t * restrict q8sums = y[i].bsums;
+        const uint8_t * m = (const uint8_t *)&utmp[2];
+        for (int j = 0; j < 16; j += 2) {
+            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
+        }
+        sumf -= dmin * sumi;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // Load 64 4-bit weights (32 bytes)
+            const v128_t q4x0 = wasm_v128_load(q4);
+            const v128_t q4x1 = wasm_v128_load(q4 + 16);
+            q4 += 32;
+
+            // Split into low/high nibbles
+            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
+            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
+            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
+            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
+
+            // Load 64 8-bit values (64 bytes)
+            const v128_t q8x0 = wasm_v128_load(q8);
+            const v128_t q8x1 = wasm_v128_load(q8 + 16);
+            const v128_t q8x2 = wasm_v128_load(q8 + 32);
+            const v128_t q8x3 = wasm_v128_load(q8 + 48);
+            q8 += 64;
+
+            // Low nibble products
+            v128_t vacc1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4l0),
+                wasm_i16x8_extend_low_i8x16(q8x0)
+            );
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4l0),
+                wasm_i16x8_extend_high_i8x16(q8x0)
+            ));
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4l1),
+                wasm_i16x8_extend_low_i8x16(q8x1)
+            ));
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4l1),
+                wasm_i16x8_extend_high_i8x16(q8x1)
+            ));
+
+            // High nibble products
+            v128_t vacc2 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4h0),
+                wasm_i16x8_extend_low_i8x16(q8x2)
+            );
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4h0),
+                wasm_i16x8_extend_high_i8x16(q8x2)
+            ));
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4h1),
+                wasm_i16x8_extend_low_i8x16(q8x3)
+            ));
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4h1),
+                wasm_i16x8_extend_high_i8x16(q8x3)
+            ));
+
+            // Accumulate scaled results
+            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
+                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
+            sumi1 += vacc1_sum * scales[2*j];
+
+            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
+                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
+            sumi2 += vacc2_sum * scales[2*j+1];
+        }
+
+        sumf += d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
+
 #elif defined __AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
@@ -5993,11 +6884,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
     __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6017,33 +6903,34 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        const __m128i prod = lsx_madd_h(mins128, q8s);
         acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         __m256i sumi = __lasx_xvldi(0);
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
-            const __m256i q4l = __lasx_xvand_v(q4bits, m4);
-            const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
+            const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
+            const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
 
             const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16l = lasx_maddubs_h(q4l, q8l);
+            __m256i p16l = lasx_madd_h_b(q4l, q8l);
             p16l = lasx_madd_h(scale_l, p16l);
 
             const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16h = lasx_maddubs_h(q4h, q8h);
+            __m256i p16h = lasx_madd_h_b(q4h, q8h);
             p16h = lasx_madd_h(scale_h, p16h);
             const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
 
@@ -6060,9 +6947,78 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
 
 
-    ft_union fi;
-    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
-    *s = hsum_float_8(acc) + fi.f ;
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
+#elif defined(__VXE__) || defined(__VXE2__)
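+    // s390x path for Q4_K: ggml_vec_dot(acc, a, b) accumulates the int8 products
+    // of a and b into the 32-bit lanes of acc, so each group of 32 low-nibble
+    // quants takes two 16-byte calls, and the 32 high-nibble quants two more.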
+    const uint8x16_t v_lm = vec_splat_u8(0x0F);
+    const int32x4_t v_z = vec_splat_s32(0);
+
+    uint8x16_t v_x[2];
+    int8x16_t  v_xl[2];
+    int8x16_t  v_y[2];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
+        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
+        const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
+
+        memcpy(utmp, x[i].scales, 12);
+
+        uint32x4_t v_mins8 = { 0 };
+        v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);
+        v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
+
+        const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
+        const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
+        const int32x4_t v_mins = v_minso + v_minse;
+        sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+        const uint8_t * restrict x0 = x[i].qs;
+        const int8_t  * restrict y0 = y[i].qs;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            v_x[0] = vec_xl(0 , x0);
+            v_x[1] = vec_xl(16, x0);
+            x0 += 32;
+
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            y0 += 32;
+
+            v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);
+            v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
+
+            const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
+            sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
+
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            y0 += 32;
+
+            v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);
+            v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
+
+            const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
+            sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
+        }
+
+        sumf += d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6388,6 +7344,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc) + summs;
 
+#elif defined __wasm_simd128__
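+    // WASM path for Q5_K: the fifth bit of each quant lives in x[i].qh and is
+    // shifted into bit 4 of the low/high nibble values before the q8 dot products.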
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // dmin stays positive; the mins term is subtracted below
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Process scales and mins
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        // Sum mins * q8sums
+        int32_t sumi_mins = 0;
+        const int16_t * restrict q8sums = y[i].bsums;
+        const uint8_t * m = (const uint8_t *)&utmp[2];
+        for (int j = 0; j < 16; j += 2) {
+            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
+        }
+        sumf -= dmin * sumi_mins; // subtract the mins contribution
+
+        v128_t qh0 = wasm_v128_load(qh);
+        v128_t qh1 = wasm_v128_load(qh + 16);
+        const uint8_t * sc = (const uint8_t *)utmp;
+
+        int32_t sumi = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            const int shift = j * 2;
+            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
+            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
+
+            v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
+            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
+            v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
+            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
+
+            v128_t q5_0 = wasm_v128_load(q5);
+            v128_t q5_1 = wasm_v128_load(q5 + 16);
+            q5 += 32;
+
+            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
+            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
+            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
+            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
+
+            v128_t q8_0 = wasm_v128_load(q8);
+            v128_t q8_1 = wasm_v128_load(q8 + 16);
+            v128_t q8_2 = wasm_v128_load(q8 + 32);
+            v128_t q8_3 = wasm_v128_load(q8 + 48);
+            q8 += 64;
+
+            // Process low quants
+            v128_t pl0 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5l_0),
+                wasm_i16x8_extend_low_i8x16(q8_0)
+            );
+            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5l_0),
+                wasm_i16x8_extend_high_i8x16(q8_0)
+            ));
+            v128_t pl1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5l_1),
+                wasm_i16x8_extend_low_i8x16(q8_1)
+            );
+            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5l_1),
+                wasm_i16x8_extend_high_i8x16(q8_1)
+            ));
+            v128_t sum_low = wasm_i32x4_add(pl0, pl1);
+
+            // Process high quants
+            v128_t ph0 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5h_0),
+                wasm_i16x8_extend_low_i8x16(q8_2)
+            );
+            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5h_0),
+                wasm_i16x8_extend_high_i8x16(q8_2)
+            ));
+            v128_t ph1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5h_1),
+                wasm_i16x8_extend_low_i8x16(q8_3)
+            );
+            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5h_1),
+                wasm_i16x8_extend_high_i8x16(q8_3)
+            ));
+            v128_t sum_high = wasm_i32x4_add(ph0, ph1);
+
+            // Accumulate with scale factors
+            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
+                        wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
+            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
+                        wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
+
+            sumi += sl * sc[2*j] + sh * sc[2*j+1];
+        }
+
+        sumf += d * sumi;
+    }
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6610,19 +7678,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m128i mzero = __lsx_vldi(0);
-    const __m256i mone  = __lasx_xvreplgr2vr_b(1);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
+    __m128 acc_m = (__m128)__lsx_vldi(0);
 
-    float summs = 0.f;
-
-   for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; ++i) {
 
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -6637,49 +7697,40 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
-        const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
-        summs += dmin * __lsx_vpickve2gr_w(hsum, 0);    //TODO check
+        const __m128i prod = lsx_madd_h(mins128, q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
-        __m256i hmask = mone;
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        __m256i xvbit;
-
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_0  = __lasx_xvadd_b(q5l_0, q5h_0);
-            hmask = __lasx_xvslli_h(hmask, 1);
-
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_1  = __lasx_xvadd_b(q5l_1, q5h_1);
-            hmask = __lasx_xvslli_h(hmask, 1);
+            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
+            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
+            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
+            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
+            const __m256i q5_0  = __lasx_xvor_v(q5l_0, q5h_0);
+            const __m256i q5_1  = __lasx_xvor_v(q5l_1, q5h_1);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
+            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
 
             p16_0 = lasx_madd_h(scale_0, p16_0);
             p16_1 = lasx_madd_h(scale_1, p16_1);
@@ -6693,8 +7744,98 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
 
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
+#elif defined(__VXE__) || defined(__VXE2__)
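+    // s390x path for Q5_K: same 4+1-bit reconstruction as above, with the qh bits
+    // merged via vec_and/vec_sl and the mins folded in through the q8 bsums.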
+    const uint8x16_t v_lm = vec_splat_u8(0x0F);
+    const uint8x16_t v_1m = vec_splat_u8(0x01);
+    const uint8x16_t v_2m = vec_splat_u8(0x02);
+
+    const int32x4_t v_z = vec_splat_s32(0);
+
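+    // vec_perm mask that gathers the 8 packed min bytes (utmp[2], utmp[3]) into
+    // the first half of the vector; vec_unpackh() below widens them to 16-bit.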
+    const uchar8x16_t v_minsm = {
+        0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
+        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
+    };
+
+    int8x16_t  q5b[4];
+    uint8x16_t q5h[4];
+
+    uint8x16_t v_xl[2];
+    uint8x16_t v_xh[2];
+    int8x16_t  v_y[4];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
+        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
+        const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);
+        const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);
+        const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
+
+        const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
+        const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
+        const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
+        const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+        const uint8_t * restrict x0l = x[i].qs;
+        const uint8_t * restrict x0h = x[i].qh;
+        const int8_t  * restrict y0 = y[i].qs;
+
+        v_xh[0] = vec_xl(0 , x0h);
+        v_xh[1] = vec_xl(16, x0h);
+
+        int32_t sumi = 0;
+        for (int j = 0; j < QK_K/64; ++j) {
+            v_xl[0] = vec_xl(0 , x0l);
+            v_xl[1] = vec_xl(16, x0l);
+            x0l += 32;
+
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            v_y[2] = vec_xl(32, y0);
+            v_y[3] = vec_xl(48, y0);
+            y0 += 64;
+
+            q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);
+            q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);
+            q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);
+            q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);
+            v_xh[0] = vec_sr(v_xh[0], 2);
+            v_xh[1] = vec_sr(v_xh[1], 2);
+
+            q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);
+            q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);
+            q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);
+            q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);
+
+            int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
+            int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
+
+            sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
+            sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
+        }
+
+        sumf += d * sumi - dmin * mins;
+    }
+
+    *s = sumf;
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -7051,6 +8192,85 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc);
 
+#elif defined __wasm_simd128__
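+    // WASM path for Q6_K: quants are first unpacked to signed 8-bit values in
+    // [-32, 31] (4 bits from ql plus 2 bits from qh, minus 32), then 16-wide
+    // products with q8 are scaled by the per-16-quant int8 scales.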
+    int8_t aux8[QK_K] __attribute__((aligned(16)));
+    int32_t aux32[8] __attribute__((aligned(16))) = {0};
+    float sums[8] __attribute__((aligned(16))) = {0};
+
+    for (int i = 0; i < nb; ++i) {
+        // Unpack 6-bit quantized data into aux8 (same layout as the scalar path)
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        int8_t * a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a += 128;
+            q4 += 64;
+            qh += 32;
+        }
+
+        const int8_t * restrict a_ptr = aux8;
+        const int8_t * restrict q8 = y[i].qs;
+        v128_t acc0 = wasm_i32x4_splat(0);
+        v128_t acc1 = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const int scale = x[i].scales[j];
+            const v128_t vscale = wasm_i32x4_splat(scale);
+
+            // Load 16 elements from a and q8
+            const v128_t a_vec = wasm_v128_load(a_ptr);
+            const v128_t q8_vec = wasm_v128_load(q8);
+
+            // Process low 8 elements
+            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
+            v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
+            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
+            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
+            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
+
+            // Process high 8 elements
+            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
+            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
+            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
+            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
+            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
+
+            // Scale and accumulate
+            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
+            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
+            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
+            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
+
+            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
+            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
+
+            a_ptr += 16;
+            q8 += 16;
+        }
+
+        // Store accumulated results
+        wasm_v128_store(&aux32[0], acc0);
+        wasm_v128_store(&aux32[4], acc1);
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) {
+            sums[l] += d * aux32[l];
+        }
+    }
+
+    // Sum final results
+    float sumf = 0;
+    for (int l = 0; l < 8; ++l) {
+        sumf += sums[l];
+    }
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     float sumf = 0;
@@ -7275,8 +8495,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m256i m2 = __lasx_xvreplgr2vr_b(3);
     const __m256i m32s = __lasx_xvreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -7289,58 +8507,42 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
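+        // scales are pre-shuffled into even/odd order and widened to 16 bits so that a single
+        // lasx_xvrepl128vei_h below can broadcast the pair of scales needed for each 128-bit lane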
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int is = 0;
-
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
-            const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
-            const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
-            const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
-            is += 4;
-
             const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
 
-            const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
-            const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
-            const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
-            const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+            const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
+            const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
+            const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
+            const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
 
-            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
-            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
-            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
-            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
+            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
+            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
+            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
-            __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
+            __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
+            __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
+            __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
+            __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
 
-            __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
-
-            p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
-            p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
-            p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
-            p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
@@ -7350,7 +8552,130 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__VXE__) || defined(__VXE2__)
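+    // s390x vector (VXE) path for q6_K x q8_K: the 6-bit weights stay biased (0..63) during the
+    // dot products; the -32 offset is folded in at the end via `mins`, i.e. sum(scale[j] * bsums[j])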
+    float sum = 0;
 
+    // Lower 4-bit and upper 2-bit masks
+    const uint8x16_t v_lm = vec_splat_u8(0x0F);
+    const uint8x16_t v_um = vec_splat_u8(0x03);
+
+    const int32x4_t v_z = vec_splat_s32(0);
+
+    int8x16_t  q6b[4];
+    uint8x16_t q6h[4];
+
+    uint8x16_t v_xl[4];
+    uint8x16_t v_xh[2];
+    int8x16_t  v_y[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict x0l = x[i].ql;
+        const uint8_t * restrict x0h = x[i].qh;
+        const int8_t  * restrict y0 = y[i].qs;
+
+        const int8_t  * restrict scale = x[i].scales;
+
+        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
+        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
+
+        const int8x16_t v_scale  = vec_xl(0, scale);
+        const int16x8_t v_scalel = vec_unpackh(v_scale);
+        const int16x8_t v_scaleh = vec_unpackl(v_scale);
+
+        const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
+        const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
+        const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
+        const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
+        const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
+
+        const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
+
+        int32_t isum = 0;
+        for (int j = 0; j < QK_K/128; ++j) {
+            // Load model upper 2 bits
+            v_xh[0] = vec_xl(0 , x0h);
+            v_xh[1] = vec_xl(16, x0h);
+            x0h += 32;
+
+            // Load model lower 4 bits
+            v_xl[0] = vec_xl(0 , x0l);
+            v_xl[1] = vec_xl(16, x0l);
+            v_xl[2] = vec_xl(32, x0l);
+            v_xl[3] = vec_xl(48, x0l);
+            x0l += 64;
+
+            // Load activation quants
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            v_y[2] = vec_xl(32, y0);
+            v_y[3] = vec_xl(48, y0);
+            y0 += 64;
+
+            q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);
+            q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);
+            uint8x16_t shifted = vec_sr(v_xh[0], 2);
+            q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
+            shifted = vec_sr(v_xh[1], 2);
+            q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
+
+            q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));
+            q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));
+            q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));
+            q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));
+
+            int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
+            int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
+            int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
+            int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
+
+            isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
+                    (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
+                    (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
+                    (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
+
+            scale += 4;
+
+            // Load activation quants
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            v_y[2] = vec_xl(32, y0);
+            v_y[3] = vec_xl(48, y0);
+            y0 += 64;
+
+            shifted = vec_sr(v_xh[0], 4);
+            q6h[0] = vec_sl(vec_and(v_um, shifted), 4);
+            shifted = vec_sr(v_xh[1], 4);
+            q6h[1] = vec_sl(vec_and(v_um, shifted), 4);
+            shifted = vec_sr(v_xh[0], 6);
+            q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
+            shifted = vec_sr(v_xh[1], 6);
+            q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
+
+            q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));
+            q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));
+            q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));
+            q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));
+
+            summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
+            summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
+            summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
+            summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
+
+            isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
+                    (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
+                    (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
+                    (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
+
+            scale += 4;
+        }
+
+        sum += d_all * y[i].d * (isum - 32 * mins);
+    }
+
+    *s = sum;
 #else
 
     int8_t  aux8[QK_K];
@@ -7711,7 +9036,57 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     }
 
     *s = 0.125f * hsum_float_8(accumf);
-
+//#elif defined(__VXE__) || defined(__VXE2__)
+//    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+//
+//    uint32_t aux32[4];
+//    const uint8_t * aux8 = (const uint8_t *)aux32;
+//
+//    float sumf = 0;
+//
+//    for (int i = 0; i < nb; ++i) {
+//        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+//        const uint16_t * restrict q2 = x[i].qs;
+//        const int8_t   * restrict q8 = y[i].qs;
+//
+//        float sumf1 = 0, sumf2 = 0;
+//
+//        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+//            int8x16_t q8b0 = vec_xl( 0, q8);
+//            int8x16_t q8b1 = vec_xl(16, q8);
+//            int8x16_t q8b2 = vec_xl(32, q8);
+//            int8x16_t q8b3 = vec_xl(48, q8);
+//            q8 += 64;
+//
+//            memcpy(aux32, q2, 4 * sizeof(uint32_t));
+//            q2 += 8;
+//
+//            int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };
+//            int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };
+//            int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };
+//            int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };
+//
+//            int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >>  7) & 127)) };
+//            int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };
+//            int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >>  7) & 127)) };
+//            int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };
+//
+//            q2u0 = vec_mul(q2u0, q2s0);
+//            q2u1 = vec_mul(q2u1, q2s1);
+//            q2u2 = vec_mul(q2u2, q2s2);
+//            q2u3 = vec_mul(q2u3, q2s3);
+//
+//            const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);
+//            const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);
+//
+//            sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));
+//            sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));
+//        }
+//
+//        sumf += d * (sumf1 + sumf2);
+//    }
+//
+//    *s = 0.25f * sumf;
 #else
 
     uint32_t aux32[2];
@@ -9665,13 +11040,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
 }
 #elif defined(__loongarch_asx)
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = __lasx_xvsigncov_b(x, x);
-    const __m256i sy = __lasx_xvsigncov_b(x, y);
-    __m256i tmp1, tmp2, tmp3;
-    tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
-    tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
-    tmp3 = __lasx_xvadd_h(tmp1, tmp2);
-    return __lasx_xvsat_h(tmp3, 15);
+    const __m256i a = __lasx_xvmulwev_h_b(x, y);
+    const __m256i b = __lasx_xvmulwod_h_b(x, y);
+    return __lasx_xvadd_h(a, b);
 }
 #endif
 
@@ -10476,6 +11847,27 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
 
+#elif defined(__VXE__) || defined(__VXE2__)
+    const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
+    const uint8x16_t v_m = vec_splat_u8(0x0F);
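+    // vec_perm against kvalues_iq4nl acts as a 16-entry table lookup, mapping each 4-bit index
+    // to its non-linear codebook value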
+
+    for (; ib < nb; ++ib) {
+        const block_iq4_nl * restrict x0 = &x[ib];
+        const block_q8_0   * restrict y0 = &y[ib];
+
+        const uint8x16_t v_x = vec_xl(0, x0->qs);
+        int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
+        int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
+
+        v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
+        v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);
+
+        const int8x16_t v_yl = vec_xl(0      , y0->qs);
+        const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
+        const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
+
+        sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
+    }
 #endif
     for (; ib < nb; ++ib) {
         const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
@@ -10721,67 +12113,31 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 #elif defined(__loongarch_asx)
 
     const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
-    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);
 
     __m256 accum = (__m256)__lasx_xvldi(0);
-    __m256i tmp1;
-    __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
 
-    mask_8f = __lsx_vreplgr2vr_b(0x8f);
     for (int ibl = 0; ibl < nb; ++ibl) {
         const uint8_t * qs = x[ibl].qs;
         const int8_t  * q8 = y[ibl].qs;
         uint16_t sh = x[ibl].scales_h;
         __m256i sumi1 = __lasx_xvldi(0);
         __m256i sumi2 = __lasx_xvldi(0);
-        __m128i zero = __lsx_vldi(0);
         for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
-            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
+            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
             const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
             const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
-
+            const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
+            const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
             const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
             const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
             const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
             const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
             sh >>= 4;
-            __m256i tmp5, tmp6;
-            tmp1 = __lasx_xvreplgr2vr_h(ls1);
-            tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
-            const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
-            tmp1 = __lasx_xvreplgr2vr_h(ls2);
-            tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
-            const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
+            const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
+            const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
             sumi1 = __lasx_xvadd_w(p_1, sumi1);
             sumi2 = __lasx_xvadd_w(p_2, sumi2);
         }
@@ -10790,6 +12146,56 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     }
 
     *s = hsum_float_8(accum);
+#elif defined(__VXE__) || defined(__VXE2__)
+    const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
+    const uint8x16_t v_m = vec_splat_u8(0x0F);
+
+    float sumf = 0;
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * restrict q4 = x[ibl].qs;
+        const int8_t  * restrict q8 = y[ibl].qs;
+
+        uint16_t h = x[ibl].scales_h;
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib = 0; ib < QK_K/64; ++ib) {
+            const uint8x16_t v_x0 = vec_xl(0       , q4);
+            const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);
+            q4 += 32;
+
+            int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
+            int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
+            int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
+            int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
+
+            v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
+            v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
+            v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
+            v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
+
+            const int8x16_t v_y0 = vec_xl( 0, q8);
+            const int8x16_t v_y1 = vec_xl(16, q8);
+            const int8x16_t v_y2 = vec_xl(32, q8);
+            const int8x16_t v_y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);
+            int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);
+
+            int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;
+            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;
+
+            h >>= 4;
+
+            sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
+            sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
 
 #else
     float sumf = 0;
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
index b307d5542..2f606d824 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
@@ -7,10 +7,8 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
-#include "ggml-quants.h"
 #include "ggml-cpu-quants.h"
 #include "ggml-threading.h"
-#include "amx/amx.h"
 #include "ggml.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
@@ -114,7 +112,8 @@ struct ggml_arm_arch_features_type {
     int has_i8mm;
     int has_sve;
     int sve_cnt;
-} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
+    int has_sme;
+} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
 #endif
 
 
@@ -238,6 +237,8 @@ typedef pthread_t ggml_thread_t;
 #else
 #if defined(__POWER9_VECTOR__)
 #define CACHE_LINE_SIZE 128
+#elif defined(__VXE__) || defined(__VXE2__)
+#define CACHE_LINE_SIZE 256
 #else
 #define CACHE_LINE_SIZE 64
 #endif
@@ -1078,29 +1079,23 @@ do {                                                              \
 #define GGML_F16_STEP 32
 #define GGML_F16_EPR  8
 
-// F16 arithmetic is not supported by AVX, so we use F32 instead
+// F16 arithmetic is not supported by LASX, so we use F32 instead
 
 #define GGML_F32Cx8          __m256
 #define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
 #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
 
 static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return (__m256)__lasx_xvld(tmp, 0);
+    __m256i a;
+    memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
+    a = __lasx_xvpermi_d(a, 0 | (1 << 4));
+    return __lasx_xvfcvtl_s_h(a);
 }
+
 static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
-    float arr[8];
-
-    __lasx_xvst(y, arr, 0);
-
-    for (int i = 0; i < 8; i++) {
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-    }
+    __m256i a = __lasx_xvfcvt_h_s(y, y);
+    a = __lasx_xvpermi_d(a, 0 | (2 << 2));
+    memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
 }
 #define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
 #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
@@ -1218,6 +1213,87 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
 #define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
 #define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
 
+#elif defined(__VXE__) || defined(__VXE2__)
+
+#define GGML_SIMD
+
+// F32 s390x
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              __vector float
+#define GGML_F32x4_ZERO         vec_splats(0.0f)
+#define GGML_F32x4_SET1         vec_splats
+#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD          vec_add
+#define GGML_F32x4_MUL          vec_mul
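+// pairwise-reduce GGML_F32_ARR (= GGML_F32_STEP / GGML_F32_EPR) accumulators down to one vector,
+// then sum its four lanes into a scalar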
+#define GGML_F32x4_REDUCE(res, x)                   \
+{                                                   \
+    int offset = GGML_F32_ARR >> 1;                 \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    offset >>= 1;                                   \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    offset >>= 1;                                   \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    res = vec_extract(x[0], 0) +                    \
+          vec_extract(x[0], 1) +                    \
+          vec_extract(x[0], 2) +                    \
+          vec_extract(x[0], 3);                     \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 s390x
+#define GGML_F16_STEP GGML_F32_STEP
+#define GGML_F16_EPR  GGML_F32_EPR
+
+static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return vec_xl(0, tmp);
+}
+
+static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
+    float arr[4];
+
+    vec_xst(y, 0, arr);
+
+    for (int i = 0; i < 4; i++) {
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+    }
+}
+
+#define GGML_F16_VEC                GGML_F32x4
+#define GGML_F16_VEC_ZERO           GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32x4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     __lzs_f16cx4_load(p)
+#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32x4_FMA
+#define GGML_F16_VEC_ADD            GGML_F32x4_ADD
+#define GGML_F16_VEC_MUL            GGML_F32x4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32x4_REDUCE
+
 #endif
 
 // GGML_F32_ARR / GGML_F16_ARR
@@ -1297,12 +1373,12 @@ struct ggml_threadpool {
     atomic_int n_graph;       // incremented when there is work to be done (i.e each graph)
     atomic_int GGML_CACHE_ALIGN n_barrier;
     atomic_int GGML_CACHE_ALIGN n_barrier_passed;
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+    atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
 
     // these are atomic as an annotation for thread-sanitizer
     atomic_bool stop;         // Used for stopping the threadpool altogether
     atomic_bool pause;        // Used for pausing the threadpool or individual threads
-    atomic_bool abort;        // Used for aborting processing of a graph
+    atomic_int abort;         // Used for aborting processing of a graph
 
     struct ggml_compute_state * workers;   // per thread state
     int          n_threads_max; // number of threads in the pool
@@ -1824,7 +1900,7 @@ inline static float ggml_silu_f32(float x) {
 
 #if __FINITE_MATH_ONLY__
 #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
-#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
+#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
 #endif
 
 #if defined(__ARM_NEON) && defined(__aarch64__)
@@ -2389,15 +2465,20 @@ bool ggml_is_numa(void) {
 #define HWCAP2_I8MM (1 << 13)
 #endif
 
+#if !defined(HWCAP2_SME)
+#define HWCAP2_SME (1 << 23)
+#endif
+
 static void ggml_init_arm_arch_features(void) {
 #if defined(__linux__) && defined(__aarch64__)
     uint32_t hwcap = getauxval(AT_HWCAP);
     uint32_t hwcap2 = getauxval(AT_HWCAP2);
 
-    ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+    ggml_arm_arch_features.has_neon    = !!(hwcap & HWCAP_ASIMD);
     ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
-    ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
-    ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE);
+    ggml_arm_arch_features.has_i8mm    = !!(hwcap2 & HWCAP2_I8MM);
+    ggml_arm_arch_features.has_sve     = !!(hwcap & HWCAP_SVE);
+    ggml_arm_arch_features.has_sme     = !!(hwcap2 & HWCAP2_SME);
 
 #if defined(__ARM_FEATURE_SVE)
     ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
@@ -2420,6 +2501,11 @@ static void ggml_init_arm_arch_features(void) {
     }
     ggml_arm_arch_features.has_i8mm = oldp;
 
+    if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_sme = oldp;
+
     ggml_arm_arch_features.has_sve = 0;
     ggml_arm_arch_features.sve_cnt = 0;
 #else
@@ -2443,6 +2529,12 @@ static void ggml_init_arm_arch_features(void) {
     ggml_arm_arch_features.has_sve = 0;
     ggml_arm_arch_features.sve_cnt = 0;
 #endif
+
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
+    ggml_arm_arch_features.has_sme = 1;
+#else
+    ggml_arm_arch_features.has_sme = 0;
+#endif
 #endif
 }
 #endif
@@ -3967,6 +4059,57 @@ static void ggml_compute_forward_dup_bytes(
     }
 }
 
+static void ggml_compute_forward_dup_q(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
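+    // dequantize the quantized src0 into the F32 dst, one block of qk values per iteration;
+    // each thread handles a contiguous range of the nr blocks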
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    size_t qk = ggml_blck_size(type);
+    const int64_t nr = ggml_nelements(src1) / qk;
+
+    // destination must be contiguous in the first dimension
+    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
+    // must either have first dimension large enough to hold a row, or fully contiguous
+    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+
+        uint32_t i = ir * qk;
+
+        const int64_t i03 = i/(ne00 * ne01 * ne02);
+        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+        const int64_t i13 = i/(ne10 * ne11 * ne12);
+        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + x_offset),
+                     (float *) ((char *)  dst->data + dst_offset), qk);
+    }
+}
+
 static void ggml_compute_forward_dup(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -3993,6 +4136,10 @@ static void ggml_compute_forward_dup(
             } break;
         default:
             {
+                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_dup_q(params, dst);
+                    break;
+                }
                 GGML_ABORT("fatal error");
             }
     }
@@ -6691,20 +6838,20 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * grad = dst->src[1];
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
 
     assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-    assert(ggml_are_same_shape(src0, grad));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6863,7 @@ static void ggml_compute_forward_silu_back_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_silu_backward_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
                 (float *) ((char *) grad->data + i1*(grad->nb[1])));
 
 #ifndef NDEBUG
@@ -6895,7 +7042,7 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +7113,7 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7165,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -7042,8 +7190,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 const int64_t i12 = i02;
                 const int64_t i13 = i03;
 
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 
                 ggml_float sum_xx  = 0.0;
                 ggml_float sum_xdz = 0.0;
@@ -7066,9 +7214,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 {
                     // z = rms_norm(x)
                     //
-                    // rms_norm(src0) =
+                    // rms_norm(src1) =
                     //     scale(
-                    //         src0,
+                    //         src1,
                     //         div(
                     //             1,
                     //             sqrt(
@@ -7076,13 +7224,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
                     //                     scale(
                     //                         sum(
                     //                             sqr(
-                    //                                 src0)),
+                    //                                 src1)),
                     //                         (1.0/N)),
                     //                     eps))));
 
                     // postorder:
                     // ## op    args         grad
-                    // 00 param src0         grad[#00]
+                    // 00 param src1         grad[#00]
                     // 01 const 1
                     // 02 sqr   (#00)        grad[#02]
                     // 03 sum   (#02)        grad[#03]
@@ -7159,6 +7307,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 // dx := scale(dx, rrms)
                 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
                 ggml_vec_cpy_f32  (ne00, dx, x);
                 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
                 ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7439,6 +7588,7 @@ UseGgmlGemm1:;
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = ggml_type_size(vec_dot_type);
         const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
@@ -7446,6 +7596,7 @@ UseGgmlGemm1:;
         assert(params->wsize >= ne13*nbw3);
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
+    #if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
@@ -7455,6 +7606,20 @@ UseGgmlGemm1:;
                 }
             }
         }
+    #else
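+        // convert a per-thread slice of whole quant blocks from every src1 row, so the work is
+        // balanced even when there are fewer rows (ne11) than threads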
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t bs = ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = (ith * ne10/bs) / nth;
+                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+    #endif
     }
 
     if (ith == 0) {
@@ -7509,7 +7674,7 @@ UseGgmlGemm2:;
     int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
     // If the chunking is poor for the number of threads on this setup, scrap the whole plan.  Re-chunk it by thread.
-    //   Also, chunking by thread was measured to have perform better on NUMA systems.  See https://github.com/ggerganov/llama.cpp/pull/6915
+    //   Also, chunking by thread was measured to perform better on NUMA systems.  See https://github.com/ggml-org/llama.cpp/pull/6915
     //   In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
     if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
         // distribute the thread work across the inner or outer loop based on which one is larger
@@ -7542,7 +7707,6 @@ UseGgmlGemm2:;
         if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
             num_rows_per_vec_dot = 1;
         }
-
         ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0 * nchunk1) {
@@ -7555,6 +7719,84 @@ UseGgmlGemm2:;
 
 // ggml_compute_forward_mul_mat_id
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
+
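+// one entry per row routed to an expert: i1 = the expert-use slot within the ids row,
+// i2 = the src1 row (token) it came from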
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
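+// compute one [ir0_start, ir0_end) x [ir1_start, ir1_end) tile of dst for the rows routed to expert cur_a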
+static void ggml_compute_forward_mul_mat_id_one_chunk(
+    struct ggml_tensor * dst,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    const struct ggml_tensor * ids,
+    const int64_t cur_a,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end,
+    const char * src0_cur,
+    const struct mmid_row_mapping * matrix_rows,
+    const size_t row_size,
+    const bool src1_cont,
+    const void * wdata) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    ggml_vec_dot_t    const vec_dot      = type_traits_cpu[type].vec_dot;
+    enum ggml_type    const vec_dot_type = type_traits_cpu[type].vec_dot_type;
+
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    float tmp[16];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
+                const int64_t _i12 = ir1; // logical row index for this expert
+
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                const int id       = row_mapping.i1; // selected expert index
+
+                const int64_t  i11 = id % ne11;
+                const int64_t  i12 = row_mapping.i2; // row index in src1
+
+                const int64_t  i1 = id;  // selected expert index
+                const int64_t  i2 = i12; // row
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char *) wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                    ? (i11      + i12*ne11)*row_size
+                    : (i11*nb11 + i12*nb12));
+
+                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
+                }
+
+                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
+            }
+        }
+    }
+}
+
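+// carve `size` bytes out of the scratch buffer at *p, aligned to `align`, and advance *p past them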
+static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
+
+    void * ptr = *p;
+    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
+    *p = (void *) ((char *) ptr + size);
+    return ptr;
+}
+
 static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
@@ -7572,7 +7814,6 @@ static void ggml_compute_forward_mul_mat_id(
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
-    ggml_vec_dot_t    const vec_dot         = type_traits_cpu[type].vec_dot;
     enum ggml_type    const vec_dot_type    = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t const from_float      = type_traits_cpu[vec_dot_type].from_float;
 
@@ -7590,21 +7831,27 @@ static void ggml_compute_forward_mul_mat_id(
     const int n_ids = ids->ne[0]; // n_expert_used
     const int n_as  = ne02;       // n_expert
 
-    char * wdata_src1_end = (src1->type == vec_dot_type) ?
-            (char *) params->wdata :
-            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    void * wdata_cur = params->wdata;
 
-    struct mmid_row_mapping {
-        int32_t i1;
-        int32_t i2;
-    };
+    if (src1->type != vec_dot_type) {
+        incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    }
 
-    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
-    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
+    int64_t * matrix_row_counts = // [n_as]
+        incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
+
+    struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
+        incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
+
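+    // one cache line per expert: each holds an atomic chunk counter, padded so that counters for
+    // different experts do not share a cache line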
+    char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
+        incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
+
+    GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
 
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = ggml_type_size(vec_dot_type);
         const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
@@ -7612,19 +7859,32 @@ static void ggml_compute_forward_mul_mat_id(
         assert(params->wsize >= ne13*nbw3);
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
+#if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+            for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
                     from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                                ne10);
                 }
             }
         }
+#else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t bs = ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = (ith * ne10/bs) / nth;
+                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+#endif
     }
 
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
     if (ith == 0) {
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -7642,9 +7902,14 @@ static void ggml_compute_forward_mul_mat_id(
         }
     }
 
+    // reset current_chunk
+    for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
+        *current_chunk_ctr = nth;
+    }
+
     ggml_barrier(params->threadpool);
 
-    // compute each matrix multiplication in sequence
     for (int cur_a = 0; cur_a < n_as; ++cur_a) {
         const int64_t cne1 = matrix_row_counts[cur_a];
 
@@ -7652,84 +7917,64 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
-
-        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const char * src0_cur = (const char *) src0->data + cur_a * nb02;
+        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-        const int64_t nr0 = ne01; // src0 rows
-        const int64_t nr1 = cne1; // src1 rows
+        const int64_t nr0 = ne01;
+        const int64_t nr1 = cne1;
 
-        // distribute the thread work across the inner or outer loop based on which one is larger
+        int chunk_size = 16;
+        if (nr0 == 1 || nr1 == 1) {
+            chunk_size = 64;
+        }
 
-        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+#if defined(__aarch64__)
+        // disable for ARM
+        const bool disable_chunking = true;
+#else
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
+#endif // defined(__aarch64__)
 
-        const int64_t ith0 = ith % nth0;
-        const int64_t ith1 = ith / nth0;
+        int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+        int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
-        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
-        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+        if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
+            nchunk0 = nr0 > nr1 ? nth : 1;
+            nchunk1 = nr0 > nr1 ? 1 : nth;
+        }
 
-        const int64_t ir010 = dr0*ith0;
-        const int64_t ir011 = MIN(ir010 + dr0, nr0);
+        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
 
-        const int64_t ir110 = dr1*ith1;
-        const int64_t ir111 = MIN(ir110 + dr1, nr1);
+        int current_chunk = ith;
 
-        // threads with no work simply yield (not sure if it helps)
-        //if (ir010 >= ir011 || ir110 >= ir111) {
-        //    sched_yield();
-        //    continue;
-        //}
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
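+        // each thread starts on chunk `ith`; once finished it pulls the next chunk index from the
+        // shared per-expert counter (initialized to nth above)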
 
-        // block-tiling attempt
-        const int64_t blck_0 = 16;
-        const int64_t blck_1 = 16;
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0;
+            const int64_t ith1 = current_chunk / nchunk0;
 
-        // attempt to reduce false-sharing (does not seem to make a difference)
-        float tmp[16];
+            const int64_t ir0_start = dr0 * ith0;
+            const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
 
-        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
-            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
-                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t _i12 = ir1; // logical row index for this expert
+            const int64_t ir1_start = dr1 * ith1;
+            const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
-                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
-                    const int id       = row_mapping.i1; // selected expert index
+            ggml_compute_forward_mul_mat_id_one_chunk(
+                dst, src0, src1, ids, cur_a,
+                ir0_start, ir0_end, ir1_start, ir1_end,
+                src0_cur, matrix_rows, row_size, src1_cont, wdata
+            );
 
-                    const int64_t  i11 = id % ne11;
-                    const int64_t  i12 = row_mapping.i2; // row index in src1
-
-                    const int64_t  i1 = id;  // selected expert index
-                    const int64_t  i2 = i12; // row
-
-                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                    //       the original src1 data pointer, so we should index using the indices directly
-                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                    const char * src1_col = (const char *) wdata +
-                        (src1_cont || src1->type != vec_dot_type
-                        ? (i11      + i12*ne11)*row_size
-                        : (i11*nb11 + i12*nb12));
-
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
-
-                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                    //}
-
-                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
-                    }
-
-                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
-                }
+            if (nth >= nchunk0 * nchunk1) {
+                break;
             }
+
+            current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
         }
     }
-
-#undef MMID_MATRIX_ROW
 }
 
 // ggml_compute_forward_out_prod
@@ -7750,12 +7995,13 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +8043,10 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
     const int64_t blck_1 = 16;
 
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
     for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
         const int64_t bir1 = MIN(bir + blck_1, ir1);
         for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +8057,8 @@ static void ggml_compute_forward_out_prod_f32(
                 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
                 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-                const int64_t i02 = i2;
-                const int64_t i03 = i3;
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
 
                 //const int64_t i10 = i1;
                 const int64_t i12 = i2;
@@ -7821,7 +8071,7 @@ static void ggml_compute_forward_out_prod_f32(
 
                     float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
                     float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
 
                     ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
                 }
@@ -7830,7 +8080,7 @@ static void ggml_compute_forward_out_prod_f32(
 
                     float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
                     float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
 
                     ggml_vec_mad_f32(ne0, d, s0, *s1);
                 }
@@ -8906,9 +9156,9 @@ static void ggml_compute_forward_soft_max(
 }
 
 
-// ggml_compute_forward_soft_max_back
+// ggml_compute_forward_soft_max_ext_back
 
-static void ggml_compute_forward_soft_max_back_f32(
+static void ggml_compute_forward_soft_max_ext_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8921,6 +9171,14 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
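+    // only the scale op_param is applied here; ALiBi (max_bias != 0) is not supported in the backward pass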
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
@@ -8969,10 +9227,11 @@ static void ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
         float dot_y_dy = 0;
-        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        ggml_vec_cpy_f32 (nc, dx, dy);
-        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32 (nc, dx, dx, y);
+        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32  (nc, dx, dy);
+        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32  (nc, dx, dx, y);
+        ggml_vec_scale_f32(nc, dx, scale);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -8983,7 +9242,7 @@ static void ggml_compute_forward_soft_max_back_f32(
     }
 }
 
-static void ggml_compute_forward_soft_max_back(
+static void ggml_compute_forward_soft_max_ext_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8992,7 +9251,7 @@ static void ggml_compute_forward_soft_max_back(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_back_f32(params, dst);
+                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
             } break;
         default:
             {
@@ -9009,10 +9268,6 @@ static void ggml_compute_forward_clamp_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->ith != 0) {
-        return;
-    }
-
     float min;
     float max;
     memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
@@ -9985,9 +10240,10 @@ static void ggml_compute_forward_im2col_back_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -10009,11 +10265,11 @@ static void ggml_compute_forward_im2col_back_f32(
     const int64_t IH = is_2D ? ne1 : 1;
     const int64_t IW = ne0;
 
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
 
-    const int64_t OH = is_2D ? ne12 : 1;
-    const int64_t OW = ne11;
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
 
     int ofs0 = is_2D ? nb3 : nb2;
     int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10315,9 @@ static void ggml_compute_forward_im2col_back_f32(
                                     continue;
                                 }
 
-                                const float * const src_data = (const float *) src1->data
+                                const float * const grad_in = (const float *) src0->data
                                     + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                             }
                         }
                         float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -11856,9 +12112,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12053,6 +12309,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
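+    // the first T*C floats of dst hold the attention output; the remaining
+    // head_size*C floats per sequence hold the updated recurrent state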
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
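+    // partition the heads across threads; this thread handles heads [h_start, h_end)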
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same as C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
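+    // pick the widest available SIMD width for the vectorized path (AVX-512: 16,
+    // AVX: 8, NEON: 4 floats); without SIMD support the scalar branch below is used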
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
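+            // the first token of each sequence reads its previous state from src[4];
+            // subsequent tokens update the state buffer in place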
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load GLA_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle any remaining elements; in practice head_size is a multiple of GLA_VECTOR_SIZE, so this loop is not used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12346,22 +12793,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * opt0 = dst->src[2];
+    const struct ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
 
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
 
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
     // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = ggml_nrows(src0f);
 
     // rows per thread
     const int64_t dr = (nr + nth - 1)/nth;
@@ -12370,12 +12817,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
-        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
+        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
 
 #ifndef NDEBUG
         for (int64_t i = 0; i < nc; ++i) {
@@ -12388,11 +12835,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
         ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
-        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
@@ -12689,7 +13136,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
-                ggml_compute_forward_soft_max_back(params, tensor);
+                ggml_compute_forward_soft_max_ext_back(params, tensor);
             } break;
         case GGML_OP_ROPE:
             {
@@ -12806,6 +13253,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13105,6 +13556,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -13513,14 +13965,19 @@ struct ggml_cplan ggml_graph_plan(
                         cur = 0;
                         const struct ggml_tensor * src0 = node->src[0];
                         const struct ggml_tensor * src1 = node->src[1];
+                        const struct ggml_tensor * ids = node->src[2];
                         const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
-                        if (src1->type != vec_dot_type) {
-                            cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
-                        }
                         const int n_as = src0->ne[2];
-                        cur += GGML_PAD(cur, sizeof(int64_t));       // align
-                        cur += n_as * sizeof(int64_t);               // matrix_row_counts
-                        cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                        // src1
+                        if (src1->type != vec_dot_type) {
+                            cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
+                        }
+                        // matrix_row_counts
+                        cur += n_as * sizeof(int64_t) + sizeof(int64_t);
+                        // matrix_rows
+                        cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
+                        // atomic_current_chunk
+                        cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
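+                        // the extra sizeof(int64_t) added to each buffer above is slack for aligning the pointers inside wdata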
                     } break;
                 case GGML_OP_OUT_PROD:
                     {
@@ -13530,6 +13987,7 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_SOFT_MAX:
                 case GGML_OP_ROPE:
+                case GGML_OP_ROPE_BACK:
                     {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     } break;
@@ -13640,20 +14098,24 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         /*.threadpool=*/ tp,
     };
 
-    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
+    for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
         struct ggml_tensor * node = cgraph->nodes[node_n];
 
         ggml_compute_forward(&params, node);
 
         if (state->ith == 0 && cplan->abort_callback &&
                 cplan->abort_callback(cplan->abort_callback_data)) {
-            tp->abort = true;
+            atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
             tp->ec    = GGML_STATUS_ABORTED;
         }
 
-        ggml_barrier(state->threadpool);
+        if (node_n + 1 < cgraph->n_nodes) {
+            ggml_barrier(state->threadpool);
+        }
     }
 
+    ggml_barrier(state->threadpool);
+
     return 0;
 }
 
@@ -13820,7 +14282,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
         threadpool->current_chunk    = 0;
         threadpool->stop             = false;
         threadpool->pause            = tpp->paused;
-        threadpool->abort            = false;
+        threadpool->abort            = -1;
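+        // abort == -1 means no abort; otherwise it holds the index of the first graph node that must not run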
         threadpool->workers          = NULL;
         threadpool->n_threads_max    = tpp->n_threads;
         threadpool->n_threads_cur    = tpp->n_threads;
@@ -13899,7 +14361,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         threadpool->cgraph           = cgraph;
         threadpool->cplan            = cplan;
         threadpool->current_chunk    = 0;
-        threadpool->abort            = false;
+        threadpool->abort            = -1;
         threadpool->ec               = GGML_STATUS_SUCCESS;
     }
 
@@ -14098,6 +14560,14 @@ int ggml_cpu_has_vsx(void) {
 #endif
 }
 
+int ggml_cpu_has_vxe(void) {
+#if defined(__VXE__) || defined(__VXE2__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_neon(void) {
 #if defined(__ARM_ARCH) && defined(__ARM_NEON)
     return ggml_arm_arch_features.has_neon;
@@ -14138,6 +14608,14 @@ int ggml_cpu_get_sve_cnt(void) {
 #endif
 }
 
+int ggml_cpu_has_sme(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
+    return ggml_arm_arch_features.has_sme;
+#else
+    return 0;
+#endif
+}
+
 void ggml_cpu_init(void) {
     // needed to initialize f16 tables
     {
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
index f11399cc6..a84203f29 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -14,6 +14,10 @@
 #include "ggml-cpu-hbm.h"
 #endif
 
+#ifdef GGML_USE_CPU_KLEIDIAI
+#include "kleidiai/kleidiai.h"
+#endif
+
 #if defined(__APPLE__)
 #include 
 #include 
@@ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
         }
 #endif
 
+#ifdef GGML_USE_CPU_KLEIDIAI
+        if (ggml_backend_cpu_kleidiai_buffer_type()) {
+            bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
+        }
+#endif
+
 #ifdef GGML_USE_CPU_AARCH64
         if (ggml_backend_cpu_aarch64_buffer_type()) {
             bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
@@ -284,14 +294,14 @@ struct ggml_backend_cpu_device_context {
                         &hKey) == ERROR_SUCCESS) {
             DWORD cpu_brand_size = 0;
             if (RegQueryValueExA(hKey,
-                                TEXT("ProcessorNameString"),
+                                "ProcessorNameString",
                                 NULL,
                                 NULL,
                                 NULL,
                                 &cpu_brand_size) == ERROR_SUCCESS) {
                 description.resize(cpu_brand_size);
                 if (RegQueryValueExA(hKey,
-                                    TEXT("ProcessorNameString"),
+                                    "ProcessorNameString",
                                     NULL,
                                     NULL,
                                     (LPBYTE)&description[0], // NOLINT
@@ -403,12 +413,21 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
-        case GGML_OP_ROPE_BACK:
-            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+        case GGML_OP_SOFT_MAX_BACK: {
+            if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
+                return false;
+            }
+            float max_bias = 0.0f;
+
+            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+
+            return max_bias == 0.0f;
+        }
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
+                src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         default:
             return true;
     }
@@ -525,19 +544,22 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
         if (ggml_cpu_has_dotprod()) {
             features.push_back({ "DOTPROD", "1" });
         }
-        if (ggml_cpu_has_matmul_int8()) {
-            features.push_back({ "MATMUL_INT8", "1" });
-        }
         if (ggml_cpu_get_sve_cnt() > 0) {
             static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
             features.push_back({ "SVE_CNT", sve_cnt.c_str() });
         }
+        if (ggml_cpu_has_sme()) {
+            features.push_back({ "SME", "1" });
+        }
         if (ggml_cpu_has_riscv_v()) {
             features.push_back({ "RISCV_V", "1" });
         }
         if (ggml_cpu_has_vsx()) {
             features.push_back({ "VSX", "1" });
         }
+        if (ggml_cpu_has_vxe()) {
+            features.push_back({ "VXE", "1" });
+        }
         if (ggml_cpu_has_wasm_simd()) {
             features.push_back({ "WASM_SIMD", "1" });
         }
@@ -553,6 +575,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
     #ifdef GGML_USE_OPENMP
         features.push_back({ "OPENMP", "1" });
     #endif
+    #ifdef GGML_USE_CPU_KLEIDIAI
+        features.push_back({ "KLEIDIAI", "1" });
+    #endif
     #ifdef GGML_USE_CPU_AARCH64
         features.push_back({ "AARCH64_REPACK", "1" });
     #endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 3f260ce5a..e0482c593 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -54,6 +54,7 @@
 #include "ggml-quants.h"
 
 #include 
+#include <array>
 
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1052,6 +1053,704 @@ class tinyBLAS_Q0_AVX {
       } \
    } \
 
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+                const TA *A, int64_t lda,
+                const TB *B, int64_t ldb,
+                TC *C, int64_t ldc,
+                int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+
+    template<int RM, int RN>
+    inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+       for (int I = 0; I < RM; I++) {
+          for (int J = 0; J < RN; J++) {
+             *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
+          }
+       }
+    }
+
+    template<int size>
+    inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
+       vector signed int vec_C[4];
+       vector float CA[4] = {0};
+       vector float res[4] = {0};
+       __builtin_mma_disassemble_acc(vec_C, ACC);
+       for (int i = 0; i < 4; i++) {
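+          // comparray holds per-row sums of A's quants; the -128 factor removes the bias
+          // introduced by flipping B's signed int8 values to unsigned (+128) during packing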
+          CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+          res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+          fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+       }
+    }
+
+    template<typename VA, typename VB>
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset5 = aoffset4 + lda;
+            aoffset6 = aoffset5 + lda;
+            aoffset7 = aoffset6 + lda;
+            aoffset8 = aoffset7 + lda;
+            aoffset += 8 * lda;
+
+            i = (cols >> 3);
+            if (i > 0) {
+               do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+                    C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
+                    C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
+                    C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
+                    C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
+
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+                    __builtin_vsx_disassemble_pair(c5, &C5);
+                    __builtin_vsx_disassemble_pair(c6, &C6);
+                    __builtin_vsx_disassemble_pair(c7, &C7);
+                    __builtin_vsx_disassemble_pair(c8, &C8);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    t1 = vec_perm(c5[0], c6[0], swiz1);
+                    t2 = vec_perm(c5[0], c6[0], swiz2);
+                    t3 = vec_perm(c7[0], c8[0], swiz1);
+                    t4 = vec_perm(c7[0], c8[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+128);
+                    vec_xst(t6, 0, vecOffset+144);
+                    vec_xst(t7, 0, vecOffset+160);
+                    vec_xst(t8, 0, vecOffset+176);
+
+                    t1 = vec_perm(c5[1], c6[1], swiz1);
+                    t2 = vec_perm(c5[1], c6[1], swiz2);
+                    t3 = vec_perm(c7[1], c8[1], swiz1);
+                    t4 = vec_perm(c7[1], c8[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+192);
+                    vec_xst(t6, 0, vecOffset+208);
+                    vec_xst(t7, 0, vecOffset+224);
+                    vec_xst(t8, 0, vecOffset+240);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    aoffset5 += lda;
+                    aoffset6 += lda;
+                    aoffset7 += lda;
+                    aoffset8 += lda;
+                    vecOffset += 256;
+                    i--;
+               } while(i > 0);
+            }
+            j--;
+        } while(j > 0);
+    }
+
+    if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+        i = (cols >> 3);
+            if (i > 0) {
+               do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+               } while(i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 3);
+        if (i > 0) {
+                do {
+                    switch(rows) {
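+                        // intentional fall-through: a 3-row tail also loads rows 2 and 1, a 2-row tail loads row 1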
+                        case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                                __builtin_vsx_disassemble_pair(c3, &C3);
+                        case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                                __builtin_vsx_disassemble_pair(c2, &C2);
+                        case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                                __builtin_vsx_disassemble_pair(c1, &C1);
+                                break;
+                    }
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+               } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+        // TODO: KERNEL_16x8 and KERNEL_8x16 currently have performance issues;
+        // the code below will be enabled once they are resolved.
+        /*if (m_rem >= 16 && n_rem >= 8) {
+            mc = 16;
+            nc = 8;
+            gemm<16,8>(m0, m, n0, n);
+        } else if(m_rem >= 8 && n_rem >= 16) {
+            mc = 8;
+            nc = 16;
+            gemm<8,16>(m0, m, n0, n);
+        }*/
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem > 4)) {
+            nc = 4;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[16] = {0};
+        acc_t acc_0, acc_1;
+        std::array<int, 4> comparray;
+        vector float fin_res[8] = {0};
+        vector float vs[8] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
+            }
+            for (int I = 0; I<4; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 4; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii, jj+4, 4, fin_res);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[8] = {0};
+        acc_t acc_0, acc_1;
+        std::array<int, 8> comparray;
+        vector float fin_res[8] = {0};
+        vector float vs[8] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+            }
+            for (int I = 0; I<8; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 8; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii+4, jj, 4, fin_res);
+    }
+
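+    // computes an 8x8 output tile with four 4x4 MMA accumulators (two row halves x two column halves)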
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[16] = {0};
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        std::array<int, 8> comparray;
+        vector float fin_res[16] = {0};
+        vector float vs[16] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(&acc_2);
+            __builtin_mma_xxsetaccz(&acc_3);
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
+            }
+            for (int I = 0; I<8; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 8; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
+            compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii+4, jj, 4, fin_res);
+        save_res<4, 4>(ii, jj+4, 8, fin_res);
+        save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        vec_t vec_A[8], vec_B[8] = {0};
+        vector signed int vec_C[4];
+        acc_t acc_0;
+
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            std::array<int, RM> comparray;
+            vector float res[4] = {0};
+            vector float fin_res[4] = {0};
+            vector float vs[4] = {0};
+            vector float CA[4] = {0};
+            __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
+            for (int l = 0; l < k; l++) {
+                __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_mma_xxsetaccz(&acc_0);
+                packNormal((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
+                for(int x = 0; x < 8; x+=4) {
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
+                }
+                for (int I = 0; I<RM; I++) {
+                    for (int J = 0; J<RN; J++) {
+                        *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    }
+                }
+                __builtin_mma_disassemble_acc(vec_C, &acc_0);
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < RM; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    const int8_t *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
+
+                for (int i = 0; i < RM; i++) {
+                    CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
+                    res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+                    fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
+                }
+            }
+            save_res<RM, RN>(ii, jj, 0, fin_res);
+        }
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+       if constexpr(RM == 4 && RN == 8) {
+          KERNEL_4x8(ii,jj);
+       } else if constexpr(RM == 8 && RN == 4) {
+          KERNEL_8x4(ii,jj);
+       } else if constexpr(RM == 8 && RN == 8) {
+          KERNEL_8x8(ii,jj);
+       } else {
+          static_assert(false, "RN/RM values not supported");
+       }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_PPC {
   public:
@@ -1071,13 +1770,17 @@ class tinyBLAS_PPC {
 
     void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
 
-    void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
+    template<typename VA>
+    void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
         int64_t i, j;
-        float *aoffset = NULL, *boffset = NULL;
-        float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-
-        aoffset = const_cast<float*>(a);
+        TA *aoffset = NULL, *boffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VA t1, t2, t3, t4, t5, t6, t7, t8;
+        aoffset = const_cast<TA*>(a);
         boffset = vec;
         j = (rows >> 3);
         if (j > 0) {
@@ -1093,9 +1796,6 @@ class tinyBLAS_PPC {
                 aoffset += 8 * lda;
                 i = (cols >> 3);
                 if (i > 0) {
-                    __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
-                    vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
-                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
                     do {
                         C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                         C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1175,21 +1875,19 @@ class tinyBLAS_PPC {
                     } while(i > 0);
                 }
                 if (cols & 4) {
-                    vector float c1, c2, c3, c4, c5, c6, c7, c8;
-                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
-                    c1 = vec_xl(0, aoffset1);
-                    c2 = vec_xl(0, aoffset2);
-                    c3 = vec_xl(0, aoffset3);
-                    c4 = vec_xl(0, aoffset4);
-                    c5 = vec_xl(0, aoffset5);
-                    c6 = vec_xl(0, aoffset6);
-                    c7 = vec_xl(0, aoffset7);
-                    c8 = vec_xl(0, aoffset8);
+                    c1[0] = vec_xl(0, aoffset1);
+                    c2[0] = vec_xl(0, aoffset2);
+                    c3[0] = vec_xl(0, aoffset3);
+                    c4[0] = vec_xl(0, aoffset4);
+                    c5[0] = vec_xl(0, aoffset5);
+                    c6[0] = vec_xl(0, aoffset6);
+                    c7[0] = vec_xl(0, aoffset7);
+                    c8[0] = vec_xl(0, aoffset8);
 
-                    t1 = vec_mergeh(c1, c2);
-                    t2 = vec_mergeh(c3, c4);
-                    t3 = vec_mergeh(c5, c6);
-                    t4 = vec_mergeh(c7, c8);
+                    t1 = vec_mergeh(c1[0], c2[0]);
+                    t2 = vec_mergeh(c3[0], c4[0]);
+                    t3 = vec_mergeh(c5[0], c6[0]);
+                    t4 = vec_mergeh(c7[0], c8[0]);
                     t5 = vec_xxpermdi(t1, t2, 0);
                     t6 = vec_xxpermdi(t3, t4, 0);
                     t7 = vec_xxpermdi(t1, t2, 3);
@@ -1199,10 +1897,10 @@ class tinyBLAS_PPC {
                     vec_xst(t7, 0, boffset+8);
                     vec_xst(t8, 0, boffset+12);
 
-                    t1 = vec_mergel(c1, c2);
-                    t2 = vec_mergel(c3, c4);
-                    t3 = vec_mergel(c5, c6);
-                    t4 = vec_mergel(c7, c8);
+                    t1 = vec_mergel(c1[0], c2[0]);
+                    t2 = vec_mergel(c3[0], c4[0]);
+                    t3 = vec_mergel(c5[0], c6[0]);
+                    t4 = vec_mergel(c7[0], c8[0]);
                     t5 = vec_xxpermdi(t1, t2, 0);
                     t6 = vec_xxpermdi(t3, t4, 0);
                     t7 = vec_xxpermdi(t1, t2, 3);
@@ -1224,9 +1922,6 @@ class tinyBLAS_PPC {
             aoffset += 4 * lda;
             i = (cols >> 3);
             if (i > 0) {
-                __vector_pair C1, C2, C3, C4;
-                vector float c1[2], c2[2], c3[2], c4[2];
-                vector float t1, t2, t3, t4, t5, t6, t7, t8;
                 do {
                     C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                     C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1273,22 +1968,20 @@ class tinyBLAS_PPC {
             }
 
             if (cols & 4) {
-                vector float c1, c2, c3, c4;
-                vector float t1, t2, t3, t4;
-                c1 = vec_xl(0, aoffset1);
-                c2 = vec_xl(0, aoffset2);
-                c3 = vec_xl(0, aoffset3);
-                c4 = vec_xl(0, aoffset4);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
+                c4[0] = vec_xl(0, aoffset4);
 
-                t1 = vec_mergeh(c1, c2);
-                t2 = vec_mergeh(c3, c4);
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset);
                 vec_xst(t4, 0, boffset+4);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset+8);
@@ -1300,21 +1993,19 @@ class tinyBLAS_PPC {
             aoffset2 = aoffset1 + lda;
             aoffset3 = aoffset2 + lda;
             if (cols & 4) {
-                vector float c1, c2, c3, c4 = {0};
-                vector float t1, t2, t3, t4;
-                c1 = vec_xl(0, aoffset1);
-                c2 = vec_xl(0, aoffset2);
-                c3 = vec_xl(0, aoffset3);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
 
-                t1 = vec_mergeh(c1, c2);
-                t2 = vec_mergeh(c3, c4);
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset);
                 vec_xst(t4, 0, boffset+4);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset+8);
@@ -1322,14 +2013,13 @@ class tinyBLAS_PPC {
             }
         }
     }
-
     void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
         for (int l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -1344,8 +2034,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -1365,8 +2055,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -1388,8 +2078,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_2);
         __builtin_mma_xxsetaccz(&acc_3);
         for (int l = 0; l < k; l+=8) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
             for(int x = 0; x < 16; x+=2) {
                 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -1572,15 +2262,15 @@ class tinyBLAS_PPC {
             vec_t vec_A[4], vec_B[4];
             for (int l=0; l<k; l+=4) {
                 if (RN >= 4 && RM == 1) {
-                    float* a = const_cast<float*>(A+(ii)*lda+l);
-                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                    TA* a = const_cast<TA*>(A+(ii)*lda+l);
+                    packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
                     vec_A[0] = (vec_t)vec_xl(0,a);
-                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
-                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
-                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+                    vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
                 } else {
-                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
-                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+                    packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                    packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
                 }
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -1590,7 +2280,7 @@ class tinyBLAS_PPC {
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
             for (int I = 0; I < RM; I++) {
                 for (int J = 0; J < RN; J++) {
-                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                 }
             }
        }
@@ -1813,6 +2503,20 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+
+#elif defined(__MMA__)
+        if (n < 8 && n != 4)
+           return false;
+        if (m < 8 && m != 4)
+           return false;
+        tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
+
 #else
         return false;
 #endif
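
For context on the work split used by tinyBLAS_Q0_PPC::gemm() above (and mirrored in the existing tinyBLAS_PPC class), the following is a minimal host-side sketch of the same tile-partitioning arithmetic. It is illustrative only, not part of the patch; the function name and the printf are invented for the example, and the m0/n0 offsets are omitted.

    // Sketch: gemm() splits the grid of RM x RN output tiles across nth threads;
    // thread ith owns a contiguous run of ceil(tiles/nth) tiles.
    #include <cstdint>
    #include <cstdio>

    static void partition_tiles(int64_t m, int64_t n, int RM, int RN, int ith, int nth) {
        const int64_t ytiles = m / RM;
        const int64_t xtiles = n / RN;
        const int64_t tiles  = xtiles * ytiles;
        const int64_t duty   = (tiles + nth - 1) / nth;           // ceiling division
        const int64_t start  = duty * ith;
        const int64_t end    = start + duty > tiles ? tiles : start + duty;
        for (int64_t job = start; job < end; ++job) {
            const int64_t ii = job / xtiles * RM;                  // tile row offset
            const int64_t jj = job % xtiles * RN;                  // tile column offset
            printf("thread %d computes the tile at (%lld, %lld)\n", ith, (long long) ii, (long long) jj);
        }
    }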
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
index 14761650f..96bd5a0be 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
@@ -7,7 +7,7 @@ if (CUDAToolkit_FOUND)
 
     if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
         # native == GPUs available at build time
-        # 52     == Maxwell, lowest CUDA 12 standard
+        # 50     == Maxwell, lowest CUDA 12 standard
         # 60     == P100, FP16 CUDA intrinsics
         # 61     == Pascal, __dp4a instruction (per-byte integer dot product)
         # 70     == V100, FP16 tensor cores
@@ -15,9 +15,9 @@ if (CUDAToolkit_FOUND)
         if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
             set(CMAKE_CUDA_ARCHITECTURES "native")
         elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
+            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
         else()
-            set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
+            set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80")
         endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
     file(GLOB   GGML_SOURCES_CUDA "*.cu")
-    file(GLOB   SRCS "template-instances/fattn-wmma*.cu")
+    file(GLOB   SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB   SRCS "template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
index c7b6be4e2..ce4b9cfb5 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
 
 template <typename T>
 static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
 
-    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
-    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
-    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+    const int64_t tid2  = tid23 % ne2;
+    const int64_t tid3  = tid23 / ne2;
 
     if (tid0 >= ne0) {
         return;
     }
 
     T sum = 0;
-    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+                }
             }
         }
     }
-    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
 template<float (*bin_op)(const float, const float)>
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
 
 template <typename T>
 static void repeat_back_cuda(
-    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
-    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
+        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
 }
 
 template<class op>
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->type == dst->type);
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    GGML_ASSERT(dst->ne[3] == 1);
+    GGML_ASSERT(ne2*ne3 <= (1 << 15));
+
+    const size_t ts = ggml_type_size(src0->type);
+    const size_t s00 = nb00 / ts;
+    const size_t s01 = nb01 / ts;
+    const size_t s02 = nb02 / ts;
+    const size_t s03 = nb03 / ts;
 
     switch (dst->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             float       * dst_d  = (float       *) dst->data;
-            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
         } break;
         default: {
             GGML_ASSERT(false);
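
The reworked k_repeat_back above now reduces over all four dimensions and addresses the source through explicit element strides (s00..s03), which is why the ggml_is_contiguous(src0) assertion could be dropped. As an illustration of the index math only (not part of the patch, with an invented function name), a plain CPU reference of the same reduction would be:

    #include <cstdint>

    static void repeat_back_ref(const float * src, float * dst,
                                int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                                size_t s00, size_t s01, size_t s02, size_t s03,
                                int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
        for (int64_t i3 = 0; i3 < ne3; i3++)
        for (int64_t i2 = 0; i2 < ne2; i2++)
        for (int64_t i1 = 0; i1 < ne1; i1++)
        for (int64_t i0 = 0; i0 < ne0; i0++) {
            float sum = 0.0f;
            // every src element that maps onto (i0..i3) modulo the dst extents is accumulated
            for (int64_t j3 = i3; j3 < ne03; j3 += ne3)
            for (int64_t j2 = i2; j2 < ne02; j2 += ne2)
            for (int64_t j1 = i1; j1 < ne01; j1 += ne1)
            for (int64_t j0 = i0; j0 < ne00; j0 += ne0) {
                sum += src[j3*s03 + j2*s02 + j1*s01 + j0*s00];
            }
            dst[((i3*ne2 + i2)*ne1 + i1)*ne0 + i0] = sum;          // dst is contiguous
        }
    }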
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
index 2c0a56226..adf0d3ecb 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
@@ -41,29 +41,78 @@
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
-#define GGML_CUDA_CC_PASCAL     600
-#define GGML_CUDA_CC_DP4A       610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA      700
-#define GGML_CUDA_CC_TURING     750
-#define GGML_CUDA_CC_AMPERE     800
-#define GGML_CUDA_CC_OFFSET_AMD 1000000
+#define GGML_CUDA_CC_PASCAL       600
+#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA        700
+#define GGML_CUDA_CC_TURING       750
+#define GGML_CUDA_CC_AMPERE       800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_OFFSET_AMD   0x1000000
 
 // GCN/CNDA, wave size is 64
-#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 803)  // Tonga, Fiji, Polaris, minimum for fast fp16
-#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 900)  // Vega56/64, minimum for fp16 dual issue
-#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 906)  // MI50/Radeon VII, minimum for dp4a
-#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 908)  // MI100, minimum for MFMA, acc registers
-#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 910)  // MI210, minimum acc register renameing
-#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 942)  // MI300
+#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
+#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
+#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
+#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renameing
+#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300
 
 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
-#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
-#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
-#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
+#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
+#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+
+#define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
+#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
+#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
 #define GGML_CUDA_CC_QY1        210
 #define GGML_CUDA_CC_QY2        220
 
+#ifdef __CUDA_ARCH_LIST__
+constexpr bool ggml_cuda_has_arch_impl(int) {
+    return false;
+}
+
+template<class ... Archs>
+constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
+    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
+}
+
+constexpr bool ggml_cuda_has_arch(const int arch) {
+    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
+}
+
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
+    if (cur == 0) {
+        GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
+    }
+    return cur;
+}
+
+template<class ... Archs>
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
+    if (first <= arch && first > cur) {
+        return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
+    } else {
+        return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
+    }
+}
+
+constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
+    return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
+}
+#else
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+    return arch;
+}
+#endif // __CUDA_ARCH_LIST__
+
+// ---------------------------------------------------------------------------------------------------------
+
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #if defined(_MSC_VER)
@@ -117,11 +166,11 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif
 
-#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
+#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
 #else
 #define GGML_CUDA_ASSUME(x)
-#endif // CUDART_VERSION >= 11100
+#endif // CUDART_VERSION >= 11010
 
 #ifdef GGML_CUDA_F16
 typedef half dfloat; // dequantize float
@@ -131,6 +180,10 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif // GGML_CUDA_F16
 
+#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
+#define GGML_USE_VMM
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
+
 #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 #define FP16_AVAILABLE
 #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
@@ -144,23 +197,55 @@ typedef float2 dfloat2;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define INT8_MMA_AVAILABLE
+#define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
-#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
-#define FLASH_ATTN_AVAILABLE
-#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-static constexpr bool fast_fp16_available(const int cc) {
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+
+static bool fp16_available(const int cc) {
+    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
+}
+
+static bool fast_fp16_available(const int cc) {
+    return fp16_available(cc) && cc != 610;
+}
+
+// To be used for feature selection of external libraries, e.g. cuBLAS.
+static bool fast_fp16_hardware_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
-static constexpr bool fp16_mma_available(const int cc) {
+// Any FP16 tensor core instructions are available for ggml code.
+static bool fp16_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+}
+
+// To be used for feature selection of external libraries, e.g. cuBLAS.
+static bool fp16_mma_hardware_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }
 
-static constexpr bool int8_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
+// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static bool new_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
+}
+
+static bool cp_async_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
+static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+    return __AMDGCN_WAVEFRONT_SIZE;
+#else
+    return 32;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
 
 [[noreturn]]
@@ -186,53 +271,46 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
     }
     return a;
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
-
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
-        reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
-        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
     }
     return a;
-#else
-#pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
-    }
-    return a;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #else
     NO_DEVICE_CODE;
@@ -240,10 +318,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 }
@@ -265,35 +344,34 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-
-#if CUDART_VERSION >= CUDART_HMAX
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
+    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
+#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
-#else
+#elif !defined(GGML_USE_HIP)
     half2 ret;
     reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
     reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
-#endif // CUDART_VERSION >= CUDART_HMAX
-
 #else
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 #pragma unroll
-   for (int offset = 16; offset > 0; offset >>= 1) {
-       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+   for (int offset = width/2; offset > 0; offset >>= 1) {
+       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
    }
    return x;
 #else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
 #if CUDART_VERSION < CUDART_HMASK
@@ -333,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 
 #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
     return __dp4a(a, b, c);
-#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
     const int8_t * a8 = (const int8_t *) &a;
     const int8_t * b8 = (const int8_t *) &b;
     return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
@@ -516,6 +594,7 @@ struct ggml_cuda_device_info {
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
         size_t  total_vram;
+        int     warp_size;          // Number of threads in a dispatch
     };
 
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
@@ -588,7 +667,7 @@ struct ggml_tensor_extra_gpu {
 };
 
 
-#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
 #define USE_CUDA_GRAPH
 #endif
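
The warp_reduce_* helpers in common.cuh now take a width template parameter (defaulting to WARP_SIZE). The sketch below is not part of the patch; it assumes common.cuh is included and a 32-thread block, and the kernel name is invented. It shows the intended use: with width == 16, __shfl_xor_sync only exchanges values inside groups of 16 lanes, so each half-warp produces its own partial sum.

    __global__ void rowsum_halfwarp(const float * x, float * out) {
        const int lane = threadIdx.x;               // assumes blockDim.x == 32
        float v = x[lane];
        v = warp_reduce_sum<16>(v);                 // lanes 0-15 and 16-31 reduce separately
        if (lane % 16 == 0) {
            out[lane / 16] = v;                     // one result per half-warp
        }
    }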
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu b/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
index 5eb9f08da..aafbaf803 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
@@ -124,7 +124,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
           uint64_t   nb1,
           uint64_t   nb2,
           uint64_t   nb3){
-    static_assert(dim >= 0 && dim <= 3, "dim must be between 0 and 3");
+    static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]");
 
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
index 5b0dfacef..795b720d6 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
+            if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cp-async.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/cp-async.cuh
new file mode 100644
index 000000000..ecb659997
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cp-async.cuh
@@ -0,0 +1,46 @@
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bit.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
+// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
+template <int preload>
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+    if (preload == 256) {
+        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 128) {
+        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 64) {
+        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else
+#endif // CUDART_VERSION >= 11040
+    {
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    }
+#else
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src);
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.wait_all;");
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
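
A hypothetical usage sketch for the new cp-async.cuh helpers (the kernel below is not part of the patch and its name is invented): each thread stages 16 bytes of K into shared memory, converting the generic shared-memory pointer to a 32-bit address with __cvta_generic_to_shared as described in the header comment, then waits for its own copies and synchronizes before other warps read the tile.

    __global__ void stage_tile(const half2 * __restrict__ K) {
        __shared__ half2 tile[4*32];                              // 16 bytes per thread, 32 threads
        const unsigned int dst = (unsigned int) __cvta_generic_to_shared(tile + 4*threadIdx.x);
        cp_async_cg_16<0>(dst, K + 4*threadIdx.x);                // no L2 prefetch hint
        cp_async_wait_all();                                      // per-thread completion only
        __syncthreads();                                          // required before other warps read tile[]
        // ... consume tile[] ...
    }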
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
index 54c0f66d2..cca2bee0b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
@@ -1,4 +1,5 @@
 #include "cpy.cuh"
+#include "dequantize.cuh"
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
 }
 
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
-    const block_q8_0 * xi = (const block_q8_0 *) cxi;
-    float * dsti = (float *) cdsti;
+    float * cdstf = (float *)(cdsti);
 
-    const float d = (float)xi->d;
-
-    for (int j = 0; j < QK8_0; j++) {
-       dsti[j] = xi->qs[j] * d;
+#pragma unroll
+    for (int j = 0; j < QK8_0; j += 2) {
+        dfloat2 dq;
+        dequantize_q8_0(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + 1) = dq.y;
     }
 }
 
@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
+template<dequantize_kernel_t dequant, int qk>
+static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
+    float * cdstf = (float *)(cdsti);
+
+#pragma unroll
+    for (int j = 0; j < qk/2; j++) {
+        dfloat2 dq;
+        dequant(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + qk/2) = dq.y;
+    }
+}
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f16_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
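
The layout written by the new cpy_blck_q_f32 (dq.x at j, dq.y at j + qk/2) mirrors how the dequantize kernels unpack a block: the two values packed into one quantized byte land qk/2 elements apart in the dequantized row. As a scalar illustration only (not part of the patch, with an invented function name and assuming the public ggml_fp16_to_fp32 helper), Q4_0 dequantization looks like this:

    static void dequantize_block_q4_0_ref(const block_q4_0 * x, float * y) {
        const float d = ggml_fp16_to_fp32(x->d);
        for (int j = 0; j < QK4_0/2; ++j) {
            const int x0 = (x->qs[j] & 0x0F) - 8;   // low nibble  -> y[j]
            const int x1 = (x->qs[j] >>   4) - 8;   // high nibble -> y[j + QK4_0/2]
            y[j]           = x0*d;
            y[j + QK4_0/2] = x1*d;
        }
    }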
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cross-entropy-loss.cu
index ed09406a8..0ce4afbb2 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cross-entropy-loss.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cross-entropy-loss.cu
@@ -5,95 +5,89 @@
 #include <cmath>
 #include <cstdint>
 
-static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-    const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
+template <bool use_shared>
+static __global__ void cross_entropy_loss_f32(
+        const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
+    extern __shared__ float tmp[];
 
-    const int ne_tmp = WARP_SIZE*nclasses;
-
-    extern __shared__ float tmp_all[];
-    float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
-    float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
-
-    // Each warp first loads ne_tmp logits/labels into shared memory:
-    for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
-        const int ig = i0*nclasses + i; // ig == i global
-
-        tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
-        tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
-    }
-
-    // Each thread in the warp then calculates the cross entropy loss for a single row.
-    // TODO: pad in order to avoid shared memory bank conflicts.
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
 
     // Find maximum for softmax:
-    float max = -INFINITY;
-    for (int i = 0; i < nclasses; ++i) {
-        max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
+    float max_logit = -INFINITY;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = logits[i];
+        max_logit = fmaxf(max_logit, val);
+
+        if (use_shared) {
+            tmp[i] = val;
+        }
     }
+    max_logit = warp_reduce_max(max_logit);
 
     // Calculate log(softmax(logits)) which is just logits - max:
     float sum = 0.0f;
-    for (int i = 0; i < nclasses; ++i) {
-        float val = tmp_logits[lane_id*nclasses + i] - max;
-        sum += expf(val);
-        tmp_logits[lane_id*nclasses + i] = val;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float logit_i = use_shared ? tmp[i] : logits[i];
+        sum += expf(logit_i - max_logit);
     }
+    sum = warp_reduce_sum(sum);
     sum = logf(sum);
 
     // log(exp(logits - max) / sum) = (logits - max) - log(sum)
     float loss = 0.0f;
-    for (int i = 0; i < nclasses; ++i) {
-        loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float logit_i = use_shared ? tmp[i] : logits[i];
+        loss += (logit_i - max_logit - sum) * labels[i];
     }
     loss = -warp_reduce_sum(loss) / (float)k;
 
-    __syncthreads();
-
-    if (lane_id == 0) {
-        tmp_all[warp_id] = loss;
-    }
-
-    __syncthreads();
-
-    if (warp_id != 0) {
-        return;
-    }
-
-    loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
-    loss = warp_reduce_sum(loss);
-
-    if (lane_id != 0) {
+    if (threadIdx.x != 0) {
         return;
     }
 
     dst[blockIdx.x] = loss;
 }
 
-static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+template <bool use_shared>
+static __global__ void cross_entropy_loss_back_f32(
+        const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
+        float * __restrict__ dst, const int nclasses) {
     extern __shared__ float tmp[];
 
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
+    dst    += int64_t(blockIdx.x)*nclasses;
+
     float maxval = -INFINITY;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        const float val = logits[blockIdx.x*nclasses + i];
+        const float val = logits[i];
         maxval = fmaxf(maxval, val);
-        tmp[i] = val;
+
+        if (use_shared) {
+            tmp[i] = val;
+        }
     }
     maxval = warp_reduce_max(maxval);
 
     float sum = 0.0f;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        const float val = expf(tmp[i] - maxval);
+        const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
         sum += val;
-        tmp[i] = val;
+
+        if (use_shared) {
+            tmp[i] = val;
+        } else {
+            dst[i] = val;
+        }
     }
     sum = warp_reduce_sum(sum);
     const float sm_scale = 1.0f/sum;
 
-    const float d_by_nrows = *loss/gridDim.x;
+    const float d_by_nrows = *grad/gridDim.x;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+        const float val = use_shared ? tmp[i] : dst[i];
+        dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
     }
 }
 
@@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t stream = ctx.stream();
 
-    const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
-    const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
-    const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(nrows, 1, 1);
+    const size_t nbytes_shared = ne00*sizeof(float);
+
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
 
-    cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    if (nbytes_shared <= smpbo) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    } else {
+        cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    }
+    CUDA_CHECK(cudaGetLastError());
 
     // Combine results from individual blocks:
     sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
 
 void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * opt0 = dst->src[2];
+    const ggml_tensor * grad  = dst->src[0];
+    const ggml_tensor * src0f = dst->src[1];
+    const ggml_tensor * src1f = dst->src[2];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(opt0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1f->type == GGML_TYPE_F32);
+    GGML_ASSERT( grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(  dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_scalar(grad));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
+    GGML_ASSERT(ggml_are_same_shape(src0f, dst));
 
-    const int64_t ne00  = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t ne00  = src0f->ne[0];
+    const int64_t nrows = ggml_nrows(src0f);
 
-    const float * src0_d = (const float *) src0->data;
-    const float * src1_d = (const float *) src1->data;
-    const float * opt0_d = (const float *) opt0->data;
-    float       * dst_d  = (float       *) dst->data;
+    const float * grad_d  = (const float *) grad->data;
+    const float * src0f_d = (const float *) src0f->data;
+    const float * src1f_d = (const float *) src1f->data;
+    float       * dst_d   = (float       *) dst->data;
 
     cudaStream_t stream = ctx.stream();
 
     const dim3 blocks_dim(WARP_SIZE, 1, 1);
     const dim3 blocks_num(nrows, 1, 1);
-    const int shmem = ne00*sizeof(float);
+    const size_t nbytes_shared = ne00*sizeof(float);
 
-    cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    if (nbytes_shared <= smpbo) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+    } else {
+        cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+    }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
index ee9752da6..7b9566fb4 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
@@ -516,6 +516,96 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template<int D, int ncols1, int ncols2> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    constexpr int ncols = ncols1*ncols2;
+
+    const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
+
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while(true) {
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
+
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
+
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
+
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
+
+        max_val = max_val_new;
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+    *dst = dst_val / rowsum;
+}
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -581,11 +671,14 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
+template <int D, int ncols1, int ncols2, int parallel_blocks>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -603,20 +696,25 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int id  = ggml_cuda_get_device();
+    const int cc  = ggml_cuda_info().devices[id].cc;
+    const int nsm = ggml_cuda_info().devices[id].nsm;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -649,39 +747,61 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
+
+        const int nblocks_stream_k = max_blocks;
+
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
+
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
 
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -693,16 +813,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
+    if constexpr (parallel_blocks == 0) {
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
+
+            flash_attn_stream_k_fixup<D, ncols1, ncols2>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+
+        flash_attn_combine_results<D, parallel_blocks>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
     CUDA_CHECK(cudaGetLastError());
 }
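
For reference on the fixup math introduced above: each partial result carries the maximum logit and the softmax row sum it was computed with, and merging two partials rescales both accumulators to a common maximum before adding (the kernel additionally flushes very small scales to zero via SOFTMAX_FTZ_THRESHOLD, omitted here). A standalone sketch of that combination rule, with hypothetical names and no ggml dependencies:

```cpp
// Sketch: merging two partial flash-attention results with the running-max trick,
// mirroring the scale_val/scale_add logic in flash_attn_stream_k_fixup.
#include <cmath>
#include <cstdio>

struct Partial {
    float max_val; // maximum KQ logit seen by this partial result
    float rowsum;  // sum of exp(logit - max_val)
    float acc;     // V-weighted accumulator, scaled consistently with rowsum
};

static Partial combine(const Partial & a, const Partial & b) {
    const float m  = fmaxf(a.max_val, b.max_val);
    const float sa = expf(a.max_val - m); // rescale a to the common maximum
    const float sb = expf(b.max_val - m); // rescale b to the common maximum
    return {m, sa*a.rowsum + sb*b.rowsum, sa*a.acc + sb*b.acc};
}

int main() {
    const Partial a = {1.0f, 2.0f, 3.0f};
    const Partial b = {2.5f, 1.5f, 4.0f};
    const Partial c = combine(a, b);
    // Final output = combined accumulator / combined row sum,
    // matching "*dst = dst_val / rowsum" at the end of the fixup kernel.
    printf("%f\n", c.acc / c.rowsum);
    return 0;
}
```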
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
new file mode 100644
index 000000000..718ee5402
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -0,0 +1,1021 @@
+#include "common.cuh"
+#include "cp-async.cuh"
+#include "mma.cuh"
+#include "fattn-common.cuh"
+
+using namespace ggml_cuda_mma;
+
+typedef tile<16,  8, half2> tile_A;
+typedef tile< 8,  8, half2> tile_B;
+typedef tile<16,  8, half2> tile_B_16;
+typedef tile<16,  8, float> tile_C_KQ;
+typedef tile<16, 16, float> tile_C_KQ_16;
+typedef tile<16,  4, half2> tile_C_VKQ;
+typedef tile<16,  8, half2> tile_C_VKQ_16;
+
+template<int D, int nwarps, int KQ_per_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    // If cp.async is available, load up to the highest power of 2 in D asynchronously:
+#ifdef CP_ASYNC_AVAILABLE
+    static_assert(D >= 64 && D < 512, "bad D");
+    constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
+
+    const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
+
+    constexpr int preload = 64;
+    constexpr int h2_per_chunk = 16/sizeof(half2);
+    constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
+    constexpr int stride_i = WARP_SIZE / chunks_per_row;
+#pragma unroll
+    for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
+        const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
+        const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
+
+        cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
+    }
+#else
+    constexpr int k0_sync_start = 0;
+#endif // CP_ASYNC_AVAILABLE
+    static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
+
+    // If D is not a power of 2, the rest is loaded synchronously.
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds");
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop  =                                         D/2 - (D/2) % (1*stride_k);
+        const int stride_i = WARP_SIZE / stride_k;
+
+        if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
+            continue;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
+            const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+#pragma unroll
+            for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
+            }
+        }
+    }
+}
+
+template<int ncols1, int nwarps, int KQ_per_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
+        const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
+    static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter");
+#ifdef CP_ASYNC_AVAILABLE
+    constexpr int preload = KQ_per_iter * sizeof(half);
+    constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter;
+    constexpr int stride_j = nwarps * cols_per_warp;
+
+    const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+        const int j = j0 + threadIdx.y*cols_per_warp +
+            (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8));
+
+        if (j0 + stride_j > ncols1 && j >= ncols1) {
+            break;
+        }
+
+        const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8));
+
+        cp_async_cg_16<preload>(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+    }
+#else
+    constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter;
+    constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+        const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2));
+
+        if (j0 + stride_j > ncols1 && j >= ncols1) {
+            break;
+        }
+
+        const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2);
+
+        tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i];
+    }
+#endif // CP_ASYNC_AVAILABLE
+}
+
+template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half2  * const __restrict__ mask_h2,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_KV,
+        const int stride_mask,
+        const int jt,
+        half2        * const __restrict__ tile_K,
+        half2        * const __restrict__ tile_V,
+        half2        * const __restrict__ tile_mask,
+        const tile_B * const __restrict__ Q_B,
+        tile_C_VKQ   * const __restrict__ VKQ_C,
+        float        * const __restrict__ KQ_max,
+        float        * const __restrict__ KQ_rowsum,
+        const int kb0) {
+#ifdef NEW_MMA_AVAILABLE
+    constexpr int cols_per_warp   = ntiles * tile_B::I;
+    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+    constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int D2_padded       = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    const int k_VKQ_0 = kb0 * KQ_per_iter;
+    tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles];
+
+    // Use wide variants of tiles if ntiles >= 2.
+    tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
+    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+    tile_C_KQ_16  * KQ_C_16  = (tile_C_KQ_16  *) KQ_C;
+
+#ifdef CP_ASYNC_AVAILABLE
+    cp_async_wait_all();
+    __syncthreads();
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+#else
+    if (ncols2 > 1 || mask_h2) {
+        flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
+    }
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
+    __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
+
+    // Calculate tile of KQ:
+#pragma unroll
+    for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) {
+        const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+#pragma unroll
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
+            tile_A K_A;
+            load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
+            if (ntiles == 1) {
+                mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    // Wide version of KQ_C is column-major => swap A and B.
+                    mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
+                }
+            }
+        }
+    }
+
+#ifndef CP_ASYNC_AVAILABLE
+    __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
+
+    if (use_logit_softcap) {
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
+            }
+        }
+    }
+
+    float KQ_max_new[cols_per_thread];
+#pragma unroll
+    for (int col = 0; col < cols_per_thread; ++col) {
+        KQ_max_new[col] = KQ_max[col];
+    }
+    float KQ_rowsum_add[cols_per_thread] = {0.0f};
+
+    if (ntiles == 1) {
+        if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+            for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) {
+                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                    const int i = i0 + tile_C_KQ::get_i(l);
+                    const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
+
+                    KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
+                        __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]);
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+            }
+        }
+
+        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = 16; offset >= 4; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            }
+        }
+
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
+
+                KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
+            }
+        }
+    } else { // ntiles > 1
+        if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+            for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) {
+                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
+                        const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
+                        const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
+
+                        const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]);
+                        const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
+                        KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
+                        KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
+                    }
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+                    const int KQ_index = 2*t + (l/2) % 2;
+                    KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+                }
+            }
+        }
+
+        // Values per KQ column are spread across 4 threads, does not need full warp reduce:
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = 2; offset >= 1; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            }
+        }
+
+        static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+                    const int KQ_index = 2*t + (l/2) % 2;
+
+                    KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
+
+                    KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
+                }
+            }
+        }
+    }
+
+    {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+            KQ_max[col] = KQ_max_new[col];
+
+            // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
+    // Convert KQ C tiles into B tiles for VKQ calculation:
+    tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles];
+    tile_B_16 * B_16 = (tile_B_16 *) B;
+    static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size");
+    if (ntiles == 1) {
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) {
+            B[k] = get_transposed(get_half2(KQ_C[k]));
+        }
+    } else {
+        for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+                B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
+            }
+        }
+    }
+
+#ifdef CP_ASYNC_AVAILABLE
+    // Preload K tile for next iteration:
+    cp_async_wait_all();
+    __syncthreads();
+    if (!last_iter) {
+        if (ncols2 > 1 || mask_h2) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask);
+        }
+        flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV);
+    }
+#else
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+    __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
+
+    // Calculate VKQ tile:
+#pragma unroll
+    for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
+        static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size");
+#pragma unroll
+        for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) {
+            const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+
+            tile_A A;
+            load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
+            if (ntiles == 1) {
+                mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    // Wide version of VKQ_C is column-major => swap A and B.
+                    mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
+                }
+            }
+        }
+    }
+
+#ifndef CP_ASYNC_AVAILABLE
+    __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
+
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
+template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half2  * const __restrict__ mask_h2,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_Q1,
+        const int stride_Q2,
+        const int stride_KV,
+        const int stride_mask,
+        const int jt,
+        const int kb0_start,
+        const int kb0_stop) {
+#ifdef NEW_MMA_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr int ncols           = ncols1 * ncols2;
+    constexpr int cols_per_warp   = ntiles * tile_B::I;
+    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+    constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+
+    static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
+
+    static_assert(D           % nwarps == 0, "bad D");
+    static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter");
+
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements:
+    extern __shared__ half2 tile_K[];
+#ifdef CP_ASYNC_AVAILABLE
+    half2 * tile_V    = tile_K + KQ_per_iter*D2_padded;
+#else
+    half2 * tile_V    = tile_K;
+#endif // CP_ASYNC_AVAILABLE
+    half2 * tile_mask = tile_V + KQ_per_iter*D2_padded;
+
+    tile_B       Q_B[D/(2*tile_B::J) * ntiles];
+    tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles];
+
+    tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
+    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+
+    float KQ_rowsum[cols_per_thread] = {0.0f};
+    float KQ_max[cols_per_thread];
+#pragma unroll
+    for (int col = 0; col < cols_per_thread; ++col) {
+        KQ_max[col] = -FLT_MAX/2.0f;
+    }
+
+    // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
+    // The loading is done with decreasing granularity for D for better memory bandwidth.
+    const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start  = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop   =                             D/2 - (D/2) % (1*stride_k);
+        const int stride_jc = WARP_SIZE / stride_k;
+
+        if (k0_start == k0_stop) {
+            continue;
+        }
+
+#pragma unroll
+        for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
+            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+            if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
+                break;
+            }
+
+            const int j = jc / ncols2;
+            const int c = jc % ncols2;
+
+            if (jt*ncols1 + j < ne01) {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
+                    tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                }
+            } else {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f);
+                }
+            }
+        }
+    }
+
+    __syncthreads();
+
+    {
+        const int j0 = (threadIdx.y / np) * cols_per_warp;
+
+#pragma unroll
+        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+            if (ntiles == 1) {
+                load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
+                        tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded);
+                }
+            }
+        }
+    }
+
+    __syncthreads();
+
+    // Preload mask and K data for first iteration when using cp_async:
+#ifdef CP_ASYNC_AVAILABLE
+    if (ncols2 > 1 || mask_h2) {
+        flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask);
+    }
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV);
+#endif // CP_ASYNC_AVAILABLE
+
+    // Iterate over ne11 == previous tokens:
+    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+        constexpr bool last_iter = false;
+        flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+    }
+    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+        constexpr bool last_iter = true;
+        flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+    }
+
+    // With cp_async there is no __syncthreads at the end of the iter,
+    //     there can be a race condition on shared memory access for combining/writing back results.
+#ifdef CP_ASYNC_AVAILABLE
+    if (nwarps*cols_per_warp > KQ_per_iter) {
+        __syncthreads();
+    }
+#endif // CP_ASYNC_AVAILABLE
+
+    // Finally, sum up partial KQ rowsums.
+    // The partial sums are spread across 8/4 threads each, does not need full reduce.
+    {
+        constexpr int offset_first = ntiles == 1 ? 16 : 2;
+        constexpr int offset_last  = ntiles == 1 ?  4 : 1;
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
+                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+            }
+        }
+    }
+
+    // Write VKQ accumulators to shared memory in column-major format.
+    // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
+    // Also for np > 1 the combination is done via these values in shared memory.
+    if (ntiles == 1) {
+        const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
+#pragma unroll
+        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+            const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
+
+#pragma unroll
+            for (int l = 0; l < tile_B::ne; ++l) {
+                const int k = k0 + tile_B::get_j(l);
+
+                tile_K[jc_cwd*D2_padded + k] = B.x[l];
+            }
+        }
+    } else {
+#pragma unroll
+        for (int t = 0; t < ntiles/2; ++t) {
+            const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
+#pragma unroll
+            for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
+                    const int j = j0 + tile_C_VKQ_16::get_i(l);
+                    const int k = k0 + tile_C_VKQ_16::get_j(l);
+
+                    tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
+                }
+            }
+        }
+    }
+
+    if constexpr (ntiles == 1) {
+        const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
+        const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
+        const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
+            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+            ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+        }
+
+        __syncthreads();
+
+        if (np == 1) {
+            // No combination is needed, the meta data can be directly written from registers to VRAM.
+            if (needs_fixup && threadIdx.x < tile_B::I) {
+                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+            if (is_fixup && threadIdx.x < tile_B::I) {
+                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+        }
+    } else {
+        static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
+        const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
+            + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
+            + tile_C_VKQ_16::get_i(threadIdx.x % 4);
+        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
+            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+            ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+        }
+
+        __syncthreads();
+
+        if (np == 1) {
+            // No combination is needed, the meta data can be directly written from registers to VRAM.
+            if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+            if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+        }
+    }
+
+    static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
+    if (np > 1 && threadIdx.y % np == 0) {
+        // Combine the meta data for parallel warps via shared memory.
+        // Warps with threadIdx.y % np != 0 must NOT return early.
+        // All threads must return simultaneously to avoid race conditions with work on the next tile.
+
+        constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+
+        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+        float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4;
+        float2 meta[nmeta];
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2];
+        }
+
+        float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
+#pragma unroll
+        for (int imeta = 1; imeta < nmeta; ++imeta) {
+            KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
+        }
+#pragma unroll
+        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+            if (offset >= WARP_SIZE) {
+                continue;
+            }
+            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+        }
+
+        float KQ_cms[nmeta]; // KQ combine max scale per warp.
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
+        }
+
+        float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
+#pragma unroll
+        for (int imeta = 1; imeta < nmeta; ++imeta) {
+            KQ_crs += KQ_cms[imeta]*meta[imeta].y;
+        }
+#pragma unroll
+        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+            if (offset >= WARP_SIZE) {
+                continue;
+            }
+            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+        }
+
+        // Write back combined meta data:
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+                // Combined KQ max scale + rowsum.
+                meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs);
+            }
+        }
+
+        // Combined KQ max + rowsum.
+        static_assert(cols_per_warp <= WARP_SIZE);
+        if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+        if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+    }
+
+    if (np > 1) {
+        __syncthreads();
+    }
+
+    if (np == 1 || threadIdx.y % np == 0) {
+        // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
+        // The values after that are for the partial results of the individual blocks.
+        float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
+
+#pragma unroll
+        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+            const int k0_start  = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+            const int k0_stop   =                             D/2 - (D/2) % (1*stride_k);
+            const int stride_jc = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                continue;
+            }
+
+#pragma unroll
+            for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
+                const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
+                    break;
+                }
+
+                const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
+
+                const int j_dst = jc_dst / ncols2;
+                const int c_dst = jc_dst % ncols2;
+
+                if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+                    continue;
+                }
+
+                const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2;
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    float2 dstk_val = make_float2(0.0f, 0.0f);
+#pragma unroll
+                    for (int ip = 0; ip < np; ++ip) {
+                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0];
+                        const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]);
+                        dstk_val.x += dstk_val_add.x*KQ_crs;
+                        dstk_val.y += dstk_val_add.y*KQ_crs;
+                    }
+
+                    if (!needs_fixup && !is_fixup) {
+                        const float KQ_rowsum_j = meta_j[1];
+                        dstk_val.x /= KQ_rowsum_j;
+                        dstk_val.y /= KQ_rowsum_j;
+                    }
+
+                    if (is_fixup) {
+                        dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val;
+                    } else {
+                        dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val;
+                    }
+                }
+            }
+        }
+    }
+
+    if (np > 1) {
+        __syncthreads();
+    }
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
+template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap>
+__launch_bounds__(nwarps*WARP_SIZE, 2)
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter");
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+    const int stride_Q1   = nb01 / sizeof(float2);
+    const int stride_Q2   = nb02 / sizeof(float2);
+    const int stride_KV   = nb11 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half2);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+
+    constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice.
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+
+    // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
+    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
+    // In the most general case >2 seams can fall into the same tile.
+
+    // kb0 == k start index when in the output tile.
+    int kb0_start = kbc % iter_k;
+    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == iter_k) {
+        const int channel = kbc / (iter_k*iter_j);
+        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
+        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+        const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+        const int kb0_start_kernel = kb0_start * kb_niter;
+        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+        constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        if (kb0_start == 0) {
+            constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
+            flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+        } else {
+            constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
+            flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+        }
+
+        kbc += iter_k;
+        kbc -= kbc % iter_k;
+
+        kb0_start = 0;
+        kb0_stop  = min(iter_k, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int channel = kbc / (iter_k*iter_j);
+    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
+    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+    const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+    const int kb0_start_kernel = kb0_start * kb_niter;
+    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+    constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+    constexpr bool needs_fixup = false;
+    flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+         ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+}
+
+template <int D, int ncols1, int ncols2>
+void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    constexpr int ncols         = ncols1 * ncols2;
+    constexpr int KQ_per_iter   = D <= 128 && ncols1 <= 64 ? 64 : 32;
+    constexpr int nwarps        = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
+    constexpr int ntiles        = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
+    constexpr int cols_per_warp = ntiles * tile_B::I;
+
+    static_assert(D     %    tile_B::J  == 0, "bad D");
+    static_assert(ncols % cols_per_warp == 0, "bad ncols");
+
+    const ggml_tensor * KQV = dst;
+    const int id    = ggml_cuda_get_device();
+    const int cc    = ggml_cuda_info().devices[id].cc;
+
+    const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter;
+
+    const size_t nbytes_shared_KV      = KQ_shared_rows       * (D           + 8) * sizeof(half);
+    const size_t nbytes_shared_mask    = ncols1               * (KQ_per_iter + 8) * sizeof(half);
+    const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D           + 8) * sizeof(half);
+
+    const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
+    }
+
+    launch_fattn<D, ncols1, ncols2, 0>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
+}
+
+
+#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2)                          \
+    template void ggml_cuda_flash_attn_ext_mma_f16_case                     \
+    <D, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,   8);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  16);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  32);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  64);
+
+// Kernels with ncols == 128 are only 4% faster due to register pressure.
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
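
As a side note on the stream-k decomposition used by the new file: the flattened (head group, tile, k-chunk) index space is split evenly across a fixed number of persistent blocks, so a block's slice may begin or end mid-tile; those boundary tiles are what the fixup kernel later reconciles. A small host-side sketch of the same index math, with made-up sizes:

```cpp
// Sketch: stream-k style work assignment, same integer split as the kernel's
// kbc = blockIdx.x*iter_k*iter_j*(ne02/ncols2) / gridDim.x.
#include <cstdio>

int main() {
    const int iter_k   = 8;  // k-chunks (FATTN_KQ_STRIDE slices) per output tile
    const int n_tiles  = 5;  // output tiles = iter_j * (ne02/ncols2), made up here
    const int n_blocks = 6;  // persistent CUDA blocks, e.g. 2*nsm

    const int total = iter_k*n_tiles;
    for (int b = 0; b < n_blocks; ++b) {
        const int kbc0 = (b + 0)*total / n_blocks;
        const int kbc1 = (b + 1)*total / n_blocks;
        // If a slice starts mid-tile, that tile is shared with the previous block
        // and must later be reconciled by the fixup pass.
        const bool starts_mid_tile = kbc0 % iter_k != 0;
        printf("block %d: kbc [%2d, %2d) tiles %d..%d%s\n",
            b, kbc0, kbc1, kbc0/iter_k, (kbc1 - 1)/iter_k,
            starts_mid_tile ? " (first tile shared, needs fixup)" : "");
    }
    return 0;
}
```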
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 4d314dacb..ef3569fab 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -44,8 +44,13 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
+
     // Skip unused kernel variants for faster compilation:
+#ifdef FP16_MMA_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
@@ -280,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
@@ -288,16 +293,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
-            constexpr int      D = 64;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 64;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
-            constexpr int      D = 128;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 128;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
index bb3360447..04b69c83b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -44,11 +44,13 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
+
+    // Skip unused kernel variants for faster compilation:
+#ifdef FP16_MMA_AVAILABLE
     NO_DEVICE_CODE;
     return;
-#endif // FLASH_ATTN_AVAILABLE
-    // Skip unused kernel variants for faster compilation:
+#endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
@@ -280,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
@@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
-            constexpr int      D = 64;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 64;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
-            constexpr int      D = 128;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 128;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
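
Both tile launchers above map the runtime head size Q->ne[0] onto a compile-time template parameter through a switch. A generic standalone sketch of that dispatch pattern follows; the names are made up and it stands in for the real kernel launches:

```cpp
// Sketch: dispatching a runtime value to a compile-time template parameter,
// as the head-size switch in launch_fattn_tile_f16_64_128/_f32_64_128 does.
#include <cstdio>

template <int D>
static void run_kernel_for_head_size() {
    // In the real code this picks flash_attn_tile_ext_*<D, ...> and launches it.
    printf("running specialization with D = %d\n", D);
}

static void dispatch_head_size(const int d) {
    switch (d) {
        case  64: run_kernel_for_head_size< 64>(); break;
        case 128: run_kernel_for_head_size<128>(); break;
        default:  printf("unsupported head size %d\n", d); break;
    }
}

int main() {
    dispatch_head_size(64);
    dispatch_head_size(128);
    return 0;
}
```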
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index 34a2992c7..b7686c1ec 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -41,7 +41,8 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -294,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
@@ -303,7 +304,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn<D, cols_per_block, 1, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index a28fc8b7f..c1d2dd8d1 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -41,6 +41,8 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifdef FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -276,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template 
@@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template 
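The vector kernels above (and the tile kernels earlier) now share the same guard pattern: the whole kernel body is wrapped in a feature check and collapses to NO_DEVICE_CODE otherwise, so variants that can never be launched on the target architecture cost almost nothing to compile. A minimal sketch of the idiom with placeholder names (FEATURE_AVAILABLE is not a real ggml macro, and ggml's NO_DEVICE_CODE additionally prints an error before trapping):

    __global__ void guarded_kernel(float * dst, int n) {
    #ifdef FEATURE_AVAILABLE
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = 1.0f;            // real body, compiled only when the feature is enabled
        }
    #else
        (void) dst; (void) n;         // stub body: the symbol still links,
        __trap();                     // but reaching it at runtime is an error (role of NO_DEVICE_CODE)
    #endif
    }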
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
new file mode 100644
index 000000000..8828652fb
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -0,0 +1,648 @@
+// Old and deprecated WMMA FlashAttention implementation.
+// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
+// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-wmma-f16.cuh"
+
+#ifdef FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif // FP16_MMA_AVAILABLE
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+
+    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+
+    const int stride_Q  = nb01 / sizeof(float);
+    const int stride_KV = nb11 / sizeof(half);
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
+
+    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
+    frag_b Q_b[D/16][ncols/frag_n];
+
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
+    half2 * KQ2 = (half2 *) KQ;
+
+    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float       KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2       KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
+
+    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
+    }
+
+    // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
+    }
+
+    __syncthreads();
+
+    // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+            frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+                    if (use_logit_softcap) {
+                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                    }
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+                    if (use_logit_softcap) {
+                        // There is no dedicated tangens hyperbolicus function for half2.
+                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                    }
+                }
+
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
+        }
+
+        __syncthreads();
+
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+            }
+        }
+
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            }
+
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same<KQ_acc_t, float>::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                    break;
+                }
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j_VKQ = j0 + threadIdx.y;
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+        float KQ_rowsum_j;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            float dst_val = VKQ[j_VKQ*D_padded + i];
+            if (parallel_blocks == 1) {
+                dst_val /= KQ_rowsum_j;
+            }
+            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+        }
+
+        if (parallel_blocks == 1 || threadIdx.x != 0) {
+            continue;
+        }
+
+        float2 dst_meta_val;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+        }
+        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+}
+
+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+template <int D, int cols_per_block, typename KQ_acc_t>
+void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    constexpr int nwarps = 4;
+
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (4*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        return;
+    }
+    if (2*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        return;
+    }
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    }
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+}
+
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+
+    if (prec != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                case 256:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                // case 256:
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                //     break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+        constexpr int cols_per_block = 8;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 16;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 80:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 112:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    switch (Q->ne[0]) {
+        case 64:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            break;
+        case 80:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            break;
+        case 96:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            break;
+        case 112:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            break;
+        case 128:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            break;
+        case 256:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
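The numerically interesting part of the kernel above is the streaming softmax: KQ_max tracks the running maximum of the logits per row, KQ_max_scale = exp(old_max - new_max) rescales everything already accumulated whenever the maximum grows, KQ_rowsum collects the denominator, and differences at or below SOFTMAX_FTZ_THRESHOLD are flushed to zero instead of producing near-denormal exponentials. A scalar host-side sketch of that bookkeeping (illustration of the math only, not the kernel's code; the real threshold constant lives in the common headers):

    #include <algorithm>
    #include <cfloat>
    #include <cmath>
    #include <vector>

    // Process one row of logits chunk by chunk, keeping the softmax denominator
    // correct without ever holding the whole row at once.
    float streaming_softmax_denominator(const std::vector<std::vector<float>> & chunks,
                                        const float ftz_threshold = -20.0f) {
        float row_max = -FLT_MAX/2.0f;   // same initial value as KQ_max_f
        float row_sum = 0.0f;            // role of KQ_rowsum

        for (const std::vector<float> & chunk : chunks) {
            float chunk_max = row_max;
            for (const float x : chunk) {
                chunk_max = std::max(chunk_max, x);
            }

            // Rescale what was accumulated so far (role of KQ_max_scale);
            // in the kernel the VKQ accumulator is rescaled by the same factor.
            const float diff_max = row_max - chunk_max;   // <= 0
            const float scale    = diff_max <= ftz_threshold ? 0.0f : std::exp(diff_max);
            row_sum *= scale;
            row_max  = chunk_max;

            for (const float x : chunk) {
                const float diff = x - row_max;            // <= 0
                row_sum += diff <= ftz_threshold ? 0.0f : std::exp(diff);
            }
        }
        return row_sum;   // softmax weight of x is exp(x - row_max)/row_sum
    }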
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index 860d0e6dc..beeea95eb 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -1,543 +1,3 @@
 #include "common.cuh"
-#include "fattn-common.cuh"
 
-#ifdef FP16_MMA_AVAILABLE
-#include <mma.h>
-#endif // FP16_MMA_AVAILABLE
-
-// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-static __global__ void flash_attn_ext_f16(
-        const char * __restrict__ Q,
-        const char * __restrict__ K,
-        const char * __restrict__ V,
-        const char * __restrict__ mask,
-        float      * __restrict__ dst,
-        float2     * __restrict__ dst_meta,
-        const float scale,
-        const float max_bias,
-        const float m0,
-        const float m1,
-        const uint32_t n_head_log2,
-        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-#ifdef FP16_MMA_AVAILABLE
-    // Skip unused kernel variants for faster compilation:
-    if (use_logit_softcap && !(D == 128 || D == 256)) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
-
-    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
-    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
-
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
-    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
-    constexpr int frag_m = ncols == 8 ? 32 : 16;
-    constexpr int frag_n = ncols == 8 ?  8 : 16;
-    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
-
-    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
-    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
-    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
-
-    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
-    constexpr int D_padded = D + 8;
-    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
-    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
-
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
-    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
-
-    const int stride_Q  = nb01 / sizeof(float);
-    const int stride_KV = nb11 / sizeof(half);
-
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
-    const half  slopeh = __float2half(slopef);
-    const half2 slope2 = make_half2(slopef, slopef);
-
-    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
-
-    frag_b Q_b[D/16][ncols/frag_n];
-
-    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
-    constexpr int mem_KQ = ncols*kqs_padded*kqar;
-    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
-    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
-    float * KQ_f = (float *) KQ;
-    half2 * KQ2 = (half2 *) KQ;
-
-    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
-    float       KQ_max_f[ncols/nwarps];
-    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
-
-#pragma unroll
-    for (int j = 0; j < ncols/nwarps; ++j) {
-        KQ_max_f[j] = -FLT_MAX/2.0f;
-    }
-
-    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
-    half2       KQ_max_h2[ncols/nwarps];
-    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
-
-#pragma unroll
-    for (int j = 0; j < ncols/nwarps; ++j) {
-        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
-    }
-
-    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
-    half2 * VKQ2 = (half2 *) VKQ;
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-#pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                break;
-            }
-            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
-        }
-    }
-
-    // Convert Q to half and apply scale, temporarily store in KQ:
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
-        }
-    }
-
-    __syncthreads();
-
-    // Load Q into tensor core fragments/registers since it will be used frequently:
-#pragma unroll
-    for (int i0 = 0; i0 < D; i0 += 16) {
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
-        }
-    }
-
-    __syncthreads();
-
-    // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
-        // Calculate tile of KQ:
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
-            frag_c_KQ KQ_c[ncols/frag_n];
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-            }
-#pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
-                frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-                }
-            }
-#pragma unroll
-            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
-            }
-        }
-
-        __syncthreads();
-
-        // Calculate softmax for each KQ column using the current max. value.
-        // The divisor is stored in KQ_rowsum and will be applied at the end.
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
-
-                    if (use_logit_softcap) {
-                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
-                    }
-                }
-
-                float KQ_max_new = KQ_max_f[j0/nwarps];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
-                }
-                KQ_max_new = warp_reduce_max(KQ_max_new);
-
-                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
-                KQ_max_scale_f[j0/nwarps] = expf(diff);
-                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                    KQ_max_scale_f[j0/nwarps] = 0.0f;
-                }
-                KQ_max_f[j0/nwarps] = KQ_max_new;
-
-                float KQ_rowsum_add = 0.0f;
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
-                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
-                    }
-                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
-                }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
-
-                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
-            } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
-
-                    if (use_logit_softcap) {
-                        // There is no dedicated tangens hyperbolicus function for half2.
-                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
-                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
-                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
-
-                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
-                    }
-                }
-
-                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
-                }
-                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
-                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
-                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
-                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
-                KQ_max_h2[j0/nwarps] = KQ_max_new;
-
-                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
-                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
-                }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
-
-                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
-            }
-        }
-
-        __syncthreads();
-
-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
-                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
-                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*(kqar*kqs_padded) + k,
-                    kqar*kqs_padded);
-            }
-        }
-
-        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
-#pragma unroll
-        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
-            }
-
-#pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
-                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-
-                frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
-#pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
-                }
-            }
-        }
-
-        __syncthreads();
-
-        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
-#pragma unroll
-            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
-                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
-            }
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            half2 VKQ_scale;
-            if (std::is_same<KQ_acc_t, float>::value) {
-                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
-            } else {
-                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
-            }
-
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                    break;
-                }
-
-                half2 VKQ_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int l = 0; l < VKQ_ratio; ++l) {
-                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
-                }
-                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
-            }
-        }
-
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j_VKQ = j0 + threadIdx.y;
-        if (ic0 + j_VKQ >= ne01) {
-            return;
-        }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-
-        float KQ_rowsum_j;
-        if (std::is_same<KQ_acc_t, float>::value) {
-            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
-        } else {
-            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
-        }
-
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            float dst_val = VKQ[j_VKQ*D_padded + i];
-            if (parallel_blocks == 1) {
-                dst_val /= KQ_rowsum_j;
-            }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
-        }
-
-        if (parallel_blocks == 1 || threadIdx.x != 0) {
-            continue;
-        }
-
-        float2 dst_meta_val;
-        if (std::is_same<KQ_acc_t, float>::value) {
-            dst_meta_val.x = KQ_max_f[j0/nwarps];
-        } else {
-            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
-        }
-        dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
-    }
-#else
-   NO_DEVICE_CODE;
-#endif // FP16_MMA_AVAILABLE
-}
-
-constexpr int get_max_power_of_2(int x) {
-    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
-}
-
-static_assert(get_max_power_of_2(1) == 1, "Test failed.");
-static_assert(get_max_power_of_2(2) == 2, "Test failed.");
-static_assert(get_max_power_of_2(4) == 4, "Test failed.");
-static_assert(get_max_power_of_2(6) == 2, "Test failed.");
-
-// Number of VKQ rows calculated in parallel:
-constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
-    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
-}
-
-static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
-static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
-static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
-static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
-
-template <int D, int cols_per_block, typename KQ_acc_t>
-void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    constexpr int nwarps = 4;
-
-    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
-    float logit_softcap;
-    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
-
-    if (4*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-        return;
-    }
-    if (2*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-        return;
-    }
-    constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel;
-    if (logit_softcap == 0.0f) {
-        constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-    } else {
-        constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-    }
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-}
-
-#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                         \
-    template void ggml_cuda_flash_attn_ext_wmma_f16_case                              \
-    <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
-// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE(128,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE(256,  8, half);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
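With the definitions moved into fattn-wmma-f16.cu, the header shrinks to the single declaration above; the kernel template, the per-head-size case function, and the dispatch switch are now private to that one translation unit, which is why the long list of extern explicit-instantiation declarations disappears. A generic sketch of the same layout with placeholder names (not the ggml ones):

    // feature.cuh — the only symbol other translation units see:
    //   void run_feature(int head_size, int n);

    // feature.cu — template machinery stays local to this file:
    template <int D>
    static void run_feature_case(int n) {
        // D-specific kernel launch would go here.
        (void) n;
    }

    void run_feature(int head_size, int n) {
        switch (head_size) {
            case  64: run_feature_case< 64>(n); break;
            case 128: run_feature_case<128>(n); break;
            default:  /* unsupported head size */ break;
        }
    }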
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
index 0b26b0f8e..b1becccb4 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
@@ -1,5 +1,6 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
+#include "fattn-mma-f16.cuh"
 #include "fattn-tile-f16.cuh"
 #include "fattn-tile-f32.cuh"
 #include "fattn-vec-f16.cuh"
@@ -7,144 +8,89 @@
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
-#include 
+template <int D, int ncols2>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
 
-static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
-
-    if (prec != GGML_PREC_DEFAULT) {
-        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
-            constexpr int cols_per_block = 16;
-            switch (Q->ne[0]) {
-                case 64:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                    break;
-                case 80:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                    break;
-                case 96:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                    break;
-                case 112:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                    break;
-                case 128:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                    break;
-                case 256:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
-                    break;
-                default:
-                    GGML_ABORT("fatal error");
-                    break;
-            }
-        } else {
-            constexpr int cols_per_block = 32;
-            switch (Q->ne[0]) {
-                case 64:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                    break;
-                case 80:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                    break;
-                case 96:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                    break;
-                case 112:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                    break;
-                case 128:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                    break;
-                // case 256:
-                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                //     break;
-                default:
-                    GGML_ABORT("fatal error");
-                    break;
-            }
-        }
+    if (Q->ne[1] <= 8/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
         return;
     }
 
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
-        constexpr int cols_per_block = 8;
-        switch (Q->ne[0]) {
-            case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                break;
-            case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                break;
-            case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                break;
-            case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
+    if (Q->ne[1] <= 16/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
         return;
     }
 
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 16;
-        switch (Q->ne[0]) {
-            case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                break;
-            case 80:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
-                break;
-            case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                break;
-            case 112:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
-                break;
-            case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                break;
-            case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
+    if (Q->ne[1] <= 32/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
         return;
     }
 
-    constexpr int cols_per_block = 32;
+    ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
+}
+
+template <int ncols2>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
     switch (Q->ne[0]) {
         case 64:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
             break;
         case 80:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
             break;
         case 96:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
             break;
         case 112:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
             break;
         case 128:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
         case 256:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
 }
+
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * mask = dst->src[3];
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    const bool use_gqa_opt = mask && max_bias == 0.0f;
+
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+    if (use_gqa_opt && gqa_ratio % 8 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
+        return;
+    }
+
+    if (use_gqa_opt && gqa_ratio == 4) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
+        return;
+    }
+
+    if (use_gqa_opt && gqa_ratio == 2) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
+}
+
 #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
         ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
@@ -296,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -323,15 +272,26 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (!fp16_mma_available(cc)) {
-        if (Q->ne[1] <= 8) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        if (prec == GGML_PREC_DEFAULT) {
+            if (Q->ne[1] <= 8) {
+                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            }
         } else {
-            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            if (Q->ne[1] <= 8) {
+                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+            }
         }
         return;
     }
 
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
+        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
@@ -341,5 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         }
     }
 
-    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
+    if (cc == GGML_CUDA_CC_VOLTA) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
 }
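// Illustrative sketch (not part of the patch; helper name is hypothetical): the new
// MMA FlashAttention path above picks the ncols2 template argument from the
// grouped-query-attention ratio Q->ne[2]/K->ne[2] whenever the GQA optimization is
// usable (a mask is present and max_bias == 0.0f). This standalone snippet only
// mirrors that selection logic.
#include <cstdio>

static int pick_ncols2(int gqa_ratio, bool use_gqa_opt) {
    if (use_gqa_opt && gqa_ratio % 8 == 0) return 8;
    if (use_gqa_opt && gqa_ratio == 4)     return 4;
    if (use_gqa_opt && gqa_ratio == 2)     return 2;
    return 1; // no GQA optimization: one Q column per K/V column
}

int main() {
    // e.g. 32 query heads sharing 8 KV heads -> gqa_ratio = 4 -> ncols2 = 4
    printf("%d\n", pick_ncols2(4,  true)); // 4
    printf("%d\n", pick_ncols2(16, true)); // 8
    printf("%d\n", pick_ncols2(3,  true)); // 1
    return 0;
}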
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
index 4c3703238..4cef53a98 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
@@ -3,15 +3,15 @@
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
-            const void * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+        const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
     const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i10 =  blockDim.y*blockIdx.y + threadIdx.y;
     const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
     const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
 
@@ -22,10 +22,10 @@ static __global__ void k_get_rows(
     const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+    const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
+    const int ib   =  i00/qk;      // block index
+    const int iqs  = (i00%qk)/qr;  // quant index
     const int iybs = i00 - i00%qk; // dst block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
 
@@ -39,15 +39,15 @@ static __global__ void k_get_rows(
 
 template<typename src0_t, typename dst_t>
 static __global__ void k_get_rows_float(
-            const src0_t * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+        const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i00 =  blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 =  blockDim.y*blockIdx.y + threadIdx.y;
     const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
     const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
 
@@ -58,14 +58,38 @@ static __global__ void k_get_rows_float(
     const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+    const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
     dst_row[i00] = src0_row[i00];
 }
 
+template<typename grad_t, typename dst_t>
+static __global__ void k_get_rows_back_float(
+        const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
+    const int col = blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
+
+    float sum = 0.0f;
+
+    for (int64_t i = 0; i < nrows_grad; ++i) {
+        if (rows[i] != dst_row) {
+            continue;
+        }
+        sum += grad[i*ncols + col];
+    }
+
+    dst[dst_row*ncols + col] = sum;
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+static void get_rows_cuda(
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
     GGML_ASSERT(ne00 % 2 == 0);
 
     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
-            src0_dd, src1_dd, dst_dd,
-            ne00, /*ne01, ne02, ne03,*/
-            /*ne10, ne11,*/ ne12, /*ne13,*/
-            /* s0,*/ s1, s2, s3,
-            /* nb00,*/ nb01, nb02, nb03,
-            s10, s11, s12/*, s13*/);
+        src0_dd, src1_dd, dst_dd,
+        ne00, /*ne01, ne02, ne03,*/
+        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /* s0,*/ s1, s2, s3,
+        /* nb00,*/ nb01, nb02, nb03,
+        s10, s11, s12/*, s13*/);
 
     GGML_UNUSED(dst);
 }
 
 template<typename src0_t>
-static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+static void get_rows_cuda_float(
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    GGML_ASSERT(ne13 == 1);
+
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
@@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
     //const size_t s13 = nb13 / ggml_element_size(src1);
 
     k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
-            src0_dd, src1_dd, dst_dd,
-            ne00, /*ne01, ne02, ne03,*/
-            /*ne10, ne11,*/ ne12, /*ne13,*/
-            /* s0,*/ s1, s2, s3,
-            /* nb00,*/ nb01, nb02, nb03,
-            s10, s11, s12/*, s13*/);
+        src0_dd, src1_dd, dst_dd,
+        ne00, /*ne01, ne02, ne03,*/
+        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /* s0,*/ s1, s2, s3,
+        /* nb00,*/ nb01, nb02, nb03,
+        s10, s11, s12/*, s13*/);
 
     GGML_UNUSED(dst);
 }
@@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
+
+    const void    * src0_d = (const void    *) src0->data;
+    const int32_t * src1_d = (const int32_t *) src1->data;
+    float         * dst_d  = (float         *) dst->data;
+
     cudaStream_t stream = ctx.stream();
 
-
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
     GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
+    GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));
 
     switch (src0->type) {
         case GGML_TYPE_F16:
-            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         default:
             // TODO: k-quants
@@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             break;
     }
 }
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const float   * src0_d = (const float   *) src0->data;
+    const int32_t * src1_d = (const int32_t *) src1->data;
+    float         * dst_d  = (float         *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(ne02*ne03 == 1);
+    GGML_ASSERT(ne12*ne13 == 1);
+    GGML_ASSERT(ne2*ne3 == 1);
+
+    const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne1, 1);
+
+    k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
+}
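// Illustrative CPU reference (not part of the patch; helper name is hypothetical) for
// what the new GET_ROWS_BACK kernel computes: for every destination row r, the sum of
// all gradient rows i whose index rows[i] equals r. The CUDA kernel above assigns one
// thread per output column and loops over nrows_grad; the scatter-add formulation
// below is mathematically equivalent.
#include <cstdint>
#include <vector>

static void get_rows_back_ref(
        const std::vector<float>   & grad, // nrows_grad x ncols, row-major
        const std::vector<int32_t> & rows, // nrows_grad row indices
        std::vector<float>         & dst,  // nrows_dst x ncols, zero-initialized
        int64_t ncols, int64_t nrows_grad) {
    for (int64_t i = 0; i < nrows_grad; ++i) {
        const int32_t r = rows[i];
        for (int64_t c = 0; c < ncols; ++c) {
            dst[r*ncols + c] += grad[i*ncols + c];
        }
    }
}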
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cuh
index bbf130232..a1ca643f1 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cuh
@@ -1,5 +1,8 @@
 #include "common.cuh"
 
 #define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
 
 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
index 9286f8665..1adf08fad 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -37,10 +37,13 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"
+#include "ggml.h"
 
 #include <algorithm>
 #include <array>
 #include <atomic>
+#include <charconv>
 #include <cinttypes>
 #include <cstddef>
 #include <cstdint>
@@ -61,7 +64,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 [[noreturn]]
 void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
     int id = -1; // in case cudaGetDevice fails
-    cudaGetDevice(&id);
+    (void)cudaGetDevice(&id);
 
     GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
     GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
@@ -118,12 +121,78 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
 #endif
 }
 
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+static int ggml_cuda_parse_id(char devName[]) {
+    // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
+    // these values are not stable so this is susceptible to breakage
+    // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
+    int archMajor = 0x0;
+    int archMinor = 0x0;
+    int archNum = GGML_CUDA_CC_OFFSET_AMD;
+    int archLen = strlen(devName);
+    char archName[archLen + 1];
+
+    // strip leading 'gfx' while copying into our buffer
+    if (archLen > 3) {
+        strcpy(archName, &devName[3]);
+        archLen -= 3;
+    }
+
+    // trim trailing :xnack- or :sramecc- statuses
+    archLen = strcspn(archName, ":");
+    archName[archLen] = '\0';
+
+    // tease out the version information
+    if (archLen > 8) {
+        // versions labeled generic use '-' as delimiter
+        // strip the trailing "-generic" then iterate through what remains
+        if ((strstr(archName, "-generic"))) {
+            archName[archLen - 8] = '\0';
+            char * pch;
+            if ((pch = strtok(archName, "-"))) {
+                archMajor = (int)strtoul(pch, 0, 16);
+                if ((pch = strtok(NULL, "-"))) {
+                    archMinor = 0x10 * (int)strtoul(pch, 0, 16);
+                }
+            }
+        }
+    } else if (archLen >= 3) {
+        // last two digits should be the minor * 0x10 + stepping
+        archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
+        archName[archLen - 2] = '\0';
+
+        // only the major version remains
+        archMajor = (int)strtoul(archName, 0, 16);
+    }
+    archNum += archMajor * 0x100;
+    archNum += archMinor;
+    return archNum;
+}
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
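// Minimal sketch (not part of the patch; helper name is hypothetical, encoding assumed
// from the code above) of what ggml_cuda_parse_id produces for the common,
// non-"generic" Target IDs: the last two hex digits of the gfx name are minor/stepping,
// the remainder is the major version, and status flags after ':' are stripped.
#include <cstdio>
#include <cstdlib>
#include <string>

static int parse_gfx(const char * dev_name, int offset) {
    std::string s = dev_name;
    if (s.rfind("gfx", 0) == 0) s.erase(0, 3);  // strip leading "gfx"
    s = s.substr(0, s.find(':'));               // drop ":xnack-"/":sramecc-" flags
    if (s.size() < 3) return offset;
    const int minor = (int) std::strtoul(s.substr(s.size() - 2).c_str(), nullptr, 16);
    const int major = (int) std::strtoul(s.substr(0, s.size() - 2).c_str(), nullptr, 16);
    return offset + major*0x100 + minor;
}

int main() {
    // offset 0 for readability; the real code adds GGML_CUDA_CC_OFFSET_AMD
    printf("%#x\n", parse_gfx("gfx906", 0));                 // 0x906
    printf("%#x\n", parse_gfx("gfx90a:sramecc+:xnack-", 0)); // 0x90a
    printf("%#x\n", parse_gfx("gfx1030", 0));                // 0x1030
    return 0;
}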
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
     // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
-    rocblas_initialize();
-    CUDA_CHECK(cudaDeviceSynchronize());
+    {
+        int major_version = 0;
+        size_t version_length = 0;
+        if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
+            std::vector<char> version(version_length+1, '\0');
+            if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
+                version.resize(::strlen(version.data()));
+                int parsed_value = 0;
+                if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
+                    major_version = parsed_value;
+                }
+            }
+        }
+        if (major_version < 4) {
+            GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
+            rocblas_initialize();
+            CUDA_CHECK(cudaDeviceSynchronize());
+        }
+    }
 #endif
 
     ggml_cuda_device_info info = {};
@@ -151,7 +220,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#if defined(GGML_USE_VMM)
         CUdevice device;
         CU_CHECK(cuDeviceGet(&device, id));
         CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -163,24 +232,46 @@ static ggml_cuda_device_info ggml_cuda_init() {
             alloc_prop.location.id = id;
             CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#endif // defined(GGML_USE_VMM)
         info.devices[id].vmm = !!device_vmm;
 
         cudaDeviceProp prop;
         CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
 
-        info.devices[id].nsm   = prop.multiProcessorCount;
-        info.devices[id].smpb  = prop.sharedMemPerBlock;
+        info.devices[id].nsm       = prop.multiProcessorCount;
+        info.devices[id].smpb      = prop.sharedMemPerBlock;
+        info.devices[id].warp_size = prop.warpSize;
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
+
+        info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
+        if ((info.devices[id].cc & 0xff00) == 0x0) {
+            GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s  cc %d.%d\n",
+                            id, prop.name, prop.gcnArchName, prop.major, prop.minor);
+
+            // Fallback to prop.major and prop.minor
+            if (prop.major > 0) {
+                info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
+                info.devices[id].cc += prop.minor * 0x10;
+            }
+        }
+        GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
+                      id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
+                      device_vmm ? "yes" : "no", prop.warpSize);
+#elif defined(GGML_USE_MUSA)
+        // TODO: refine the .cc to reflect MUSA's actual CC capabilities
+        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
+                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #else
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
+                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     }
 
@@ -299,7 +390,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
 };
 
 // pool with virtual memory
-#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#if defined(GGML_USE_VMM)
 struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
 
@@ -308,6 +399,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     size_t pool_used = 0;
     size_t pool_size = 0;
     size_t granularity;
+#if defined(GGML_USE_HIP)
+    std::vector<std::pair<CUdeviceptr, size_t>> mappings;
+#endif
 
     explicit ggml_cuda_pool_vmm(int device) :
         device(device),
@@ -316,7 +410,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
 
     ~ggml_cuda_pool_vmm() {
         if (pool_addr != 0) {
+#if defined(GGML_USE_HIP)
+            // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
+            for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
+                CU_CHECK(cuMemUnmap(mapping.first, mapping.second));
+            }
+#else
             CU_CHECK(cuMemUnmap(pool_addr, pool_size));
+#endif
             CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
         }
     }
@@ -349,7 +450,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
             }
 
             // map at the end of the pool
-            CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
+            CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
+            CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0));
+#if defined(GGML_USE_HIP)
+            mappings.push_back({start_ptr, reserve_size});
+#endif
 
             // the memory allocation handle is no longer needed after mapping
             CU_CHECK(cuMemRelease(handle));
@@ -359,7 +464,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
             access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
             access.location.id = device;
             access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
-            CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
+            CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
 
             // add to the pool
             pool_size += reserve_size;
@@ -371,7 +476,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
 
         GGML_ASSERT(pool_addr != 0);
 
-        void * ptr = (void *) (pool_addr + pool_used);
+        void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used));
         *actual_size = size;
         pool_used += size;
 
@@ -390,17 +495,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         pool_used -= size;
 
         // all deallocations must be in reverse order of the allocations
-        GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
+        GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
     }
 };
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#endif // defined(GGML_USE_VMM)
 
 std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#if defined(GGML_USE_VMM)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
     }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#endif // defined(GGML_USE_VMM)
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
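// Usage sketch (a fragment, not part of the patch; it assumes a ggml_backend_cuda_context
// named ctx as elsewhere in this file): the pool returned by new_pool_for_device is used
// through the RAII wrapper ggml_cuda_pool_alloc, whose destructor frees the buffer.
// Nested scopes therefore release buffers in reverse order of allocation, which is the
// stack discipline the VMM pool asserts on in its free() above.
{
    ggml_cuda_pool_alloc<half> a(ctx.pool(), 1024);      // allocated first
    {
        ggml_cuda_pool_alloc<float> b(ctx.pool(), 256);  // allocated second
        // ... launch kernels reading/writing a.get() and b.get() ...
    } // b is freed here, first
} // a is freed here, last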
 
@@ -547,7 +652,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
     cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
         // clear the error
-        cudaGetLastError();
+        (void)cudaGetLastError();
         GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
         return nullptr;
     }
@@ -962,7 +1067,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
     cudaError_t err = cudaMallocHost((void **) &ptr, size);
     if (err != cudaSuccess) {
         // clear the error
-        cudaGetLastError();
+        (void)cudaGetLastError();
         GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
                            size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return nullptr;
@@ -1082,7 +1187,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
+
+    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1103,28 +1210,38 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
-
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
-
-        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-        if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-            cu_compute_type = CUBLAS_COMPUTE_32F;
-        }
 
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        CUBLAS_CHECK(
-            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
-                                src1_ptr,       CUDA_R_16F, ne10,
-                    &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha, src0_ptr,  CUDA_R_16F, ne00,
+                                src1_ptr,  CUDA_R_16F, ne10,
+                        &beta,   dst_dd_i, CUDA_R_32F, ldc,
+                        CUBLAS_COMPUTE_32F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        } else {
+            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
+
+            const half alpha_f16 = 1.0f;
+            const half beta_f16 = 0.0f;
+
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
+                                    src1_ptr,      CUDA_R_16F, ne10,
+                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
+                        CUBLAS_COMPUTE_16F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     } else {
         ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
         ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
@@ -1197,7 +1314,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
                         CUDA_CHECK(err);
                     } else {
                         // reset the error
-                        cudaGetLastError();
+                        (void)cudaGetLastError();
                     }
                 } else {
                     cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
@@ -1205,7 +1322,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
                         CUDA_CHECK(err);
                     } else {
                         // reset the error
-                        cudaGetLastError();
+                        (void)cudaGetLastError();
                     }
                 }
             }
@@ -1256,8 +1373,6 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne13 = src1->ne[3];
     const int64_t nrows1 = ggml_nrows(src1);
 
-    GGML_ASSERT(ne03 == ne13);
-
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
@@ -1271,9 +1386,11 @@ static void ggml_cuda_op_mul_mat(
 
     GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
 
     const int64_t i02_divisor = ne12 / ne02;
+    const int64_t i03_divisor = ne13 / ne03;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
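// Illustrative helper (not part of the patch; name is hypothetical) for the broadcast
// indexing that the new i03_divisor enables: each dst slice (i02, i03) maps back to the
// src0 slice it broadcasts from by integer division, matching the src0_dd_i offset
// computed further down in this function.
#include <cassert>
#include <cstdint>

static int64_t src0_slice_index(int64_t i02, int64_t i03,
                                int64_t ne02, int64_t ne03,
                                int64_t ne12, int64_t ne13) {
    assert(ne12 % ne02 == 0 && ne13 % ne03 == 0);
    const int64_t i02_divisor = ne12 / ne02;
    const int64_t i03_divisor = ne13 / ne03;
    return (i03/i03_divisor)*ne02 + i02/i02_divisor; // flattened (dim3, dim2) index
}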
@@ -1289,6 +1406,7 @@ static void ggml_cuda_op_mul_mat(
     GGML_ASSERT(!(split && ne02 > 1));
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
+    GGML_ASSERT(!(split && ne03 < ne13));
 
     ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
 
@@ -1369,12 +1487,7 @@ static void ggml_cuda_op_mul_mat(
             const size_t nbytes_data    = ggml_nbytes(src0);
             const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
             dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
-        // TODO: remove this for MUSA once the Guilty Lockup issue is resolved
-#ifndef GGML_USE_MUSA
             CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
-#else // GGML_USE_MUSA
-            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
-#endif // !GGML_USE_MUSA
         }
 
         // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
@@ -1452,7 +1565,8 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+                const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
+                char  *  src0_dd_i =  dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
                 float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
                 char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
                 float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -1496,8 +1610,9 @@ static void ggml_cuda_op_mul_mat(
                     CUDA_CHECK(cudaGetLastError());
                 }
 
-                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+                if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                        src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
                 }
 
                 // do the computation
@@ -1613,10 +1728,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
 
-    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-    }
-
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
@@ -1645,6 +1756,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta  = &beta_f32;
     }
 
+    if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    }
+
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 
@@ -1672,9 +1789,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-#ifdef GGML_USE_MUSA
-    GGML_ASSERT(false);
-#else // !GGML_USE_MUSA
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
@@ -1717,7 +1831,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif // GGML_USE_MUSA
 #endif
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
@@ -1752,14 +1865,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
             const int cc              = ggml_cuda_info().devices[id].cc;
             use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
+            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
+            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
         }
     } else {
         const int cc              = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
+        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
+        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
     }
 
     // debug helpers
@@ -1770,7 +1883,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
@@ -2003,6 +2116,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;
+        case GGML_OP_GET_ROWS_BACK:
+            ggml_cuda_op_get_rows_back(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -2094,16 +2210,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_LEAKY_RELU:
             ggml_cuda_op_leaky_relu(ctx, dst);
             break;
+        case GGML_OP_SILU_BACK:
+            ggml_cuda_op_silu_back(ctx, dst);
+            break;
         case GGML_OP_RMS_NORM:
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
+        case GGML_OP_RMS_NORM_BACK:
+            ggml_cuda_op_rms_norm_back(ctx, dst);
+            break;
         case GGML_OP_MUL_MAT:
-            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
-                return false;
-            } else {
-                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
-            }
+            ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
@@ -2141,9 +2258,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SOFT_MAX:
             ggml_cuda_op_soft_max(ctx, dst);
             break;
+        case GGML_OP_SOFT_MAX_BACK:
+            ggml_cuda_op_soft_max_back(ctx, dst);
+            break;
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
+        case GGML_OP_ROPE_BACK:
+            ggml_cuda_op_rope_back(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -2173,6 +2296,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -2291,6 +2417,66 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
+static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+
+    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
+
+        if (node->op == GGML_OP_CPY) {
+            // store the copy op parameter which changes with each token.
+            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (!ptr) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+            } else {
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                }
+            }
+        }
+
+        if (!use_cuda_graph) {
+            break;
+        }
+    }
+
+    return use_cuda_graph;
+}
+
 static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
     graph_node_properties->node_address = node->data;
     graph_node_properties->node_op = node->op;
@@ -2341,149 +2527,111 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
 
     return true;
 }
-#endif
 
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
 
-    ggml_cuda_set_device(cuda_ctx->device);
+    if (cuda_graph_update_required) {
+        // Extract nodes from graph
+        // First call with null argument gets number of nodes in graph
+        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+        // Subsequent call with non-null argument gets nodes
+        cuda_ctx->cuda_graph->nodes.clear();
+        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+        cuda_ctx->cuda_graph->params.clear();
+        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+        if (cuda_ctx->cuda_graph->num_nodes > 0) {
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
 
-#ifdef USE_CUDA_GRAPH
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
-
-    // Objects required for CUDA Graph
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
-
-    bool use_cuda_graph = true;
-    bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
-
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
-        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
-#endif
-        }
-    }
-
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
-    }
-
-    if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) {
-            cuda_graph_update_required = true;
-        }
-
-        // Check if the graph size has changed
-        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
-            cuda_graph_update_required = true;
-            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
-        }
-
-        // Loop over nodes in GGML graph to determine if CUDA graph update is required
-        // and store properties to allow this comparison for the next token
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            bool has_matching_properties = true;
-            if (!cuda_graph_update_required) {
-                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
-            }
-            if (!has_matching_properties) {
-                cuda_graph_update_required = true;
-            }
-            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
-        }
-
-        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            ggml_tensor * node = cgraph->nodes[i];
-
-            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                continue;
-            }
-
-            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
-                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_MUL_MAT_ID) {
-                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-                // disable CUDA graphs for batch size > 1 for now.
-                // Changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
-
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (!ptr) {
-                    use_cuda_graph = false;
-#ifndef NDEBUG
-                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-                } else {
-                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+            // Loop over nodes, and extract kernel parameters from each node
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                cudaGraphNodeType node_type;
+                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                if (node_type == cudaGraphNodeTypeKernel) {
+                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                    if (stat == cudaErrorInvalidDeviceFunction) {
+                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                        // We don't need to update blas nodes, so clear error and move on.
+                        (void)cudaGetLastError();
+                    } else {
+                        GGML_ASSERT(stat == cudaSuccess);
                     }
                 }
             }
-
-            if (!use_cuda_graph) {
-                break;
+        }
+    } else {
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        // on update steps, the live parameters will already be captured
+        int k = 0;
+        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
+                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
             }
         }
-
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
-
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
-#endif
-        }
     }
+}
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
-        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
-    }
+static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
-#else
-    bool use_cuda_graph = false;
     bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
 
-    bool graph_evaluated_or_captured = false;
+    if (cuda_ctx->cuda_graph->instance == nullptr) {
+        cuda_graph_update_required = true;
+    }
+
+    // Check if the graph size has changed
+    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+        cuda_graph_update_required = true;
+        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+    }
+
+    // Loop over nodes in GGML graph to determine if CUDA graph update is required
+    // and store properties to allow this comparison for the next token
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        bool has_matching_properties = true;
+        if (!cuda_graph_update_required) {
+            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        }
+        if (!has_matching_properties) {
+            cuda_graph_update_required = true;
+        }
+        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+    }
+
+    return cuda_graph_update_required;
+}
+
+static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+
+    cudaGraphExecUpdateResultInfo result_info;
+#ifdef __HIP_PLATFORM_AMD__
+    hipGraphNode_t errorNode;
+    hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#else
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+#endif
+    if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
+#endif
+
+        // The pre-existing graph exec cannot be updated due to violated constraints
+        // so instead clear error and re-instantiate
+        (void)cudaGetLastError();
+        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+        cuda_ctx->cuda_graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+    } else {
+        GGML_ASSERT(stat == cudaSuccess);
+    }
+}
+#endif
+
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+   [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs,  bool & graph_evaluated_or_captured, bool & use_cuda_graph,
+    bool & cuda_graph_update_required) {
 
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2521,19 +2669,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
                 CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
                 cuda_ctx->cuda_graph->graph = nullptr;
             }
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
 
-#if 0
-            if (disable_cuda_graphs_due_to_failed_capture) {
-                use_cuda_graph = false;
-                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
-            } else {
-                graph_evaluated_or_captured = true; // CUDA graph has been captured
-            }
-#endif
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
@@ -2546,72 +2683,91 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
 
         // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-
-        if (cuda_graph_update_required) {
-            // Extract nodes from graph
-            // First call with null argument gets number of nodes in graph
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-            // Subsequent call with non-null argument gets nodes
-            cuda_ctx->cuda_graph->nodes.clear();
-            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-            cuda_ctx->cuda_graph->params.clear();
-            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-            if (cuda_ctx->cuda_graph->num_nodes > 0) {
-                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-                // Loop over nodes, and extract kernel parameters from each node
-                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                    cudaGraphNodeType node_type;
-                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                    if (node_type == cudaGraphNodeTypeKernel) {
-                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                        if (stat == cudaErrorInvalidDeviceFunction) {
-                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                            // We don't need to update blas nodes, so clear error and move on.
-                            cudaGetLastError();
-                        } else {
-                            GGML_ASSERT(stat == cudaSuccess);
-                        }
-                    }
-                }
-            }
-        }
-
-        // One of the arguments to the copy kernel is updated for each token, hence we need to
-        // replace that argument with the updated value in the CUDA graph
-        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
-            int k = 0;
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-                }
-            }
-        }
+        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
 
         // Update graph executable
-        cudaGraphExecUpdateResultInfo result_info;
-        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-        if (stat == cudaErrorGraphExecUpdateFailure) {
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
-#endif
-            // The pre-existing graph exec cannot be updated due to violated constraints
-            // so instead clear error and re-instantiate
-            cudaGetLastError();
-            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-            cuda_ctx->cuda_graph->instance = nullptr;
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-        } else {
-            GGML_ASSERT(stat == cudaSuccess);
-        }
+        update_cuda_graph_executable(cuda_ctx);
+
         // Launch graph
         CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
 #else
         graph_evaluated_or_captured = true;
-#endif // USE_CUDA_GRAPH
+#endif  // USE_CUDA_GRAPH
     }
+}
+
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+    // vector of pointers to CUDA cpy kernels, which are required to identify
+    // kernel parameters which need to be updated in the graph for each token
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
+
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+        }
+    }
+
+    // Disable CUDA graphs in the presence of the env var, an old GPU, a use-case which is changing too rapidly,
+    // or a previous graph capture failure.
+    // Also disable for multi-GPU for now. TODO: investigate
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
+                             ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (use_cuda_graph && cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
+        }
+
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+        }
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
+#else
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+    bool graph_evaluated_or_captured = false;
+
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
 
     return GGML_STATUS_SUCCESS;
 }
@@ -2687,11 +2843,11 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
         return false;
     }
 
-#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
+#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
     cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
     if (err != cudaSuccess) {
         // clear the error
-        cudaGetLastError();
+        (void)cudaGetLastError();
 
         GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
                            size / 1024.0 / 1024.0, cudaGetErrorString(err));
@@ -2699,8 +2855,10 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
     }
     return true;
 #else
+    GGML_UNUSED(buffer);
+    GGML_UNUSED(size);
     return false;
-#endif
+#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 }
 
 void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
@@ -2711,7 +2869,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
     cudaError_t err = cudaHostUnregister(buffer);
     if (err != cudaSuccess) {
         // clear the error
-        cudaGetLastError();
+        (void)cudaGetLastError();
     }
 }
 
@@ -2843,9 +3001,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                     return false;
                 }
-                if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
-                    return false;
-                }
 #ifdef GGML_USE_MUSA
                 if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
@@ -2887,7 +3042,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 }
             } break;
         case GGML_OP_OUT_PROD:
-            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -2903,6 +3058,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                         return false;
                 }
             } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2922,15 +3081,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                     return true;
                 }
@@ -2961,7 +3132,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
         case GGML_OP_REPEAT_BACK:
-                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
+                return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2976,8 +3147,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 }
                 return false;
             } break;
+        case GGML_OP_SILU_BACK:
+            return ggml_is_contiguous(op->src[0]);
+            break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
+        case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
         case GGML_OP_NONE:
@@ -3002,15 +3178,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
             return true;
+        case GGML_OP_SOFT_MAX_BACK: {
+            float max_bias = 0.0f;
+            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+            return max_bias == 0.0f;
+        }
         case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_ROPE_BACK: {
+            const size_t ts = ggml_type_size(op->src[0]->type);
+            const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
+            return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
+        }
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_UNPAD:
@@ -3018,11 +3205,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
-#endif
+#endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
@@ -3035,8 +3223,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-            return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -3059,6 +3247,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
             return op->ne[1];
         case GGML_OP_MUL_MAT_ID:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
             return op->ne[2];
         default:
             return ggml_nrows(op);
@@ -3161,7 +3350,7 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
         features.push_back({ "FORCE_CUBLAS", "1" });
     #endif
 
-    #ifdef GGML_CUDA_NO_VMM
+    #ifndef GGML_USE_VMM
         features.push_back({ "NO_VMM", "1" });
     #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/gla.cu b/ml/backend/ggml/ggml/src/ggml-cuda/gla.cu
new file mode 100644
index 000000000..f7d615a82
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/gla.cu
@@ -0,0 +1,93 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template <int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+     const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = HEAD_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4 & k = (float4 &)(_k[j]);
+            const float4 & r = (float4 &)(_r[j]);
+            const float4 & td = (float4 &)(_td[j]);
+            float4 & s = (float4 &)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+
+            y += r.x * s.x;
+            y += r.y * s.y;
+            y += r.z * s.z;
+            y += r.w * s.w;
+        }
+        dst[t] = y * scale;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * td_d = (const float *)dst->src[3]->data;
+    const float * s_d  = (const float *)dst->src[4]->data;
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float scale;
+    memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+
+    if (C / H == 64) {
+        gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}
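
For reference (not part of the patch): per head, gated_linear_attn_f32 keeps a head_size x head_size recurrent state and, for every timestep, decays it with td, adds the outer product of k and v, and reads it out with r, scaling the result. Below is a plain C++ restatement of that recurrence for a single head; the layout and names are simplified and do not match the ggml tensor layout, so treat it only as an illustration of the math the kernel parallelizes (one thread per output index i, with that thread's state row held in registers).

#include <vector>
#include <cstddef>

// Single-head reference: k, v, r, td and dst are [T][D] (row-major), state is [D][D]
// with row i = output/value dimension and column j = key dimension.
static void gated_linear_attn_ref(int T, int D, float scale,
                                  const float * k, const float * v,
                                  const float * r, const float * td,
                                  std::vector<float> & state, float * dst) {
    for (int t = 0; t < T; ++t) {
        for (int i = 0; i < D; ++i) {            // output dimension (one CUDA thread per i)
            float y = 0.0f;
            for (int j = 0; j < D; ++j) {        // key dimension
                float & s = state[(size_t) i*D + j];
                s = s*td[t*D + j] + k[t*D + j]*v[t*D + i]; // decay old state, add k*v outer product
                y += r[t*D + j]*s;                         // read out with r
            }
            dst[t*D + i] = y*scale;
        }
    }
    // The CUDA kernel additionally writes the final state back after the last timestep.
}
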
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/gla.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/gla.cuh
new file mode 100644
index 000000000..2c82ad7dd
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/gla.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
index 7d11540af..9206bfeba 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
@@ -1,221 +1,394 @@
+// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
+// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
+// The documentation for the PTX instructions can be found under:
+//   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
+//
+// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
+// A is a row-major matrix with shape M x K.
+// B is a column-major matrix with shape K x N.
+// C is a column-major matrix with shape M x N.
+// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
+// Note that J is measured in physical 32 bit elements instead of logical elements.
+// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
+// All matrix tiles have ne physical 32 bit elements per warp.
+//
+// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+
 #include "common.cuh"
 
-struct mma_int_A_I16K4 {
-    static constexpr int I  = 16;
-    static constexpr int K  = 4;
-    static constexpr int ne = 2;
 
-    int x[ne] = {0};
+#if CUDART_VERSION >= 11080
 
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l%2) * (I/2) + threadIdx.x / K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    int ret = 0;
+
+#ifdef NEW_MMA_AVAILABLE
+    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+        : "=r"(ret) : "r"(x));
+#else
+    NO_DEVICE_CODE;
+#endif // defined(NEW_MMA_AVAILABLE)
+    return ret;
+}
+
+#else
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    // Imagine transposing row-major matrix to column-major matrix.
+    const int src_i_low  = 2 * (threadIdx.x % 4);
+    const int src_i_high = src_i_low + 1;
+    const int src_j      = threadIdx.x / 4;
+
+    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
+    const int src_laneid_high = src_i_high * 4 + src_j / 2;
+
+    const int shift_low  = ((src_j + 0) % 2) * 16;
+    const int shift_high = ((src_j + 1) % 2) * 16;
+
+    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
+    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+
+    return ret_low | ret_high;
+}
+
+#endif // CUDART_VERSION >= 11080
+
+static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
+    half2 ret;
+    *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
+    return ret;
+}
+
+namespace ggml_cuda_mma {
+
+    template <int I_, int J_, typename T>
+    struct tile {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && (J == 4 || J == 8)) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 16) {
+                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 8 && J == 8) {
+                return 4 * l + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x % 4) + l % 2;
+            } else if constexpr (I == 16 && J == 16) {
+                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, half2> {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+        }
         return ret;
     }
 
-    static __device__ __forceinline__ int get_k(const int /* l */) {
-        const int ret = threadIdx.x % K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
+    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+        tile<8, 8, half2> ret;
+        ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
+        ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
+
         return ret;
     }
 
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE)
-        const int * xs = xs0 + (threadIdx.x%I)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(x[0]), "+r"(x[1])
+    template <int I, int J, typename T>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+#pragma unroll
+        for (int l = 0; l < t.ne; ++l) {
+            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+        }
+    }
+
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
+        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_i(l)*stride + get_k(l)];
-        }
-#endif // defined(INT8_MMA_AVAILABLE)
-    }
-};
-
-struct mma_int_A_I16K8 {
-    static constexpr int I  = 16;
-    static constexpr int K  = 8;
-    static constexpr int ne = 4;
-
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
+        load_generic(t, xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 
-    static __device__ __forceinline__ int get_k(const int l) {
-        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
-
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE)
-        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
-        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
+        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_i(l)*stride + get_k(l)];
-        }
-#endif // defined(INT8_MMA_AVAILABLE)
+        load_generic(t, xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
-        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
-    }
-};
-
-struct mma_int_B_J8K4 {
-    static constexpr int J  = 8;
-    static constexpr int K  = 4;
-    static constexpr int ne = 1;
-
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x / K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    static __device__ __forceinline__ int get_k(const int /* l */) {
-        const int ret = threadIdx.x % K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
-
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
-        const int * xs = xs0 + (threadIdx.x%J)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
-            : "+r"(x[0])
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
 #else
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_k(l)];
-        }
-#endif // defined(INT8_MMA_AVAILABLE)
-    }
-};
-
-struct mma_int_B_J8K8 {
-    static constexpr int J  = 8;
-    static constexpr int K  = 8;
-    static constexpr int ne = 2;
-
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x / (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
+        load_generic(t, xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 
-    static __device__ __forceinline__ int get_k(const int l) {
-        const int ret = l * (K/2) + threadIdx.x % (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
-
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
-        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(x[0]), "+r"(x[1])
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix_trans(
+            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
+            : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
             : "l"(xs));
 #else
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_k(l)];
-        }
-#endif // defined(INT8_MMA_AVAILABLE)
-    }
-};
-
-struct mma_int_C_I16J8 {
-    static constexpr int I  = 16;
-    static constexpr int J  = 8;
-    static constexpr int ne = 4;
-
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
+        GGML_UNUSED(t);
+        GGML_UNUSED(xs0);
+        GGML_UNUSED(stride);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
     }
 
-    static __device__ __forceinline__ int get_j(const int l) {
-        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
-#ifdef INT8_MMA_AVAILABLE
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
+#ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
 #else
         // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[0]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[1]), "r"(B.x[0]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
-#ifdef INT8_MMA_AVAILABLE
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
+#ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
 #else
         // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[0]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[1]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[2]), "r"(B.x[1]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[3]), "r"(B.x[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
-};
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+}
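
To make the layout described in the header comment concrete, here is a small host-side check (not part of the patch) of the thread-to-element mapping for tile<16, 8, int>: each of the 32 lanes in a warp owns ne = 16*8/32 = 4 elements, and the two helpers below are direct copies of the corresponding get_i/get_j branches above, with threadIdx.x replaced by an explicit lane index.

#include <cstdio>

// Host copies of tile<16, 8, int>::get_i / get_j.
static int get_i_16x8(int lane, int l) { return (l / 2)*8 + lane/4; }
static int get_j_16x8(int lane, int l) { return 2*(lane % 4) + l % 2; }

int main() {
    int owner[16][8];
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < 4; ++l) {              // ne = 4 elements per lane
            owner[get_i_16x8(lane, l)][get_j_16x8(lane, l)] = lane;
        }
    }
    // Each (i, j) of the 16x8 tile is owned by exactly one lane; print the map.
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 8; ++j) {
            printf("%2d ", owner[i][j]);
        }
        printf("\n");
    }
    return 0;
}
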
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
index 270251df4..10f2ebb1c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
@@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
+    const int cc = ggml_cuda_info().devices[id].cc;
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the kernel writes into
@@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
+        cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
 
     switch (src0->type) {
@@ -132,11 +133,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (int8_mma_available(cc)) {
+    if (new_mma_available(cc)) {
         return true;
     }
 
-    if (cc < GGML_CUDA_CC_DP4A) {
+    if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
         return false;
     }
 
@@ -145,8 +146,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
 #endif //GGML_CUDA_FORCE_MMQ
 
     if (cc < GGML_CUDA_CC_OFFSET_AMD) {
-        return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+        return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
-    return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
index 3cd508a1d..0451c65f3 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
@@ -7,6 +7,8 @@
 #include <climits>
 #include <cstdint>
 
+using namespace ggml_cuda_mma;
+
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
 #define MMQ_NWARPS 8
@@ -86,19 +88,20 @@ struct tile_x_sizes {
     int sc;
 };
 
-static constexpr int get_mmq_x_max_host(const int cc) {
-    return int8_mma_available(cc) ? 128 :
+static int get_mmq_x_max_host(const int cc) {
+    return new_mma_available(cc) ? 128 :
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
 #ifdef GGML_CUDA_FORCE_MMQ
-        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128                     : 64;
+            128                     : 64;
 #else
-        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+            MMQ_DP4A_MAX_BATCH_SIZE : 64;
 #endif // GGML_CUDA_FORCE_MMQ
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     return 128;
-#else // INT8_MMA_AVAILABLE
+#else // NEW_MMA_AVAILABLE
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return 128;
@@ -116,11 +119,12 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
-static constexpr int get_mmq_y_host(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+static int get_mmq_y_host(const int cc) {
+    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+        (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
 }
 
 static constexpr __device__ int get_mmq_y_device() {
@@ -209,10 +213,10 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+    return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
 }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
@@ -220,21 +224,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
 static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
     return 8;
 }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 // ------------------------------------------------------------
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_0;
     const int kqsx = threadIdx.x % QI4_0;
@@ -250,12 +254,12 @@ template  static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -271,11 +275,11 @@ template  static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -322,14 +326,14 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_1;
     const int kqsx = threadIdx.x % QI4_1;
@@ -345,12 +349,12 @@ template  static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -366,11 +370,11 @@ template  static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -417,14 +421,14 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_0;
     const int kqsx = threadIdx.x % QI5_0;
@@ -456,13 +460,13 @@ template  static __device__ __forceinlin
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
@@ -478,25 +482,25 @@ template  static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_1;
     const int kqsx = threadIdx.x % QI5_1;
@@ -526,13 +530,13 @@ template  static __device__ __forceinlin
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -548,25 +552,25 @@ template  static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_tile + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI8_0;
     const int kqsx = threadIdx.x % QI8_0;
@@ -581,13 +585,13 @@ template  static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
         x_qs[i*(2*WARP_SIZE + 1)     + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
@@ -603,11 +607,11 @@ template  static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = bxi->d;
 #else
         x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -645,15 +649,15 @@ template 
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -661,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    tile_A A[ntx][WARP_SIZE/QI8_0];
+    float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -672,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -689,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B  B;
-            float dB[mma_C::ne/2];
+            tile_B B;
+            float dB[tile_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                     dB[l] =             y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -710,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/QI8_0], B);
+                tile_C C;
+                mma(C, A[n][k01/QI8_0], B);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
                 }
             }
         }
@@ -756,23 +760,23 @@ template 
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_dm = (const half2 *) y;
 
-    mma_A    A[ntx][WARP_SIZE/QI8_1];
-    float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+    tile_A   A[ntx][WARP_SIZE/QI8_1];
+    float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -782,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -799,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B    B;
-            float2 dsB[mma_C::ne/2];
+            tile_B   B;
+            float2 dsB[tile_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/QI8_1], B);
+                tile_C C;
+                mma(C, A[n][k01/QI8_1], B);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
                 }
             }
         }
@@ -864,28 +868,28 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_A_I16K8 mma_A_K8;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile<16, 8, int> tile_A_8;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    float  dA[ntx][mma_C::ne/2][8];
+    tile_A  A[ntx][8];
+    float  dA[ntx][tile_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -893,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -910,31 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                 }
             }
         }
@@ -942,20 +947,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI2_K;
 
@@ -977,11 +982,11 @@ template  static __device__ __forceinlin
 
             const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -992,11 +997,11 @@ template  static __device__ __forceinlin
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
         x_dm[i*(WARP_SIZE + 1)       + kqsx] = x_dm_ik;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -1051,29 +1056,29 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_A_I16K8 mma_A_K8;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile<16, 8, int> tile_A_8;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    float  dA[ntx][mma_C::ne/2][8];
-    float  mA[ntx][mma_C::ne/2][8];
+    tile_A  A[ntx][8];
+    float  dA[ntx][tile_C::ne/2][8];
+    float  mA[ntx][tile_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1081,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
         }
     }
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1104,57 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float2 dB[mma_C::ne/2];
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float2 dB[tile_C::ne/2];
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int j = j0 + tile_C::get_j(l);
 
             dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
         }
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B B[2];
+            tile_B B[2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
-            mma_C Cm[2];
+            tile_C Cm[2];
             if (k01 >= WARP_SIZE * 3/4) {
-                mma_A A1;
+                tile_A A1;
                 A1.x[0] = 0x01010101;
                 A1.x[1] = 0x01010101;
-                Cm[0].mma_K4(A1, B[0]);
-                Cm[1].mma_K4(A1, B[1]);
+                mma(Cm[0], A1, B[0]);
+                mma(Cm[1], A1, B[1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C Cd[2];
+                tile_C Cd[2];
 
-                Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                mma(Cd[0], A[n][k01/4 + 0], B[0]);
+                mma(Cd[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
                     if (k01 >= WARP_SIZE * 3/4) {
                         tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
                     }
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
                 }
             }
         }
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
-            float2 sB[mma_C::ne/2];
+            float2 sB[tile_C::ne/2];
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
@@ -1162,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
                 }
             }
         }
@@ -1172,13 +1178,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
@@ -1186,7 +1192,7 @@ template  static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI3_K;
 
@@ -1212,11 +1218,11 @@ template  static __device__ __forceinlin
 
             const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
     }
 
@@ -1242,7 +1248,7 @@ template  static __device__ __forceinlin
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         const int8_t * sc8 = (const int8_t *) &sc;
         const float d = bxi->d;
 
@@ -1252,10 +1258,10 @@ template  static __device__ __forceinlin
         }
 #else
         x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifndef INT8_MMA_AVAILABLE
+#ifndef NEW_MMA_AVAILABLE
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
         int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
@@ -1268,7 +1274,7 @@ template  static __device__ __forceinlin
 
         x_df[i] = bxi->d;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1317,7 +1323,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
@@ -1325,7 +1331,7 @@ template  static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1338,15 +1344,15 @@ template  static __device__ __forceinlin
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
         const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
@@ -1407,7 +1413,7 @@ template  static __device__ __forceinlin
 
         x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1446,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
 #else
@@ -1454,7 +1460,7 @@ template  static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1478,16 +1484,16 @@ template  static __device__ __forceinlin
         const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
         const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
         x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
@@ -1548,7 +1554,7 @@ template  static __device__ __forceinlin
 
         x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1587,7 +1593,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
     int   * x_sc = (int   *) (x_df + WARP_SIZE/QI6_K);
@@ -1596,7 +1602,7 @@ template  static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1619,13 +1625,13 @@ template  static __device__ __forceinlin
         const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
         const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*(2*WARP_SIZE + 1)     + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
@@ -1641,11 +1647,11 @@ template  static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q6_K       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1658,11 +1664,11 @@ template  static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #else
         x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -1702,17 +1708,17 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1720,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    int   scA[ntx][mma_C::ne/2][8];
-    float  dA[ntx][mma_C::ne/2];
+    tile_A   A[ntx][8];
+    int    scA[ntx][tile_C::ne/2][8];
+    float   dA[ntx][tile_C::ne/2];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1732,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
-            A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),         MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
         }
 
 #pragma unroll
@@ -1741,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             const int k0 = k00 + k01;
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
                 const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
                 const int8_t * sc        = (const int8_t *) &sc_packed;
@@ -1755,40 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
             dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
         }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float tmp[ntx][mma_C::ne] = {{0.0f}};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float tmp[ntx][tile_C::ne] = {{0.0f}};
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0         + k01, MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                 }
             }
@@ -1797,28 +1804,28 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+            for (int l = 0; l < tile_C::ne; ++l) {
+                sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
             }
         }
     }
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_NL;
     const int kqsx = threadIdx.x % QI4_NL;
@@ -1836,13 +1843,13 @@ template  static __device__ __forceinlin
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
@@ -1858,25 +1865,25 @@ template  static __device__ __forceinlin
 
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = __half2float(bxi->d);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_XXS/2);
 
@@ -1905,36 +1912,36 @@ template  static __device__ __forceinlin
             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_XS/2);
 
@@ -1959,38 +1966,38 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_S/2);
 
@@ -2022,38 +2029,38 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI3_XXS/2);
 
@@ -2080,36 +2087,36 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/2;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI3_S/2);
 
@@ -2143,36 +2150,36 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = ls*d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI1_S;
 
@@ -2198,37 +2205,37 @@ template  static __device__ __forceinlin
             const int grid0 = (grid >> 0) & 0x0F0F0F0F;
             const int grid1 = (grid >> 4) & 0x0F0F0F0F;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
         const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
 #else
         x_ds[i*(WARP_SIZE/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI4_XS
     const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
@@ -2246,13 +2253,13 @@ template  static __device__ __forceinlin
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -2270,11 +2277,11 @@ template  static __device__ __forceinlin
         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = d * (ls - 32);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -2307,36 +2314,36 @@ template
 static __device__ __forceinline__ void mmq_write_back_mma(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
 
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
-#ifdef INT8_MMA_AVAILABLE
-    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
-#endif // INT8_MMA_AVAILABLE
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
+#ifdef NEW_MMA_AVAILABLE
+    static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
 
                 if (j > j_max) {
                     continue;
                 }
 
-                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+                const int i = i0 + n*tile_C::I + tile_C::get_i(l);
 
                 if (need_check && i > i_max) {
                     continue;
                 }
 
-                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+                dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
             }
         }
     }
@@ -2505,13 +2512,13 @@ static __device__ void mul_mat_q_process_tile(
     int * tile_y = (int *) data_mul_mat_q;
     int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
     constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
 #else
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
@@ -2643,7 +2650,7 @@ static __global__ void mul_mat_q(
     const int jt =  kbc /    (blocks_per_ne00*nty);
     const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
 
-    constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
+    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
         (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
             it, jt, kb0_start, kb0_stop);
@@ -2749,7 +2756,7 @@ template
 static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
-    const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
     return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }
@@ -2825,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
 
     int mmq_x_best  = 0;
     int nparts_best = INT_MAX;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
index ac45f2d17..f89ed03b5 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
@@ -1,25 +1,29 @@
+#include "ggml.h"
 #include "common.cuh"
 #include "mmv.cuh"
 
 template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
-    const int64_t row     = blockIdx.x;
-    const int64_t channel = blockIdx.z;
-    const int     tid     = threadIdx.x;
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
+    const int64_t row       = blockIdx.x;
+    const int64_t channel   = blockIdx.y;
+    const int64_t sample    = blockIdx.z;
+    const int     tid       = threadIdx.x;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
-    y   +=  channel               *stride_channel_y;
-    dst +=  channel               *stride_channel_dst;
+    x   +=  (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=   sample              *stride_sample_y   +  channel               *stride_channel_y;
+    dst +=   sample              *stride_sample_dst +  channel               *stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
     extern __shared__ char data_mmv[];
     float * buf_iw = (float *) data_mmv;
 
-    if (block_size > WARP_SIZE) {
-        if (tid < WARP_SIZE) {
+    if (block_size > warp_size) {
+        if (tid < warp_size) {
             buf_iw[tid] = 0.0f;
         }
         __syncthreads();
@@ -67,16 +71,16 @@ static __global__ void mul_mat_vec(
         static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
-    sumf = warp_reduce_sum(sumf);
+    sumf = warp_reduce_sum<warp_size>(sumf);
 
-    if (block_size > WARP_SIZE) {
-        buf_iw[tid/WARP_SIZE] = sumf;
+    if (block_size > warp_size) {
+        buf_iw[tid/warp_size] = sumf;
         __syncthreads();
-        if (tid >= WARP_SIZE) {
+        if (tid >= warp_size) {
             return;
         }
         sumf = buf_iw[tid];
-        sumf = warp_reduce_sum(sumf);
+        sumf = warp_reduce_sum<warp_size>(sumf);
     }
 
     if (tid != 0) {
@@ -90,16 +94,28 @@ template 
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols      % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
     GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    GGML_ASSERT(nsamples_y  % nsamples_x  == 0);
     const int64_t channel_ratio = nchannels_y / nchannels_x;
+    const int64_t sample_ratio  = nsamples_y  / nsamples_x;
+    int device;
+    int warp_size;
 
-    int64_t block_size_best = WARP_SIZE;
-    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
-    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+    CUDA_CHECK(cudaGetDevice(&device));
+    warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t block_size_best = warp_size;
+    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
+    int64_t max_block_size  = 256;
+    if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+        max_block_size = 128;
+    }
+    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
         const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
         if (niter < niter_best) {
             niter_best      = niter;
@@ -107,41 +123,49 @@ static void launch_mul_mat_vec_cuda(
         }
     }
 
-    const int smem = WARP_SIZE*sizeof(float);
-    const dim3 block_nums(nrows, 1, nchannels_y);
+    const int smem = warp_size*sizeof(float);
+    const dim3 block_nums(nrows, nchannels_y, nsamples_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
             mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
             mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
             mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
             mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
             mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
             mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
             mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  256: {
             mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -153,16 +177,19 @@ template
 static void mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, half>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, float>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
     }
 }
@@ -171,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
-    GGML_ASSERT(src1->ne[1] == 1);
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne11 == 1);
+    GGML_ASSERT(ne12 == ne2);
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(nb00 == ts_src0);
+    GGML_ASSERT(nb10 == ts_src1);
+    GGML_ASSERT(nb0  == ts_dst);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
@@ -182,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;
 
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne12 = src1->ne[2];
-    GGML_ASSERT(dst->ne[2] == ne12);
-
-    GGML_ASSERT(src0->ne[3] == 1);
-    GGML_ASSERT(src1->ne[3] == 1);
-    GGML_ASSERT( dst->ne[3] == 1);
-
-    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
-    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
-    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
-    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -233,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t stride_row         = ne00;
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
-    const int64_t channel_stride_x   = 0;
-    const int64_t channel_stride_y   = 0;
-    const int64_t channel_stride_dst = 0;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_y         = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
index e3b912d87..4fb466ca0 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
@@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
 
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
+    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
         switch(ncols_y) {
             case 1:
                 nwarps = 4;
@@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda(
                 break;
         }
     }
+
     const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
     const dim3 block_nums(nblocks, 1, 1);
     const dim3 block_dims(WARP_SIZE, nwarps, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
index 133e219f0..f127616ed 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
@@ -1,24 +1,36 @@
 #include "norm.cuh"
+#include <cstdint>
 
 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
 
-    float2 mean_var = make_float2(0.f, 0.f);
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float2 mean_var = make_float2(0.0f, 0.0f);
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
+        const float xi = x[col];
         mean_var.x += xi;
         mean_var.y += xi * xi;
     }
 
     // sum up partial sums
     mean_var = warp_reduce_sum(mean_var);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float2 s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = mean_var;
         }
@@ -32,7 +44,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     const float inv_std = rsqrtf(var + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
+        dst[col] = (x[col] - mean) * inv_std;
     }
 }
 
@@ -40,14 +52,8 @@ template 
 static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
     // blockIdx.x: num_groups idx
     // threadIdx.x: block_size idx
-    int start = blockIdx.x * group_size;
-    int end = start + group_size;
-
-    start += threadIdx.x;
-
-    if (end >= ne_elements) {
-        end = ne_elements;
-    }
+    const int start =     blockIdx.x*group_size + threadIdx.x;
+    const int end   = min(blockIdx.x*group_size + group_size,  ne_elements);
 
     float tmp = 0.0f; // partial sum for thread in warp
 
@@ -56,10 +62,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 
     tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -68,11 +75,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp = warp_reduce_sum(tmp);
     }
 
-    float mean = tmp / group_size;
+    const float mean = tmp / group_size;
     tmp = 0.0f;
 
     for (int j = start; j < end; j += block_size) {
-        float xi = x[j] - mean;
+        const float xi = x[j] - mean;
         dst[j] = xi;
         tmp += xi * xi;
     }
@@ -80,8 +87,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -90,31 +97,42 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp = warp_reduce_sum(tmp);
     }
 
-    float variance = tmp / group_size;
-    float scale = rsqrtf(variance + eps);
+    const float variance = tmp / group_size;
+    const float scale = rsqrtf(variance + eps);
     for (int j = start; j < end; j += block_size) {
         dst[j] *= scale;
     }
 }
 
 template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
+        const float xi = x[col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
     tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -127,22 +145,77 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * x[row*ncols + col];
+        dst[col] = scale * x[col];
     }
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
+template <int block_size>
+static __global__ void rms_norm_back_f32(
+        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    grad += int64_t(row)*ncols;
+    xf   += int64_t(row)*ncols;
+    dst  += int64_t(row)*ncols;
+
+    float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
+    float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xfi = xf[col];
+        sum_xx += xfi * xfi;
+        sum_xg += xfi * grad[col];
+    }
+
+    // sum up partial sums
+    sum_xx = warp_reduce_sum(sum_xx);
+    sum_xg = warp_reduce_sum(sum_xg);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum_xx[32];
+        __shared__ float s_sum_xg[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum_xx[warp_id] = sum_xx;
+            s_sum_xg[warp_id] = sum_xg;
+        }
+        __syncthreads();
+
+        sum_xx = s_sum_xx[lane_id];
+        sum_xx = warp_reduce_sum(sum_xx);
+
+        sum_xg = s_sum_xg[lane_id];
+        sum_xg = warp_reduce_sum(sum_xg);
+    }
+
+    const float mean_eps = sum_xx / ncols + eps;
+    const float sum_eps  = sum_xx + ncols*eps;
+
+    const float scale_grad = rsqrtf(mean_eps);
+    const float scale_x    = -scale_grad * sum_xg/sum_eps;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale_grad*grad[col] + scale_x*xf[col];
+    }
+}
+
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
+static void group_norm_f32_cuda(
+        const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
     if (group_size < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -152,35 +225,51 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
     }
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
+static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
     }
 }
 
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -189,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -198,6 +285,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     float eps;
     memcpy(&eps, dst->op_params + 1, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
     group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
@@ -205,20 +293,50 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * grad  = dst->src[0]; // gradients
+    const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
+
+    const float * grad_d  = (const float *) grad->data;
+    const float * src0f_d = (const float *) src0f->data;
+    float       * dst_d   = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(grad));
+
+    GGML_ASSERT( grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+    GGML_ASSERT(  dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0f->ne[0];
+    const int64_t nrows = ggml_nrows(src0f);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
+
+    rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
 }
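rms_norm_back_f32 above implements the analytic gradient of y = x / sqrt(mean(x^2) + eps): per row, dst = scale_grad*grad + scale_x*x with scale_grad = 1/sqrt(mean(x^2) + eps) and scale_x = -scale_grad * sum(x*grad) / (sum(x^2) + n*eps). A small CPU sketch (plain C++, hypothetical names, not the kernel itself) that checks this formula against a finite difference of L = sum(grad*y):

#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<double> rms_norm(const std::vector<double> & x, double eps) {
    double sum_xx = 0.0;
    for (double v : x) sum_xx += v*v;
    const double scale = 1.0/std::sqrt(sum_xx/x.size() + eps);
    std::vector<double> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) y[i] = x[i]*scale;
    return y;
}

// Same accumulation as the kernel: sum of x^2 and of x*grad over the row,
// then one scale applied to grad and one applied to x.
static std::vector<double> rms_norm_back(
        const std::vector<double> & x, const std::vector<double> & g, double eps) {
    double sum_xx = 0.0, sum_xg = 0.0;
    for (size_t i = 0; i < x.size(); ++i) { sum_xx += x[i]*x[i]; sum_xg += x[i]*g[i]; }
    const double scale_grad = 1.0/std::sqrt(sum_xx/x.size() + eps);
    const double scale_x    = -scale_grad*sum_xg/(sum_xx + x.size()*eps);
    std::vector<double> dx(x.size());
    for (size_t i = 0; i < x.size(); ++i) dx[i] = scale_grad*g[i] + scale_x*x[i];
    return dx;
}

int main() {
    const double eps = 1e-5;
    std::vector<double> x = {0.3, -1.2, 2.0, 0.7};
    std::vector<double> g = {1.0,  0.5, -0.25, 2.0};
    const std::vector<double> dx = rms_norm_back(x, g, eps);

    // finite-difference check of dL/dx_0 with L = sum(g*y)
    const double h = 1e-6;
    std::vector<double> xp = x, xm = x;
    xp[0] += h; xm[0] -= h;
    const std::vector<double> yp = rms_norm(xp, eps), ym = rms_norm(xm, eps);
    double lp = 0.0, lm = 0.0;
    for (size_t i = 0; i < x.size(); ++i) { lp += g[i]*yp[i]; lm += g[i]*ym[i]; }
    std::printf("analytic %.6f vs finite diff %.6f\n", dx[0], (lp - lm)/(2*h));
    return 0;
}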
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
index 431a8f74d..d63d34380 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/out-prod.cu b/ml/backend/ggml/ggml/src/ggml-cuda/out-prod.cu
index 619cfdcb5..c9b2b699c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/out-prod.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/out-prod.cu
@@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_ASSERT(ne01 == ne11);
     GGML_ASSERT(ne0 == ne00);
     GGML_ASSERT(ne1 == ne10);
 
-    GGML_ASSERT(ne2 == src0->ne[2]);
+    GGML_ASSERT(ne2 % src0->ne[2] == 0);
+    GGML_ASSERT(ne3 % src0->ne[3] == 0);
+
     GGML_ASSERT(ne2 == src1->ne[2]);
-    GGML_ASSERT(ne3 == src0->ne[3]);
     GGML_ASSERT(ne3 == src1->ne[3]);
 
     const float * src0_d = (const float *) src0->data;
@@ -33,19 +32,37 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const float alpha = 1.0f;
     const float beta = 0.0f;
 
-    GGML_ASSERT(ne2 == 1);
-    GGML_ASSERT(ne3 == 1);
     CUBLAS_CHECK(cublasSetStream(handle, stream));
 
+    const int64_t lda = nb01 / sizeof(float);
+    const int64_t ldc = nb1  / sizeof(float);
+
     const bool src1_T = ggml_is_transposed(src1);
     const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
     const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
     GGML_ASSERT(                             (src1_T ?        nb11 :        nb10) == sizeof(float));
 
-    CUBLAS_CHECK(
-        cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
-                ne0, ne1, ne01,
-                &alpha, src0_d, ne00,
-                        src1_d, ldb,
-                &beta,  dst_d,  ne0));
+    // data strides in dimensions 2/3
+    const size_t s02 = nb02 / sizeof(float);
+    const size_t s03 = nb03 / sizeof(float);
+    const size_t s12 = nb12 / sizeof(float);
+    const size_t s13 = nb13 / sizeof(float);
+    const size_t s2  = nb2  / sizeof(float);
+    const size_t s3  = nb3  / sizeof(float);
+
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
+    // TODO batched matrix multiplication
+    for (int64_t i3 = 0; i3 < ne3; ++i3) {
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            CUBLAS_CHECK(
+                cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                        ne0, ne1, ne01,
+                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
+                                src1_d +  i3      *s13 +  i2      *s12, ldb,
+                        &beta,  dst_d  +  i3      *s3  +  i2      *s2,  ldc));
+        }
+    }
 }
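The out-prod.cu change replaces the single cublasSgemm call with one GEMM per (i2, i3) slice and allows src0 to be broadcast when dst carries more slices (dps2 = ne2/ne02, dps3 = ne3/ne03, as used for grouped-query attention). A tiny sketch of the slice pairing the loop produces (plain C++, example sizes assumed, the GEMM itself omitted):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne2 = 4, ne3 = 2;   // dst slices in dims 2 and 3 (assumed example)
    const int64_t ne02 = 2, ne03 = 1; // src0 slices in dims 2 and 3 (assumed example)
    const int64_t dps2 = ne2 / ne02;  // dst slices served by one src0 slice, dim 2
    const int64_t dps3 = ne3 / ne03;  // dst slices served by one src0 slice, dim 3
    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            // one GEMM per dst slice; src0 advances only every dps2/dps3 slices
            std::printf("dst slice (%lld, %lld) <- src0 slice (%lld, %lld), src1 slice (%lld, %lld)\n",
                        (long long) i2, (long long) i3,
                        (long long) (i2/dps2), (long long) (i3/dps3),
                        (long long) i2, (long long) i3);
        }
    }
    return 0;
}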
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
index 39fd4b165..b4b874093 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
@@ -92,4 +92,4 @@ void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     unpad_f32_cuda(src0_d, dst_d,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
-}
+}
\ No newline at end of file
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
index 2c84778d2..18f691b2d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
@@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 
 // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+template<bool forward>
 static __device__ void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta) {
+        const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
+        float mscale, float & cos_theta, float & sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -29,24 +30,28 @@ static __device__ void rope_yarn(
         // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
     }
-    *cos_theta = cosf(theta) * mscale;
-    *sin_theta = sinf(theta) * mscale;
+    cos_theta = cosf(theta) * mscale;
+    sin_theta = sinf(theta) * mscale;
+    if (!forward) {
+        sin_theta *= -1.0f;
+    }
 }
 
-template<typename T, bool has_ff>
+template<bool forward, bool has_ff, typename T>
 static __global__ void rope_norm(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
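The rope_yarn change above adds a forward flag and negates sin_theta when it is false. Because the RoPE rotation matrix is orthogonal, negating the sine applies the transposed (inverse) rotation, which is exactly what propagating gradients back through the rotation requires, so the backward pass can reuse the same kernels. A minimal numeric sketch (plain C++, hypothetical names):

#include <cmath>
#include <cstdio>

// Apply the 2D rotation used by RoPE; forward = false flips the sine, giving
// the inverse rotation.
static void rotate(float theta, float x0, float x1, float & y0, float & y1, bool forward) {
    const float c = std::cos(theta);
    const float s = forward ? std::sin(theta) : -std::sin(theta);
    y0 = x0*c - x1*s;
    y1 = x0*s + x1*c;
}

int main() {
    const float theta = 0.37f;
    float f0, f1, b0, b1;
    rotate(theta, 1.5f, -0.25f, f0, f1, /*forward=*/true);
    rotate(theta, f0, f1, b0, b1, /*forward=*/false); // inverse rotation
    std::printf("restored: %.6f %.6f\n", b0, b1);     // ~ (1.5, -0.25)
    return 0;
}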
@@ -54,39 +59,43 @@ static __global__ void rope_norm(
         return;
     }
 
-    const int i  = row*ne0 + i0;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    const int idst = row_dst*ne0 + i0;
+    const int ix   = channel_x*s2 + row_x*s1 + i0;
+
+    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + 1];
 
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
+    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_ff>
+template<bool forward, bool has_ff, typename T>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -94,39 +103,43 @@ static __global__ void rope_neox(
         return;
     }
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+
+    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims/2];
 
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_ff>
+template<bool forward, bool has_ff, typename T>
 static __global__ void rope_multi(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
+        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -134,25 +147,28 @@ static __global__ void rope_multi(
         return;
     }
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
-    int sec_w = sections.v[1] + sections.v[0];
-    int sector = (i0 / 2) % sect_dims;
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+
+    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
     if (sector < sections.v[0]) {
-        theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sections.v[0] && sector < sec_w) {
-        theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-        theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sec_w + sections.v[2]) {
-        theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -160,42 +176,46 @@ static __global__ void rope_multi(
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims/2];
 
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_ff>
+template<bool forward, bool has_ff, typename T>
 static __global__ void rope_vision(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows; // i2-th tokens
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    int sect_dims = sections.v[0] + sections.v[1];
-    int sec_w = sections.v[1] + sections.v[0];
-    int sector = (i0 / 2) % sect_dims;
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+
+    const int sect_dims = sections.v[0] + sections.v[1];
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
     if (sector < sections.v[0]) {
         const int p = sector;
-        theta_base = pos[i2]*powf(theta_scale, p);
+        theta_base = pos[channel_x]*powf(theta_scale, p);
     }
     else if (sector >= sections.v[0] && sector < sec_w) {
         const int p = sector - sections.v[0];
-        theta_base = pos[i2 + ne2]*powf(theta_scale, p);
+        theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -203,19 +223,20 @@ static __global__ void rope_vision(
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims];
 
-    dst[i + 0]      = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]      = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T>
+template<bool forward, typename T>
 static void rope_norm_cuda(
-    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -224,22 +245,21 @@ static void rope_norm_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     } else {
-        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     }
 }
 
-template<typename T>
+template<bool forward, typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -248,22 +268,21 @@ static void rope_neox_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     }
 }
 
-template<typename T>
+template<bool forward, typename T>
 static void rope_multi_cuda(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -272,22 +291,21 @@ static void rope_multi_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_multi<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     } else {
-        rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_multi<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     }
 }
 
-template<typename T>
+template<bool forward, typename T>
 static void rope_vision_cuda(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -298,80 +316,18 @@ static void rope_vision_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_vision<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     } else {
-        rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_vision<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     }
 }
 
-static void rope_norm_cuda_f16(
-    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_norm_cuda_f32(
-    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
-
-    rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_multi_cuda_f16(
-    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_multi_cuda_f32(
-    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_vision_cuda_f16(
-    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_vision_cuda_f32(
-    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <bool forward>
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
@@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
@@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ne02 = src0->ne[2]; // num heads
     const int64_t nr = ggml_nrows(src0);
 
+    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
+    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
+
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
@@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     // compute
     if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_neox_cuda<forward>(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_neox_cuda<forward>(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_mrope && !is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_multi_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_multi_cuda<forward>(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_multi_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_multi_cuda<forward>(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_vision_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_vision_cuda<forward>(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_vision_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_vision_cuda<forward>(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_norm_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_norm_cuda<forward>(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_norm_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_norm_cuda<forward>(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     }
 }
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_rope_impl<true>(ctx, dst);
+}
+
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_rope_impl<false>(ctx, dst);
+}
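
The forward/backward split above relies on RoPE being a pure rotation: applying the rotation with angle theta and then with -theta returns the original pair, which is why a single templated launcher with a compile-time direction flag can serve both ggml_cuda_op_rope and ggml_cuda_op_rope_back. A minimal host-side sketch of that identity (illustrative only, not part of the patch; names are made up):

    // Rotating a (x0, x1) pair by theta and then by -theta recovers the
    // original values - the relationship between the RoPE forward and
    // backward kernels.
    #include <cmath>
    #include <cstdio>

    static void rope_pair(float theta, float & x0, float & x1) {
        const float c = std::cos(theta), s = std::sin(theta);
        const float y0 = x0*c - x1*s;
        const float y1 = x0*s + x1*c;
        x0 = y0;
        x1 = y1;
    }

    int main() {
        float x0 = 0.25f, x1 = -1.5f;
        rope_pair( 0.3f, x0, x1); // forward rotation
        rope_pair(-0.3f, x0, x1); // backward rotation undoes it
        std::printf("%.3f %.3f\n", x0, x1); // prints 0.250 -1.500
    }

The added s01/s02 strides (row and plane strides in elements) are what let the same kernels run on non-contiguous sources, which is why the ggml_is_contiguous assertion is dropped.
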
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh
index 0f787a0b2..9139f3b22 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh
@@ -3,3 +3,5 @@
 #define CUDA_ROPE_BLOCK_SIZE 256
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
index c24abae1f..aac6e0999 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
@@ -1,5 +1,7 @@
 #include "common.cuh"
+#include "ggml.h"
 #include "softmax.cuh"
+#include <cstdint>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -11,14 +13,26 @@ __device__ float __forceinline__ t2f32(half val) {
     return __half2float(val);
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+// When ncols_template == 0 the bounds for the loops in this function are not known at compile time and the loops can't be unrolled.
+// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <bool use_shared, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(
+        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
+        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
 
+    x    += int64_t(rowx)*ncols;
+    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    dst  += int64_t(rowx)*ncols;
+
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
@@ -29,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
     // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
+    float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
 
     float max_val = -INFINITY;
 
@@ -41,10 +55,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
             break;
         }
 
-        const int64_t ix = (int64_t)rowx*ncols + col;
-        const int64_t iy = (int64_t)rowy*ncols + col;
-
-        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -110,8 +121,32 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
             return;
         }
 
-        const int64_t idst = (int64_t)rowx*ncols + col;
-        dst[idst] = vals[col] * inv_sum;
+        dst[col] = vals[col] * inv_sum;
+    }
+}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
+static __global__ void soft_max_back_f32(
+        const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+
+    grad += int64_t(rowx)*ncols;
+    dstf += int64_t(rowx)*ncols;
+    dst  += int64_t(rowx)*ncols;
+
+    float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dgf_dot += dstf[col]*grad[col];
+    }
+
+    dgf_dot = warp_reduce_sum(dgf_dot);
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
     }
 }
 
@@ -121,7 +156,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
     const uint32_t n_head      = nrows_x/nrows_y;
@@ -131,50 +166,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
         }
     } else {
-        const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
     }
 }
 
+static void soft_max_back_f32_cuda(
+        const float * grad, const float * dstf, float * dst,
+        const int ncols, const int nrows, const float scale, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows,     1, 1);
+
+    soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
+}
+
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    const float * src0_d = (const float *)src0->data;
-    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+    const float * src0_d = (const float *) src0->data;
+    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
+    float       *  dst_d = (float *) dst->data;
 
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -189,18 +242,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
     if (use_f16) {
-        const half * src1_dd = (const half *)src1_d;
-
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     } else {
-        const float * src1_dd = (const float *)src1_d;
-
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     }
 }
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // grad
+    const ggml_tensor * src1 = dst->src[1]; // forward pass output
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
+    soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
+}
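
soft_max_back_f32 applies the softmax Jacobian to the incoming gradients: with y the saved forward-pass output and g the gradients, dst[i] = scale * (g[i] - sum_j g[j]*y[j]) * y[i], computed with one warp per row; the GGML_ASSERT(max_bias == 0.0f) reflects that the backward pass does not cover the ALiBi bias path. A CPU reference of the same formula, as a sketch for checking (not part of the patch):

    // CPU reference for the backward formula used by soft_max_back_f32 (sketch).
    #include <cstdio>
    #include <vector>

    static std::vector<float> soft_max_back_ref(const std::vector<float> & grad,
                                                const std::vector<float> & y, float scale) {
        float dot = 0.0f; // dot product of forward output and gradients
        for (size_t i = 0; i < y.size(); ++i) {
            dot += grad[i]*y[i];
        }
        std::vector<float> dx(y.size());
        for (size_t i = 0; i < y.size(); ++i) {
            dx[i] = scale * (grad[i] - dot) * y[i];
        }
        return dx;
    }

    int main() {
        const std::vector<float> y    = {0.1f, 0.2f, 0.7f}; // softmax output, sums to 1
        const std::vector<float> grad = {1.0f, 0.0f, 0.0f};
        const std::vector<float> dx   = soft_max_back_ref(grad, y, 1.0f);
        std::printf("%f %f %f\n", dx[0], dx[1], dx[2]);
    }
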
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cuh
index 4ef4ff86c..93dfee835 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cuh
@@ -3,3 +3,5 @@
 #define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu b/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu
index e0dafc1d2..f9589080a 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu
@@ -1,6 +1,6 @@
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
 #define USE_CUB
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
 
 #ifdef USE_CUB
 #include <cub/cub.cuh>
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
new file mode 100644
index 000000000..80108615a
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 1, 8);
+DECL_FATTN_MMA_F16_CASE(80, 1, 8);
+DECL_FATTN_MMA_F16_CASE(96, 1, 8);
+DECL_FATTN_MMA_F16_CASE(112, 1, 8);
+DECL_FATTN_MMA_F16_CASE(128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(256, 1, 8);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
new file mode 100644
index 000000000..66161c0ab
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 1);
+DECL_FATTN_MMA_F16_CASE(80, 16, 1);
+DECL_FATTN_MMA_F16_CASE(96, 16, 1);
+DECL_FATTN_MMA_F16_CASE(112, 16, 1);
+DECL_FATTN_MMA_F16_CASE(128, 16, 1);
+DECL_FATTN_MMA_F16_CASE(256, 16, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
new file mode 100644
index 000000000..ee88c72aa
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 2);
+DECL_FATTN_MMA_F16_CASE(80, 16, 2);
+DECL_FATTN_MMA_F16_CASE(96, 16, 2);
+DECL_FATTN_MMA_F16_CASE(112, 16, 2);
+DECL_FATTN_MMA_F16_CASE(128, 16, 2);
+DECL_FATTN_MMA_F16_CASE(256, 16, 2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
new file mode 100644
index 000000000..d888a5a42
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 4);
+DECL_FATTN_MMA_F16_CASE(80, 16, 4);
+DECL_FATTN_MMA_F16_CASE(96, 16, 4);
+DECL_FATTN_MMA_F16_CASE(112, 16, 4);
+DECL_FATTN_MMA_F16_CASE(128, 16, 4);
+DECL_FATTN_MMA_F16_CASE(256, 16, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
new file mode 100644
index 000000000..d93a2d08e
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 4);
+DECL_FATTN_MMA_F16_CASE(80, 2, 4);
+DECL_FATTN_MMA_F16_CASE(96, 2, 4);
+DECL_FATTN_MMA_F16_CASE(112, 2, 4);
+DECL_FATTN_MMA_F16_CASE(128, 2, 4);
+DECL_FATTN_MMA_F16_CASE(256, 2, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
new file mode 100644
index 000000000..617464c94
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 8);
+DECL_FATTN_MMA_F16_CASE(80, 2, 8);
+DECL_FATTN_MMA_F16_CASE(96, 2, 8);
+DECL_FATTN_MMA_F16_CASE(112, 2, 8);
+DECL_FATTN_MMA_F16_CASE(128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(256, 2, 8);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
new file mode 100644
index 000000000..970d2b686
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 1);
+DECL_FATTN_MMA_F16_CASE(80, 32, 1);
+DECL_FATTN_MMA_F16_CASE(96, 32, 1);
+DECL_FATTN_MMA_F16_CASE(112, 32, 1);
+DECL_FATTN_MMA_F16_CASE(128, 32, 1);
+DECL_FATTN_MMA_F16_CASE(256, 32, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
new file mode 100644
index 000000000..65cd377c3
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 2);
+DECL_FATTN_MMA_F16_CASE(80, 32, 2);
+DECL_FATTN_MMA_F16_CASE(96, 32, 2);
+DECL_FATTN_MMA_F16_CASE(112, 32, 2);
+DECL_FATTN_MMA_F16_CASE(128, 32, 2);
+DECL_FATTN_MMA_F16_CASE(256, 32, 2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
new file mode 100644
index 000000000..f4a8bf348
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 2);
+DECL_FATTN_MMA_F16_CASE(80, 4, 2);
+DECL_FATTN_MMA_F16_CASE(96, 4, 2);
+DECL_FATTN_MMA_F16_CASE(112, 4, 2);
+DECL_FATTN_MMA_F16_CASE(128, 4, 2);
+DECL_FATTN_MMA_F16_CASE(256, 4, 2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
new file mode 100644
index 000000000..de191a8ab
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 4);
+DECL_FATTN_MMA_F16_CASE(80, 4, 4);
+DECL_FATTN_MMA_F16_CASE(96, 4, 4);
+DECL_FATTN_MMA_F16_CASE(112, 4, 4);
+DECL_FATTN_MMA_F16_CASE(128, 4, 4);
+DECL_FATTN_MMA_F16_CASE(256, 4, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
new file mode 100644
index 000000000..e8cb0e1b3
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 8);
+DECL_FATTN_MMA_F16_CASE(80, 4, 8);
+DECL_FATTN_MMA_F16_CASE(96, 4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 4, 8);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
new file mode 100644
index 000000000..a532e9629
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 1);
+DECL_FATTN_MMA_F16_CASE(80, 64, 1);
+DECL_FATTN_MMA_F16_CASE(96, 64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 64, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
new file mode 100644
index 000000000..bf25181aa
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 1);
+DECL_FATTN_MMA_F16_CASE(80, 8, 1);
+DECL_FATTN_MMA_F16_CASE(96, 8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 8, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
new file mode 100644
index 000000000..378c132e6
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 2);
+DECL_FATTN_MMA_F16_CASE(80, 8, 2);
+DECL_FATTN_MMA_F16_CASE(96, 8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 8, 2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
new file mode 100644
index 000000000..372641be9
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 4);
+DECL_FATTN_MMA_F16_CASE(80, 8, 4);
+DECL_FATTN_MMA_F16_CASE(96, 8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 8, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
new file mode 100644
index 000000000..9ff5968b6
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 8);
+DECL_FATTN_MMA_F16_CASE(80, 8, 8);
+DECL_FATTN_MMA_F16_CASE(96, 8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 8, 8);
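
The new template-instance files above each pin one (head size, ncols1, ncols2) combination so the flash-attention kernels are compiled in separate translation units. A hedged illustration of the general pattern such DECL_* macros usually expand to, explicit template instantiation, with placeholder names rather than the actual macro:

    // Placeholder illustration of per-TU explicit instantiation; the real
    // DECL_FATTN_MMA_F16_CASE macro and the instantiated functions differ.
    template <int D, int ncols1, int ncols2>
    void flash_attn_case() {
        // a real instance would launch the kernel specialised on D/ncols1/ncols2
    }

    #define DECL_CASE(D, ncols1, ncols2) \
        template void flash_attn_case<D, ncols1, ncols2>()

    DECL_CASE( 64, 8, 8);
    DECL_CASE(128, 8, 8);
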
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
index 81fc92202..6b21f407d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
@@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void silu_back_f32(
+        const float * grad, const float * xf, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float xfi = xf[i];
+    const float s = 1.0f / (1.0f + expf(-xfi));
+    dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
+}
+
 static __global__ void tanh_f32(const float * x, float * dst, int k) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BACK_BLOCK_SIZE;
+    silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
+}
+
 static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
     tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // input from forward pass
+    const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
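
silu_back_f32 above uses the closed-form derivative of SiLU: with s(x) = 1/(1+exp(-x)) and silu(x) = x*s(x), d silu/dx = s(x)*(1 + x*(1 - s(x))). A small host-side check of that expression against a finite difference (illustrative only, not part of the patch):

    // Host-side check of the SiLU derivative used in silu_back_f32 (sketch).
    #include <cmath>
    #include <cstdio>

    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    static float silu_grad(float x) {
        const float s = 1.0f / (1.0f + std::exp(-x));
        return s * (1.0f + x * (1.0f - s)); // same expression as the kernel
    }

    int main() {
        const float x = 0.7f, h = 1e-3f;
        const float fd = (silu(x + h) - silu(x - h)) / (2.0f*h); // finite difference
        std::printf("analytic = %f, finite diff = %f\n", silu_grad(x), fd);
    }
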
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
index c91936728..e7f62643a 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
@@ -4,6 +4,7 @@
 #define CUDA_STEP_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_SILU_BACK_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
 #define CUDA_SIGMOID_BLOCK_SIZE 256
@@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu b/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu
index 42578341a..bbdafbee5 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu
@@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * s_d  = (const float *)dst->src[5]->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];
 
     float * dst_d = (float *)dst->data;
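
The index fix above reads T and H from ne[2] and ne[1] of src[0] instead of ne[3] and ne[2]. A sketch of the implied mapping; the shape comments are assumptions inferred from the new indexing only:

    // Assumed dimension mapping for the WKV6 sources after this change.
    #include <cstdint>

    struct wkv6_dims { int64_t B, T, C, H; };

    static wkv6_dims wkv6_dims_from(const int64_t k_ne[4],    // src[0]: {head_size, H, T, 1} (assumed)
                                    const int64_t s_ne[4],    // src[5]: state, ne[1] = B
                                    const int64_t dst_ne[4])  // dst:    ne[0] = C
    {
        return { s_ne[1], k_ne[2], dst_ne[0], k_ne[1] };
    }
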
 
diff --git a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
index b15fbd24d..4a0384dd4 100644
--- a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
@@ -40,13 +40,20 @@ find_package(hip     REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
 
+if (${hip_VERSION} VERSION_LESS 5.5)
+    message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
+endif()
+
 message(STATUS "HIP and hipBLAS found")
 
+# Work around old compilers
+set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
+
 file(GLOB   GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
 list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
 
 file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
-file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
+file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
@@ -70,7 +77,9 @@ ggml_add_backend_library(ggml-hip
                         )
 
 # TODO: do not use CUDA definitions for HIP
-target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+if (NOT GGML_BACKEND_DL)
+    target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+endif()
 
 add_compile_definitions(GGML_USE_HIP)
 
@@ -90,6 +99,18 @@ if (GGML_CUDA_NO_PEER_COPY)
     add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
 endif()
 
+if (GGML_HIP_GRAPHS)
+    add_compile_definitions(GGML_HIP_GRAPHS)
+endif()
+
+if (GGML_HIP_NO_VMM)
+    add_compile_definitions(GGML_HIP_NO_VMM)
+endif()
+
+if (NOT GGML_CUDA_FA)
+    add_compile_definitions(GGML_CUDA_NO_FA)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)
diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h
index 549772c57..1fbcbd045 100644
--- a/ml/backend/ggml/ggml/src/ggml-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-impl.h
@@ -3,6 +3,8 @@
 // GGML internal header
 
 #include "ggml.h"
+#include "gguf.h"
+
 #include 
 #include 
 #include  // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
@@ -14,7 +16,7 @@
 #include 
 #endif // __ARM_FEATURE_SVE
 
-#if defined(__ARM_NEON) && !defined(__CUDACC__)
+#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
 // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
 //
 //   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
@@ -551,22 +553,15 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
-// expose GGUF internals for test code
-
-GGML_API size_t gguf_type_size(enum gguf_type type);
-
-GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
-
-struct gguf_buf {
-    void * data;
-    size_t size;
-    size_t offset;
-};
-GGML_API struct gguf_buf gguf_buf_init(size_t size);
-GGML_API void gguf_buf_free(struct gguf_buf buf);
-
-GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta);
-
 #ifdef __cplusplus
 }
 #endif
+
+#ifdef __cplusplus
+#include <vector>
+
+// expose GGUF internals for test code
+GGML_API size_t gguf_type_size(enum gguf_type type);
+GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
+GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta);
+#endif // __cplusplus
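
With the GGUF test hooks now guarded by __cplusplus, gguf_write_to_buf serializes into a std::vector taken by reference instead of the removed gguf_buf struct. A hedged usage sketch; the int8_t element type, the helper name, and the include path are assumptions, not taken from the patch:

    // Hypothetical helper: serialize only the GGUF metadata into a byte buffer.
    #include <cstdint>
    #include <vector>
    #include "ggml-impl.h" // internal header; assumes it is on the include path

    static void dump_gguf_meta(const struct gguf_context * ctx, std::vector<int8_t> & out) {
        out.clear();
        gguf_write_to_buf(ctx, out, /*only_meta =*/ true);
    }
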
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
index f10966df3..c3610ac07 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -477,7 +477,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 GGML_TABLE_END()
 
-//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
@@ -512,7 +511,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
     0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
 GGML_TABLE_END()
-//#endif
 
 
 GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
@@ -2513,24 +2511,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
     const half d_all = xb->d;
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const uint16_t * ql = (device const uint16_t *)xb->ql;
+    device const uint16_t * qh = (device const uint16_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
 
-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
+    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
+    qh = qh + 16*(il/8) + 8*(il&1);
     float sc = scales[(il%2) + 2 * ((il/2))];
     il = (il/2) & 3;
 
-    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
+    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0                       : 0x0F0F0F0F;
     const float ml = d_all * sc * 32.f;
-    const float dl = d_all * sc * coef;
-    for (int i = 0; i < 16; ++i) {
-        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
-                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
+    const float dl0 = d_all * sc;
+    const float dl1 = dl0 / 256.f;
+    const float dl2 = dl0 / (256.f * 256.f);
+    const float dl3 = dl0 / (256.f * 256.f * 256.f);
+    const uint8_t shr_h = il>2 ? 2 : 0;
+    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
+    const uint8_t shr_l = il>1 ? 4 : 0;
+    for (int i = 0; i < 4; ++i) {
+        const uint32_t  low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
+        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
+        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
+        reg[i][0] = dl0 *  ((half)(q & 0xFF))       - ml;
+        reg[i][1] = dl1 * ((float)(q & 0xFF00))     - ml;
+        reg[i][2] = dl2 * ((float)(q & 0xFF0000))   - ml;
+        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
     }
 }
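
The rewritten dequantize_q6_K loads 16-bit words, packs four quant bytes into a 32-bit value, and recovers each byte by masking alone: multiplying the masked word by dl/256^k equals extracting byte k and scaling it by dl, so the inner loop needs no per-byte shifts. A host-side check of that identity (illustrative, not part of the patch):

    // Scaling a masked 32-bit word by dl/256^k equals extracting byte k
    // and scaling it by dl.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t q  = 0x4A3B2C1Du;
        const float    dl = 0.125f;
        const float masked  = dl / 256.0f * (float)( q & 0xFF00);     // form used by the kernel
        const float shifted = dl          * (float)((q >> 8) & 0xFF); // conventional shift-and-mask
        std::printf("%f %f\n", masked, shifted); // identical
    }
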
 
@@ -3198,7 +3205,7 @@ kernel void kernel_soft_max(
     }
 
     // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
     threadgroup_barrier(mem_flags::mem_none);
 
     float sum = simd_sum(lsum);
@@ -3303,7 +3310,7 @@ kernel void kernel_soft_max_4(
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
     // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
     threadgroup_barrier(mem_flags::mem_none);
 
     float sum = simd_sum(lsum);
@@ -6517,6 +6524,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }
 
+template
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+
 kernel void kernel_concat(
     constant ggml_metal_kargs_concat & args,
     device  const char * src0,
@@ -6601,7 +6651,6 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
-
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -6632,7 +6681,7 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -6798,7 +6847,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2; ++row) {
+        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }
@@ -6914,7 +6963,7 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -7046,7 +7095,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;
@@ -7091,6 +7140,10 @@ void kernel_mul_mv_q6_K_f32_impl(
 
     const int row = 2*r0 + sgitg;
 
+    if (row >= args.ne0) {
+        return;
+    }
+
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
@@ -7246,7 +7299,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -7364,7 +7417,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -7474,7 +7527,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.5f;
@@ -7586,7 +7639,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -7699,7 +7752,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -7799,7 +7852,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -7894,7 +7947,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -7984,7 +8037,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -8073,7 +8126,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
index 318addecc..e4c093f9c 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
@@ -19,7 +19,17 @@
 // max number of MTLCommandBuffer used to submit a graph for processing
 #define GGML_METAL_MAX_COMMAND_BUFFERS 8
 
-#define UNUSED(x) (void)(x)
+#ifndef TARGET_OS_VISION
+#define TARGET_OS_VISION 0
+#endif
+
+// create residency sets only on macOS >= 15.0, iOS/tvOS >= 18.0, visionOS >= 2.0
+#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
+    TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
+    TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
+    TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
+#define GGML_METAL_HAS_RESIDENCY_SETS 1
+#endif
 
 // globals
 
@@ -39,6 +49,7 @@ static struct ggml_backend_metal_device_context {
 
     bool has_simdgroup_reduction;
     bool has_simdgroup_mm;
+    bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
 
@@ -48,6 +59,7 @@ static struct ggml_backend_metal_device_context {
     /*.mtl_device_ref_count    =*/ 0,
     /*.has_simdgroup_reduction =*/ false,
     /*.has_simdgroup_mm        =*/ false,
+    /*.has_residency_sets      =*/ false,
     /*.has_bfloat              =*/ false,
     /*.use_bfloat              =*/ false,
     /*.name                    =*/ "",
@@ -59,12 +71,18 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
+    }
 
+    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
         ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
+#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
+        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
+#endif
+
         ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
         ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
 
@@ -90,8 +108,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
-        [ctx->mtl_device release];
-        ctx->mtl_device = nil;
+        if (ctx->mtl_device) {
+            [ctx->mtl_device release];
+            ctx->mtl_device = nil;
+        }
     }
 }
 
@@ -388,6 +408,16 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,
@@ -484,6 +514,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
     ctx->queue  = [device newCommandQueue];
+    if (ctx->queue == nil) {
+        GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
+        return NULL;
+    }
+
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
     id metal_library;
@@ -650,6 +685,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
     GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction     ? "true" : "false");
     GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm            ? "true" : "false");
+    GGML_LOG_INFO("%s: has residency sets    = %s\n", __func__, ctx_dev->has_residency_sets          ? "true" : "false");
     GGML_LOG_INFO("%s: has bfloat            = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false");
     GGML_LOG_INFO("%s: use bfloat            = %s\n", __func__, ctx_dev->use_bfloat                  ? "true" : "false");
     GGML_LOG_INFO("%s: hasUnifiedMemory      = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
@@ -988,6 +1024,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,                  cpy_q4_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,                  cpy_q4_0_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,                  cpy_q4_1_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,                  cpy_q4_1_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,                  cpy_q5_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,                  cpy_q5_0_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,                  cpy_q5_1_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,                  cpy_q5_1_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,                  cpy_q8_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,                  cpy_q8_0_f16,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
@@ -1037,8 +1083,70 @@ struct ggml_backend_metal_buffer_context {
     // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
     int n_buffers;
     struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+
+    // optional MTLResidencySet
+    id rset;
 };
 
+// rset init
+static bool ggml_backend_metal_buffer_rset_init(
+        struct ggml_backend_metal_buffer_context * ctx,
+        struct ggml_backend_metal_device_context * ctx_dev,
+        id device) {
+    ctx->rset = nil;
+
+    if (!ctx_dev->has_residency_sets) {
+        return true;
+    }
+
+#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
+    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+        MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
+        desc.label = @"ggml_backend_metal";
+        desc.initialCapacity = ctx->n_buffers;
+
+        NSError * error;
+        ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];
+        if (error) {
+            GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            [desc release];
+            return false;
+        }
+
+        [desc release];
+
+        for (int i = 0; i < ctx->n_buffers; i++) {
+            [ctx->rset addAllocation:ctx->buffers[i].metal];
+        }
+
+        [ctx->rset commit];
+        [ctx->rset requestResidency];
+
+        return true;
+    }
+#else
+    GGML_UNUSED(ctx_dev);
+    GGML_UNUSED(device);
+#endif
+
+    return true;
+}
+
+// rset free
+static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) {
+#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
+    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+        if (ctx->rset) {
+            [ctx->rset endResidency];
+            [ctx->rset removeAllAllocations];
+            [ctx->rset release];
+        }
+    }
+#else
+    GGML_UNUSED(ctx);
+#endif
+}
+
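The two helpers above bracket the lifetime of a Metal buffer context: the residency set is created once after the buffers[] array has been filled, every Metal allocation is registered with it, and it is torn down when the buffer is freed. A minimal call-pattern sketch (helper names taken from this diff, everything else illustrative):

    /* sketch only - not part of the patch */
    if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
        // false means MTLResidencySet creation actually failed; missing OS
        // support is treated as success and the feature is simply skipped
        free(ctx);
        return NULL;
    }
    // ... buffer in use by the backend ...
    ggml_backend_metal_buffer_rset_free(ctx);   // endResidency + release; nil-safe
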
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -1122,12 +1230,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction;
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
-        case GGML_OP_NORM:
             return true;
+        case GGML_OP_NORM:
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
@@ -1201,6 +1310,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                             default:
                                 return false;
                         }
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        switch (op->type) {
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_F16:
+                                return true;
+                            default:
+                                return false;
+                        }
                     default:
                         return false;
                 };
@@ -1897,7 +2018,7 @@ static void ggml_metal_encode_node(
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                 // TODO: add ggml_metal_kargs struct
-                // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+                // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 if (id_src1) {
@@ -3843,10 +3964,6 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
                 id<MTLComputePipelineState> pipeline = nil;
 
                 switch (src0t) {
@@ -3880,7 +3997,47 @@ static void ggml_metal_encode_node(
                             switch (dstt) {
                                 case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
                                 case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                                default: GGML_ASSERT(false && "not implemented");
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q8_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
                             };
                         } break;
                     default: GGML_ABORT("not implemented");
@@ -3910,7 +4067,11 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
             } break;
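As a quick arithmetic check of the relocated thread-count computation above (numbers illustrative, not taken from the diff): the GGML_ASSERT and nth now sit after pipeline selection and buffer binding, but evaluate the same expression as before.

    /* for a Q4_0 source row (block size 32) with ne00 = 4096:
       nth = MIN(1024, 4096/32) = 128 threads per threadgroup, and the grid is
       (ne01, ne02, ne03) threadgroups - one per source row/plane. */
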
         case GGML_OP_SET:
             {
@@ -4209,6 +4370,8 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
     for (int i = 0; i < ctx->n_buffers; i++) {
         [ctx->buffers[i].metal release];
     }
+
+    ggml_backend_metal_buffer_rset_free(ctx);
     ggml_backend_metal_device_rel(buffer->buft->device->context);
 
     if (ctx->owned) {
@@ -4232,19 +4395,19 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
 static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
     memset((char *)tensor->data + offset, value, size);
 
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);
 
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     memcpy(data, (const char *)tensor->data + offset, size);
 
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
 }
 
 static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
@@ -4254,7 +4417,7 @@ static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, c
     }
     return false;
 
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -4280,7 +4443,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
 static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
     return "Metal";
 
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
@@ -4304,8 +4467,8 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t s
     }
 #endif
 #endif
-    UNUSED(device);
-    UNUSED(size_aligned);
+    GGML_UNUSED(device);
+    GGML_UNUSED(size_aligned);
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -4318,7 +4481,8 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
         size_aligned += (size_page - (size_aligned % size_page));
     }
 
-    id device = ggml_backend_metal_device_acq(buft->device->context);
-    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
 
     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;
@@ -4341,7 +4505,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
-        ggml_backend_metal_device_rel(buft->device->context);
+        ggml_backend_metal_device_rel(ctx_dev);
+        return NULL;
+    }
+
+    if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+        GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+        free(ctx);
+        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -4352,7 +4523,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
 
 static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 32;
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
@@ -4362,13 +4533,13 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty
 
     return max_size;
 
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
     return true;
 
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
@@ -4391,7 +4562,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
 static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
     return "Metal_Mapped";
 
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
@@ -4434,7 +4605,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
         size_aligned += (size_page - (size_aligned % size_page));
     }
 
-    id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
+    struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -4487,6 +4659,13 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
         }
     }
 
+    if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+        GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+        free(ctx);
+        ggml_backend_metal_device_rel(ctx_dev);
+        return NULL;
+    }
+
     return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
 
@@ -4495,7 +4674,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
 static const char * ggml_backend_metal_name(ggml_backend_t backend) {
     return "Metal";
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
@@ -4800,6 +4979,13 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
         }
     }
 
+    if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+        GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+        free(ctx);
+        ggml_backend_metal_device_rel(ctx_dev);
+        return NULL;
+    }
+
     return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
 
@@ -4813,7 +4999,7 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml
     return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
             buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
 
-    UNUSED(dev);
+    GGML_UNUSED(dev);
 }
 
 static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
index 204c93e6f..f38909d0b 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
@@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
     const half d_all = xb->d;
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const uint16_t * ql = (device const uint16_t *)xb->ql;
+    device const uint16_t * qh = (device const uint16_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
 
-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
+    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
+    qh = qh + 16*(il/8) + 8*(il&1);
     float sc = scales[(il%2) + 2 * ((il/2))];
     il = (il/2) & 3;
 
-    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
+    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0                       : 0x0F0F0F0F;
     const float ml = d_all * sc * 32.f;
-    const float dl = d_all * sc * coef;
-    for (int i = 0; i < 16; ++i) {
-        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
-                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
+    const float dl0 = d_all * sc;
+    const float dl1 = dl0 / 256.f;
+    const float dl2 = dl0 / (256.f * 256.f);
+    const float dl3 = dl0 / (256.f * 256.f * 256.f);
+    const uint8_t shr_h = il>2 ? 2 : 0;
+    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
+    const uint8_t shr_l = il>1 ? 4 : 0;
+    for (int i = 0; i < 4; ++i) {
+        const uint32_t  low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
+        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
+        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
+        reg[i][0] = dl0 *  ((half)(q & 0xFF))       - ml;
+        reg[i][1] = dl1 * ((float)(q & 0xFF00))     - ml;
+        reg[i][2] = dl2 * ((float)(q & 0xFF0000))   - ml;
+        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
     }
 }
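The rewritten dequantize_q6_K reads the quants as 16-bit words, assembles packed 32-bit values and extracts four bytes per iteration; the per-byte scales dl1..dl3 fold the missing right-shifts into the scale itself. A small sanity check of that trick (illustrative, not from the diff):

    /* if byte b occupies bits 8..15 of q, then (q & 0xFF00) == 256*b, so with
       dl1 = dl0/256 we get dl1 * (q & 0xFF00) == dl0 * b - the same value as
       scaling the extracted byte directly; dl2 = dl0/256^2 and dl3 = dl0/256^3
       do the same for bits 16..23 and 24..31. */
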
 
@@ -1058,7 +1067,7 @@ kernel void kernel_soft_max(
     }
 
     // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
     threadgroup_barrier(mem_flags::mem_none);
 
     float sum = simd_sum(lsum);
@@ -1163,7 +1172,7 @@ kernel void kernel_soft_max_4(
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
     // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
     threadgroup_barrier(mem_flags::mem_none);
 
     float sum = simd_sum(lsum);
@@ -4377,6 +4386,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }
 
+template <typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+
 kernel void kernel_concat(
     constant ggml_metal_kargs_concat & args,
     device  const char * src0,
@@ -4461,7 +4513,6 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
-
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -4492,7 +4543,7 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -4658,7 +4709,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2; ++row) {
+        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }
@@ -4774,7 +4825,7 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -4906,7 +4957,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;
@@ -4951,6 +5002,10 @@ void kernel_mul_mv_q6_K_f32_impl(
 
     const int row = 2*r0 + sgitg;
 
+    if (row >= args.ne0) {
+        return;
+    }
+
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
@@ -5106,7 +5161,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5224,7 +5279,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5334,7 +5389,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.5f;
@@ -5446,7 +5501,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5559,7 +5614,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5659,7 +5714,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5754,7 +5809,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5844,7 +5899,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5933,7 +5988,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c
index 7ffcd907d..635aa2990 100644
--- a/ml/backend/ggml/ggml/src/ggml.c
+++ b/ml/backend/ggml/ggml/src/ggml.c
@@ -128,6 +128,10 @@ static void ggml_print_backtrace_symbols(void) {
 #endif
 
 static void ggml_print_backtrace(void) {
+    const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
+    if (GGML_NO_BACKTRACE) {
+        return;
+    }
     char attach[32];
     snprintf(attach, sizeof(attach), "attach %d", getpid());
     int pid = fork();
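The new early return lets the crash backtrace be disabled at run time: ggml_print_backtrace() does nothing when the GGML_NO_BACKTRACE environment variable is set (any value, including empty, counts). A usage sketch from a host program (illustrative, not part of the diff):

    #include <stdlib.h>

    static void disable_ggml_backtrace(void) {
        // set before the first abort: suppresses the fork()/debugger-attach machinery
        setenv("GGML_NO_BACKTRACE", "1", 1);
    }
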
@@ -236,7 +240,11 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
 
 
 void * ggml_aligned_malloc(size_t size) {
+#if defined(__s390x__)
+    const int alignment = 256;
+#else
     const int alignment = 64;
+#endif
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
     return _aligned_malloc(size, alignment);
@@ -969,6 +977,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV6",
+    "GATED_LINEAR_ATTN",
 
     "UNARY",
 
@@ -988,7 +997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1066,6 +1075,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "get_rel_pos(x)",
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
+    "gated_linear_attn(k, v, q, gate, s)",
 
     "unary(x)",
 
@@ -1085,7 +1095,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1375,7 +1385,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
         (t0->nb[3] == t1->nb[3]);
 }
 
-// check if t1 can be represented as a repeatition of t0
+// check if t1 can be represented as a repetition of t0
 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -1590,15 +1600,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
 
     struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
 
-#ifdef __clang__
-    // temporary until ggml_tensor::backend is removed
-    #pragma clang diagnostic push
-    #pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
-        /*.backend      =*/ GGML_BACKEND_TYPE_CPU,
         /*.buffer       =*/ NULL,
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
@@ -1614,10 +1617,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.padding      =*/ { 0 },
     };
 
-#ifdef __clang__
-    #pragma clang diagnostic pop
-#endif
-
     // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
     //GGML_ASSERT_ALIGNED(result->data);
 
@@ -3461,12 +3460,14 @@ struct ggml_tensor * ggml_soft_max_ext(
     return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
-// ggml_soft_max_back
+// ggml_soft_max_ext_back
 
-static struct ggml_tensor * ggml_soft_max_back_impl(
+static struct ggml_tensor * ggml_soft_max_ext_back_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias,
         bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
@@ -3474,21 +3475,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
     result->src[0] = a;
     result->src[1] = b;
 
+    memcpy((float *) result->op_params + 0, &scale,    sizeof(float));
+    memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
+
     return result;
 }
 
-struct ggml_tensor * ggml_soft_max_back(
+struct ggml_tensor * ggml_soft_max_ext_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    return ggml_soft_max_back_impl(ctx, a, b, false);
+        struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
 }
 
-struct ggml_tensor * ggml_soft_max_back_inplace(
+struct ggml_tensor * ggml_soft_max_ext_back_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    return ggml_soft_max_back_impl(ctx, a, b, true);
+        struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
 }
 
 // ggml_rope
@@ -3706,7 +3714,7 @@ void ggml_rope_yarn_corr_dims(
 
 // ggml_rope_back
 
-struct ggml_tensor * ggml_rope_back(
+struct ggml_tensor * ggml_rope_ext_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -3720,29 +3728,32 @@ struct ggml_tensor * ggml_rope_back(
         float                 attn_factor,
         float                 beta_fast,
         float                 beta_slow) {
-    GGML_ASSERT(ggml_is_vector(b));
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] == b->ne[0]);
-
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-
-    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-    memcpy(params +  5, &freq_base,    sizeof(float));
-    memcpy(params +  6, &freq_scale,   sizeof(float));
-    memcpy(params +  7, &ext_factor,   sizeof(float));
-    memcpy(params +  8, &attn_factor,  sizeof(float));
-    memcpy(params +  9, &beta_fast,    sizeof(float));
-    memcpy(params + 10, &beta_slow,    sizeof(float));
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op     = GGML_OP_ROPE_BACK;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
+    struct ggml_tensor * result = ggml_rope_ext(
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+    result->op = GGML_OP_ROPE_BACK;
     return result;
 }
 
+struct ggml_tensor * ggml_rope_multi_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   sections[4],
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    struct ggml_tensor * result = ggml_rope_multi(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+    result->op = GGML_OP_ROPE_BACK;
+    return result;
+}
 // ggml_clamp
 
 struct ggml_tensor * ggml_clamp(
@@ -4661,15 +4672,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     GGML_ASSERT(ggml_is_contiguous(state));
 
     const int64_t S = k->ne[0];
-    const int64_t H = k->ne[2];
-    const int64_t n_tokens = k->ne[3];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
     {
-        GGML_ASSERT(k->ne[1] == 1);
-        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
-        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
-        // TODO: RWKV v4 and v5
-        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }
 
@@ -4688,6 +4697,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     return result;
 }
 
+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * state,
+        float scale) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_f32(result, 0, scale);
+
+    result->op     = GGML_OP_GATED_LINEAR_ATTN;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = q;
+    result->src[3] = g;
+    result->src[4] = state;
+
+    return result;
+}
+
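ggml_gated_linear_attn follows the same layout conventions as the updated ggml_rwkv_wkv6 above: per-token inputs are [S, H, n_tokens] and the result concatenates the attention output with the updated recurrent state along dim 1. A shape sketch plus a hypothetical call (the scale value here is an assumption, not prescribed by the diff):

    /* k, v, q, g : F32, ne = {S, H, n_tokens, 1}
       state      : F32 with S*S*H*n_seqs elements, ne[1] == n_seqs
       result     : F32, ne = {S*H, n_tokens + S*n_seqs, 1, 1} - output first,
                    then the new state, per the "concat output and new_state" comment */
    struct ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, state, 1.0f/sqrtf((float) S));
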
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -5062,10 +5114,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         struct ggml_tensor  * c) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-    GGML_ASSERT(ggml_is_scalar(c));
+    GGML_ASSERT(ggml_is_scalar(a));
+    GGML_ASSERT(ggml_are_same_shape(b, c));
 
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
 
     result->op     = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
     result->src[0] = a;
@@ -5244,7 +5296,7 @@ static void ggml_sub_or_set(
 }
 
 static void ggml_compute_backward(
-        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
+        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
     struct ggml_tensor * tensor = cgraph->nodes[i];
     struct ggml_tensor * grad   = ggml_graph_get_grad(cgraph, tensor);
 
@@ -5316,7 +5368,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5388,7 +5440,7 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 float eps;
                 memcpy(&eps, tensor->op_params, sizeof(float));
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
             }
         } break;
         case GGML_OP_MUL_MAT: {
@@ -5408,21 +5460,25 @@ static void ggml_compute_backward(
             // src1.shape   [n,p,qq,rr]
 
             if (src0_needs_grads) {
-                struct ggml_tensor * s1_tg =
+                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+                struct ggml_tensor * tmp =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                const int64_t qq = s1_tg->ne[2];
-                const int64_t rr = s1_tg->ne[3];
-                const int64_t q1 = src0->ne[2];
-                const int64_t r1 = src0->ne[3];
-                const bool ne2_broadcasted = qq > q1;
-                const bool ne3_broadcasted = rr > r1;
-                if (ne2_broadcasted || ne3_broadcasted) {
-                    // sum broadcast repetitions of s1_tg into shape of src0
-                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                if (!ggml_are_same_shape(tmp, src0)) {
+                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+                    GGML_ASSERT(tmp->ne[3] == 1);
+
+                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+                    const size_t nb2 = tmp->nb[2] * nr2;
+                    const size_t nb3 = tmp->nb[2];
+
+                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+                    tmp = ggml_repeat_back(ctx, tmp, src0);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
             }
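The reworked MUL_MAT backward re-views the out_prod result so that the broadcast copies belonging to the same src0 plane line up along dim 3 before ggml_repeat_back() sums them. A worked shape example (numbers invented):

    /* src0: [64, 128, 2, 1]; src1 and grad have ne2 = 6, ne3 = 1
       tmp = out_prod(src1, grad): [64, 128, 6, 1]
       nr2 = 6/2 = 3, so tmp is re-viewed (strides only, no copy) as [64, 128, 2, 3],
       where view[.., i2, i3] picks tmp plane 3*i2 + i3, and ggml_repeat_back()
       sums the 3 copies for each src0 plane back into [64, 128, 2, 1]. */
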
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5491,7 +5547,9 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
             }
         } break;
         case GGML_OP_RESHAPE: {
@@ -5571,7 +5629,13 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_SOFT_MAX: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
+                float scale    = 1.0f;
+                float max_bias = 0.0f;
+
+                memcpy(&scale,    (const float *) tensor->op_params + 0, sizeof(float));
+                memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
+
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
             }
             GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
         } break;
@@ -5583,6 +5647,7 @@ static void ggml_compute_backward(
                 //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
                 const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
                 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                int sections[4] = {0, 0, 0, 0};
 
                 memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));
                 memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));
@@ -5590,10 +5655,14 @@ static void ggml_compute_backward(
                 memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));
                 memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));
                 memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));
+                memcpy(&sections,                   tensor->op_params + 11, sizeof(sections));
 
-                ggml_add_or_set(ctx, cgraph, isrc0,
-                    ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
-                        freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
+                struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
+                    ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
+                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
+                    ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
+                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+                ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
             }
             GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
         } break;
@@ -5607,7 +5676,7 @@ static void ggml_compute_backward(
                 const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
                 const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
 
-                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
             }
         } break;
         case GGML_OP_POOL_2D: {
@@ -5650,7 +5719,7 @@ static void ggml_compute_backward(
                 } break;
                 case GGML_UNARY_OP_SILU: {
                     if (src0_needs_grads) {
-                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
                     }
                 } break;
                 case GGML_UNARY_OP_EXP: {
@@ -5667,7 +5736,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_CROSS_ENTROPY_LOSS: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
             }
             GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;
@@ -6438,1271 +6507,6 @@ size_t ggml_quantize_chunk(
 
 ////////////////////////////////////////////////////////////////////////////////
 
-struct gguf_str {
-    uint64_t n;  // GGUFv2
-    char * data;
-};
-
-static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
-    [GGUF_TYPE_UINT8]   = sizeof(uint8_t),
-    [GGUF_TYPE_INT8]    = sizeof(int8_t),
-    [GGUF_TYPE_UINT16]  = sizeof(uint16_t),
-    [GGUF_TYPE_INT16]   = sizeof(int16_t),
-    [GGUF_TYPE_UINT32]  = sizeof(uint32_t),
-    [GGUF_TYPE_INT32]   = sizeof(int32_t),
-    [GGUF_TYPE_FLOAT32] = sizeof(float),
-    [GGUF_TYPE_BOOL]    = sizeof(bool),
-    [GGUF_TYPE_STRING]  = sizeof(struct gguf_str),
-    [GGUF_TYPE_UINT64]  = sizeof(uint64_t),
-    [GGUF_TYPE_INT64]   = sizeof(int64_t),
-    [GGUF_TYPE_FLOAT64] = sizeof(double),
-    [GGUF_TYPE_ARRAY]   = 0, // undefined
-};
-static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
-static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
-    [GGUF_TYPE_UINT8]   = "u8",
-    [GGUF_TYPE_INT8]    = "i8",
-    [GGUF_TYPE_UINT16]  = "u16",
-    [GGUF_TYPE_INT16]   = "i16",
-    [GGUF_TYPE_UINT32]  = "u32",
-    [GGUF_TYPE_INT32]   = "i32",
-    [GGUF_TYPE_FLOAT32] = "f32",
-    [GGUF_TYPE_BOOL]    = "bool",
-    [GGUF_TYPE_STRING]  = "str",
-    [GGUF_TYPE_ARRAY]   = "arr",
-    [GGUF_TYPE_UINT64]  = "u64",
-    [GGUF_TYPE_INT64]   = "i64",
-    [GGUF_TYPE_FLOAT64] = "f64",
-};
-static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
-union gguf_value {
-    uint8_t  uint8;
-    int8_t   int8;
-    uint16_t uint16;
-    int16_t  int16;
-    uint32_t uint32;
-    int32_t  int32;
-    float    float32;
-    uint64_t uint64;
-    int64_t  int64;
-    double   float64;
-    bool     bool_;
-
-    struct gguf_str str;
-
-    struct {
-        enum gguf_type type;
-
-        uint64_t n;  // GGUFv2
-        void * data;
-    } arr;
-};
-
-struct gguf_kv {
-    struct gguf_str key;
-
-    enum  gguf_type  type;
-    union gguf_value value;
-};
-
-struct gguf_header {
-    char magic[4];
-
-    uint32_t version;
-    uint64_t n_tensors; // GGUFv2
-    uint64_t n_kv;      // GGUFv2
-};
-
-struct gguf_tensor_info {
-    struct gguf_str name;
-
-    uint32_t n_dims;
-    uint64_t ne[GGML_MAX_DIMS];
-
-    enum ggml_type type;
-
-    uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
-
-    // for writing API
-    const void * data;
-    size_t size;
-};
-
-struct gguf_context {
-    struct gguf_header header;
-
-    struct gguf_kv          * kv;
-    struct gguf_tensor_info * infos;
-
-    size_t alignment;
-    size_t offset;    // offset of `data` from beginning of file
-    size_t size;      // size of `data` in bytes
-
-    //uint8_t * padding;
-    void * data;
-};
-
-size_t gguf_type_size(enum gguf_type type) {
-    GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
-    return GGUF_TYPE_SIZE[type];
-}
-
-static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
-    if (info->n_dims > GGML_MAX_DIMS) {
-        fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
-        return false;
-    }
-
-    if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
-        fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
-        return false;
-    }
-
-    if (strlen(info->name.data) >= GGML_MAX_NAME) {
-        fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
-        return false;
-    }
-
-    for (uint32_t i = 0; i < info->n_dims; ++i) {
-        if (info->ne[i] <= 0) {
-            fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
-            return false;
-        }
-    }
-
-    // prevent overflow for total number of elements
-    if (INT64_MAX/info->ne[1] <= info->ne[0]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
-        return false;
-    }
-
-    if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
-        return false;
-    }
-
-    if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
-        return false;
-    }
-
-    return true;
-}
-
-static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
-    const size_t n = fread(dst, 1, size, file);
-    *offset += n;
-    return n == size;
-}
-
-static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
-    p->n    = 0;
-    p->data = NULL;
-
-    bool ok = true;
-
-    ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset);
-
-    // early exit if string length is invalid, prevents from integer overflow
-    if (p->n == SIZE_MAX) {
-        fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
-        return false;
-    }
-
-    p->data = calloc(p->n + 1, 1);
-    if (!p->data) {
-        fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
-        return false;
-    }
-
-    ok = ok && gguf_fread_el(file,  p->data, p->n, offset);
-
-    return ok;
-}
-
-static void gguf_free_kv(struct gguf_kv * kv) {
-    if (kv->key.data) {
-        GGML_FREE(kv->key.data);
-    }
-
-    if (kv->type == GGUF_TYPE_STRING) {
-        if (kv->value.str.data) {
-            GGML_FREE(kv->value.str.data);
-        }
-    }
-
-    if (kv->type == GGUF_TYPE_ARRAY) {
-        if (kv->value.arr.data) {
-            if (kv->value.arr.type == GGUF_TYPE_STRING) {
-                for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
-                    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
-                    if (str->data) {
-                        GGML_FREE(str->data);
-                    }
-                }
-            }
-            GGML_FREE(kv->value.arr.data);
-        }
-    }
-}
-
-struct gguf_context * gguf_init_empty(void) {
-    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
-    if (!ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
-        return NULL;
-    }
-
-    memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
-    ctx->header.version   = GGUF_VERSION;
-    ctx->header.n_tensors = 0;
-    ctx->header.n_kv      = 0;
-
-    ctx->kv    = NULL;
-    ctx->infos = NULL;
-
-    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-    ctx->offset    = 0;
-    ctx->size      = 0;
-
-    ctx->data = NULL;
-
-    return ctx;
-}
-
-struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
-    // offset from start of file
-    size_t offset = 0;
-
-    char magic[4];
-
-    // check the magic before making allocations
-    {
-        gguf_fread_el(file, &magic, sizeof(magic), &offset);
-
-        for (uint32_t i = 0; i < sizeof(magic); i++) {
-            if (magic[i] != GGUF_MAGIC[i]) {
-                fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
-                return NULL;
-            }
-        }
-    }
-
-    bool ok = true;
-
-    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
-    if (!ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
-        return NULL;
-    }
-
-    // read the header
-    {
-        strncpy(ctx->header.magic, magic, 4);
-
-        ctx->kv    = NULL;
-        ctx->infos = NULL;
-        ctx->data  = NULL;
-
-        ok = ok && gguf_fread_el(file, &ctx->header.version,   sizeof(ctx->header.version),   &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
-
-        if (ctx->header.version == 1) {
-            fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        // sanity-checks to prevent from integer/buffer overflows
-
-        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info));
-        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead());
-        ok = ok && (ctx->header.n_kv      < (SIZE_MAX/2)/sizeof(struct gguf_kv));
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read header\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-    }
-
-    // read the kv pairs
-    {
-        const uint64_t n_kv = ctx->header.n_kv;
-
-        if (n_kv > 0) {
-            ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
-            if (!ctx->kv) {
-                fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
-                gguf_free(ctx);
-                return NULL;
-            }
-        }
-
-        for (uint64_t i = 0; i < n_kv; ++i) {
-            struct gguf_kv * kv = &ctx->kv[i];
-
-            //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
-
-            ok = ok && gguf_fread_str(file, &kv->key,                    &offset);
-            ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
-
-            //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
-
-            switch (kv->type) {
-                case GGUF_TYPE_UINT8:   ok = ok && gguf_fread_el (file, &kv->value.uint8,   sizeof(kv->value.uint8),   &offset); break;
-                case GGUF_TYPE_INT8:    ok = ok && gguf_fread_el (file, &kv->value.int8,    sizeof(kv->value.int8),    &offset); break;
-                case GGUF_TYPE_UINT16:  ok = ok && gguf_fread_el (file, &kv->value.uint16,  sizeof(kv->value.uint16),  &offset); break;
-                case GGUF_TYPE_INT16:   ok = ok && gguf_fread_el (file, &kv->value.int16,   sizeof(kv->value.int16),   &offset); break;
-                case GGUF_TYPE_UINT32:  ok = ok && gguf_fread_el (file, &kv->value.uint32,  sizeof(kv->value.uint32),  &offset); break;
-                case GGUF_TYPE_INT32:   ok = ok && gguf_fread_el (file, &kv->value.int32,   sizeof(kv->value.int32),   &offset); break;
-                case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
-                case GGUF_TYPE_UINT64:  ok = ok && gguf_fread_el (file, &kv->value.uint64,  sizeof(kv->value.uint64),  &offset); break;
-                case GGUF_TYPE_INT64:   ok = ok && gguf_fread_el (file, &kv->value.int64,   sizeof(kv->value.int64),   &offset); break;
-                case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
-                case GGUF_TYPE_BOOL:    ok = ok && gguf_fread_el (file, &kv->value.bool_,   sizeof(kv->value.bool_),   &offset); break;
-                case GGUF_TYPE_STRING:  ok = ok && gguf_fread_str(file, &kv->value.str,                                &offset); break;
-                case GGUF_TYPE_ARRAY:
-                    {
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.n,    sizeof(kv->value.arr.n),    &offset);
-
-                        switch (kv->value.arr.type) {
-                            case GGUF_TYPE_UINT8:
-                            case GGUF_TYPE_INT8:
-                            case GGUF_TYPE_UINT16:
-                            case GGUF_TYPE_INT16:
-                            case GGUF_TYPE_UINT32:
-                            case GGUF_TYPE_INT32:
-                            case GGUF_TYPE_FLOAT32:
-                            case GGUF_TYPE_UINT64:
-                            case GGUF_TYPE_INT64:
-                            case GGUF_TYPE_FLOAT64:
-                            case GGUF_TYPE_BOOL:
-                                {
-                                    // prevent from integer overflow in the malloc below
-                                    if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
-                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
-                                    if (!kv->value.arr.data) {
-                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
-                                } break;
-                            case GGUF_TYPE_STRING:
-                                {
-                                    // prevent from integer overflow in the malloc below
-                                    if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
-                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
-                                    if (!kv->value.arr.data) {
-                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
-                                        ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
-                                    }
-                                } break;
-                            case GGUF_TYPE_ARRAY:
-                            default:
-                                {
-                                    fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
-                                    ok = false;
-                                } break;
-                        }
-                    } break;
-                default:
-                    {
-                        fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
-                        ok = false;
-                    } break;
-            }
-
-            if (!ok) {
-                break;
-            }
-        }
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-    }
-
-    // read the tensor infos
-    if (ctx->header.n_tensors > 0) {
-        ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
-        if (!ctx->infos) {
-            fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            for (int j = 0; j < GGML_MAX_DIMS; ++j) {
-                info->ne[j] = 1;
-            }
-
-            ok = ok && gguf_fread_str(file, &info->name,                          &offset);
-            ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims),  &offset);
-
-            ok = ok && (info->n_dims <= GGML_MAX_DIMS);
-
-            for (uint32_t j = 0; j < info->n_dims; ++j) {
-                ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
-            }
-
-            ok = ok && gguf_fread_el (file, &info->type,   sizeof(info->type),    &offset);
-            ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset),  &offset);
-
-            ok = ok && gguf_tensor_info_sanitize(info);
-
-            // make sure there is no duplicated tensor names
-            for (uint64_t j = 0; j < i && ok; ++j) {
-                if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
-                    fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
-                    ok = false;
-                }
-            }
-
-            if (!ok) {
-                fprintf(stderr, "%s: failed to read tensor info\n", __func__);
-                gguf_free(ctx);
-                return NULL;
-            }
-        }
-    }
-
-    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-
-    int alignment_idx = gguf_find_key(ctx, "general.alignment");
-    if (alignment_idx != -1) {
-        ctx->alignment = gguf_get_val_u32(ctx, alignment_idx);
-    }
-
-    // we require the data section to be aligned, so take into account any padding
-    {
-        const size_t offset_pad = offset % ctx->alignment;
-
-        if (offset_pad != 0) {
-            offset += ctx->alignment - offset_pad;
-            fseek(file, offset, SEEK_SET);
-        }
-    }
-
-    // store the current file offset - this is where the data section starts
-    ctx->offset = offset;
-
-    // compute the total size of the data section, taking into account the alignment
-    {
-        ctx->size = 0;
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            const int64_t ne =
-                (int64_t) info->ne[0] *
-                (int64_t) info->ne[1] *
-                (int64_t) info->ne[2] *
-                (int64_t) info->ne[3];
-
-            if (ggml_blck_size(info->type) == 0 ) {
-                // this tensor type support have been removed:
-                fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
-                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type));
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            if (ne % ggml_blck_size(info->type) != 0) {
-                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
-                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            const size_t size_cur = ggml_row_size(info->type, ne);
-
-            ctx->size += GGML_PAD(size_cur, ctx->alignment);
-        }
-    }
-
-    // load the tensor data only if requested
-    if (params.ctx != NULL) {
-        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
-        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
-        // the ggml_tensor structs to the appropriate locations in the binary blob
-
-        // compute the exact size needed for the new ggml_context
-        const size_t mem_size =
-            params.no_alloc ?
-            (ctx->header.n_tensors    )*ggml_tensor_overhead() :
-            (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
-
-        struct ggml_init_params pdata = {
-            .mem_size   = mem_size,
-            .mem_buffer = NULL,
-            .no_alloc   = params.no_alloc,
-        };
-
-        *params.ctx = ggml_init(pdata);
-        if (*params.ctx == NULL) {
-            fprintf(stderr, "%s: failed to initialize context\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        struct ggml_context * ctx_data = *params.ctx;
-
-        struct ggml_tensor * data = NULL;
-
-        if (!params.no_alloc) {
-            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
-
-            ok = ok && data != NULL;
-
-            // read the binary blob with the tensor data
-            ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset);
-
-            if (!ok) {
-                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
-                ggml_free(ctx_data);
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            ctx->data = data->data;
-        }
-
-        ggml_set_no_alloc(ctx_data, true);
-
-        // create the tensors
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            const int64_t ne[GGML_MAX_DIMS] = {
-                ctx->infos[i].ne[0],
-                ctx->infos[i].ne[1],
-                ctx->infos[i].ne[2],
-                ctx->infos[i].ne[3],
-            };
-
-            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
-
-            ok = ok && cur != NULL;
-
-            if (!ok) {
-                break;
-            }
-
-            ggml_set_name(cur, ctx->infos[i].name.data);
-
-            // point the data member to the appropriate location in the binary blob using the tensor infos
-            if (!params.no_alloc) {
-              //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
-                cur->data = (char *) data->data + ctx->infos[i].offset;               // offset from data
-            }
-        }
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
-            ggml_free(ctx_data);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        ggml_set_no_alloc(ctx_data, params.no_alloc);
-    }
-
-    return ctx;
-}
-
-struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
-    FILE * file = ggml_fopen(fname, "rb");
-    if (!file) {
-        fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
-        return NULL;
-    }
-
-    struct gguf_context * result = gguf_init_from_file_impl(file, params);
-    fclose(file);
-    return result;
-}
-
-void gguf_free(struct gguf_context * ctx) {
-    if (ctx == NULL) {
-        return;
-    }
-
-    if (ctx->kv) {
-        // free string memory - not great..
-        for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
-            gguf_free_kv(&ctx->kv[i]);
-        }
-
-        GGML_FREE(ctx->kv);
-    }
-
-    if (ctx->infos) {
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            if (info->name.data) {
-                GGML_FREE(info->name.data);
-            }
-        }
-
-        GGML_FREE(ctx->infos);
-    }
-
-    GGML_FREE(ctx);
-}
-
-const char * gguf_type_name(enum gguf_type type) {
-    return GGUF_TYPE_NAME[type];
-}
-
-int gguf_get_version(const struct gguf_context * ctx) {
-    return ctx->header.version;
-}
-
-size_t gguf_get_alignment(const struct gguf_context * ctx) {
-    return ctx->alignment;
-}
-
-size_t gguf_get_data_offset(const struct gguf_context * ctx) {
-    return ctx->offset;
-}
-
-void * gguf_get_data(const struct gguf_context * ctx) {
-    return ctx->data;
-}
-
-int gguf_get_n_kv(const struct gguf_context * ctx) {
-    return ctx->header.n_kv;
-}
-
-int gguf_find_key(const struct gguf_context * ctx, const char * key) {
-    // return -1 if key not found
-    int keyfound = -1;
-
-    const int n_kv = gguf_get_n_kv(ctx);
-
-    for (int i = 0; i < n_kv; ++i) {
-        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
-            keyfound = i;
-            break;
-        }
-    }
-
-    return keyfound;
-}
-
-const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].key.data;
-}
-
-enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].type;
-}
-
-enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.type;
-}
-
-const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.data;
-}
-
-const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    struct gguf_kv * kv = &ctx->kv[key_id];
-    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
-    return str->data;
-}
-
-int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.n;
-}
-
-uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
-    return ctx->kv[key_id].value.uint8;
-}
-
-int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
-    return ctx->kv[key_id].value.int8;
-}
-
-uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
-    return ctx->kv[key_id].value.uint16;
-}
-
-int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
-    return ctx->kv[key_id].value.int16;
-}
-
-uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
-    return ctx->kv[key_id].value.uint32;
-}
-
-int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
-    return ctx->kv[key_id].value.int32;
-}
-
-float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
-    return ctx->kv[key_id].value.float32;
-}
-
-uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
-    return ctx->kv[key_id].value.uint64;
-}
-
-int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
-    return ctx->kv[key_id].value.int64;
-}
-
-double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
-    return ctx->kv[key_id].value.float64;
-}
-
-bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
-    return ctx->kv[key_id].value.bool_;
-}
-
-const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
-    return ctx->kv[key_id].value.str.data;
-}
-
-const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
-    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
-    return &ctx->kv[key_id].value;
-}
-
-int gguf_get_n_tensors(const struct gguf_context * ctx) {
-    return ctx->header.n_tensors;
-}
-
-int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
-    // return -1 if tensor not found
-    int tensorfound = -1;
-
-    const int n_tensors = gguf_get_n_tensors(ctx);
-
-    for (int i = 0; i < n_tensors; ++i) {
-        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
-            tensorfound = i;
-            break;
-        }
-    }
-
-    return tensorfound;
-}
-
-size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].offset;
-}
-
-char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].name.data;
-}
-
-enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].type;
-}
-
-// returns the index
-static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
-    const int idx = gguf_find_key(ctx, key);
-    if (idx >= 0) {
-        return idx;
-    }
-
-    const int n_kv = gguf_get_n_kv(ctx);
-
-    ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
-    ctx->kv[n_kv].key.n    = strlen(key);
-    ctx->kv[n_kv].key.data = strdup(key);
-    ctx->header.n_kv++;
-
-    return n_kv;
-}
-
-void gguf_remove_key(struct gguf_context * ctx, const char * key) {
-    const int idx = gguf_find_key(ctx, key);
-    if (idx >= 0) {
-        const int n_kv = gguf_get_n_kv(ctx);
-        gguf_free_kv(&ctx->kv[idx]);
-        for (int i = idx; i < n_kv-1; ++i) {
-            ctx->kv[i] = ctx->kv[i+1];
-        }
-        ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv));
-        ctx->header.n_kv--;
-    }
-}
-
-void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_UINT8;
-    ctx->kv[idx].value.uint8 = val;
-}
-
-void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type       = GGUF_TYPE_INT8;
-    ctx->kv[idx].value.int8 = val;
-}
-
-void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT16;
-    ctx->kv[idx].value.uint16 = val;
-}
-
-void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT16;
-    ctx->kv[idx].value.int16 = val;
-}
-
-void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT32;
-    ctx->kv[idx].value.uint32 = val;
-}
-
-void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT32;
-    ctx->kv[idx].value.int32 = val;
-}
-
-void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type          = GGUF_TYPE_FLOAT32;
-    ctx->kv[idx].value.float32 = val;
-}
-
-void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT64;
-    ctx->kv[idx].value.uint64 = val;
-}
-
-void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT64;
-    ctx->kv[idx].value.int64 = val;
-}
-
-void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type          = GGUF_TYPE_FLOAT64;
-    ctx->kv[idx].value.float64 = val;
-}
-
-void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_BOOL;
-    ctx->kv[idx].value.bool_ = val;
-}
-
-void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.str.n    = strlen(val);
-    ctx->kv[idx].value.str.data = strdup(val);
-}
-
-void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
-    ctx->kv[idx].value.arr.type = type;
-    ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
-    memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
-}
-
-void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
-    ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
-    for (int i = 0; i < n; i++) {
-        struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
-        str->n    = strlen(data[i]);
-        str->data = strdup(data[i]);
-    }
-}
-
-// set or add KV pairs from another context
-void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
-    for (uint32_t i = 0; i < src->header.n_kv; i++) {
-        switch (src->kv[i].type) {
-            case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, src->kv[i].key.data, src->kv[i].value.uint8);    break;
-            case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, src->kv[i].key.data, src->kv[i].value.int8);     break;
-            case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16);   break;
-            case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16);    break;
-            case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32);   break;
-            case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32);    break;
-            case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32);  break;
-            case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64);   break;
-            case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64);    break;
-            case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64);  break;
-            case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_);    break;
-            case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
-            case GGUF_TYPE_ARRAY:
-                {
-                    if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
-                        const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
-                        for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
-                            data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
-                        }
-                        gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
-                        GGML_FREE((void *)data);
-                    } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
-                        GGML_ABORT("nested arrays not supported");
-                    } else {
-                        gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
-                    }
-                } break;
-            default: GGML_ABORT("invalid type");
-        }
-    }
-}
-
-void gguf_add_tensor(
-             struct gguf_context * ctx,
-        const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor);
-    if (gguf_find_tensor(ctx, tensor->name) != -1) {
-        GGML_ABORT("duplicated tensor name");
-    }
-
-    const int idx = ctx->header.n_tensors;
-    ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
-
-    ctx->infos[idx].name.n    = strlen(tensor->name);
-    ctx->infos[idx].name.data = strdup(tensor->name);
-
-    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
-        ctx->infos[idx].ne[i] = 1;
-    }
-
-    ctx->infos[idx].n_dims = ggml_n_dims(tensor);
-    for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
-        ctx->infos[idx].ne[i] = tensor->ne[i];
-    }
-
-    ctx->infos[idx].type   = tensor->type;
-    ctx->infos[idx].offset = 0;
-    ctx->infos[idx].data   = tensor->data;
-    ctx->infos[idx].size   = ggml_nbytes(tensor);
-
-    if (ctx->header.n_tensors > 0) {
-        ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
-    }
-
-    ctx->header.n_tensors++;
-}
-
-void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
-    const int idx = gguf_find_tensor(ctx, name);
-    if (idx < 0) {
-        GGML_ABORT("tensor not found");
-    }
-
-    ctx->infos[idx].type = type;
-}
-
-void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) {
-    const int idx = gguf_find_tensor(ctx, name);
-    if (idx < 0) {
-        GGML_ABORT("tensor not found");
-    }
-
-    ctx->infos[idx].data = data;
-    ctx->infos[idx].size = size;
-
-    // update offsets
-    for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
-        ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
-    }
-}
-
-//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) {
-//    fwrite(&val->n,   sizeof(val->n),    1, file);
-//    fwrite(val->data, sizeof(char), val->n, file);
-//}
-//
-//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) {
-//    fwrite(val, sizeof(char), size, file);
-//}
-
-struct gguf_buf gguf_buf_init(size_t size) {
-    struct gguf_buf buf = {
-        /*buf.data   =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
-        /*buf.size   =*/ size,
-        /*buf.offset =*/ 0,
-    };
-
-    return buf;
-}
-
-void gguf_buf_free(struct gguf_buf buf) {
-    if (buf.data) {
-        GGML_FREE(buf.data);
-    }
-}
-
-static void gguf_buf_grow(struct gguf_buf * buf, size_t size) {
-    if (buf->offset + size > buf->size) {
-        buf->size = 1.5*(buf->offset + size);
-        if (buf->data) {
-            buf->data = realloc(buf->data, buf->size);
-        }
-    }
-}
-
-static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) {
-    gguf_buf_grow(buf, sizeof(val->n) + val->n);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
-    }
-    buf->offset += sizeof(val->n);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, val->data, val->n);
-    }
-    buf->offset += val->n;
-}
-
-static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) {
-    gguf_buf_grow(buf, el_size);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, val, el_size);
-    }
-    buf->offset += el_size;
-}
-
-void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
-    // write header
-    gguf_bwrite_el(buf, &ctx->header.magic,     sizeof(ctx->header.magic));
-    gguf_bwrite_el(buf, &ctx->header.version,   sizeof(ctx->header.version));
-    gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
-    gguf_bwrite_el(buf, &ctx->header.n_kv,      sizeof(ctx->header.n_kv));
-
-    // write key-value pairs
-    for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
-        struct gguf_kv * kv = &ctx->kv[i];
-
-        gguf_bwrite_str(buf, &kv->key);
-        gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
-
-        switch (kv->type) {
-            case GGUF_TYPE_UINT8:   gguf_bwrite_el( buf, &kv->value.uint8,   sizeof(kv->value.uint8)  ); break;
-            case GGUF_TYPE_INT8:    gguf_bwrite_el (buf, &kv->value.int8,    sizeof(kv->value.int8)   ); break;
-            case GGUF_TYPE_UINT16:  gguf_bwrite_el (buf, &kv->value.uint16,  sizeof(kv->value.uint16) ); break;
-            case GGUF_TYPE_INT16:   gguf_bwrite_el (buf, &kv->value.int16,   sizeof(kv->value.int16)  ); break;
-            case GGUF_TYPE_UINT32:  gguf_bwrite_el (buf, &kv->value.uint32,  sizeof(kv->value.uint32) ); break;
-            case GGUF_TYPE_INT32:   gguf_bwrite_el (buf, &kv->value.int32,   sizeof(kv->value.int32)  ); break;
-            case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
-            case GGUF_TYPE_UINT64:  gguf_bwrite_el (buf, &kv->value.uint64,  sizeof(kv->value.uint64) ); break;
-            case GGUF_TYPE_INT64:   gguf_bwrite_el (buf, &kv->value.int64,   sizeof(kv->value.int64)  ); break;
-            case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
-            case GGUF_TYPE_BOOL:    gguf_bwrite_el (buf, &kv->value.bool_,   sizeof(kv->value.bool_)  ); break;
-            case GGUF_TYPE_STRING:  gguf_bwrite_str(buf, &kv->value.str                               ); break;
-            case GGUF_TYPE_ARRAY:
-                {
-                    gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
-                    gguf_bwrite_el(buf, &kv->value.arr.n,    sizeof(kv->value.arr.n)   );
-
-                    switch (kv->value.arr.type) {
-                        case GGUF_TYPE_UINT8:
-                        case GGUF_TYPE_INT8:
-                        case GGUF_TYPE_UINT16:
-                        case GGUF_TYPE_INT16:
-                        case GGUF_TYPE_UINT32:
-                        case GGUF_TYPE_INT32:
-                        case GGUF_TYPE_FLOAT32:
-                        case GGUF_TYPE_UINT64:
-                        case GGUF_TYPE_INT64:
-                        case GGUF_TYPE_FLOAT64:
-                        case GGUF_TYPE_BOOL:
-                            {
-                                gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type));
-                            } break;
-                        case GGUF_TYPE_STRING:
-                            {
-                                for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
-                                    gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]);
-                                }
-                            } break;
-                        case GGUF_TYPE_ARRAY:
-                        default: GGML_ABORT("invalid type");
-                    }
-                } break;
-            default: GGML_ABORT("invalid type");
-        }
-    }
-
-    // write tensor infos
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
-        struct gguf_tensor_info * info = &ctx->infos[i];
-
-        gguf_bwrite_str(buf, &info->name);
-        gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
-        for (uint32_t j = 0; j < info->n_dims; ++j) {
-            gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
-        }
-        gguf_bwrite_el(buf, &info->type,   sizeof(info->type));
-        gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
-    }
-
-    // we require the data section to be aligned, so take into account any padding
-    {
-        const size_t offset     = buf->offset;
-        const size_t offset_pad = GGML_PAD(offset, ctx->alignment);
-
-        if (offset_pad != offset) {
-            uint8_t pad = 0;
-            for (size_t i = 0; i < offset_pad - offset; ++i) {
-                gguf_bwrite_el(buf, &pad, sizeof(pad));
-            }
-        }
-    }
-
-    if (only_meta) {
-        return;
-    }
-
-    size_t offset = 0;
-
-    // write tensor data
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
-        struct gguf_tensor_info * info = &ctx->infos[i];
-
-        const size_t size     = info->size;
-        const size_t size_pad = GGML_PAD(size, ctx->alignment);
-
-        gguf_bwrite_el(buf, info->data, size);
-
-        if (size_pad != size) {
-            uint8_t pad = 0;
-            for (size_t j = 0; j < size_pad - size; ++j) {
-                gguf_bwrite_el(buf, &pad, sizeof(pad));
-            }
-        }
-
-        GGML_ASSERT(offset == info->offset);
-
-        offset += size_pad;
-    }
-}
-
-void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
-    FILE * file = ggml_fopen(fname, "wb");
-    if (!file) {
-        GGML_ABORT("failed to open file for writing");
-    }
-
-    struct gguf_buf buf = gguf_buf_init(16*1024);
-
-    gguf_write_to_buf(ctx, &buf, only_meta);
-
-    fwrite(buf.data, 1, buf.offset, file);
-
-    gguf_buf_free(buf);
-
-    fclose(file);
-}
-
-size_t gguf_get_meta_size(const struct gguf_context * ctx) {
-    // no allocs - only compute size
-    struct gguf_buf buf = gguf_buf_init(0);
-
-    gguf_write_to_buf(ctx, &buf, true);
-
-    return buf.offset;
-}
-
-void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
-    struct gguf_buf buf = gguf_buf_init(16*1024);
-
-    gguf_write_to_buf(ctx, &buf, true);
-
-    memcpy(data, buf.data, buf.offset);
-
-    gguf_buf_free(buf);
-}
-
 void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
     g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
     g_logger_state.log_callback_user_data = user_data;
diff --git a/ml/backend/ggml/ggml/src/ggml_darwin_arm64.go b/ml/backend/ggml/ggml/src/ggml_darwin_arm64.go
index beffa64e2..d4354c603 100644
--- a/ml/backend/ggml/ggml/src/ggml_darwin_arm64.go
+++ b/ml/backend/ggml/ggml/src/ggml_darwin_arm64.go
@@ -1,6 +1,6 @@
 package ggml
 
-// #cgo CPPFLAGS: -DGGML_USE_METAL -DGGML_USE_BLAS
+// #cgo CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_BLAS
 // #cgo LDFLAGS: -framework Foundation
 import "C"
 
diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp
new file mode 100644
index 000000000..ab13669c5
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/gguf.cpp
@@ -0,0 +1,1329 @@
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "gguf.h"
+
+#include <cinttypes>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <map>
+#include <new>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
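+// compile-time mapping from a C++ value type to the corresponding gguf_type enum value;
+// used by gguf_kv below to tag stored scalar and array values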
+template <typename T>
+struct type_to_gguf_type;
+
+template <>
+struct type_to_gguf_type<uint8_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT8;
+};
+
+template <>
+struct type_to_gguf_type<int8_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT8;
+};
+
+template <>
+struct type_to_gguf_type<uint16_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT16;
+};
+
+template <>
+struct type_to_gguf_type<int16_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT16;
+};
+
+template <>
+struct type_to_gguf_type<uint32_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT32;
+};
+
+template <>
+struct type_to_gguf_type<int32_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT32;
+};
+
+template <>
+struct type_to_gguf_type<float> {
+    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT32;
+};
+
+template <>
+struct type_to_gguf_type<bool> {
+    static constexpr enum gguf_type value = GGUF_TYPE_BOOL;
+};
+
+template <>
+struct type_to_gguf_type<std::string> {
+    static constexpr enum gguf_type value = GGUF_TYPE_STRING;
+};
+
+template <>
+struct type_to_gguf_type<uint64_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT64;
+};
+
+template <>
+struct type_to_gguf_type<int64_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT64;
+};
+
+template <>
+struct type_to_gguf_type<double> {
+    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT64;
+};
+
+static const std::map<gguf_type, size_t> GGUF_TYPE_SIZE = {
+    {GGUF_TYPE_UINT8,   sizeof(uint8_t)},
+    {GGUF_TYPE_INT8,    sizeof(int8_t)},
+    {GGUF_TYPE_UINT16,  sizeof(uint16_t)},
+    {GGUF_TYPE_INT16,   sizeof(int16_t)},
+    {GGUF_TYPE_UINT32,  sizeof(uint32_t)},
+    {GGUF_TYPE_INT32,   sizeof(int32_t)},
+    {GGUF_TYPE_FLOAT32, sizeof(float)},
+    {GGUF_TYPE_BOOL,    sizeof(int8_t)},
+    {GGUF_TYPE_STRING,  0}, // undefined
+    {GGUF_TYPE_ARRAY,   0}, // undefined
+    {GGUF_TYPE_UINT64,  sizeof(uint64_t)},
+    {GGUF_TYPE_INT64,   sizeof(int64_t)},
+    {GGUF_TYPE_FLOAT64, sizeof(double)},
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+static const std::map<gguf_type, const char *> GGUF_TYPE_NAME = {
+    {GGUF_TYPE_UINT8,   "u8"},
+    {GGUF_TYPE_INT8,    "i8"},
+    {GGUF_TYPE_UINT16,  "u16"},
+    {GGUF_TYPE_INT16,   "i16"},
+    {GGUF_TYPE_UINT32,  "u32"},
+    {GGUF_TYPE_INT32,   "i32"},
+    {GGUF_TYPE_FLOAT32, "f32"},
+    {GGUF_TYPE_BOOL,    "bool"},
+    {GGUF_TYPE_STRING,  "str"},
+    {GGUF_TYPE_ARRAY,   "arr"},
+    {GGUF_TYPE_UINT64,  "u64"},
+    {GGUF_TYPE_INT64,   "i64"},
+    {GGUF_TYPE_FLOAT64, "f64"},
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+size_t gguf_type_size(enum gguf_type type) {
+    auto it = GGUF_TYPE_SIZE.find(type);
+    return it == GGUF_TYPE_SIZE.end() ? 0 : it->second;
+}
+
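+// in-memory representation of a single GGUF key-value pair;
+// non-string values (scalars and arrays) are stored as raw bytes in `data`, strings in `data_string`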
+struct gguf_kv {
+    std::string key;
+
+    bool is_array;
+    enum gguf_type type;
+
+    std::vector<int8_t>      data;
+    std::vector<std::string> data_string;
+
+    template <typename T>
+    gguf_kv(const std::string & key, const T value)
+            : key(key), is_array(false), type(type_to_gguf_type<T>::value) {
+        GGML_ASSERT(!key.empty());
+        data.resize(sizeof(T));
+        memcpy(data.data(), &value, sizeof(T));
+    }
+
+    template <typename T>
+    gguf_kv(const std::string & key, const std::vector<T> & value)
+            : key(key), is_array(true), type(type_to_gguf_type<T>::value) {
+        GGML_ASSERT(!key.empty());
+        data.resize(value.size()*sizeof(T));
+        for (size_t i = 0; i < value.size(); ++i) {
+            const T tmp = value[i];
+            memcpy(data.data() + i*sizeof(T), &tmp, sizeof(T));
+        }
+    }
+
+    gguf_kv(const std::string & key, const std::string & value)
+            : key(key), is_array(false), type(GGUF_TYPE_STRING) {
+        GGML_ASSERT(!key.empty());
+        data_string.push_back(value);
+    }
+
+    gguf_kv(const std::string & key, const std::vector<std::string> & value)
+            : key(key), is_array(true), type(GGUF_TYPE_STRING) {
+        GGML_ASSERT(!key.empty());
+        data_string = value;
+    }
+
+    const std::string & get_key() const {
+        return key;
+    }
+
+    const enum gguf_type & get_type() const {
+        return type;
+    }
+
+    size_t get_ne() const {
+        if (type == GGUF_TYPE_STRING) {
+            const size_t ne = data_string.size();
+            GGML_ASSERT(is_array || ne == 1);
+            return ne;
+        }
+        const size_t type_size = gguf_type_size(type);
+        GGML_ASSERT(data.size() % type_size == 0);
+        const size_t ne = data.size() / type_size;
+        GGML_ASSERT(is_array || ne == 1);
+        return ne;
+    }
+
+    template <typename T>
+    const T & get_val(const size_t i = 0) const {
+        GGML_ASSERT(type_to_gguf_type<T>::value == type);
+        if constexpr (std::is_same<T, std::string>::value) {
+            GGML_ASSERT(data_string.size() >= i+1);
+            return data_string[i];
+        }
+        const size_t type_size = gguf_type_size(type);
+        GGML_ASSERT(data.size() % type_size == 0);
+        GGML_ASSERT(data.size() >= (i+1)*type_size);
+        return reinterpret_cast<const T *>(data.data())[i];
+    }
+
+    void cast(const enum gguf_type new_type) {
+        const size_t new_type_size = gguf_type_size(new_type);
+        GGML_ASSERT(data.size() % new_type_size == 0);
+        type = new_type;
+    }
+};
+
+struct gguf_tensor_info {
+    struct ggml_tensor t; // for holding the equivalent info
+    uint64_t offset;      // offset from start of `data`, must be a multiple of `ALIGNMENT`
+};
+
+struct gguf_context {
+    uint32_t version = GGUF_VERSION;
+
+    std::vector<struct gguf_kv> kv;
+    std::vector<struct gguf_tensor_info> info;
+
+    size_t alignment = GGUF_DEFAULT_ALIGNMENT;
+    size_t offset    = 0; // offset of `data` from beginning of file
+    size_t size      = 0; // size of `data` in bytes
+
+    void * data = nullptr;
+};
+
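+// thin wrapper around a FILE * providing typed read helpers for scalars, enums, strings, and vectors;
+// every overload returns false on a short or failed read so callers can chain reads through `ok`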
+struct gguf_reader {
+    FILE * file;
+
+    gguf_reader(FILE * file) : file(file) {}
+
+    template <typename T>
+    bool read(T & dst) const {
+        return fread(&dst, 1, sizeof(dst), file) == sizeof(dst);
+    }
+
+    template <typename T>
+    bool read(std::vector<T> & dst, const size_t n) const {
+        dst.resize(n);
+        for (size_t i = 0; i < dst.size(); ++i) {
+            if constexpr (std::is_same<T, bool>::value) {
+                bool tmp;
+                if (!read(tmp)) {
+                    return false;
+                }
+                dst[i] = tmp;
+            } else {
+                if (!read(dst[i])) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    bool read(bool & dst) const {
+        int8_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = tmp != 0;
+        return true;
+    }
+
+    bool read(enum ggml_type & dst) const {
+        int32_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = ggml_type(tmp);
+        return true;
+    }
+
+    bool read(enum gguf_type & dst) const {
+        int32_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = gguf_type(tmp);
+        return true;
+    }
+
+    bool read(std::string & dst) const {
+        uint64_t size = -1;
+        if (!read(size)) {
+            return false;
+        }
+        dst.resize(size);
+        return fread(dst.data(), 1, dst.length(), file) == dst.length();
+    }
+
+    bool read(void * dst, const size_t size) const {
+        return fread(dst, 1, size, file) == size;
+    }
+};
+
+struct gguf_context * gguf_init_empty(void) {
+    return new gguf_context;
+}
+
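+// reads either a single value or an array of n values of type T from the file
+// and appends the result to kv; returns false if any read fails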
+template<typename T>
+bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct gguf_kv> & kv, const std::string & key, const bool is_array, const size_t n) {
+    if (is_array) {
+        std::vector<T> value;
+        try {
+            if (!gr.read(value, n)) {
+                return false;
+            }
+        } catch (std::length_error &) {
+            fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
+            return false;
+        } catch (std::bad_alloc &) {
+            fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
+            return false;
+        }
+        kv.emplace_back(key, value);
+    } else {
+        T value;
+        if (!gr.read(value)) {
+            return false;
+        }
+        kv.emplace_back(key, value);
+    }
+    return true;
+}
+
+struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
+    const struct gguf_reader gr(file);
+    struct gguf_context * ctx = new gguf_context;
+
+    bool ok = true;
+
+    // file magic
+    {
+        std::vector<char> magic;
+        ok = ok && gr.read(magic, 4);
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read magic\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        for (uint32_t i = 0; i < magic.size(); i++) {
+            if (magic[i] != GGUF_MAGIC[i]) {
+                fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
+                gguf_free(ctx);
+                return nullptr;
+            }
+        }
+    }
+
+    // header
+    int64_t n_kv      = 0;
+    int64_t n_tensors = 0;
+
+    if (ok && gr.read(ctx->version)) {
+        if (ctx->version == 1) {
+            fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
+            ok = false;
+        }
+        if (ctx->version > GGUF_VERSION) {
+            fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
+                __func__, ctx->version, GGUF_VERSION);
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (ok && gr.read(n_tensors)) {
+        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
+        if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
+            fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
+                __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (ok && gr.read(n_kv)) {
+        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
+        if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
+            fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
+                    __func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (!ok) {
+        fprintf(stderr, "%s: failed to read header\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+
+    // KV pairs
+    {
+        for (int64_t i = 0; ok && i < n_kv; ++i) {
+            std::string key;
+            gguf_type   type     = gguf_type(-1);
+            bool        is_array = false;
+            uint64_t    n        = 1;
+
+            try {
+                ok = ok && gr.read(key);
+            } catch (std::length_error &) {
+                fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
+                ok = false;
+            } catch (std::bad_alloc &) {
+                fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
+                ok = false;
+            }
+            for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
+                if (key == ctx->kv[j].key) {
+                    fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
+                    ok = false;
+                }
+            }
+            if (!ok) {
+                break;
+            }
+
+            ok = ok && gr.read(type);
+            if (type == GGUF_TYPE_ARRAY) {
+                is_array = true;
+                ok = ok && gr.read(type);
+                ok = ok && gr.read(n);
+            }
+            if (!ok) {
+                break;
+            }
+
+            switch (type) {
+                case GGUF_TYPE_UINT8:   ok = ok && gguf_read_emplace_helper<uint8_t>    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT8:    ok = ok && gguf_read_emplace_helper<int8_t>     (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT16:  ok = ok && gguf_read_emplace_helper<uint16_t>   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT16:   ok = ok && gguf_read_emplace_helper<int16_t>    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT32:  ok = ok && gguf_read_emplace_helper<uint32_t>   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT32:   ok = ok && gguf_read_emplace_helper<int32_t>    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_FLOAT32: ok = ok && gguf_read_emplace_helper<float>      (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_BOOL:    ok = ok && gguf_read_emplace_helper<bool>       (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_STRING:  ok = ok && gguf_read_emplace_helper<std::string>(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT64:  ok = ok && gguf_read_emplace_helper<uint64_t>   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT64:   ok = ok && gguf_read_emplace_helper<int64_t>    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_FLOAT64: ok = ok && gguf_read_emplace_helper<double>     (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_ARRAY:
+                default:
+                    {
+                        fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
+                        ok = false;
+                    } break;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+        GGML_ASSERT(int64_t(ctx->kv.size()) == n_kv);
+
+        const int alignment_idx = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT);
+        ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
+
+        if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
+            fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
+            gguf_free(ctx);
+            return nullptr;
+        }
+    }
+
+    // read the tensor info
+    for (int64_t i = 0; ok && i < n_tensors; ++i) {
+        struct gguf_tensor_info info;
+
+        // tensor name
+        {
+            std::string name;
+            try {
+                ok = ok && gr.read(name);
+            } catch (std::length_error &) {
+                fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
+                ok = false;
+            } catch (std::bad_alloc &) {
+                fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
+                ok = false;
+            }
+            if (name.length() >= GGML_MAX_NAME) {
+                fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
+                ok = false;
+                break;
+            }
+            ggml_set_name(&info.t, name.c_str());
+
+            // make sure there are no duplicate tensor names
+            for (int64_t j = 0; ok && j < i; ++j) {
+                if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
+                    fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
+                    ok = false;
+                    break;
+                }
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor shape
+        {
+            uint32_t n_dims = -1;
+            ok = ok && gr.read(n_dims);
+            if (n_dims > GGML_MAX_DIMS) {
+                fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
+                    __func__, info.t.name, n_dims, GGML_MAX_DIMS);
+                ok = false;
+                break;
+            }
+            for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) {
+                info.t.ne[j] = 1;
+                if (j < n_dims) {
+                    ok = ok && gr.read(info.t.ne[j]);
+                }
+
+                // check that all ne are non-negative
+                if (info.t.ne[j] < 0) {
+                    fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
+                        __func__, info.t.name, j, info.t.ne[j]);
+                    ok = false;
+                    break;
+                }
+            }
+
+            // check that the total number of elements is representable
+            if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) ||
+                       (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
+                       (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
+
+                fprintf(stderr, "%s: total number of elements in tensor '%s' with shape "
+                    "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
+                    __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
+                ok = false;
+                break;
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor type
+        {
+            ok = ok && gr.read(info.t.type);
+
+            // check that tensor type is within defined range
+            if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
+                fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n",
+                    __func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
+                ok = false;
+                break;
+            }
+            const size_t  type_size = ggml_type_size(info.t.type);
+            const int64_t blck_size = ggml_blck_size(info.t.type);
+
+            // check that row size is divisible by block size
+            if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
+                fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
+                    "not a multiple of block size (%" PRId64 ")\n",
+                    __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
+                ok = false;
+                break;
+            }
+
+            // calculate byte offsets given the tensor shape and type
+            info.t.nb[0] = type_size;
+            info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size);
+            for (int j = 2; j < GGML_MAX_DIMS; ++j) {
+                info.t.nb[j] = info.t.nb[j - 1]*info.t.ne[j - 1];
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor data offset within buffer
+        ok = ok && gr.read(info.offset);
+
+        ctx->info.push_back(info);
+    }
+
+    if (!ok) {
+        fprintf(stderr, "%s: failed to read tensor info\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+    GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
+
+    // we require the data section to be aligned, so take into account any padding
+    if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
+        fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+
+    // store the current file offset - this is where the data section starts
+    ctx->offset = ftell(file);
+
+    // compute the total size of the data section, taking into account the alignment
+    {
+        ctx->size = 0;
+        for (size_t i = 0; i < ctx->info.size(); ++i) {
+            const gguf_tensor_info & ti = ctx->info[i];
+            if (ti.offset != ctx->size) {
+                fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
+                    __func__, ti.t.name, ti.offset, ctx->size);
+                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+        }
+    }
+
+    // load the tensor data only if requested
+    if (params.ctx != nullptr) {
+        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
+        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
+        //   the ggml_tensor structs to the appropriate locations in the binary blob
+
+        // compute the exact size needed for the new ggml_context
+        const size_t mem_size =
+            params.no_alloc ?
+            (n_tensors    )*ggml_tensor_overhead() :
+            (n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
+
+        struct ggml_init_params pdata = {
+            /*mem_size   =*/ mem_size,
+            /*mem_buffer =*/ nullptr,
+            /*no_alloc   =*/ params.no_alloc,
+        };
+
+        *params.ctx = ggml_init(pdata);
+        if (*params.ctx == nullptr) {
+            fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        struct ggml_context * ctx_data = *params.ctx;
+
+        struct ggml_tensor * data = nullptr;
+
+        if (!params.no_alloc) {
+            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
+
+            ok = ok && data != nullptr;
+
+            if (ok) {
+                ggml_set_name(data, "GGUF tensor data binary blob");
+            }
+
+            // read the binary blob with the tensor data
+            ok = ok && gr.read(data->data, ctx->size);
+
+            if (!ok) {
+                fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__);
+                ggml_free(ctx_data);
+                *params.ctx = nullptr;
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            ctx->data = data->data;
+        }
+
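+        // force no_alloc while creating the tensors below so ggml does not allocate data for them;
+        // the caller's no_alloc setting is restored once their data pointers have been set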
+        ggml_set_no_alloc(ctx_data, true);
+
+        // create the tensors
+        for (size_t i = 0; i < ctx->info.size(); ++i) {
+            const struct gguf_tensor_info & info = ctx->info[i];
+
+            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, info.t.type, GGML_MAX_DIMS, info.t.ne);
+
+            ok = ok && cur != nullptr;
+
+            if (!ok) {
+                break;
+            }
+
+            ggml_set_name(cur, info.t.name);
+
+            // point the data member to the appropriate location in the binary blob using the tensor info
+            if (!params.no_alloc) {
+                cur->data = (char *) data->data + info.offset;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to create tensors\n", __func__);
+            ggml_free(ctx_data);
+            *params.ctx = nullptr;
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        ggml_set_no_alloc(ctx_data, params.no_alloc);
+    }
+
+    return ctx;
+}
+
+struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
+    FILE * file = ggml_fopen(fname, "rb");
+
+    if (!file) {
+        fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname);
+        return nullptr;
+    }
+
+    struct gguf_context * result = gguf_init_from_file_impl(file, params);
+    fclose(file);
+    return result;
+}
+
+void gguf_free(struct gguf_context * ctx) {
+    if (ctx == nullptr) {
+        return;
+    }
+    delete ctx;
+}
+
+const char * gguf_type_name(enum gguf_type type) {
+    auto it = GGUF_TYPE_NAME.find(type);
+    return it == GGUF_TYPE_NAME.end() ? nullptr : it->second;
+}
+
+uint32_t gguf_get_version(const struct gguf_context * ctx) {
+    return ctx->version;
+}
+
+size_t gguf_get_alignment(const struct gguf_context * ctx) {
+    return ctx->alignment;
+}
+
+size_t gguf_get_data_offset(const struct gguf_context * ctx) {
+    return ctx->offset;
+}
+
+int64_t gguf_get_n_kv(const struct gguf_context * ctx) {
+    return ctx->kv.size();
+}
+
+int64_t gguf_find_key(const struct gguf_context * ctx, const char * key) {
+    // return -1 if key not found
+    int64_t keyfound = -1;
+
+    const int64_t n_kv = gguf_get_n_kv(ctx);
+
+    for (int64_t i = 0; i < n_kv; ++i) {
+        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
+            keyfound = i;
+            break;
+        }
+    }
+
+    return keyfound;
+}
+
+const char * gguf_get_key(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].get_key().c_str();
+}
+
+enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].is_array ? GGUF_TYPE_ARRAY : ctx->kv[key_id].get_type();
+}
+
+enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].is_array);
+    return ctx->kv[key_id].get_type();
+}
+
+const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data.data();
+}
+
+const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data_string[i].c_str();
+}
+
+size_t gguf_get_arr_n(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+
+    if (ctx->kv[key_id].type == GGUF_TYPE_STRING) {
+        return ctx->kv[key_id].data_string.size();
+    }
+
+    const size_t type_size = gguf_type_size(ctx->kv[key_id].type);
+    GGML_ASSERT(ctx->kv[key_id].data.size() % type_size == 0);
+    return ctx->kv[key_id].data.size() / type_size;
+}
+
+uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint8_t>();
+}
+
+int8_t gguf_get_val_i8(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int8_t>();
+}
+
+uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint16_t>();
+}
+
+int16_t gguf_get_val_i16(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int16_t>();
+}
+
+uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint32_t>();
+}
+
+int32_t gguf_get_val_i32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int32_t>();
+}
+
+float gguf_get_val_f32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<float>();
+}
+
+uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint64_t>();
+}
+
+int64_t gguf_get_val_i64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int64_t>();
+}
+
+double gguf_get_val_f64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<double>();
+}
+
+bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<bool>();
+}
+
+const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<std::string>().c_str();
+}
+
+const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data.data();
+}
+
+int64_t gguf_get_n_tensors(const struct gguf_context * ctx) {
+    return ctx->info.size();
+}
+
+int64_t gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
+    // return -1 if tensor not found
+    int64_t tensor_id = -1;
+
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
+            tensor_id = i;
+            break;
+        }
+    }
+
+    return tensor_id;
+}
+
+size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].offset;
+}
+
+const char * gguf_get_tensor_name(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].t.name;
+}
+
+enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].t.type;
+}
+
+size_t gguf_get_tensor_size(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ggml_nbytes(&ctx->info[tensor_id].t);
+}
+
+int64_t gguf_remove_key(struct gguf_context * ctx, const char * key) {
+    const int64_t key_id = gguf_find_key(ctx, key);
+    if (key_id >= 0) {
+        ctx->kv.erase(ctx->kv.begin() + key_id);
+    }
+    return key_id;
+}
+
+template <typename T>
+static void gguf_check_reserved_keys(const std::string & key, const T val) {
+    if (key == GGUF_KEY_GENERAL_ALIGNMENT) {
+        if constexpr (std::is_same<T, uint32_t>::value) {
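+            // (val & (val - 1)) == 0 is the standard bit trick for "val is a power of two" (given val > 0)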
+            GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2");
+        } else {
+            GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32");
+        }
+    }
+}
+
+void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, std::string(val));
+}
+
+void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n) {
+    gguf_check_reserved_keys(key, data);
+    gguf_remove_key(ctx, key);
+
+    const size_t nbytes = n*gguf_type_size(type);
+    std::vector<int8_t> tmp(nbytes);
+    if (!tmp.empty()) {
+        memcpy(tmp.data(), data, nbytes);
+    }
+    ctx->kv.emplace_back(key, tmp);
+    ctx->kv.back().cast(type);
+}
+
+void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, size_t n) {
+    gguf_check_reserved_keys(key, data);
+    gguf_remove_key(ctx, key);
+
+    std::vector<std::string> tmp(n);
+    for (size_t i = 0; i < n; ++i) {
+        tmp[i] = data[i];
+    }
+    ctx->kv.emplace_back(key, tmp);
+}
+
+// set or add KV pairs from another context
+void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src) {
+    const int64_t n_kv = gguf_get_n_kv(src);
+    for (int64_t i = 0; i < n_kv; ++i) {
+        const struct gguf_kv & kv = src->kv[i];
+
+        if (!kv.is_array) {
+            switch (kv.get_type()) {
+                case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, kv.get_key().c_str(), kv.get_val<uint8_t>());             break;
+                case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, kv.get_key().c_str(), kv.get_val<int8_t>());              break;
+                case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, kv.get_key().c_str(), kv.get_val<uint16_t>());            break;
+                case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, kv.get_key().c_str(), kv.get_val<int16_t>());             break;
+                case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, kv.get_key().c_str(), kv.get_val<uint32_t>());            break;
+                case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, kv.get_key().c_str(), kv.get_val<int32_t>());             break;
+                case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, kv.get_key().c_str(), kv.get_val<float>());               break;
+                case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, kv.get_key().c_str(), kv.get_val<uint64_t>());            break;
+                case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, kv.get_key().c_str(), kv.get_val<int64_t>());             break;
+                case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, kv.get_key().c_str(), kv.get_val<double>());              break;
+                case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, kv.get_key().c_str(), kv.get_val<bool>());                break;
+                case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, kv.get_key().c_str(), kv.get_val<std::string>().c_str()); break;
+                case GGUF_TYPE_ARRAY:
+                default: GGML_ABORT("invalid type");
+            }
+            continue;
+        }
+
+        const size_t ne = kv.get_ne();
+
+        switch (kv.get_type()) {
+            case GGUF_TYPE_UINT8:
+            case GGUF_TYPE_INT8:
+            case GGUF_TYPE_UINT16:
+            case GGUF_TYPE_INT16:
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:
+            case GGUF_TYPE_FLOAT32:
+            case GGUF_TYPE_UINT64:
+            case GGUF_TYPE_INT64:
+            case GGUF_TYPE_FLOAT64:
+            case GGUF_TYPE_BOOL: {
+                gguf_set_arr_data(ctx, kv.get_key().c_str(), kv.get_type(), kv.data.data(), ne);
+            } break;
+            case GGUF_TYPE_STRING: {
+                std::vector<const char *> tmp(ne);
+                for (size_t j = 0; j < ne; ++j) {
+                    tmp[j] = kv.data_string[j].c_str();
+                }
+                gguf_set_arr_str(ctx, kv.get_key().c_str(), tmp.data(), ne);
+            } break;
+            case GGUF_TYPE_ARRAY:
+            default: GGML_ABORT("invalid type");
+        }
+    }
+}
+
+void gguf_add_tensor(
+             struct gguf_context * ctx,
+        const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+    if (gguf_find_tensor(ctx, tensor->name) != -1) {
+        GGML_ABORT("duplicate tensor name: %s", tensor->name);
+    }
+
+    struct gguf_tensor_info ti;
+    ti.t = *tensor;
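+    // the new tensor's data is placed right after the previous tensor's data, padded to the context alignment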
+    ti.offset = ctx->info.empty() ? 0 :
+        ctx->info.back().offset + GGML_PAD(ggml_nbytes(&ctx->info.back().t), ctx->alignment);
+    ctx->info.push_back(ti);
+}
+
+void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
+    const int64_t tensor_id = gguf_find_tensor(ctx, name);
+    if (tensor_id < 0) {
+        GGML_ABORT("tensor not found: %s", name);
+    }
+    struct ggml_tensor * tensor = &ctx->info[tensor_id].t;
+    const size_t  type_size = ggml_type_size(type);
+    const int64_t blck_size = ggml_blck_size(type);
+
+    tensor->type = type;
+    GGML_ASSERT(tensor->ne[0] % blck_size == 0 && "tensor row size not divisible by block size of new type");
+
+    tensor->nb[0] = type_size;
+    tensor->nb[1] = tensor->nb[0]*(tensor->ne[0]/blck_size);
+    for (int i = 2; i < GGML_MAX_DIMS; i++) {
+        tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1];
+    }
+
+    // update offsets
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+    for (int64_t i = tensor_id + 1; i < n_tensors; ++i) {
+        ctx->info[i].offset = ctx->info[i - 1].offset + GGML_PAD(ggml_nbytes(&ctx->info[i - 1].t), ctx->alignment);
+    }
+}
+
+void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data) {
+    const int64_t tensor_id = gguf_find_tensor(ctx, name);
+    if (tensor_id < 0) {
+        GGML_ABORT("tensor not found: %s", name);
+    }
+
+    ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const
+}
+
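+// helper for serializing a gguf_context into a flat byte buffer: header, KV pairs, tensor metadata, then the aligned tensor data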
+struct gguf_writer {
+    std::vector<int8_t> & buf;
+
+    gguf_writer(std::vector<int8_t> & buf) : buf(buf) {}
+
+    template <typename T>
+    void write(const T & val) const {
+        for (size_t i = 0; i < sizeof(val); ++i) {
+            buf.push_back(reinterpret_cast<const int8_t *>(&val)[i]);
+        }
+    }
+
+    void write(const std::vector<int8_t> & val) const {
+        buf.insert(buf.end(), val.begin(), val.end());
+    }
+
+    void write(const bool & val) const {
+        const int8_t val8 = val ? 1 : 0;
+        write(val8);
+    }
+
+    void write(const std::string & val) const {
+        {
+            const uint64_t n = val.length();
+            write(n);
+        }
+        for (size_t i = 0; i < val.length(); ++i) {
+            buf.push_back(reinterpret_cast<const int8_t *>(val.data())[i]);
+        }
+    }
+
+    void write(const char * val) const {
+        write(std::string(val));
+    }
+
+    void write(const enum ggml_type & val) const {
+        write(int32_t(val));
+    }
+
+    void write(const enum gguf_type & val) const {
+        write(int32_t(val));
+    }
+
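+    // a KV pair is written as: key string, then either the scalar type or [GGUF_TYPE_ARRAY, element type, element count], followed by the value(s)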
+    void write(const struct gguf_kv & kv) const {
+        const uint64_t ne = kv.get_ne();
+
+        write(kv.get_key());
+
+        if (kv.is_array) {
+            write(GGUF_TYPE_ARRAY);
+            write(kv.get_type());
+            write(ne);
+        } else {
+            write(kv.get_type());
+        }
+
+        switch (kv.get_type()) {
+            case GGUF_TYPE_UINT8:
+            case GGUF_TYPE_INT8:
+            case GGUF_TYPE_UINT16:
+            case GGUF_TYPE_INT16:
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:
+            case GGUF_TYPE_FLOAT32:
+            case GGUF_TYPE_UINT64:
+            case GGUF_TYPE_INT64:
+            case GGUF_TYPE_FLOAT64: {
+                write(kv.data);
+            } break;
+            case GGUF_TYPE_BOOL: {
+                for (size_t i = 0; i < ne; ++i) {
+                    write(kv.get_val<bool>(i));
+                }
+            } break;
+            case GGUF_TYPE_STRING: {
+                for (size_t i = 0; i < ne; ++i) {
+                    write(kv.get_val<std::string>(i));
+                }
+            } break;
+            case GGUF_TYPE_ARRAY:
+            default: GGML_ABORT("invalid type");
+        }
+    }
+
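+    // tensor metadata is written as: name, number of dimensions, dimension sizes, ggml type, offset within the data section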
+    void write_tensor_meta(const struct gguf_tensor_info & info) const {
+        write(info.t.name);
+
+        const uint32_t n_dims = ggml_n_dims(&info.t);
+        write(n_dims);
+
+        for (uint32_t j = 0; j < n_dims; ++j) {
+            write(info.t.ne[j]);
+        }
+        write(info.t.type);
+        write(info.offset);
+    }
+
+    void pad(const size_t alignment) const {
+        while (buf.size() % alignment != 0) {
+            const int8_t zero = 0;
+            write(zero);
+        }
+    }
+
+    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const {
+        GGML_ASSERT(buf.size() - offset_data == info.offset);
+
+        GGML_ASSERT(ggml_is_contiguous(&info.t));
+        const size_t offset = buf.size();
+        const size_t nbytes = ggml_nbytes(&info.t);
+
+        buf.resize(offset + nbytes);
+        if (info.t.buffer) {
+            ggml_backend_tensor_get(&info.t, buf.data() + offset, 0, nbytes);
+        } else {
+            GGML_ASSERT(info.t.data);
+            memcpy(buf.data() + offset, info.t.data, nbytes);
+        }
+
+        pad(alignment);
+    }
+};
+
+void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
+    const struct gguf_writer gw(buf);
+
+    const int64_t n_kv      = gguf_get_n_kv(ctx);
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+
+    // write header
+    gw.write(GGUF_MAGIC[0]);
+    gw.write(GGUF_MAGIC[1]);
+    gw.write(GGUF_MAGIC[2]);
+    gw.write(GGUF_MAGIC[3]);
+    gw.write(ctx->version);
+    gw.write(n_tensors);
+    gw.write(n_kv);
+
+    // write key-value pairs
+    for (int64_t i = 0; i < n_kv; ++i) {
+        gw.write(ctx->kv[i]);
+    }
+
+    // write tensor info
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        gw.write_tensor_meta(ctx->info[i]);
+    }
+
+    // we require the data section to be aligned
+    gw.pad(ctx->alignment);
+
+    if (only_meta) {
+        return;
+    }
+
+    const size_t offset_data = gw.buf.size();
+
+    // write tensor data
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        gw.write_tensor_data(ctx->info[i], offset_data, ctx->alignment);
+    }
+}
+
+bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
+    FILE * file = ggml_fopen(fname, "wb");
+
+    if (!file) {
+        fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
+        return false;
+    }
+
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, only_meta);
+    const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size();
+    fclose(file);
+    return ok;
+}
+
+size_t gguf_get_meta_size(const struct gguf_context * ctx) {
+    // only return size
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);
+    return buf.size();
+}
+
+void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);
+    memcpy(data, buf.data(), buf.size());
+}