diff --git a/Makefile.sync b/Makefile.sync index 4a6a1039d..949ade809 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggerganov/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=71e90e8813f90097701e62f7fce137d96ddf41e2 +FETCH_HEAD=2016f07bd106c73699ecbaace80f55db5ed95dac .PHONY: help help: diff --git a/llama/build-info.cpp b/llama/build-info.cpp index 06deb0dde..be908c364 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "71e90e8813f90097701e62f7fce137d96ddf41e2"; +char const *LLAMA_COMMIT = "2016f07bd106c73699ecbaace80f55db5ed95dac"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/examples/llava/clip-impl.h b/llama/llama.cpp/examples/llava/clip-impl.h index 4d7340a56..180ae9880 100644 --- a/llama/llama.cpp/examples/llava/clip-impl.h +++ b/llama/llama.cpp/examples/llava/clip-impl.h @@ -50,7 +50,6 @@ // tensor name constants // -#define TN_TOKEN_EMBD "%s.token_embd.weight" #define TN_POS_EMBD "%s.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" #define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat @@ -66,8 +65,6 @@ #define TN_LN_2 "%s.blk.%d.ln2.%s" #define TN_LN_PRE "%s.pre_ln.%s" #define TN_LN_POST "%s.post_ln.%s" -#define TN_TEXT_PROJ "text_projection.weight" -#define TN_VIS_PROJ "visual_projection.weight" #define TN_LLAVA_PROJ "mm.%d.%s" #define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s" #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" diff --git a/llama/llama.cpp/examples/llava/clip.cpp b/llama/llama.cpp/examples/llava/clip.cpp index 4b72ea9fd..d57b4bd6e 100644 --- a/llama/llama.cpp/examples/llava/clip.cpp +++ b/llama/llama.cpp/examples/llava/clip.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #if defined(_WIN32) #define WIN32_LEAN_AND_MEAN @@ -1719,45 +1720,6 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length return true; } -// Linear interpolation between two points -inline float clip_lerp(float s, float e, float t) { - return s + (e - s) * t; -} -// Bilinear resize function -static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { - dst.nx = target_width; - dst.ny = target_height; - dst.buf.resize(3 * target_width * target_height); - - float x_ratio = static_cast(src.nx - 1) / target_width; - float y_ratio = static_cast(src.ny - 1) / target_height; - - for (int y = 0; y < target_height; y++) { - for (int x = 0; x < target_width; x++) { - float px = x_ratio * x; - float py = y_ratio * y; - int x_floor = static_cast(px); - int y_floor = static_cast(py); - float x_lerp = px - x_floor; - float y_lerp = py - y_floor; - - for (int c = 0; c < 3; c++) { - float top = clip_lerp( - static_cast(src.buf[3 * (y_floor * src.nx + x_floor) + c]), - static_cast(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]), - x_lerp - ); - float bottom = clip_lerp( - static_cast(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]), - static_cast(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]), - x_lerp - ); - dst.buf[3 * (y * target_width + x) + c] = static_cast(clip_lerp(top, bottom, y_lerp)); - } - } - } -} - // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) { dst.nx = src.nx; @@ -1771,336 +1733,489 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 } } -inline int clip(int x, int lower, int upper) { - return std::max(lower, std::min(x, upper)); -} +// set of tools to manupulate images +// in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv +struct image_manipulation { + // Bilinear resize function + static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { + dst.nx = target_width; + dst.ny = target_height; + dst.buf.resize(3 * target_width * target_height); -static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { - const int nx = img.nx; - const int ny = img.ny; + float x_ratio = static_cast(src.nx - 1) / target_width; + float y_ratio = static_cast(src.ny - 1) / target_height; - dst.nx = target_width; - dst.ny = target_height; - dst.buf.resize(3 * target_width * target_height); + for (int y = 0; y < target_height; y++) { + for (int x = 0; x < target_width; x++) { + float px = x_ratio * x; + float py = y_ratio * y; + int x_floor = static_cast(px); + int y_floor = static_cast(py); + float x_lerp = px - x_floor; + float y_lerp = py - y_floor; - float Cc; - float C[5]; - float d0, d2, d3, a0, a1, a2, a3; - int i, j, k, jj; - int x, y; - float dx, dy; - float tx, ty; - - tx = (float)nx / (float)target_width; - ty = (float)ny / (float)target_height; - - // Bicubic interpolation; adapted from ViT.cpp, inspired from : - // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 - // -> https://en.wikipedia.org/wiki/Bicubic_interpolation - - for (i = 0; i < target_height; i++) { - for (j = 0; j < target_width; j++) { - x = (int)(tx * j); - y = (int)(ty * i); - - dx = tx * j - x; - dy = ty * i - y; - - for (k = 0; k < 3; k++) { - for (jj = 0; jj <= 3; jj++) { - d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - - a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; - a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; - a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; - - C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; - - d0 = C[0] - C[1]; - d2 = C[2] - C[1]; - d3 = C[3] - C[1]; - a0 = C[1]; - a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; - a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; - a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; - Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; - - const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); - dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); + for (int c = 0; c < 3; c++) { + float top = lerp( + static_cast(src.buf[3 * (y_floor * src.nx + x_floor) + c]), + static_cast(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]), + x_lerp + ); + float bottom = lerp( + static_cast(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]), + static_cast(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]), + x_lerp + ); + dst.buf[3 * (y * target_width + x) + c] = static_cast(lerp(top, bottom, y_lerp)); } } } } - return true; -} + // Bicubic resize function + // part of image will be cropped if the aspect ratio is different + static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { + const int nx = img.nx; + const int ny = img.ny; -// llava-1.6 type of resize_and_pad (black) -static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair& target_resolution) { - int target_width = target_resolution.first; - int target_height = target_resolution.second; + dst.nx = target_width; + dst.ny = target_height; + dst.buf.resize(3 * target_width * target_height); - float scale_w = static_cast(target_width) / image.nx; - float scale_h = static_cast(target_height) / image.ny; + float Cc; + float C[5]; + float d0, d2, d3, a0, a1, a2, a3; + int i, j, k, jj; + int x, y; + float dx, dy; + float tx, ty; - int new_width, new_height; + tx = (float)nx / (float)target_width; + ty = (float)ny / (float)target_height; - if (scale_w < scale_h) { - new_width = target_width; - new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height); - } else { - new_height = target_height; - new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); + // Bicubic interpolation; adapted from ViT.cpp, inspired from : + // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 + // -> https://en.wikipedia.org/wiki/Bicubic_interpolation + + for (i = 0; i < target_height; i++) { + for (j = 0; j < target_width; j++) { + x = (int)(tx * j); + y = (int)(ty * i); + + dx = tx * j - x; + dy = ty * i - y; + + for (k = 0; k < 3; k++) { + for (jj = 0; jj <= 3; jj++) { + d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + + C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; + + d0 = C[0] - C[1]; + d2 = C[2] - C[1]; + d3 = C[3] - C[1]; + a0 = C[1]; + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; + + const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); + dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); + } + } + } + } + + return true; } - clip_image_u8 resized_image; - // bilinear_resize(image, resized_image, new_width, new_height); - bicubic_resize(image, resized_image, new_width, new_height); + // llava-1.6 type of resize_and_pad + // if the ratio is not 1:1, padding with pad_color will be applied + // pad_color is single channel, default is 0 (black) + static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array pad_color = {0, 0, 0}) { + int target_width = target_resolution.width; + int target_height = target_resolution.height; - clip_image_u8 padded_image; - padded_image.nx = target_width; - padded_image.ny = target_height; - padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black + float scale_w = static_cast(target_width) / image.nx; + float scale_h = static_cast(target_height) / image.ny; - // Calculate padding offsets - int pad_x = (target_width - new_width) / 2; - int pad_y = (target_height - new_height) / 2; + int new_width, new_height; - // Copy the resized image into the center of the padded buffer - for (int y = 0; y < new_height; ++y) { - for (int x = 0; x < new_width; ++x) { - for (int c = 0; c < 3; ++c) { - padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; + if (scale_w < scale_h) { + new_width = target_width; + new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height); + } else { + new_height = target_height; + new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); + } + + clip_image_u8 resized_image; + bicubic_resize(image, resized_image, new_width, new_height); + + clip_image_u8 padded_image; + padded_image.nx = target_width; + padded_image.ny = target_height; + padded_image.buf.resize(3 * target_width * target_height); + + // Fill the padded image with the fill color + for (size_t i = 0; i < padded_image.buf.size(); i += 3) { + padded_image.buf[i] = pad_color[0]; + padded_image.buf[i + 1] = pad_color[1]; + padded_image.buf[i + 2] = pad_color[2]; + } + + // Calculate padding offsets + int pad_x = (target_width - new_width) / 2; + int pad_y = (target_height - new_height) / 2; + + // Copy the resized image into the center of the padded buffer + for (int y = 0; y < new_height; ++y) { + for (int x = 0; x < new_width; ++x) { + for (int c = 0; c < 3; ++c) { + padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; + } + } + } + dst = std::move(padded_image); + } + + static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) { + dst.nx = w; + dst.ny = h; + dst.buf.resize(3 * w * h); + + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int src_idx = 3 * ((y + i)*image.nx + (x + j)); + int dst_idx = 3 * (i*w + j); + dst.buf[dst_idx] = image.buf[src_idx]; + dst.buf[dst_idx + 1] = image.buf[src_idx + 1]; + dst.buf[dst_idx + 2] = image.buf[src_idx + 2]; } } } - image_output = std::move(padded_image); -} + +private: + static inline int clip(int x, int lower, int upper) { + return std::max(lower, std::min(x, upper)); + } + + // Linear interpolation between two points + static inline float lerp(float s, float e, float t) { + return s + (e - s) * t; + } +}; /** - * Selects the best resolution from a list of possible resolutions based on the original size. + * implementation of LLaVA-UHD: + * - https://arxiv.org/pdf/2403.11703 + * - https://github.com/thunlp/LLaVA-UHD + * - https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118 * - * @param original_size The original size of the image in the format (width, height). - * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. - * @return The best fit resolution in the format (width, height). + * overview: + * - an image always have a single overview (downscaled image) + * - an image can have 0 or multiple slices, depending on the image size + * - each slice can then be considered as a separate image + * + * for example: + * + * [overview] --> [slice 1] --> [slice 2] + * | | + * +--> [slice 3] --> [slice 4] */ -static std::pair select_best_resolution(const std::pair & original_size, const std::vector> & possible_resolutions) { - int original_width = original_size.first; - int original_height = original_size.second; - std::pair best_fit; - int max_effective_resolution = 0; - int min_wasted_resolution = std::numeric_limits::max(); +struct llava_uhd { + struct slice_coordinates { + int x; + int y; + clip_image_size size; + }; - for (const auto& resolution : possible_resolutions) { - int width = resolution.first; - int height = resolution.second; - float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); - int downscaled_width = static_cast(original_width * scale); - int downscaled_height = static_cast(original_height * scale); - int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); - int wasted_resolution = (width * height) - effective_resolution; - // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); - if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { - max_effective_resolution = effective_resolution; - min_wasted_resolution = wasted_resolution; - best_fit = resolution; + struct slice_instructions { + clip_image_size overview_size; // size of downscaled image + clip_image_size refined_size; // size of image right before slicing (must be multiple of slice size) + clip_image_size grid_size; // grid_size.width * grid_size.height = number of slices + std::vector slices; + bool padding_refined = false; // if true, refine image will be padded to the grid size (e.g. llava-1.6) + }; + + static int get_max_slices(struct clip_ctx * ctx) { + if (clip_is_minicpmv(ctx)) { + return 9; } + return 0; } - return best_fit; -} + static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) { + slice_instructions res; + const int patch_size = clip_get_patch_size(ctx); + const int slice_size = clip_get_image_size(ctx); + const int max_slice_nums = get_max_slices(ctx); + const int original_width = original_size.width; + const int original_height = original_size.height; + const float log_ratio = log((float)original_width / original_height); + const float ratio = (float)original_width * original_height / (slice_size * slice_size); + const int multiple = fmin(ceil(ratio), max_slice_nums); + const bool has_slices = (multiple > 1); + const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty(); -static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { - std::vector patches; - int width = image.nx; - int height = image.ny; - for (int i = 0; i < height; i += patch_size) { - for (int j = 0; j < width; j += patch_size) { - clip_image_u8_ptr patch(clip_image_u8_init()); - patch->nx = std::min(patch_size, width - j); - patch->ny = std::min(patch_size, height - i); - patch->buf.resize(3 * patch->nx * patch->ny); - for (int y = 0; y < patch->ny; ++y) { - for (int x = 0; x < patch->nx; ++x) { - for (int c = 0; c < 3; ++c) { - patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c]; + if (has_pinpoints) { + // has pinpoints, use them to calculate the grid size (e.g. llava-1.6) + auto refine_size = llava_uhd::select_best_resolution( + ctx->vision_model.hparams.image_grid_pinpoints, + original_size); + res.overview_size = clip_image_size{slice_size, slice_size}; + res.refined_size = refine_size; + res.grid_size = clip_image_size{0, 0}; + res.padding_refined = true; + + for (int y = 0; y < refine_size.height; y += slice_size) { + for (int x = 0; x < refine_size.width; x += slice_size) { + slice_coordinates slice; + slice.x = x; + slice.y = y; + slice.size.width = std::min(slice_size, refine_size.width - x); + slice.size.height = std::min(slice_size, refine_size.height - y); + res.slices.push_back(slice); + if (x == 0) { + res.grid_size.width++; } } + res.grid_size.height++; } - patches.push_back(std::move(patch)); + + return res; } - } - return patches; -} -static int ensure_divide(int length, int patch_size) { - return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size); -} + // no pinpoints, dynamically calculate the grid size (e.g. minicpmv) -static std::pair uhd_find_best_resize(std::pair original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { - int width = original_size.first; - int height = original_size.second; - if ((width * height > scale_resolution * scale_resolution) || allow_upscale) { - float r = static_cast(width) / height; - height = static_cast(scale_resolution / std::sqrt(r)); - width = static_cast(height * r); - } - int best_width = ensure_divide(width, patch_size); - int best_height = ensure_divide(height, patch_size); - return std::make_pair(best_width, best_height); -} + auto best_size = get_best_resize(original_size, slice_size, patch_size, has_slices); + res.overview_size = best_size; -static std::pair uhd_get_refine_size(std::pair original_size, std::pair grid, int scale_resolution, int patch_size, bool allow_upscale = false) { - int width, height; - std::tie(width, height) = original_size; - int grid_x, grid_y; - std::tie(grid_x, grid_y) = grid; + if (!has_slices) { + // skip slicing logic + res.refined_size = clip_image_size{0, 0}; + res.grid_size = clip_image_size{0, 0}; - int refine_width = ensure_divide(width, grid_x); - int refine_height = ensure_divide(height, grid_y); + } else { + auto best_grid = get_best_grid(max_slice_nums, multiple, log_ratio); + auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true); + res.grid_size = best_grid; + res.refined_size = refine_size; - int grid_width = refine_width / grid_x; - int grid_height = refine_height / grid_y; - - // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line) - auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair - int best_grid_width, best_grid_height; - std::tie(best_grid_width, best_grid_height) = best_grid_size; - - // std::pair refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line) - std::pair refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line) - return refine_size; -} - -static std::pair uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { - std::vector candidate_split_grids_nums; - for (int i : {multiple - 1, multiple, multiple + 1}) { - if (i == 1 || i > max_slice_nums) { - continue; - } - candidate_split_grids_nums.push_back(i); - } - - std::vector> candidate_grids; - for (int split_grids_nums : candidate_split_grids_nums) { - int m = 1; - while (m <= split_grids_nums) { - if (split_grids_nums % m == 0) { - candidate_grids.emplace_back(m, split_grids_nums / m); - } - ++m; - } - } - - std::pair best_grid{1, 1}; - float min_error = std::numeric_limits::infinity(); - for (const auto& grid : candidate_grids) { - float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second)); - if (error < min_error) { - best_grid = grid; - min_error = error; - } - } - return best_grid; -} - -// inspired from LLaVA-UHD: -// -> https://arxiv.org/pdf/2403.11703 -// -> https://github.com/thunlp/LLaVA-UHD -// -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118 -static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { - const std::pair original_size={img->nx,img->ny}; - const int original_width = img->nx; - const int original_height = img->ny; - const float log_ratio = log(1.0*original_width/original_height); - const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); - const int multiple = fmin(ceil(ratio), max_slice_nums); - - std::vector> images; - LOG_DBG("%s: multiple %d\n", __func__, multiple); - images.push_back(std::vector()); - - if (multiple <= 1) { - auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true); - clip_image_u8_ptr source_image(clip_image_u8_init()); - bicubic_resize(*img, *source_image, best_size.first, best_size.second); - // source_image = image.resize(best_size, Image.Resampling.BICUBIC) - images.back().push_back(std::move(source_image)); - } - else if (multiple > 1) { - auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size); - clip_image_u8_ptr source_image(clip_image_u8_init()); - bicubic_resize(*img, *source_image, best_size.first, best_size.second); - // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) - LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); - images.back().push_back(std::move(source_image)); - - std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); - LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); - - auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); - clip_image_u8_ptr refine_image(clip_image_u8_init()); - bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); - - LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); - - // split_to_patches - int width = refine_image->nx; - int height = refine_image->ny; - int grid_x = int(width / best_grid.first); - int grid_y = int(height / best_grid.second); - for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ - images.push_back(std::vector()); - for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ - clip_image_u8_ptr patch(clip_image_u8_init()); - patch->nx = grid_x; - patch->ny = grid_y; - patch->buf.resize(3 * patch->nx * patch->ny); - for (int y = patches_i; y < patches_i + grid_y; ++y) { - for (int x = patches_j; x < patches_j + grid_x; ++x) { - const int i = 3 * (y * refine_image->nx + x); - const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j)); - patch->buf[j] = refine_image->buf[i]; - patch->buf[j+1] = refine_image->buf[i+1]; - patch->buf[j+2] = refine_image->buf[i+2]; - } + int width = refine_size.width; + int height = refine_size.height; + int grid_x = int(width / best_grid.width); + int grid_y = int(height / best_grid.height); + for (int patches_y = 0, ic = 0; + patches_y < refine_size.height && ic < best_grid.height; + patches_y += grid_y, ic += 1) { + for (int patches_x = 0, jc = 0; + patches_x < refine_size.width && jc < best_grid.width; + patches_x += grid_x, jc += 1) { + slice_coordinates slice; + slice.x = patches_x; + slice.y = patches_y; + slice.size.width = grid_x; + slice.size.height = grid_y; + res.slices.push_back(slice); + // LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y); } - images.back().push_back(std::move(patch)); } } - } - return images; -} + return res; + } + + static std::vector slice_image(const clip_image_u8 * img, const slice_instructions & inst) { + std::vector output; + + // resize to overview size + clip_image_u8_ptr resized_img(clip_image_u8_init()); + image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height); + output.push_back(std::move(resized_img)); + if (inst.slices.empty()) { + // no slices, just return the resized image + return output; + } + + // resize to refined size + clip_image_u8_ptr refined_img(clip_image_u8_init()); + if (inst.padding_refined) { + image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size); + } else { + image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height); + } + + // create slices + for (const auto & slice : inst.slices) { + int x = slice.x; + int y = slice.y; + int w = slice.size.width; + int h = slice.size.height; + + clip_image_u8_ptr img_slice(clip_image_u8_init()); + image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h); + output.push_back(std::move(img_slice)); + } + + return output; + } + +private: + static clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { + int width = original_size.width; + int height = original_size.height; + if ((width * height > scale_resolution * scale_resolution) || allow_upscale) { + float r = static_cast(width) / height; + height = static_cast(scale_resolution / std::sqrt(r)); + width = static_cast(height * r); + } + clip_image_size res; + res.width = ensure_divide(width, patch_size); + res.height = ensure_divide(height, patch_size); + return res; + } + + /** + * Selects the best resolution from a list of possible resolutions based on the original size. + * + * @param original_size The original size of the image + * @param possible_resolutions A list of possible resolutions + * @return The best fit resolution + */ + static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector & possible_resolutions) { + int original_width = original_size.width; + int original_height = original_size.height; + clip_image_size best_fit; + int max_effective_resolution = 0; + int min_wasted_resolution = std::numeric_limits::max(); + + for (const auto & resolution : possible_resolutions) { + int width = resolution.width; + int height = resolution.height; + float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); + int downscaled_width = static_cast(original_width * scale); + int downscaled_height = static_cast(original_height * scale); + int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); + int wasted_resolution = (width * height) - effective_resolution; + // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); + if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution; + best_fit = resolution; + } + } + + return best_fit; + } + + // used by llava 1.6 with custom list of pinpoints + static clip_image_size select_best_resolution(const std::vector & pinpoints, const clip_image_size & original_size) { + std::vector possible_resolutions; + for (size_t i = 0; i < pinpoints.size(); i += 2) { + possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]}); + } + return select_best_resolution(original_size, possible_resolutions); + } + + static int ensure_divide(int length, int patch_size) { + return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size); + } + + static clip_image_size get_refine_size(const clip_image_size & original_size, const clip_image_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) { + int width = original_size.width; + int height = original_size.height; + int grid_x = grid.width; + int grid_y = grid.height; + + int refine_width = ensure_divide(width, grid_x); + int refine_height = ensure_divide(height, grid_y); + + clip_image_size grid_size; + grid_size.width = refine_width / grid_x; + grid_size.height = refine_height / grid_y; + + auto best_grid_size = get_best_resize(grid_size, scale_resolution, patch_size, allow_upscale); + int best_grid_width = best_grid_size.width; + int best_grid_height = best_grid_size.height; + + clip_image_size refine_size; + refine_size.width = best_grid_width * grid_x; + refine_size.height = best_grid_height * grid_y; + return refine_size; + } + + static clip_image_size get_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { + std::vector candidate_split_grids_nums; + for (int i : {multiple - 1, multiple, multiple + 1}) { + if (i == 1 || i > max_slice_nums) { + continue; + } + candidate_split_grids_nums.push_back(i); + } + + std::vector candidate_grids; + for (int split_grids_nums : candidate_split_grids_nums) { + int m = 1; + while (m <= split_grids_nums) { + if (split_grids_nums % m == 0) { + candidate_grids.push_back(clip_image_size{m, split_grids_nums / m}); + } + ++m; + } + } + + clip_image_size best_grid{1, 1}; + float min_error = std::numeric_limits::infinity(); + for (const auto& grid : candidate_grids) { + float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height)); + if (error < min_error) { + best_grid = grid; + min_error = error; + } + } + return best_grid; + } +}; + +// TODO @ngxson : decprecate the load_image_size singleton pattern int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { - const int max_slice_nums=9; - const int scale_resolution=448; - const int original_width = ctx_clip->load_image_size.width; - const int original_height = ctx_clip->load_image_size.height; - const float log_ratio = log(1.0*original_width/original_height); - const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); - const int multiple = fmin(ceil(ratio), max_slice_nums); - std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); - return best_grid.first; + const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size); + return inst.grid_size.width; } // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { + if (!ctx->has_vision_encoder) { + LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); + return false; + } + + clip_image_size original_size{img->nx, img->ny}; + bool pad_to_square = true; + auto & params = ctx->vision_model.hparams; + // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing + if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) { + pad_to_square = false; + } if (clip_is_minicpmv(ctx)) { - int max_slice_nums = 9; - std::vector> imgs = uhd_slice_image(img, max_slice_nums); + auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); + std::vector imgs = llava_uhd::slice_image(img, inst); + for (size_t i = 0; i < imgs.size(); ++i) { - for (size_t j = 0; j < imgs[i].size(); ++j) { - LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); - clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std); - res_imgs->entries.push_back(std::move(res)); - } + // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); } return true; } @@ -2109,7 +2224,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str auto patch_size = clip_get_patch_size(ctx) * 2; int nx = ceil((float)img->nx / patch_size) * patch_size; int ny = ceil((float)img->ny / patch_size) * patch_size; - bicubic_resize(*img, resized, nx, ny); + image_manipulation::bicubic_resize(*img, resized, nx, ny); clip_image_f32_ptr img_f32(clip_image_f32_init()); // clip_image_f32_ptr res(clip_image_f32_init()); @@ -2121,8 +2236,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { clip_image_u8 resized_image; - int32_t sz=ctx->vision_model.hparams.image_size; - bicubic_resize(*img, resized_image,sz,sz); + int sz = params.image_size; + image_manipulation::bicubic_resize(*img, resized_image, sz, sz); clip_image_f32_ptr img_f32(clip_image_f32_init()); //clip_image_save_to_bmp(resized_image, "resized.bmp"); normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); @@ -2130,156 +2245,47 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str return true; } - bool pad_to_square = true; - if (!ctx->has_vision_encoder) { - LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); - return false; - } - auto & params = ctx->vision_model.hparams; - // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing - if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) { - pad_to_square = false; - } - // free the previous res_imgs if any set - res_imgs->entries.clear(); - // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily - if (pad_to_square && img->nx != img->ny) { - int longer_side = std::max(img->nx, img->ny); + + if (pad_to_square) { + // for llava-1.5, we resize image to a square, and pad the shorter side with a background color + // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 + const int longer_side = std::max(img->nx, img->ny); temp->nx = longer_side; temp->ny = longer_side; temp->buf.resize(3 * longer_side * longer_side); - const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) - // fill with background color - for (size_t i = 0; i < temp->buf.size(); i++) { - temp->buf[i] = bc[i % 3]; + // background color in RGB from LLaVA (this is the mean rgb color * 255) + const std::array pad_color = {122, 116, 104}; + + // resize the image to the target_size + image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color); + + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); + return true; + + } else if (!params.image_grid_pinpoints.empty()) { + // "spatial_unpad" with "anyres" processing for llava-1.6 + auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); + std::vector imgs = llava_uhd::slice_image(img, inst); + + for (size_t i = 0; i < imgs.size(); ++i) { + // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); } - // copy from the input image - for (int y = 0; y < img->ny; y++) { - for (int x = 0; x < img->nx; x++) { - const int i = 3 * (y * img->nx + x); - const int j = 3 * (y * temp->nx + x); - temp->buf[j] = img->buf[i]; - temp->buf[j+1] = img->buf[i+1]; - temp->buf[j+2] = img->buf[i+2]; - } - } - } else { - if (!params.image_grid_pinpoints.empty()) { - // "spatial_unpad" with "anyres" processing for llava-1.6 - std::vector> possible_resolutions; - for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) { - possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); - } - std::pair best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); - // clip_image_save_to_bmp(*img, "input.bmp"); - resize_and_pad_image(*img, *temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 - // clip_image_save_to_bmp(*temp, "resized.bmp"); - // visually verify normalized image: - // normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std); - // { - // clip_image_u8 * temp2 = clip_image_u8_init(); - // clip_image_convert_f32_to_u8(*res, *temp2); - // clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp"); - // clip_image_u8_free(temp2); - // } + return true; - std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) - - clip_image_u8_ptr image_original_resize(clip_image_u8_init()); - // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - patches.insert(patches.begin(), std::move(image_original_resize)); - for (auto & patch : patches) { - clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std); - res_imgs->entries.push_back(std::move(res)); - } - - return true; - } else { - temp->nx = img->nx; - temp->ny = img->ny; - temp->buf.resize(img->buf.size()); - memcpy(temp->buf.data(), img->buf.data(), temp->buf.size()); - } } - const int nx = temp->nx; - const int ny = temp->ny; - // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp"); - - const int nx2 = ctx->vision_model.hparams.image_size; - const int ny2 = ctx->vision_model.hparams.image_size; - clip_image_f32_ptr res(clip_image_f32_init()); - res->nx = nx2; - res->ny = ny2; - res->buf.resize(3 * nx2 * ny2); - - const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size; - - const int nx3 = int(nx / scale + 0.5f); - const int ny3 = int(ny / scale + 0.5f); - - const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; - const auto & s3 = ctx->image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; - - for (int y = 0; y < ny3; y++) { - for (int x = 0; x < nx3; x++) { - for (int c = 0; c < 3; c++) { - // linear interpolation - const float sx = (x + 0.5f) * scale - 0.5f; - const float sy = (y + 0.5f) * scale - 0.5f; - - const int x0 = std::max(0, (int)std::floor(sx)); - const int y0 = std::max(0, (int)std::floor(sy)); - - const int x1 = std::min(x0 + 1, nx - 1); - const int y1 = std::min(y0 + 1, ny - 1); - - const float dx = sx - x0; - const float dy = sy - y0; - - const int j00 = 3 * (y0 * nx + x0) + c; - const int j01 = 3 * (y0 * nx + x1) + c; - const int j10 = 3 * (y1 * nx + x0) + c; - const int j11 = 3 * (y1 * nx + x1) + c; - - const float v00 = temp->buf[j00]; - const float v01 = temp->buf[j01]; - const float v10 = temp->buf[j10]; - const float v11 = temp->buf[j11]; - - const float v0 = v00 * (1.0f - dx) + v01 * dx; - const float v1 = v10 * (1.0f - dx) + v11 * dx; - - const float v = v0 * (1.0f - dy) + v1 * dy; - - const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); - - const int i = 3 * (y * nx3 + x) + c; - - res->buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c]; - } - } - } - - // { - // clip_image_u8 * temp2 = clip_image_u8_init(); - // clip_image_convert_f32_to_u8(*res, *temp2); - // clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp"); - // clip_image_u8_free(temp2); - // } - // res_imgs.push_back(res); - - res_imgs->entries.push_back(std::move(res)); - - return true; + GGML_ASSERT(false && "Unknown image preprocessing type"); } ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) { diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index bdf3d898b..dd01df60a 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -145,6 +145,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, + { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, + { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -1142,6 +1144,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1636,23 +1640,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index ee081fbfe..b6227eebf 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -149,6 +149,8 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, + LLM_KV_ATTENTION_KEY_LENGTH_MLA, + LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -311,6 +313,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index d6e7b3af8..4b3e6a83e 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -10,6 +10,7 @@ #include #include #include +#include // // llama_context @@ -473,7 +474,6 @@ ggml_tensor * llama_context::build_rope_shift( const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_attn_factor = cparams.yarn_attn_factor; const auto & yarn_beta_fast = cparams.yarn_beta_fast; const auto & yarn_beta_slow = cparams.yarn_beta_slow; @@ -482,6 +482,10 @@ ggml_tensor * llama_context::build_rope_shift( const auto & n_rot = hparams.n_rot; const auto & rope_type = hparams.rope_type; + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; + ggml_tensor * tmp; if (ggml_is_quantized(cur->type)) { diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp index 83f3c5a86..d740c1200 100644 --- a/llama/llama.cpp/src/llama-graph.cpp +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -1194,6 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mla, bool v_trans, float kq_scale) const { //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); @@ -1205,8 +1206,6 @@ ggml_tensor * llm_graph_context::build_attn_mha( //const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0]; - const auto n_tokens = q->ne[1]; const auto n_head = q->ne[2]; const auto n_kv = k->ne[1]; @@ -1235,7 +1234,12 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); - cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); + if (v_mla) { + cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); + cur = ggml_mul_mat(ctx0, v_mla, cur); + } + + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); @@ -1273,9 +1277,14 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA + if (v_mla) { + kqv = ggml_mul_mat(ctx0, v_mla, kqv); + } - cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); if (!cparams.offload_kqv) { // all nodes between the KV store and the attention output are run on the CPU @@ -1310,6 +1319,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { GGML_UNUSED(n_tokens); @@ -1331,7 +1341,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale); cb(cur, "kqv_out", il); @@ -1385,6 +1395,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1470,7 +1481,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, 0); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1529,6 +1540,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1548,7 +1560,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale); cb(cur, "kqv_out", il); @@ -1717,4 +1729,3 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } - diff --git a/llama/llama.cpp/src/llama-graph.h b/llama/llama.cpp/src/llama-graph.h index 51993998a..260a2af21 100644 --- a/llama/llama.cpp/src/llama-graph.h +++ b/llama/llama.cpp/src/llama-graph.h @@ -517,11 +517,12 @@ struct llm_graph_context { ggml_tensor * build_attn_mha( ggml_cgraph * gf, - ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] - ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] - ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) + ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] + ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] + ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] bool v_trans, float kq_scale) const; @@ -536,6 +537,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -550,6 +552,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -564,6 +567,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index 4567a0e90..c8a34d521 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -46,6 +46,10 @@ struct llama_hparams { uint32_t n_rel_attn_bkts = 0; uint32_t n_vocab = 0; + // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + uint32_t n_embd_head_k_mla = 0; + uint32_t n_embd_head_v_mla = 0; + // for WavTokenizer struct llama_hparams_posnet posnet; struct llama_hparams_convnext convnext; diff --git a/llama/llama.cpp/src/llama-kv-cache.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp index 5c941e7c4..35a750d39 100644 --- a/llama/llama.cpp/src/llama-kv-cache.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -27,7 +27,7 @@ bool llama_kv_cache_unified::init( recurrent = llama_model_is_recurrent(&model); v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + can_shift = !recurrent; LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index cd1d239c8..c8374159f 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -1170,6 +1170,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); @@ -3281,8 +3283,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { const bool is_lite = (hparams.n_layer == 27); + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; const int64_t q_lora_rank = hparams.n_lora_q; const int64_t kv_lora_rank = hparams.n_lora_kv; @@ -3308,14 +3316,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!is_lite) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); } - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4394,6 +4410,8 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); @@ -4600,7 +4618,7 @@ struct llm_build_llama : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -4903,14 +4921,14 @@ struct llm_build_mllama: public llm_graph_context { n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - + cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -5053,7 +5071,7 @@ struct llm_build_deci : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -5195,7 +5213,7 @@ struct llm_build_baichuan : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5310,7 +5328,7 @@ struct llm_build_xverse : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5435,7 +5453,7 @@ struct llm_build_falcon : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5565,7 +5583,7 @@ struct llm_build_grok : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -5716,7 +5734,7 @@ struct llm_build_dbrx : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5830,7 +5848,7 @@ struct llm_build_starcoder : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5929,7 +5947,7 @@ struct llm_build_refact : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6083,7 +6101,7 @@ struct llm_build_bert : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { @@ -6200,7 +6218,7 @@ struct llm_build_bloom : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6341,7 +6359,7 @@ struct llm_build_mpt : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6487,7 +6505,7 @@ struct llm_build_stablelm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6610,7 +6628,7 @@ struct llm_build_qwen : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6730,7 +6748,7 @@ struct llm_build_qwen2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6851,7 +6869,7 @@ struct llm_build_qwen2vl : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6978,7 +6996,7 @@ struct llm_build_qwen2moe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7131,7 +7149,7 @@ struct llm_build_qwen3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7252,7 +7270,7 @@ struct llm_build_qwen3moe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7392,7 +7410,7 @@ struct llm_build_phi2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -7521,7 +7539,7 @@ struct llm_build_phi3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -7656,7 +7674,7 @@ struct llm_build_plamo : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } ggml_tensor * sa_out = cur; @@ -7763,7 +7781,7 @@ struct llm_build_gpt2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7879,7 +7897,7 @@ struct llm_build_codeshell : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8008,7 +8026,7 @@ struct llm_build_orion : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8135,7 +8153,7 @@ struct llm_build_internlm2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8332,7 +8350,7 @@ struct llm_build_minicpm3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -8462,7 +8480,7 @@ struct llm_build_gemma : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -8584,7 +8602,7 @@ struct llm_build_gemma2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } cur = build_norm(cur, @@ -8725,7 +8743,7 @@ struct llm_build_gemma3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, @@ -8865,7 +8883,7 @@ struct llm_build_starcoder2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9200,7 +9218,7 @@ struct llm_build_command_r : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9335,7 +9353,7 @@ struct llm_build_cohere2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9466,7 +9484,7 @@ struct llm_build_olmo : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9586,7 +9604,7 @@ struct llm_build_olmo2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } cur = build_norm(cur, @@ -9719,7 +9737,7 @@ struct llm_build_olmoe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9852,7 +9870,7 @@ struct llm_build_openelm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9966,7 +9984,7 @@ struct llm_build_gptneox : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -10116,7 +10134,7 @@ struct llm_build_arctic : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -10271,7 +10289,7 @@ struct llm_build_deepseek : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -10361,15 +10379,22 @@ struct llm_build_deepseek2 : public llm_graph_context { llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); - const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); - const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); - - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; - const uint32_t kv_lora_rank = hparams.n_lora_kv; + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k)); + const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); ggml_tensor * cur; ggml_tensor * inpL; @@ -10395,16 +10420,14 @@ struct llm_build_deepseek2 : public llm_graph_context { { ggml_tensor * q = NULL; if (!is_lite) { - // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); cb(q, "q", il); q = build_norm(q, - model.layers[il].attn_q_a_norm, NULL, + model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); cb(q, "q", il); - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); cb(q, "q", il); } else { @@ -10412,96 +10435,125 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(q, "q", il); } - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); cb(q_nope, "q_nope", il); - // and {n_head * n_embd_head_qk_rope, n_tokens} - ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, + n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); - // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_compressed, "kv_compressed", il); + cb(kv_cmpr, "kv_cmpr", il); - // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, + n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); - kv_compressed = build_norm(kv_compressed, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, il); - cb(kv_compressed, "kv_compressed", il); - - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); - - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); - - // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); - - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); - - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); + ext_factor, attn_factor, beta_fast, beta_slow + ); cb(q_pe, "q_pe", il); - // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); + ext_factor, attn_factor, beta_fast, beta_slow + ); cb(k_pe, "k_pe", il); - ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); + kv_cmpr = build_norm(kv_cmpr, + model.layers[il].attn_kv_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); - ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); + if (is_mla) { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); - cur = build_attn(inp_attn, gf, - model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il); + } else { + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); + cb(kv, "kv", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + 0); + cb(k_nope, "k_nope_view", il); + + // and {n_embd_head_v, n_head, n_tokens} + ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, + n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + ggml_row_size(kv->type, n_embd_head_qk_nope)); + cb(Vcur, "Vcur_view", il); + + Vcur = ggml_cont(ctx0, Vcur); + cb(Vcur, "Vcur_cont", il); + + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + cb(Kcur, "Kcur", il); + + // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + } } if (il == n_layer - 1) { @@ -10667,7 +10719,7 @@ struct llm_build_bitnet : public llm_graph_context { cur = build_attn(inp_attn, gf, NULL, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, model.layers[il].attn_sub_norm, NULL, @@ -10790,7 +10842,7 @@ struct llm_build_t5_enc : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -10896,7 +10948,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_self, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, kq_b, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -10928,7 +10980,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_cross, gf, model.layers[il].wo_cross, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); @@ -11061,7 +11113,7 @@ struct llm_build_jais : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1) { @@ -11193,7 +11245,7 @@ struct llm_build_chatglm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11326,7 +11378,7 @@ struct llm_build_glm4 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11470,7 +11522,7 @@ struct llm_build_nemotron : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11601,7 +11653,7 @@ struct llm_build_exaone : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -12503,7 +12555,7 @@ struct llm_build_chameleon : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (hparams.swin_norm) { cur = build_norm(cur, @@ -12683,14 +12735,14 @@ struct llm_build_solar : public llm_graph_context { n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - + cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -13018,7 +13070,7 @@ struct llm_build_plm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -13141,7 +13193,7 @@ struct llm_build_bailingmoe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } if (il == n_layer - 1) { diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index 21c4617bb..72bab5bee 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -174,6 +174,8 @@ struct llama_layer { struct ggml_tensor * wq_b = nullptr; struct ggml_tensor * wkv_a_mqa = nullptr; struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b = nullptr; + struct ggml_tensor * wv_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index c90f636ca..ba37df355 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -1833,6 +1833,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (false || t.first == "<|fim_prefix|>" // Qwen || t.first == "" + || t.first == "" // Granite || t.first == "<|fim▁begin|>" // DeepSeek || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
@@ -1851,6 +1852,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_suffix|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
@@ -1869,6 +1871,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_middle|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
@@ -1887,6 +1890,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_pad|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == ""
                         ) {
                     special_fim_pad_id = t.second;
@@ -1905,6 +1909,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|repo_name|>"
                         || t.first == ""
                         || t.first == ""
+                        || t.first == ""    // Granite
                         ) {
                     special_fim_rep_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
diff --git a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
index 690f2e33c..8f5c3a779 100644
--- a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
+++ b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
@@ -65,10 +65,10 @@ index 273075f4..dd11f304 100644
      /* .init_tensor     = */ NULL, // no initialization required
      /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
 diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
-index cec36b36..4b057973 100644
+index e2617b06..242e50a7 100644
 --- a/ggml/src/ggml-cann/ggml-cann.cpp
 +++ b/ggml/src/ggml-cann/ggml-cann.cpp
-@@ -530,6 +530,7 @@ static void ggml_backend_cann_buffer_free_buffer(
+@@ -800,6 +800,7 @@ static void ggml_backend_cann_buffer_free_buffer(
      ggml_backend_cann_buffer_context* ctx =
          (ggml_backend_cann_buffer_context*)buffer->context;
      delete ctx;
@@ -76,7 +76,7 @@ index cec36b36..4b057973 100644
  }
  
  /**
-@@ -1199,6 +1200,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
+@@ -1472,6 +1473,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
   */
  static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
      ACL_CHECK(aclrtFreeHost(buffer->context));
@@ -85,10 +85,10 @@ index cec36b36..4b057973 100644
  
  /**
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index fafe9633..59a49560 100644
+index a7febef7..31750b6f 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -533,6 +533,7 @@ struct ggml_backend_cuda_buffer_context {
+@@ -534,6 +534,7 @@ struct ggml_backend_cuda_buffer_context {
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
      delete ctx;
@@ -96,7 +96,7 @@ index fafe9633..59a49560 100644
  }
  
  static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
-@@ -788,6 +789,7 @@ struct ggml_backend_cuda_split_buffer_context {
+@@ -789,6 +790,7 @@ struct ggml_backend_cuda_split_buffer_context {
  static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
      delete ctx;
@@ -104,7 +104,7 @@ index fafe9633..59a49560 100644
  }
  
  static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
-@@ -1061,6 +1063,7 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
+@@ -1062,6 +1064,7 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
  
  static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      CUDA_CHECK(cudaFreeHost(buffer->context));
@@ -125,10 +125,10 @@ index 50579227..2799a0a5 100644
  
  static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index 9f1c6c6c..310afe8a 100644
+index 266d8af4..12886cd3 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -4641,6 +4641,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
+@@ -4759,6 +4759,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
      }
  
      free(ctx);
@@ -137,10 +137,10 @@ index 9f1c6c6c..310afe8a 100644
  
  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
-index b8b5cbd3..14d4561b 100644
+index 05a2f4e6..392cc18d 100644
 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp
 +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
-@@ -1443,6 +1443,7 @@ struct ggml_backend_opencl_buffer_context {
+@@ -1940,6 +1940,7 @@ struct ggml_backend_opencl_buffer_context {
  static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
      delete ctx;
@@ -149,10 +149,10 @@ index b8b5cbd3..14d4561b 100644
  
  static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
-index 862b9b66..34536681 100644
+index a0667b7d..bd83adc5 100644
 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp
 +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
-@@ -443,6 +443,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+@@ -468,6 +468,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
      GGML_ASSERT(status);
      delete ctx;
@@ -161,7 +161,7 @@ index 862b9b66..34536681 100644
  
  static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
-index 3e48a924..a3d182fc 100644
+index 1de34c96..4600f61e 100644
 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp
 +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
 @@ -316,6 +316,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
@@ -189,10 +189,10 @@ index 3e48a924..a3d182fc 100644
  
  static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
 diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-index 783a0ff8..8ac1e07e 100644
+index 39f3cd34..c569a8a5 100644
 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
 +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-@@ -8639,6 +8639,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+@@ -8653,6 +8653,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
      ggml_vk_destroy_buffer(ctx->dev_buffer);
      delete ctx;
@@ -200,7 +200,7 @@ index 783a0ff8..8ac1e07e 100644
  }
  
  static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
-@@ -8782,6 +8783,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
+@@ -8796,6 +8797,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
  static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
      ggml_vk_host_free(vk_instance.devices[0], buffer->context);
diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch
index 02790a51e..e51b43730 100644
--- a/llama/patches/0002-pretokenizer.patch
+++ b/llama/patches/0002-pretokenizer.patch
@@ -10,7 +10,7 @@ logs instead of throwing an error
  1 file changed, 3 insertions(+), 11 deletions(-)
 
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index 464ff01e..0125ee53 100644
+index 48060517..a35b498c 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
 @@ -1491,16 +1491,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
diff --git a/llama/patches/0003-embeddings.patch b/llama/patches/0003-embeddings.patch
index e8253e8e2..c27dbd7b5 100644
--- a/llama/patches/0003-embeddings.patch
+++ b/llama/patches/0003-embeddings.patch
@@ -11,10 +11,10 @@ instead of forcing one or the error
  1 file changed, 3 insertions(+), 3 deletions(-)
 
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index 4735e98e..65135172 100644
+index 983385f8..32f59819 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -1232,7 +1232,7 @@ int llama_context::decode(llama_batch & inp_batch) {
+@@ -1236,7 +1236,7 @@ int llama_context::decode(llama_batch & inp_batch) {
      int64_t n_outputs_all = 0;
  
      // count outputs
@@ -23,7 +23,7 @@ index 4735e98e..65135172 100644
          for (uint32_t i = 0; i < n_tokens_all; ++i) {
              n_outputs_all += batch.logits[i] != 0;
          }
-@@ -1344,7 +1344,7 @@ int llama_context::decode(llama_batch & inp_batch) {
+@@ -1348,7 +1348,7 @@ int llama_context::decode(llama_batch & inp_batch) {
          //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
          //}
  
@@ -32,7 +32,7 @@ index 4735e98e..65135172 100644
          auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
  
          if (t_embd && res->get_embd_pooled()) {
-@@ -1488,7 +1488,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
+@@ -1492,7 +1492,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
      const auto n_embd  = hparams.n_embd;
  
      // TODO: use a per-batch flag for logits presence instead
diff --git a/llama/patches/0004-clip-unicode.patch b/llama/patches/0004-clip-unicode.patch
index adf30d0f5..81d61a827 100644
--- a/llama/patches/0004-clip-unicode.patch
+++ b/llama/patches/0004-clip-unicode.patch
@@ -10,12 +10,12 @@ filesystems for paths that include wide characters
  1 file changed, 39 insertions(+)
 
 diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index 49c90b75..4b72ea9f 100644
+index 75970615..d57b4bd6 100644
 --- a/examples/llava/clip.cpp
 +++ b/examples/llava/clip.cpp
-@@ -28,6 +28,19 @@
- #include 
+@@ -29,6 +29,19 @@
  #include 
+ #include 
  
 +#if defined(_WIN32)
 +#define WIN32_LEAN_AND_MEAN
@@ -33,7 +33,7 @@ index 49c90b75..4b72ea9f 100644
  struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
  
  //#define CLIP_DEBUG_FUNCTIONS
-@@ -1429,7 +1442,29 @@ struct clip_model_loader {
+@@ -1430,7 +1443,29 @@ struct clip_model_loader {
          {
              std::vector read_buf;
  
@@ -63,7 +63,7 @@ index 49c90b75..4b72ea9f 100644
              if (!fin) {
                  throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
              }
-@@ -1456,7 +1491,11 @@ struct clip_model_loader {
+@@ -1457,7 +1492,11 @@ struct clip_model_loader {
                      ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
                  }
              }
diff --git a/llama/patches/0005-solar-pro.patch b/llama/patches/0005-solar-pro.patch
index 770464d15..76ddc6197 100644
--- a/llama/patches/0005-solar-pro.patch
+++ b/llama/patches/0005-solar-pro.patch
@@ -1,6 +1,6 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
 From: jmorganca 
-Date: Tue, 8 Apr 2025 16:03:51 -0700
+Date: Sun, 20 Apr 2025 16:11:09 -0700
 Subject: [PATCH] solar-pro
 
 adds support for the Solar Pro architecture
@@ -15,7 +15,7 @@ adds support for the Solar Pro architecture
  7 files changed, 248 insertions(+)
 
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index a6fddc7f..0b0fedcd 100644
+index 62e1480b..f754bc8f 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
 @@ -68,6 +68,7 @@ static const std::map LLM_ARCH_NAMES = {
@@ -31,10 +31,10 @@ index a6fddc7f..0b0fedcd 100644
      { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
      { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
 +    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
+     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
+     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
  
-     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
-     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-@@ -1478,6 +1480,24 @@ static const std::map> LLM_TENSOR_N
+@@ -1482,6 +1484,24 @@ static const std::map> LLM_TENSOR_N
              { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
          },
      },
@@ -59,7 +59,7 @@ index a6fddc7f..0b0fedcd 100644
      {
          LLM_ARCH_WAVTOKENIZER_DEC,
          {
-@@ -1671,6 +1691,7 @@ static const std::map LLM_TENSOR_INFOS = {
+@@ -1660,6 +1680,7 @@ static const std::map LLM_TENSOR_INFOS = {
      {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
@@ -68,7 +68,7 @@ index a6fddc7f..0b0fedcd 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index 2c2099b3..74aa3dd0 100644
+index 98ca00a1..439aaeab 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
 @@ -72,6 +72,7 @@ enum llm_arch {
@@ -84,10 +84,10 @@ index 2c2099b3..74aa3dd0 100644
      LLM_KV_ATTENTION_SLIDING_WINDOW,
      LLM_KV_ATTENTION_SCALE,
 +    LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
+     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
  
-     LLM_KV_ROPE_DIMENSION_COUNT,
-     LLM_KV_ROPE_DIMENSION_SECTIONS,
-@@ -340,6 +342,7 @@ enum llm_tensor {
+@@ -344,6 +346,7 @@ enum llm_tensor {
      LLM_TENSOR_ENC_OUTPUT_NORM,
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
@@ -115,10 +115,10 @@ index 90dfe7a7..8a667960 100644
      if (il < n_layer) {
          return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index 4e0b5719..c3147cbc 100644
+index 80fcd65d..6e278945 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
-@@ -51,6 +51,8 @@ struct llama_hparams {
+@@ -55,6 +55,8 @@ struct llama_hparams {
      std::array n_head_kv_arr;
      std::array n_ff_arr;
  
@@ -127,7 +127,7 @@ index 4e0b5719..c3147cbc 100644
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
      uint32_t n_lora_kv          = 0;
-@@ -149,6 +151,9 @@ struct llama_hparams {
+@@ -153,6 +155,9 @@ struct llama_hparams {
      // dimension of the recurrent state embeddings
      uint32_t n_embd_v_s() const;
  
@@ -150,10 +150,10 @@ index ea73a8a7..a012aeae 100644
  llama_model_loader::llama_model_loader(
          const std::string & fname,
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index b74dd72c..5fbd0055 100644
+index 6b7bfecf..aba42819 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -1372,6 +1372,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+@@ -1374,6 +1374,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                      default: type = LLM_TYPE_UNKNOWN;
                 }
              } break;
@@ -175,7 +175,7 @@ index b74dd72c..5fbd0055 100644
          case LLM_ARCH_WAVTOKENIZER_DEC:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-@@ -3701,6 +3716,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+@@ -3717,6 +3732,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  
                          layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
  
@@ -210,7 +210,7 @@ index b74dd72c..5fbd0055 100644
                          layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                          layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                          layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-@@ -12244,6 +12287,165 @@ struct llm_build_chameleon : public llm_graph_context {
+@@ -12296,6 +12339,165 @@ struct llm_build_chameleon : public llm_graph_context {
      }
  };
  
@@ -309,14 +309,14 @@ index b74dd72c..5fbd0055 100644
 +                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
 +                        ext_factor, attn_factor, beta_fast, beta_slow
 +                        );
-+                
++
 +                cb(Qcur, "Qcur", il);
 +                cb(Kcur, "Kcur", il);
 +                cb(Vcur, "Vcur", il);
 +
 +                cur = build_attn(inp_attn, gf,
 +                        model.layers[il].wo, model.layers[il].bo,
-+                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
++                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
 +                cb(cur, "attn_out", il);
 +            }
 +
@@ -376,7 +376,7 @@ index b74dd72c..5fbd0055 100644
  struct llm_build_wavtokenizer_dec : public llm_graph_context {
      llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
          ggml_tensor * cur;
-@@ -12993,6 +13195,10 @@ llm_graph_result_ptr llama_model::build_graph(
+@@ -13045,6 +13247,10 @@ llm_graph_result_ptr llama_model::build_graph(
              {
                  llm = std::make_unique(*this, params, gf);
              } break;
@@ -387,7 +387,7 @@ index b74dd72c..5fbd0055 100644
          case LLM_ARCH_WAVTOKENIZER_DEC:
              {
                  llm = std::make_unique(*this, params, gf);
-@@ -13139,6 +13345,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
+@@ -13191,6 +13397,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
          case LLM_ARCH_GRANITE:
          case LLM_ARCH_GRANITE_MOE:
          case LLM_ARCH_CHAMELEON:
@@ -396,7 +396,7 @@ index b74dd72c..5fbd0055 100644
              return LLAMA_ROPE_TYPE_NORM;
  
 diff --git a/src/llama-model.h b/src/llama-model.h
-index 0f18dac1..e08d4ae4 100644
+index fd82d106..5865d5e9 100644
 --- a/src/llama-model.h
 +++ b/src/llama-model.h
 @@ -62,6 +62,7 @@ enum llm_type {
@@ -407,7 +407,7 @@ index 0f18dac1..e08d4ae4 100644
      LLM_TYPE_30B,
      LLM_TYPE_32B,
      LLM_TYPE_34B,
-@@ -305,6 +306,8 @@ struct llama_layer {
+@@ -307,6 +308,8 @@ struct llama_layer {
      struct ggml_tensor * ffn_up_scale   = nullptr;
      struct ggml_tensor * ffn_down_scale = nullptr;
  
diff --git a/llama/patches/0007-add-mllama-support.patch b/llama/patches/0006-add-mllama-support.patch
similarity index 94%
rename from llama/patches/0007-add-mllama-support.patch
rename to llama/patches/0006-add-mllama-support.patch
index 6d7e91d7e..e5fa0462c 100644
--- a/llama/patches/0007-add-mllama-support.patch
+++ b/llama/patches/0006-add-mllama-support.patch
@@ -1,6 +1,6 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
 From: jmorganca 
-Date: Tue, 8 Apr 2025 19:27:12 -0700
+Date: Sun, 20 Apr 2025 16:12:36 -0700
 Subject: [PATCH] add mllama support
 
 adds support for the llama 3.2 vision architecture
@@ -28,7 +28,7 @@ adds support for the llama 3.2 vision architecture
  20 files changed, 475 insertions(+), 22 deletions(-)
 
 diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
-index 91a07e2a..13127c7b 100644
+index 3d566475..654d1358 100644
 --- a/examples/llava/gemma3-cli.cpp
 +++ b/examples/llava/gemma3-cli.cpp
 @@ -106,7 +106,7 @@ struct decode_embd_batch {
@@ -79,10 +79,10 @@ index 03a22cbb..5eb40bcd 100644
              LOG_ERR("%s : failed to eval\n", __func__);
              return false;
 diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
-index 114c274b..a0e649ad 100644
+index 3fd5bebc..f0cec596 100644
 --- a/examples/llava/mtmd.cpp
 +++ b/examples/llava/mtmd.cpp
-@@ -213,7 +213,7 @@ struct decode_embd_batch {
+@@ -233,7 +233,7 @@ struct decode_embd_batch {
      std::vector seq_ids;
      std::vector         logits;
      llama_batch batch;
@@ -91,7 +91,7 @@ index 114c274b..a0e649ad 100644
          pos     .resize(n_tokens);
          n_seq_id.resize(n_tokens);
          seq_ids .resize(n_tokens + 1);
-@@ -225,6 +225,7 @@ struct decode_embd_batch {
+@@ -245,6 +245,7 @@ struct decode_embd_batch {
              /*n_tokens       =*/ n_tokens,
              /*tokens         =*/ nullptr,
              /*embd           =*/ embd,
@@ -99,9 +99,9 @@ index 114c274b..a0e649ad 100644
              /*pos            =*/ pos.data(),
              /*n_seq_id       =*/ n_seq_id.data(),
              /*seq_id         =*/ seq_ids.data(),
-@@ -291,7 +292,8 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
+@@ -311,7 +312,8 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
  
-             int32_t n_tokens = chunk.tokens_image->n_tokens();
+             int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
              float * embd = mtmd_get_output_embd(ctx);
 -            decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
 +            int n_embd  = llama_model_n_embd(llama_get_model(lctx));
@@ -158,7 +158,7 @@ index 5657fbf0..f91896e4 100644
      LLAMA_API void llama_free(struct llama_context * ctx);
  
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index 0b0fedcd..c1f78618 100644
+index f754bc8f..0568565f 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
 @@ -6,6 +6,7 @@
@@ -174,10 +174,10 @@ index 0b0fedcd..c1f78618 100644
      { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
      { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
 +    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,       "%s.attention.cross_attention_layers"       },
+     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
+     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
  
-     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
-     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-@@ -269,6 +271,40 @@ static const std::map> LLM_TENSOR_N
+@@ -271,6 +273,40 @@ static const std::map> LLM_TENSOR_N
              { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
          },
      },
@@ -218,7 +218,7 @@ index 0b0fedcd..c1f78618 100644
      {
          LLM_ARCH_DECI,
          {
-@@ -1692,6 +1728,14 @@ static const std::map LLM_TENSOR_INFOS = {
+@@ -1681,6 +1717,14 @@ static const std::map LLM_TENSOR_INFOS = {
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
      {LLM_TENSOR_BSKCN_TV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -234,7 +234,7 @@ index 0b0fedcd..c1f78618 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index 74aa3dd0..f987844d 100644
+index 439aaeab..6a989034 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
 @@ -11,6 +11,7 @@
@@ -250,10 +250,10 @@ index 74aa3dd0..f987844d 100644
      LLM_KV_ATTENTION_SCALE,
      LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
 +    LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
+     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
  
-     LLM_KV_ROPE_DIMENSION_COUNT,
-     LLM_KV_ROPE_DIMENSION_SECTIONS,
-@@ -343,6 +345,14 @@ enum llm_tensor {
+@@ -347,6 +349,14 @@ enum llm_tensor {
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
      LLM_TENSOR_BSKCN_TV,
@@ -297,10 +297,10 @@ index 01d5ca57..8682b0e6 100644
          batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
      }
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index 65135172..afe6f552 100644
+index 32f59819..0343ba8a 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -858,7 +858,7 @@ float * llama_context::get_logits_ith(int32_t i) {
+@@ -862,7 +862,7 @@ float * llama_context::get_logits_ith(int32_t i) {
              throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
          }
  
@@ -309,7 +309,7 @@ index 65135172..afe6f552 100644
      } catch (const std::exception & err) {
          LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
  #ifndef NDEBUG
-@@ -979,6 +979,10 @@ void llama_context::set_warmup(bool value) {
+@@ -983,6 +983,10 @@ void llama_context::set_warmup(bool value) {
      cparams.warmup = value;
  }
  
@@ -320,7 +320,7 @@ index 65135172..afe6f552 100644
  void llama_context::set_adapter_lora(
              llama_adapter_lora * adapter,
              float scale) {
-@@ -1054,7 +1058,7 @@ int llama_context::encode(llama_batch & inp_batch) {
+@@ -1058,7 +1062,7 @@ int llama_context::encode(llama_batch & inp_batch) {
  
      const int64_t n_embd = hparams.n_embd;
  
@@ -329,7 +329,7 @@ index 65135172..afe6f552 100644
  
      const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
  
-@@ -1194,10 +1198,9 @@ int llama_context::decode(llama_batch & inp_batch) {
+@@ -1198,10 +1202,9 @@ int llama_context::decode(llama_batch & inp_batch) {
  
      const llama_batch & batch = batch_allocr.batch;
  
@@ -341,7 +341,7 @@ index 65135172..afe6f552 100644
  
      const int64_t n_tokens_all = batch.n_tokens;
      const int64_t n_embd       = hparams.n_embd;
-@@ -1245,7 +1248,7 @@ int llama_context::decode(llama_batch & inp_batch) {
+@@ -1249,7 +1252,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  
      const bool logits_all = n_outputs_all == n_tokens_all;
  
@@ -350,7 +350,7 @@ index 65135172..afe6f552 100644
              /* simple_split */ !kv_self->recurrent,
              /* logits_all   */ logits_all);
  
-@@ -1479,12 +1482,11 @@ int llama_context::decode(llama_batch & inp_batch) {
+@@ -1483,12 +1486,11 @@ int llama_context::decode(llama_batch & inp_batch) {
  
  int32_t llama_context::output_reserve(int32_t n_outputs) {
      const auto & hparams = model.hparams;
@@ -364,7 +364,7 @@ index 65135172..afe6f552 100644
      const auto n_embd  = hparams.n_embd;
  
      // TODO: use a per-batch flag for logits presence instead
-@@ -1554,7 +1556,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
+@@ -1558,7 +1560,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
  void llama_context::output_reorder() {
      auto & out_ids = sbatch.out_ids;
      if (!out_ids.empty()) {
@@ -373,7 +373,7 @@ index 65135172..afe6f552 100644
          const uint32_t n_embd  = model.hparams.n_embd;
  
          GGML_ASSERT((size_t) n_outputs == out_ids.size());
-@@ -2061,7 +2063,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
+@@ -2065,7 +2067,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
      {
          LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
  
@@ -382,7 +382,7 @@ index 65135172..afe6f552 100644
  
          io.write(&logits_size, sizeof(logits_size));
  
-@@ -2244,6 +2246,7 @@ llama_context_params llama_context_default_params() {
+@@ -2248,6 +2250,7 @@ llama_context_params llama_context_default_params() {
          /*.offload_kqv                 =*/ true,
          /*.flash_attn                  =*/ false,
          /*.no_perf                     =*/ true,
@@ -390,7 +390,7 @@ index 65135172..afe6f552 100644
          /*.abort_callback              =*/ nullptr,
          /*.abort_callback_data         =*/ nullptr,
      };
-@@ -2371,6 +2374,10 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
+@@ -2375,6 +2378,10 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
      ctx->set_warmup(warmup);
  }
  
@@ -426,7 +426,7 @@ index 30e550f0..85ad91b9 100644
  
      enum llama_pooling_type pooling_type;
 diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
-index cd955d63..83f3c5a8 100644
+index a85e9728..d740c120 100644
 --- a/src/llama-graph.cpp
 +++ b/src/llama-graph.cpp
 @@ -546,6 +546,12 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -442,7 +442,7 @@ index cd955d63..83f3c5a8 100644
  //
  // llm_graph_context
  //
-@@ -1495,6 +1501,25 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
+@@ -1506,6 +1512,25 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
      return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
  }
  
@@ -469,7 +469,7 @@ index cd955d63..83f3c5a8 100644
          llm_graph_input_attn_cross * inp,
          ggml_cgraph * gf,
 diff --git a/src/llama-graph.h b/src/llama-graph.h
-index 5b6618f9..51993998 100644
+index d192dc14..260a2af2 100644
 --- a/src/llama-graph.h
 +++ b/src/llama-graph.h
 @@ -86,6 +86,7 @@ public:
@@ -518,7 +518,7 @@ index 8a667960..6a02de03 100644
 +    return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
 +}
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index c3147cbc..4567a0e9 100644
+index 6e278945..c8a34d52 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
 @@ -2,6 +2,8 @@
@@ -536,9 +536,9 @@ index c3147cbc..4567a0e9 100644
      uint32_t n_rel_attn_bkts = 0;
 +    uint32_t n_vocab = 0;
  
-     // for WavTokenizer
-     struct llama_hparams_posnet   posnet;
-@@ -52,6 +55,7 @@ struct llama_hparams {
+     // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+     uint32_t n_embd_head_k_mla = 0;
+@@ -56,6 +59,7 @@ struct llama_hparams {
      std::array n_ff_arr;
  
      std::array, 4> n_bskcn_arr = {};
@@ -546,7 +546,7 @@ index c3147cbc..4567a0e9 100644
  
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
-@@ -154,6 +158,9 @@ struct llama_hparams {
+@@ -158,6 +162,9 @@ struct llama_hparams {
      // Block skip connection
      bool n_bskcn(uint32_t n, uint32_t il) const;
  
@@ -557,7 +557,7 @@ index c3147cbc..4567a0e9 100644
  };
  
 diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
-index dbf5f118..9310f262 100644
+index 7c9d46d8..69f8d35a 100644
 --- a/src/llama-kv-cache.cpp
 +++ b/src/llama-kv-cache.cpp
 @@ -95,8 +95,16 @@ bool llama_kv_cache_unified::init(
@@ -593,7 +593,7 @@ index a012aeae..2e11507d 100644
      bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) {
          const int kid = gguf_find_key(meta.get(), key.c_str());
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 5fbd0055..d5ad466e 100644
+index aba42819..d051696c 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
 @@ -419,6 +419,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
@@ -650,7 +650,7 @@ index 5fbd0055..d5ad466e 100644
          case LLM_ARCH_DECI:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-@@ -1548,7 +1562,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+@@ -1550,7 +1564,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
          const int64_t n_embd_head_v = hparams.n_embd_head_v;
          const int64_t n_ff          = hparams.n_ff();
          const int64_t n_embd_gqa    = n_embd_v_gqa;
@@ -659,7 +659,7 @@ index 5fbd0055..d5ad466e 100644
          const int64_t n_token_types = vocab.n_token_types();
          const int64_t n_rot         = hparams.n_rot;
          const int64_t n_expert      = hparams.n_expert;
-@@ -1801,6 +1815,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+@@ -1803,6 +1817,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                          }
                      }
                  } break;
@@ -712,7 +712,7 @@ index 5fbd0055..d5ad466e 100644
              case LLM_ARCH_DECI:
                  {
                      tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-@@ -4665,6 +4725,246 @@ struct llm_build_llama : public llm_graph_context {
+@@ -4683,6 +4743,246 @@ struct llm_build_llama : public llm_graph_context {
      }
  };
  
@@ -893,14 +893,14 @@ index 5fbd0055..d5ad466e 100644
 +                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
 +                        ext_factor, attn_factor, beta_fast, beta_slow
 +                        );
-+                
++
 +                cb(Qcur, "Qcur", il);
 +                cb(Kcur, "Kcur", il);
 +                cb(Vcur, "Vcur", il);
 +
 +                cur = build_attn(inp_attn, gf,
 +                    model.layers[il].wo, model.layers[il].bo,
-+                    Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
++                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 +
 +                if (il == n_layer - 1) {
 +                    // skip computing output for unused tokens
@@ -959,7 +959,7 @@ index 5fbd0055..d5ad466e 100644
  struct llm_build_deci : public llm_graph_context {
      llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
          const int64_t n_embd_head = hparams.n_embd_head_v;
-@@ -12965,6 +13265,10 @@ llm_graph_result_ptr llama_model::build_graph(
+@@ -13017,6 +13317,10 @@ llm_graph_result_ptr llama_model::build_graph(
              {
                  llm = std::make_unique(*this, params, gf);
              } break;
@@ -970,7 +970,7 @@ index 5fbd0055..d5ad466e 100644
          case LLM_ARCH_DECI:
              {
                  llm = std::make_unique(*this, params, gf);
-@@ -13325,6 +13629,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
+@@ -13377,6 +13681,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
          // use what we call a normal RoPE, operating on pairs of consecutive head values
          case LLM_ARCH_LLAMA:
          case LLM_ARCH_LLAMA4:
@@ -979,7 +979,7 @@ index 5fbd0055..d5ad466e 100644
          case LLM_ARCH_BAICHUAN:
          case LLM_ARCH_STARCODER:
 diff --git a/src/llama-model.h b/src/llama-model.h
-index e08d4ae4..21c4617b 100644
+index 5865d5e9..72bab5be 100644
 --- a/src/llama-model.h
 +++ b/src/llama-model.h
 @@ -11,6 +11,7 @@
@@ -998,7 +998,7 @@ index e08d4ae4..21c4617b 100644
      LLM_TYPE_236B,
      LLM_TYPE_314B,
      LLM_TYPE_671B,
-@@ -308,6 +310,16 @@ struct llama_layer {
+@@ -310,6 +312,16 @@ struct llama_layer {
  
      struct ggml_tensor * bskcn_tv = nullptr;
  
diff --git a/llama/patches/0006-conditional-fattn.patch b/llama/patches/0006-conditional-fattn.patch
deleted file mode 100644
index 5be2b81d0..000000000
--- a/llama/patches/0006-conditional-fattn.patch
+++ /dev/null
@@ -1,25 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Daniel Hiltgen 
-Date: Wed, 9 Oct 2024 17:26:23 -0700
-Subject: [PATCH] conditional-fattn
-
----
- ggml/src/ggml-cuda/ggml-cuda.cu | 2 ++
- 1 file changed, 2 insertions(+)
-
-diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 59a49560..b70c6a32 100644
---- a/ggml/src/ggml-cuda/ggml-cuda.cu
-+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2338,9 +2338,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
-         case GGML_OP_ARGSORT:
-             ggml_cuda_op_argsort(ctx, dst);
-             break;
-+#if !defined(GGML_DISABLE_FLASH_ATTN)
-         case GGML_OP_FLASH_ATTN_EXT:
-             ggml_cuda_flash_attn_ext(ctx, dst);
-             break;
-+#endif
-         case GGML_OP_CROSS_ENTROPY_LOSS:
-             ggml_cuda_cross_entropy_loss(ctx, dst);
-             break;
diff --git a/llama/patches/0008-add-unpad-operator.patch b/llama/patches/0007-add-unpad-operator.patch
similarity index 97%
rename from llama/patches/0008-add-unpad-operator.patch
rename to llama/patches/0007-add-unpad-operator.patch
index ce409ce9f..116545d67 100644
--- a/llama/patches/0008-add-unpad-operator.patch
+++ b/llama/patches/0007-add-unpad-operator.patch
@@ -147,10 +147,10 @@ index 410a3720..3eca1cf8 100644
  void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index b70c6a32..67208cba 100644
+index 31750b6f..0fef9522 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2245,6 +2245,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+@@ -2246,6 +2246,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
          case GGML_OP_PAD:
              ggml_cuda_op_pad(ctx, dst);
              break;
@@ -160,7 +160,7 @@ index b70c6a32..67208cba 100644
          case GGML_OP_ARANGE:
              ggml_cuda_op_arange(ctx, dst);
              break;
-@@ -3223,6 +3226,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
+@@ -3222,6 +3225,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
          case GGML_OP_UPSCALE:
              return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
          case GGML_OP_PAD:
@@ -233,7 +233,7 @@ index 8fd386b0..e2ededc3 100644
  void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 +void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index 310afe8a..b121ab9e 100644
+index 12886cd3..b2e95a66 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
 @@ -341,6 +341,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
@@ -244,7 +244,7 @@ index 310afe8a..b121ab9e 100644
      GGML_METAL_KERNEL_TYPE_ARANGE_F32,
      GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
      GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
-@@ -998,6 +999,7 @@ @implementation GGMLMetalClass
+@@ -1020,6 +1021,7 @@ @implementation GGMLMetalClass
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                     upscale_f32,                     true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                         pad_f32,                         true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,              pad_reflect_1d_f32,              true);
@@ -252,7 +252,7 @@ index 310afe8a..b121ab9e 100644
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,          timestep_embedding_f32,          true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                      arange_f32,                      true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,             argsort_f32_i32_asc,             true);
-@@ -1339,6 +1341,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
+@@ -1384,6 +1386,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
          case GGML_OP_POOL_2D:
          case GGML_OP_PAD:
          case GGML_OP_PAD_REFLECT_1D:
@@ -260,7 +260,7 @@ index 310afe8a..b121ab9e 100644
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_ARGSORT:
          case GGML_OP_LEAKY_RELU:
-@@ -3669,6 +3672,36 @@ static void ggml_metal_encode_node(
+@@ -3731,6 +3734,36 @@ static void ggml_metal_encode_node(
  
                  const int nth = MIN(1024, ne0);
  
@@ -298,10 +298,10 @@ index 310afe8a..b121ab9e 100644
              } break;
          case GGML_OP_ARANGE:
 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index b08666e2..e3185e5b 100644
+index 8d6e99e6..71f0f97f 100644
 --- a/ggml/src/ggml-metal/ggml-metal.metal
 +++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -2968,6 +2968,51 @@ kernel void kernel_pad_reflect_1d_f32(
+@@ -2975,6 +2975,51 @@ kernel void kernel_pad_reflect_1d_f32(
      }
  }
  
diff --git a/llama/patches/0009-fix-deepseek-deseret-regex.patch b/llama/patches/0008-fix-deepseek-deseret-regex.patch
similarity index 99%
rename from llama/patches/0009-fix-deepseek-deseret-regex.patch
rename to llama/patches/0008-fix-deepseek-deseret-regex.patch
index 7c7336492..9b2d33984 100644
--- a/llama/patches/0009-fix-deepseek-deseret-regex.patch
+++ b/llama/patches/0008-fix-deepseek-deseret-regex.patch
@@ -12,7 +12,7 @@ regex
  2 files changed, 22 insertions(+), 1 deletion(-)
 
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index 0125ee53..d74919d2 100644
+index a35b498c..032019c9 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
 @@ -296,7 +296,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
diff --git a/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch b/llama/patches/0009-maintain-ordering-for-rules-for-grammar.patch
similarity index 100%
rename from llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
rename to llama/patches/0009-maintain-ordering-for-rules-for-grammar.patch
diff --git a/llama/patches/0011-ensure-KV-cache-is-fully-defragmented.patch b/llama/patches/0010-ensure-KV-cache-is-fully-defragmented.patch
similarity index 96%
rename from llama/patches/0011-ensure-KV-cache-is-fully-defragmented.patch
rename to llama/patches/0010-ensure-KV-cache-is-fully-defragmented.patch
index f1fe118e6..c9d4e9ad8 100644
--- a/llama/patches/0011-ensure-KV-cache-is-fully-defragmented.patch
+++ b/llama/patches/0010-ensure-KV-cache-is-fully-defragmented.patch
@@ -22,10 +22,10 @@ multiple batches of processing until everything is complete.
  4 files changed, 51 insertions(+), 106 deletions(-)
 
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index afe6f552..d6e7b3af 100644
+index 0343ba8a..4b3e6a83 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -590,13 +590,12 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
+@@ -594,13 +594,12 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
  
  llm_graph_result_ptr llama_context::build_kv_self_defrag(
          ggml_context * ctx0,
@@ -41,7 +41,7 @@ index afe6f552..d6e7b3af 100644
  #if 0
      // CPU defrag
      //
-@@ -668,32 +667,20 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
+@@ -672,32 +671,20 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
          ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
      }
  #else
@@ -79,7 +79,7 @@ index afe6f552..d6e7b3af 100644
  
              ggml_tensor * view_v_src;
              ggml_tensor * view_v_dst;
-@@ -701,34 +688,30 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
+@@ -705,34 +692,30 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
              if (cparams.flash_attn) {
                  // NOTE: the V cache is not transposed when using flash attention
                  view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
@@ -122,7 +122,7 @@ index afe6f552..d6e7b3af 100644
  #endif
  
      return res;
-@@ -737,8 +720,6 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
+@@ -741,8 +724,6 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
  void llama_context::kv_self_update() {
      auto & kv = kv_self;
  
@@ -131,7 +131,7 @@ index afe6f552..d6e7b3af 100644
      if (kv->has_shift) {
          if (!kv->get_can_shift()) {
              GGML_ABORT("The current context does not support K-shift");
-@@ -759,8 +740,6 @@ void llama_context::kv_self_update() {
+@@ -763,8 +744,6 @@ void llama_context::kv_self_update() {
              res->set_inputs(nullptr);
  
              graph_compute(gf, false);
@@ -140,7 +140,7 @@ index afe6f552..d6e7b3af 100644
          }
  
          {
-@@ -775,49 +754,28 @@ void llama_context::kv_self_update() {
+@@ -779,49 +758,28 @@ void llama_context::kv_self_update() {
      // defragment the KV cache if needed
      if (kv->do_defrag) {
          LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
@@ -202,7 +202,7 @@ index afe6f552..d6e7b3af 100644
  }
  
  enum llama_pooling_type llama_context::pooling_type() const {
-@@ -1301,9 +1259,12 @@ int llama_context::decode(llama_batch & inp_batch) {
+@@ -1305,9 +1263,12 @@ int llama_context::decode(llama_batch & inp_batch) {
          // find KV slot
          {
              if (!kv_self->find_slot(ubatch)) {
@@ -241,7 +241,7 @@ index baa03276..a59ff8fd 100644
      // TODO: read/write lora adapters and cvec
      size_t state_write_data(llama_io_write_i & io);
 diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
-index 9310f262..5c941e7c 100644
+index 69f8d35a..35a750d3 100644
 --- a/src/llama-kv-cache.cpp
 +++ b/src/llama-kv-cache.cpp
 @@ -781,17 +781,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
diff --git a/llama/patches/0012-sort-devices-by-score.patch b/llama/patches/0011-sort-devices-by-score.patch
similarity index 100%
rename from llama/patches/0012-sort-devices-by-score.patch
rename to llama/patches/0011-sort-devices-by-score.patch
diff --git a/llama/patches/0013-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0012-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
similarity index 100%
rename from llama/patches/0013-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
rename to llama/patches/0012-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
diff --git a/llama/patches/0014-remove-amx.patch b/llama/patches/0013-remove-amx.patch
similarity index 100%
rename from llama/patches/0014-remove-amx.patch
rename to llama/patches/0013-remove-amx.patch
diff --git a/llama/patches/0015-fix-string-arr-kv-loading.patch b/llama/patches/0014-fix-string-arr-kv-loading.patch
similarity index 99%
rename from llama/patches/0015-fix-string-arr-kv-loading.patch
rename to llama/patches/0014-fix-string-arr-kv-loading.patch
index ae7e4ca9f..01f1b71eb 100644
--- a/llama/patches/0015-fix-string-arr-kv-loading.patch
+++ b/llama/patches/0014-fix-string-arr-kv-loading.patch
@@ -53,7 +53,7 @@ index 381a9c7d..e45b453d 100644
  }
  
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index d74919d2..c90f636c 100644
+index 032019c9..ba37df35 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
 @@ -1459,7 +1459,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
diff --git a/llama/patches/0016-ollama-debug-tensor.patch b/llama/patches/0015-ollama-debug-tensor.patch
similarity index 100%
rename from llama/patches/0016-ollama-debug-tensor.patch
rename to llama/patches/0015-ollama-debug-tensor.patch
diff --git a/llama/patches/0017-add-model-quantizations.patch b/llama/patches/0016-add-model-quantizations.patch
similarity index 92%
rename from llama/patches/0017-add-model-quantizations.patch
rename to llama/patches/0016-add-model-quantizations.patch
index 1f40328c7..3d078b03f 100644
--- a/llama/patches/0017-add-model-quantizations.patch
+++ b/llama/patches/0016-add-model-quantizations.patch
@@ -13,7 +13,7 @@ models not supported in llama.cpp
  4 files changed, 24 insertions(+)
 
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index c1f78618..bdf3d898 100644
+index 0568565f..dd01df60 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
 @@ -73,6 +73,7 @@ static const std::map LLM_ARCH_NAMES = {
@@ -24,7 +24,7 @@ index c1f78618..bdf3d898 100644
      { LLM_ARCH_UNKNOWN,          "(unknown)"        },
  };
  
-@@ -1582,6 +1583,22 @@ static const std::map> LLM_TENSOR_N
+@@ -1586,6 +1587,22 @@ static const std::map> LLM_TENSOR_N
              { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
          },
      },
@@ -48,7 +48,7 @@ index c1f78618..bdf3d898 100644
          LLM_ARCH_UNKNOWN,
          {
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index f987844d..ee081fbf 100644
+index 6a989034..b6227eeb 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
 @@ -75,6 +75,7 @@ enum llm_arch {
@@ -60,10 +60,10 @@ index f987844d..ee081fbf 100644
      LLM_ARCH_BAILINGMOE,
      LLM_ARCH_UNKNOWN,
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index d5ad466e..cd1d239c 100644
+index d051696c..c8374159 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -1423,6 +1423,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+@@ -1425,6 +1425,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                      default: type = LLM_TYPE_UNKNOWN;
                  }
              } break;
@@ -71,7 +71,7 @@ index d5ad466e..cd1d239c 100644
          default: throw std::runtime_error("unsupported model architecture");
      }
  
-@@ -13652,6 +13653,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
+@@ -13704,6 +13705,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
          case LLM_ARCH_CHAMELEON:
          case LLM_ARCH_SOLAR:
          case LLM_ARCH_BAILINGMOE:
diff --git a/llama/patches/0018-add-op_neg.patch b/llama/patches/0018-add-op_neg.patch
deleted file mode 100644
index b9cab0e34..000000000
--- a/llama/patches/0018-add-op_neg.patch
+++ /dev/null
@@ -1,76 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Tue, 8 Apr 2025 20:41:24 -0700
-Subject: [PATCH] add op_neg
-
-adds the neg operator to ggml
----
- ggml/src/ggml-metal/ggml-metal.m     | 15 +++++++++++++++
- ggml/src/ggml-metal/ggml-metal.metal |  7 +++++++
- 2 files changed, 22 insertions(+)
-
-diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index b121ab9e..fea50521 100644
---- a/ggml/src/ggml-metal/ggml-metal.m
-+++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -461,6 +461,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
-     GGML_METAL_KERNEL_TYPE_SQRT,
-     GGML_METAL_KERNEL_TYPE_SIN,
-     GGML_METAL_KERNEL_TYPE_COS,
-+    GGML_METAL_KERNEL_TYPE_NEG,
-     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
-     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
-     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
-@@ -1119,6 +1120,7 @@ @implementation GGMLMetalClass
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                            sqrt,                            true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                             sin,                             true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
-+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
-@@ -1280,6 +1282,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
-                 case GGML_UNARY_OP_GELU_QUICK:
-                 case GGML_UNARY_OP_SILU:
-                 case GGML_UNARY_OP_ELU:
-+                case GGML_UNARY_OP_NEG:
-                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
-                 default:
-                     return false;
-@@ -1966,6 +1969,18 @@ static void ggml_metal_encode_node(
- 
-                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                 } break;
-+                case GGML_UNARY_OP_NEG:
-+                {
-+                    id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
-+
-+                    [encoder setComputePipelineState:pipeline];
-+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-+
-+                    const int64_t n = ggml_nelements(dst);
-+
-+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-+                } break;
-                 default:
-                 {
-                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
-diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index e3185e5b..ede9d1e6 100644
---- a/ggml/src/ggml-metal/ggml-metal.metal
-+++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -949,6 +949,13 @@ kernel void kernel_cos(
-     dst[tpig] = cos(src0[tpig]);
- }
- 
-+kernel void kernel_neg(
-+        device const float * src0,
-+        device       float * dst,
-+        uint tpig[[thread_position_in_grid]]) {
-+    dst[tpig] = -src0[tpig];
-+}
-+
- kernel void kernel_sum_rows(
-         device const float * src0,
-         device       float * dst,
diff --git a/llama/patches/0019-fix-compiler-error-in-clip.h.patch b/llama/patches/0019-fix-compiler-error-in-clip.h.patch
deleted file mode 100644
index a0b90df18..000000000
--- a/llama/patches/0019-fix-compiler-error-in-clip.h.patch
+++ /dev/null
@@ -1,39 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Tue, 8 Apr 2025 20:49:50 -0700
-Subject: [PATCH] fix compiler error in clip.h
-
-fixes an error that occurs in clip.h when compiling
-using CGo
----
- examples/llava/clip.h | 5 +++--
- 1 file changed, 3 insertions(+), 2 deletions(-)
-
-diff --git a/examples/llava/clip.h b/examples/llava/clip.h
-index cc133a58..5fc45d3e 100644
---- a/examples/llava/clip.h
-+++ b/examples/llava/clip.h
-@@ -30,12 +30,13 @@ struct clip_image_size {
-     int height;
- };
- 
-+struct clip_image_f32;
- struct clip_image_u8_batch;
- struct clip_image_f32_batch;
- 
- struct clip_context_params {
-     bool use_gpu;
--    ggml_log_level verbosity;
-+    enum ggml_log_level verbosity;
- };
- 
- // deprecated, use clip_init
-@@ -84,7 +85,7 @@ CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
- CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
- CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
- CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
--CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
-+CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
- 
- /**
-  * Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
diff --git a/llama/patches/0020-Revert-Simplify-and-improve-CUDA-graphs-through-use-.patch b/llama/patches/0020-Revert-Simplify-and-improve-CUDA-graphs-through-use-.patch
deleted file mode 100644
index a9d41d2a3..000000000
--- a/llama/patches/0020-Revert-Simplify-and-improve-CUDA-graphs-through-use-.patch
+++ /dev/null
@@ -1,600 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Sat, 12 Apr 2025 13:06:57 -0700
-Subject: [PATCH] Revert "Simplify and improve CUDA graphs through use of
- indirect copy pointers (#9017)"
-
-this commit in llama.cpp causes errors when running llama 3.2
-vision - temporarily revert it
-
-This reverts commit 3f9da22c2b21a2cef216de50006436ef1cab8764.
----
- ggml/src/ggml-cuda/common.cuh   |   8 +-
- ggml/src/ggml-cuda/cpy.cu       | 149 ++++++++++++--------------------
- ggml/src/ggml-cuda/cpy.cuh      |   2 -
- ggml/src/ggml-cuda/ggml-cuda.cu |  93 +++++++++++++++-----
- 4 files changed, 124 insertions(+), 128 deletions(-)
-
-diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
-index 8284a001..a718b6a1 100644
---- a/ggml/src/ggml-cuda/common.cuh
-+++ b/ggml/src/ggml-cuda/common.cuh
-@@ -729,13 +729,7 @@ struct ggml_cuda_graph {
-     bool disable_due_to_failed_graph_capture = false;
-     int number_consecutive_updates = 0;
-     std::vector ggml_graph_properties;
--    bool use_cpy_indirection = false;
--    std::vector cpy_dest_ptrs;
--    char ** dest_ptrs_d;
--    int dest_ptrs_size = 0;
--    // Index to allow each cpy kernel to be aware of it's position within the graph
--    // relative to other cpy nodes.
--    int graph_cpynode_index = -1;
-+    std::vector updated_kernel_arg;
- #endif
- };
- 
-diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
-index 4f4faa3e..8396df28 100644
---- a/ggml/src/ggml-cuda/cpy.cu
-+++ b/ggml/src/ggml-cuda/cpy.cu
-@@ -39,18 +39,16 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
- }
- 
- template 
--static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
-+static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                                    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
--                                   const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
-+                                   const int nb12, const int nb13) {
-     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
- 
-     if (i >= ne) {
-         return;
-     }
- 
--    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
--
-     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
-     // then combine those indices with the corresponding byte offsets to get the total offsets
-     const int64_t i03 = i/(ne00 * ne01 * ne02);
-@@ -297,18 +295,16 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
- }
- 
- template 
--static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
-+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
--                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
-+                                 const int nb12, const int nb13) {
-     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
- 
-     if (i >= ne) {
-         return;
-     }
- 
--    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
--
-     const int i03 = i/(ne00 * ne01 * ne02);
-     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-@@ -325,18 +321,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
- }
- 
- template 
--static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
-+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
-                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
--                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
-+                                 const int nb12, const int nb13) {
-     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
- 
-     if (i >= ne) {
-         return;
-     }
- 
--    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
--
-     const int i03 = i/(ne00 * ne01 * ne02);
-     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-@@ -352,97 +346,76 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
-     cpy_blck(cx + x_offset, cdst + dst_offset);
- }
- 
--// Copy destination pointers to GPU to be available when pointer indirection is in use
--
--void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
--#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
--    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
--        CUDA_CHECK(cudaStreamSynchronize(stream));
--        if (cuda_graph->dest_ptrs_d != nullptr) {
--            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
--        }
--        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
--        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
--    }
--    // copy destination pointers to GPU
--    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
--    cuda_graph->graph_cpynode_index = 0; // reset index
--#else
--    GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs);
--    GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream);
--#endif
--}
--
- static void ggml_cpy_f16_f32_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_f32_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_bf16_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_f16_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_q8_0_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     GGML_ASSERT(ne % QK8_0 == 0);
-     const int num_blocks = ne / QK8_0;
-     cpy_f32_q<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_q8_0_f32_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     const int num_blocks = ne;
-     cpy_q_f32<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_q4_0_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     GGML_ASSERT(ne % QK4_0 == 0);
-     const int num_blocks = ne / QK4_0;
-     cpy_f32_q<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_q4_0_f32_cuda(
-@@ -451,22 +424,22 @@ static void ggml_cpy_q4_0_f32_cuda(
-     const int nb00, const int nb01, const int nb02,
-     const int nb03, const int ne10, const int ne11, const int ne12,
-     const int nb10, const int nb11, const int nb12, const int nb13,
--    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    cudaStream_t stream) {
-     const int num_blocks = ne;
-     cpy_q_f32, QK4_0><<>>(
-         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
--         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_q4_1_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     GGML_ASSERT(ne % QK4_1 == 0);
-     const int num_blocks = ne / QK4_1;
-     cpy_f32_q<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_q4_1_f32_cuda(
-@@ -475,22 +448,22 @@ static void ggml_cpy_q4_1_f32_cuda(
-     const int nb00, const int nb01, const int nb02,
-     const int nb03, const int ne10, const int ne11, const int ne12,
-     const int nb10, const int nb11, const int nb12, const int nb13,
--    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    cudaStream_t stream) {
-     const int num_blocks = ne;
-     cpy_q_f32, QK4_1><<>>(
-         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
--         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_q5_0_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     GGML_ASSERT(ne % QK5_0 == 0);
-     const int num_blocks = ne / QK5_0;
-     cpy_f32_q<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_q5_0_f32_cuda(
-@@ -499,22 +472,22 @@ static void ggml_cpy_q5_0_f32_cuda(
-     const int nb00, const int nb01, const int nb02,
-     const int nb03, const int ne10, const int ne11, const int ne12,
-     const int nb10, const int nb11, const int nb12, const int nb13,
--    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    cudaStream_t stream) {
-     const int num_blocks = ne;
-     cpy_q_f32, QK5_0><<>>(
-         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
--        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_q5_1_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     GGML_ASSERT(ne % QK5_1 == 0);
-     const int num_blocks = ne / QK5_1;
-     cpy_f32_q<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_q5_1_f32_cuda(
-@@ -523,32 +496,32 @@ static void ggml_cpy_q5_1_f32_cuda(
-     const int nb00, const int nb01, const int nb02,
-     const int nb03, const int ne10, const int ne11, const int ne12,
-     const int nb10, const int nb11, const int nb12, const int nb13,
--    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    cudaStream_t stream) {
-     const int num_blocks = ne;
-     cpy_q_f32, QK5_1><<>>(
-         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
--        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f32_iq4_nl_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     GGML_ASSERT(ne % QK4_NL == 0);
-     const int num_blocks = ne / QK4_NL;
-     cpy_f32_q<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- static void ggml_cpy_f16_f16_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
--    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
- 
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<<>>
--        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
- }
- 
- void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
-@@ -585,62 +558,48 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
-     char * src0_ddc = (char *) src0->data;
-     char * src1_ddc = (char *) src1->data;
- 
--    char ** dest_ptrs_d = nullptr;
--    int graph_cpynode_index = -1;
--#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
--    if(ctx.cuda_graph->use_cpy_indirection) {
--        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
--        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
--    }
--#endif
-     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
-         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
--        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
--        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
--        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
--        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
--        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
--        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
-         ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
--            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
--        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
-         ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
--            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
--        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
-         ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
--            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
--        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
--        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
--        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
--        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
--        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-     } else {
-         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
-                 ggml_type_name(src0->type), ggml_type_name(src1->type));
-     }
--#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
--    if(ctx.cuda_graph->use_cpy_indirection) {
--        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
--    }
--#endif
--
- }
- 
- void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
-index 6bed0564..28b06cdd 100644
---- a/ggml/src/ggml-cuda/cpy.cuh
-+++ b/ggml/src/ggml-cuda/cpy.cuh
-@@ -7,5 +7,3 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
- void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
- 
- void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
--
--void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
-diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 67208cba..a44788db 100644
---- a/ggml/src/ggml-cuda/ggml-cuda.cu
-+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2477,11 +2477,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
- 
- #ifdef USE_CUDA_GRAPH
- static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
--    bool use_cuda_graph) {
-+    std::vector & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
- 
-     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
--    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
--
-+    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
-     for (int i = 0; i < cgraph->n_nodes; i++) {
-         ggml_tensor * node = cgraph->nodes[i];
- 
-@@ -2513,11 +2512,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
-         }
- 
-         if (node->op == GGML_OP_CPY) {
--
--            // Store the pointers which are updated for each token, such that these can be sent
--            // to the device and accessed using indirection from CUDA graph
--            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
--
-+            // store the copy op parameter which changes with each token.
-+            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-             // store a pointer to each copy op CUDA kernel to identify it later
-             void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-             if (!ptr) {
-@@ -2525,6 +2521,10 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
- #ifndef NDEBUG
-                 GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
- #endif
-+            } else {
-+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-+                }
-             }
-         }
- 
-@@ -2533,12 +2533,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
-         }
-     }
- 
--    if (use_cuda_graph) {
--        cuda_ctx->cuda_graph->use_cpy_indirection = true;
--        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
--        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
--    }
--
-     return use_cuda_graph;
- }
- 
-@@ -2593,6 +2587,51 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
-     return true;
- }
- 
-+static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
-+
-+    if (cuda_graph_update_required) {
-+        // Extract nodes from graph
-+        // First call with null argument gets number of nodes in graph
-+        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-+        // Subsequent call with non-null argument gets nodes
-+        cuda_ctx->cuda_graph->nodes.clear();
-+        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-+        cuda_ctx->cuda_graph->params.clear();
-+        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-+        if (cuda_ctx->cuda_graph->num_nodes > 0) {
-+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-+
-+            // Loop over nodes, and extract kernel parameters from each node
-+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-+                cudaGraphNodeType node_type;
-+                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-+                if (node_type == cudaGraphNodeTypeKernel) {
-+                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-+                    if (stat == cudaErrorInvalidDeviceFunction) {
-+                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-+                        // We don't need to update blas nodes, so clear error and move on.
-+                        (void)cudaGetLastError();
-+                    } else {
-+                        GGML_ASSERT(stat == cudaSuccess);
-+                    }
-+                }
-+            }
-+        }
-+    } else {
-+        // One of the arguments to the copy kernel is updated for each token, hence we need to
-+        // replace that argument with the updated value in the CUDA graph
-+        // on update steps, the live parameters will already be captured
-+        int k = 0;
-+        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-+            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-+                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-+                *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
-+                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-+            }
-+        }
-+    }
-+}
-+
- static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
- 
-     bool cuda_graph_update_required = false;
-@@ -2652,7 +2691,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
- #endif
- 
- static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
--    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
-+   [[maybe_unused]] std::vector & ggml_cuda_cpy_fn_ptrs,  bool & graph_evaluated_or_captured, bool & use_cuda_graph,
-+    bool & cuda_graph_update_required) {
- 
-     while (!graph_evaluated_or_captured) {
-         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
-@@ -2702,9 +2742,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
-         if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-             CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-         }
--        if (cuda_graph_update_required) { // Update graph executable
--            update_cuda_graph_executable(cuda_ctx);
--        }
-+
-+        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-+        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
-+
-+        // Update graph executable
-+        update_cuda_graph_executable(cuda_ctx);
-+
-         // Launch graph
-         CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
- #else
-@@ -2718,6 +2762,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
- 
-     ggml_cuda_set_device(cuda_ctx->device);
- 
-+    // vector of pointers to CUDA cpy kernels, which are required to identify
-+    // kernel parameters which need updated in the graph for each token
-+    std::vector ggml_cuda_cpy_fn_ptrs;
-+
- #ifdef USE_CUDA_GRAPH
-     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
- 
-@@ -2751,7 +2799,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
-     if (use_cuda_graph) {
-         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
- 
--        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
-+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
-+                             ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
- 
-         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-         if (use_cuda_graph && cuda_graph_update_required) {
-@@ -2772,10 +2821,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
-         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
-     }
- 
--    if (!use_cuda_graph) {
--        cuda_ctx->cuda_graph->use_cpy_indirection = false;
--    }
--
- #else
-     bool use_cuda_graph = false;
-     bool cuda_graph_update_required = false;
-@@ -2783,7 +2828,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
- 
-     bool graph_evaluated_or_captured = false;
- 
--    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
-+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
- 
-     return GGML_STATUS_SUCCESS;
- }
diff --git a/ml/backend/ggml/ggml/include/ggml-rpc.h b/ml/backend/ggml/ggml/include/ggml-rpc.h
index 4e0d210f8..c8b6097f7 100644
--- a/ml/backend/ggml/ggml/include/ggml-rpc.h
+++ b/ml/backend/ggml/ggml/include/ggml-rpc.h
@@ -7,6 +7,9 @@
 extern "C" {
 #endif
 
+#define RPC_PROTO_MAJOR_VERSION    1
+#define RPC_PROTO_MINOR_VERSION    0
+#define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
 
 // backend API
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
index 74a31abb2..175cba329 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
@@ -183,67 +183,63 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang
 
 #if defined(__AVX2__) || defined(__AVX512F__)
 #if defined(__AVX512F__)
-// add int16_t pairwise and return as 512 bit int vector
-static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
+// add int16_t pairwise and return as 512 bit int vector, then add the accumulator
+static inline __m512i sum_i16_pairs_acc_int32x16(const __m512i acc, const __m512i x) {
     const __m512i ones = _mm512_set1_epi16(1);
-    return _mm512_madd_epi16(ones, x);
+    return _mm512_add_epi32(acc, _mm512_madd_epi16(ones, x));
 }
 
-static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
+static inline __m512i mul_sum_us8_pairs_acc_int32x16(const __m512i acc, const __m512i ax, const __m512i sy) {
 #if defined(__AVX512VNNI__)
-    const __m512i zero = _mm512_setzero_si512();
-    return _mm512_dpbusd_epi32(zero, ax, sy);
+    return _mm512_dpbusd_epi32(acc, ax, sy);
 #else
     // Perform multiplication and create 16-bit values
     const __m512i dot = _mm512_maddubs_epi16(ax, sy);
-    return sum_i16_pairs_int_32x16(dot);
+    return sum_i16_pairs_acc_int32x16(acc, dot);
 #endif
 }
 
-// multiply int8_t, add results pairwise twice and return as 512 bit int vector
-static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
+// multiply int8_t, add results pairwise twice and return as 512 bit int vector,then add the accumulator
+static inline __m512i mul_sum_i8_pairs_acc_int32x16(const __m512i acc, const __m512i x, const __m512i y) {
     const __m512i zero = _mm512_setzero_si512();
     // Get absolute values of x vectors
     const __m512i ax = _mm512_abs_epi8(x);
     // Sign the values of the y vectors
     __mmask64 blt0 = _mm512_movepi8_mask(x);
     const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
-    return mul_sum_us8_pairs_int32x16(ax, sy);
+    return mul_sum_us8_pairs_acc_int32x16(acc, ax, sy);
 }
 #endif
 
-// add int16_t pairwise and return as 256 bit int vector
-static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
+// add int16_t pairwise and return as 256 bit int vector, then add the accumulator
+static inline __m256i sum_i16_pairs_acc_int32x8(const __m256i acc, const __m256i x) {
     const __m256i ones = _mm256_set1_epi16(1);
-    return _mm256_madd_epi16(ones, x);
+    return _mm256_add_epi32(acc, _mm256_madd_epi16(ones, x));
 }
 
-static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
+static inline __m256i mul_sum_us8_pairs_acc_int32x8(const __m256i acc, const __m256i ax, const __m256i sy) {
 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
-    const __m256i zero = _mm256_setzero_si256();
-    return _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_dpbusd_epi32(acc, ax, sy);
 #elif defined(__AVXVNNI__)
-    const __m256i zero = _mm256_setzero_si256();
-    return _mm256_dpbusd_avx_epi32(zero, ax, sy);
+    return _mm256_dpbusd_avx_epi32(acc, ax, sy);
 #else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-    return sum_i16_pairs_int32x8(dot);
+    return sum_i16_pairs_acc_int32x8(acc, dot);
 #endif
 }
 
 // Integer variant of the function defined in ggml-quants.c
-// multiply int8_t, add results pairwise twice and return as 256 bit int vector
-static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
-#if __AVXVNNIINT8__
-    const __m256i zero = _mm256_setzero_si256();
-    return _mm256_dpbssd_epi32(zero, x, y);
+// multiply int8_t, add results pairwise twice and return as 256 bit int vector, then add the accumulator
+static inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m256i x, const __m256i y) {
+#if defined(__AVXVNNIINT8__)
+    return _mm256_dpbssd_epi32(acc, x, y);
 #else
     // Get absolute values of x vectors
     const __m256i ax = _mm256_sign_epi8(x, x);
     // Sign the values of the y vectors
     const __m256i sy = _mm256_sign_epi8(y, x);
-    return mul_sum_us8_pairs_int32x8(ax, sy);
+    return mul_sum_us8_pairs_acc_int32x8(acc, ax, sy);
 #endif
 }
 #endif
@@ -1175,17 +1171,17 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
                 // ...........................................................................
                 // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
 
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
+                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0));
+                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85));
 
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
+                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170));
+                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255));
 
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
+                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0));
+                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85));
 
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
+                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170));
+                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255));
 
                 // Accumulated values multipled with appropriate scales
                 acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
@@ -3239,22 +3235,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
 
                         // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
                         // Resembles MMLAs into 2x2 matrices in ARM Version
-                        __m512i iacc_mat_00_sp1 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
-                        __m512i iacc_mat_01_sp1 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
-                        __m512i iacc_mat_10_sp1 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
-                        __m512i iacc_mat_11_sp1 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
-                        __m512i iacc_mat_00_sp2 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
-                        __m512i iacc_mat_01_sp2 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
-                        __m512i iacc_mat_10_sp2 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
-                        __m512i iacc_mat_11_sp2 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+                        const __m512i zero = _mm512_setzero_epi32();
+                        __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
+                        __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
+                        __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
+                        __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
+                        __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
+                        __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
+                        __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
+                        __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
 
                         // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
                         __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
@@ -3430,22 +3419,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
 
                     // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
                     // Resembles MMLAs into 2x2 matrices in ARM Version
-                    __m512i iacc_mat_00_sp1 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
-                    __m512i iacc_mat_01_sp1 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
-                    __m512i iacc_mat_10_sp1 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
-                    __m512i iacc_mat_11_sp1 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
-                    __m512i iacc_mat_00_sp2 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
-                    __m512i iacc_mat_01_sp2 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
-                    __m512i iacc_mat_10_sp2 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
-                    __m512i iacc_mat_11_sp2 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+                    const __m512i zero = _mm512_setzero_epi32();
+                    __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
+                    __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
+                    __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
+                    __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
+                    __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
+                    __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
+                    __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
+                    __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
 
                     // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
                     __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
@@ -3605,22 +3587,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
 
                         // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
                         // Resembles MMLAs into 2x2 matrices in ARM Version
-                        __m256i iacc_mat_00_sp1 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
-                        __m256i iacc_mat_01_sp1 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
-                        __m256i iacc_mat_10_sp1 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
-                        __m256i iacc_mat_11_sp1 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
-                        __m256i iacc_mat_00_sp2 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
-                        __m256i iacc_mat_01_sp2 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
-                        __m256i iacc_mat_10_sp2 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
-                        __m256i iacc_mat_11_sp2 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+                        const __m256i zero = _mm256_setzero_si256();
+                        __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
+                        __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
+                        __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
+                        __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
+                        __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
+                        __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
+                        __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
+                        __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
 
                         // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
                         __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
@@ -3769,22 +3744,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
 
                     // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
                     // Resembles MMLAs into 2x2 matrices in ARM Version
-                    __m256i iacc_mat_00_sp1 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
-                    __m256i iacc_mat_01_sp1 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
-                    __m256i iacc_mat_10_sp1 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
-                    __m256i iacc_mat_11_sp1 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
-                    __m256i iacc_mat_00_sp2 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
-                    __m256i iacc_mat_01_sp2 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
-                    __m256i iacc_mat_10_sp2 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
-                    __m256i iacc_mat_11_sp2 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+                    const __m256i zero = _mm256_setzero_si256();
+                    __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
+                    __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
+                    __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
+                    __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
+                    __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
+                    __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
+                    __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
+                    __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
 
                     // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
                     __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
@@ -4076,7 +4044,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-#if defined(__AVX2__)
+#if defined(__AVX2__) || defined(__AVX512F__)
     const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;
     const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
     int64_t b_nb = n / QK_K;
@@ -4086,8 +4054,748 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c
     const __m256i m4b = _mm256_set1_epi8(0x0F);
     // Permute mask used for easier vector processing at later stages
     __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
-
+    int64_t xstart = 0;
     int anr = nr - nr % 16;; // Used to align nr with boundary of 16
+#ifdef __AVX512F__
+    int anc = nc - nc % 16; // Used to align nc with boundary of 16
+    // Mask to mask out nibbles from packed bytes expanded to 512 bit length
+    const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
+    //Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
+    for (; y < anr / 4; y += 4) {
+
+        const block_q8_Kx4 * a_ptrs[4];
+
+        a_ptrs[0] = a_ptr_start + (y * nb);
+        for (int i = 0; i < 3; ++i) {
+            a_ptrs[i + 1] = a_ptrs[i] + nb;
+        }
+
+        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < anc / 8; x += 2) {
+
+            const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+            const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+            // Master FP accumulators
+            __m512 acc_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_rows[i] = _mm512_setzero_ps();
+            }
+
+            __m512 acc_min_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_min_rows[i] = _mm512_setzero_ps();
+            }
+
+            // For super block
+            for (int64_t b = 0; b < nb; b++) {
+                // Scale values - Load the sixteen scale values from two block_q4_kx8 structures
+                const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures
+                const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);
+
+                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));
+
+                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));
+
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
+
+                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+                    const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);
+                    const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);
+
+                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+                    const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);
+                    const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);
+
+                    //4-bit -> 8-bit
+                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
+                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
+                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
+                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
+
+                    const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)
+                    const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)
+                    const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)
+                    const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)
+
+                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
+                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
+                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
+                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
+
+                    const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)
+                    const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)
+                    const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)
+                    const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)
+                    const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)
+                    const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)
+                    const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)
+                    const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)
+                    const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)
+                    const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)
+                    const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)
+
+                    const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)
+                    const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)
+                    const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)
+                    const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)
+                    const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)
+                    const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)
+                    const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)
+                    const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)
+
+                    // Shuffle pattern two - right side input
+                    const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)
+                    const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)
+                    const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)
+                    const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)
+                    const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)
+                    const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)
+                    const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31)
+                    const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)
+
+                    const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)
+                    const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)
+                    const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)
+                    const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)
+                    const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)
+                    const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)
+                    const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)
+                    const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)
+
+                    uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];
+
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
+                    memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);
+                    utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_00 = utmp_00[1] & kmask1;
+                    utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);
+                    utmp_00[2] = uaux_00;
+                    utmp_00[0] &= kmask1;
+
+                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
+                    memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);
+                    utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_01 = utmp_01[1] & kmask1;
+                    utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);
+                    utmp_01[2] = uaux_01;
+                    utmp_01[0] &= kmask1;
+
+                    memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);
+                    utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_10 = utmp_10[1] & kmask1;
+                    utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);
+                    utmp_10[2] = uaux_10;
+                    utmp_10[0] &= kmask1;
+
+                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
+                    memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);
+                    utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_11 = utmp_11[1] & kmask1;
+                    utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);
+                    utmp_11[2] = uaux_11;
+                    utmp_11[0] &= kmask1;
+
+                    // Scales of first sub block in the sb loop
+                    const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);
+                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+
+                    // Scales of second sub block in the sb loop
+                    const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);
+                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));
+
+                    const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);
+
+                    for (int rp = 0; rp < 4; rp++) {
+
+                        // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
+                        // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
+                        __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);
+                        __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);
+                        __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);
+                        __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);
+                        __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);
+                        __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);
+                        __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);
+                        __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);
+                        __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);
+                        __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);
+                        __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);
+                        __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);
+                        __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);
+                        __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);
+                        __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);
+                        __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);
+
+                        __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);
+                        __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);
+                        __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);
+                        __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);
+                        __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);
+                        __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);
+                        __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);
+                        __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);
+
+                        __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);
+                        __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);
+                        __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);
+                        __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);
+                        __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);
+                        __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);
+                        __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);
+                        __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);
+
+                        // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
+                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
+                        __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                        lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);
+                        __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);
+
+                        // Shuffle pattern one - left side input
+                        const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
+                        const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
+                        const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                        const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
+                        const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
+                        const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
+                        const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
+                        const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
+
+                        const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
+                        const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
+                        const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                        const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
+                        const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
+                        const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
+                        const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
+                        const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
+
+                        const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
+                        const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
+                        const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
+                        const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
+                        const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
+                        const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
+                        const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
+                        const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
+
+                        const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
+                        const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
+                        const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
+                        const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
+                        const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
+                        const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
+                        const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
+                        const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
+
+                        // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
+                        __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));
+                        __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));
+                        __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));
+                        __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));
+                        __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));
+                        __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));
+                        __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));
+                        __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));
+
+                        __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));
+                        __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));
+                        __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
+                        __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
+                        __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
+                        __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
+                        __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
+                        __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
+
+                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                        __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                        __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                        __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                        __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
+
+                        __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                        __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                        __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                        __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
+                        iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
+                        iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
+                        iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
+                        iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);
+
+                        iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);
+                        iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);
+                        iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
+                        iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
+
+                        // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
+                        __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
+                        __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
+                        __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
+                        __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+
+                        __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                        __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                        __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                        __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+                        // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
+                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
+                        const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
+                        const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
+
+                        // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                        acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+                        acc_rows[rp * 4  + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+
+                        __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
+                        __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
+                        __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
+                        __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
+
+                        acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
+                        acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
+                        acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
+                        acc_min_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
+                    }
+                }
+            }
+            // Store the accumulated values
+            for (int i = 0; i < 16; i++) {
+                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
+        }
+    }
+
+    for (; y < nr / 4; y++) {
+
+        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
+
+        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < anc / 8; x += 2) {
+
+            const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+            const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+            // Master FP accumulators
+            __m512 acc_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_rows[i] = _mm512_setzero_ps();
+            }
+
+            __m512 acc_min_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_min_rows[i] = _mm512_setzero_ps();
+            }
+
+            // For super block
+            for (int64_t b = 0; b < nb; b++) {
+                // Scale values - Load the sixteen scale values from two block_q4_kx8 structures
+                const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures
+                const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);
+
+                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));
+
+                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));
+
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
+
+                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+                    const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);
+                    const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);
+
+                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+                    const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);
+                    const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);
+
+                    //4-bit -> 8-bit
+                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
+                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
+                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
+                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
+
+                    const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)
+                    const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)
+                    const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)
+                    const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)
+
+                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
+                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
+                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
+                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
+
+                    const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)
+                    const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)
+                    const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)
+                    const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)
+                    const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)
+                    const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)
+                    const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)
+                    const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)
+                    const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)
+                    const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)
+                    const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)
+
+                    const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)
+                    const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)
+                    const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)
+                    const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)
+                    const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)
+                    const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)
+                    const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)
+                    const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)
+
+                    // Shuffle pattern two - right side input
+                    const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)
+                    const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)
+                    const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)
+                    const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)
+                    const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)
+                    const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)
+                    const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31)
+                    const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)
+
+                    const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)
+                    const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)
+                    const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)
+                    const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)
+                    const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)
+                    const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)
+                    const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)
+                    const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)
+
+                    uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];
+
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
+                    memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);
+                    utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_00 = utmp_00[1] & kmask1;
+                    utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);
+                    utmp_00[2] = uaux_00;
+                    utmp_00[0] &= kmask1;
+
+                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
+                    memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);
+                    utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_01 = utmp_01[1] & kmask1;
+                    utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);
+                    utmp_01[2] = uaux_01;
+                    utmp_01[0] &= kmask1;
+
+                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
+                    memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);
+                    utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_10 = utmp_10[1] & kmask1;
+                    utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);
+                    utmp_10[2] = uaux_10;
+                    utmp_10[0] &= kmask1;
+
+                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
+                    memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);
+                    utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_11 = utmp_11[1] & kmask1;
+                    utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);
+                    utmp_11[2] = uaux_11;
+                    utmp_11[0] &= kmask1;
+
+                    // Scales of first sub block in the sb loop
+                    const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);
+                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+
+                    // Scales of second sub block in the sb loop
+                    const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);
+                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));
+
+                    const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);
+
+                    // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
+                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
+                    __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);
+                    __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);
+                    __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);
+                    __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);
+                    __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);
+                    __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);
+                    __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);
+                    __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);
+                    __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);
+                    __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);
+                    __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);
+                    __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);
+                    __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);
+                    __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);
+                    __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);
+                    __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);
+
+                    //Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into a 512 bit vector
+                    __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);
+                    __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);
+                    __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);
+                    __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);
+                    __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);
+                    __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);
+                    __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);
+                    __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);
+
+                    __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);
+                    __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);
+                    __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);
+                    __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);
+                    __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);
+                    __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);
+                    __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);
+                    __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);
+
+                    // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
+                    __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
+                    __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                    lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);
+                    __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);
+
+                    // Shuffle pattern one - left side input
+                    const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
+                    const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
+                    const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                    const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
+                    const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
+                    const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
+                    const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
+                    const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
+
+                    const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
+                    const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
+                    const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                    const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
+                    const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
+                    const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
+                    const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
+                    const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
+
+                    const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
+                    const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
+                    const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
+                    const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
+                    const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
+                    const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
+                    const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
+                    const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
+
+                    const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
+                    const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
+                    const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
+                    const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
+                    const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
+                    const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
+                    const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
+                    const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
+
+                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
+                    __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));
+                    __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));
+                    __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));
+                    __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));
+                    __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));
+                    __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));
+                    __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));
+                    __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));
+
+                    __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));
+                    __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));
+                    __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
+                    __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
+                    __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
+                    __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
+                    __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
+                    __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
+
+                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                    __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                    __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                    __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                    __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
+
+                    __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                    __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                    __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                    __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
+                    iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
+                    iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
+                    iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
+                    iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);
+
+                    iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);
+                    iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);
+                    iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
+                    iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
+
+                    // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
+                    __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
+                    __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
+                    __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
+                    __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+
+                    __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                    __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                    __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                    __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+                    // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
+                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
+                    const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
+                    const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
+
+                    // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+
+                    __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
+                    __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
+                    __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
+                    __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
+
+                    acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
+                    acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
+                    acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
+                    acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
+                }
+            }
+            // Store accumlated values
+            for (int i = 0; i < 4; i++) {
+                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
+        }
+    }
+    if (anc != nc) {
+        xstart = anc/8;
+        y = 0;
+    }
+#endif //AVX512F
+
     // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
     for (; y < anr / 4; y += 4) {
 
@@ -4099,7 +4807,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c
         }
 
         // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
-        for (int64_t x = 0; x < nc / 8; x++) {
+        for (int64_t x = xstart; x < nc / 8; x++) {
 
             const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
 
@@ -4433,7 +5141,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c
 
         const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
 
-        for (int64_t x = 0; x < nc / 8; x++) {
+        for (int64_t x = xstart; x < nc / 8; x++) {
 
             const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
index 09f8382b9..4b688a67e 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -425,6 +425,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         }
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
+        case GGML_OP_GET_ROWS_BACK:
+            return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16;
         case GGML_OP_OUT_PROD:
             return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
                 src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
index a718b6a12..8284a0017 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
@@ -729,7 +729,13 @@ struct ggml_cuda_graph {
     bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
     std::vector ggml_graph_properties;
-    std::vector updated_kernel_arg;
+    bool use_cpy_indirection = false;
+    std::vector cpy_dest_ptrs;
+    char ** dest_ptrs_d;
+    int dest_ptrs_size = 0;
+    // Index to allow each cpy kernel to be aware of it's position within the graph
+    // relative to other cpy nodes.
+    int graph_cpynode_index = -1;
 #endif
 };
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
index 8396df28c..ed25646e8 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
@@ -39,16 +39,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
 }
 
 template 
-static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
                                    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                   const int nb12, const int nb13) {
+                                   const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
     const int64_t i03 = i/(ne00 * ne01 * ne02);
@@ -295,16 +297,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
 }
 
 template 
-static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13) {
+                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
@@ -321,16 +325,18 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
 }
 
 template 
-static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13) {
+                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
@@ -346,76 +352,97 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+// Copy destination pointers to GPU to be available when pointer indirection is in use
+
+void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+        if (cuda_graph->dest_ptrs_d != nullptr) {
+            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
+        }
+        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
+        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
+    }
+    // copy destination pointers to GPU
+    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
+    cuda_graph->graph_cpynode_index = 0; // reset index
+#else
+    GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs);
+    GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream);
+#endif
+}
+
 static void ggml_cpy_f16_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_bf16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = ne;
     cpy_q_f32<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
@@ -424,22 +451,22 @@ static void ggml_cpy_q4_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32, QK4_0><<>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
@@ -448,22 +475,22 @@ static void ggml_cpy_q4_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32, QK4_1><<>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
     const int num_blocks = ne / QK5_0;
     cpy_f32_q<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
@@ -472,22 +499,22 @@ static void ggml_cpy_q5_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32, QK5_0><<>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
     const int num_blocks = ne / QK5_1;
     cpy_f32_q<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
@@ -496,35 +523,35 @@ static void ggml_cpy_q5_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32, QK5_1><<>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
     const int num_blocks = ne / QK4_NL;
     cpy_f32_q<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<<>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
@@ -558,53 +585,68 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
+    char ** dest_ptrs_d = nullptr;
+    int graph_cpynode_index = -1;
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
+        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
+        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
+    }
+#endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
+        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
+    }
+#endif
+
 }
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    ggml_cuda_cpy(ctx, src0, dst);
+    bool disable_indirection = true;
+    ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cuh
index 28b06cdda..0bd3c0c6f 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cuh
@@ -2,8 +2,10 @@
 
 #define CUDA_CPY_BLOCK_SIZE 64
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1,  bool disable_indirection = false);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
+
+void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
index a44788db2..0fef9522d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -96,31 +96,32 @@ int ggml_cuda_get_device() {
 
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
-#if defined(GGML_USE_HIP) && defined(GGML_HIP_UMA)
-    auto res = hipMallocManaged(ptr, size);
-    if (res == hipSuccess) {
-        // if error we "need" to know why...
-        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
-    }
-    return res;
-#else
-
-#if !defined(GGML_USE_HIP)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
     {
         err = cudaMallocManaged(ptr, size);
+#if defined(GGML_USE_HIP)
+        if (err == hipSuccess) {
+            CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+        }
+
+        // fall back to cudaMalloc if not supported (e.g. on Windows)
+        if (err == hipErrorNotSupported) {
+            static bool warned_unsupported = false;
+            if (!warned_unsupported) {
+                GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n");
+                warned_unsupported = true;
+            }
+
+            err = cudaMalloc(ptr, size);
+        }
+#endif // defined(GGML_USE_HIP)
     }
     else
     {
         err = cudaMalloc(ptr, size);
     }
     return err;
-#else
-    return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIP)
-
-#endif
 }
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@@ -2341,11 +2342,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
-#if !defined(GGML_DISABLE_FLASH_ATTN)
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cuda_flash_attn_ext(ctx, dst);
             break;
-#endif
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
@@ -2477,10 +2476,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 
 #ifdef USE_CUDA_GRAPH
 static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    std::vector & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+    bool use_cuda_graph) {
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2498,7 +2498,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         if (node->op == GGML_OP_MUL_MAT_ID) {
             use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
 #endif
         }
 
@@ -2512,8 +2512,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
 
         if (node->op == GGML_OP_CPY) {
-            // store the copy op parameter which changes with each token.
-            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+
+            // Store the pointers which are updated for each token, such that these can be sent
+            // to the device and accessed using indirection from CUDA graph
+            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
+
             // store a pointer to each copy op CUDA kernel to identify it later
             void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
             if (!ptr) {
@@ -2521,10 +2524,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #ifndef NDEBUG
                 GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
 #endif
-            } else {
-                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                }
             }
         }
 
@@ -2533,6 +2532,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
     }
 
+    if (use_cuda_graph) {
+        cuda_ctx->cuda_graph->use_cpy_indirection = true;
+        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
+        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
+    }
+
     return use_cuda_graph;
 }
 
@@ -2587,51 +2592,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     return true;
 }
 
-static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
-
-    if (cuda_graph_update_required) {
-        // Extract nodes from graph
-        // First call with null argument gets number of nodes in graph
-        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-        // Subsequent call with non-null argument gets nodes
-        cuda_ctx->cuda_graph->nodes.clear();
-        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-        cuda_ctx->cuda_graph->params.clear();
-        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-        if (cuda_ctx->cuda_graph->num_nodes > 0) {
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-            // Loop over nodes, and extract kernel parameters from each node
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                cudaGraphNodeType node_type;
-                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                if (node_type == cudaGraphNodeTypeKernel) {
-                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                    if (stat == cudaErrorInvalidDeviceFunction) {
-                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                        // We don't need to update blas nodes, so clear error and move on.
-                        (void)cudaGetLastError();
-                    } else {
-                        GGML_ASSERT(stat == cudaSuccess);
-                    }
-                }
-            }
-        }
-    } else {
-        // One of the arguments to the copy kernel is updated for each token, hence we need to
-        // replace that argument with the updated value in the CUDA graph
-        // on update steps, the live parameters will already be captured
-        int k = 0;
-        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
-                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-            }
-        }
-    }
-}
-
 static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
     bool cuda_graph_update_required = false;
@@ -2691,8 +2651,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 #endif
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-   [[maybe_unused]] std::vector & ggml_cuda_cpy_fn_ptrs,  bool & graph_evaluated_or_captured, bool & use_cuda_graph,
-    bool & cuda_graph_update_required) {
+    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
 
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2742,13 +2701,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
             CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
         }
-
-        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
-
-        // Update graph executable
-        update_cuda_graph_executable(cuda_ctx);
-
+        if (cuda_graph_update_required) { // Update graph executable
+            update_cuda_graph_executable(cuda_ctx);
+        }
         // Launch graph
         CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
 #else
@@ -2762,10 +2717,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     ggml_cuda_set_device(cuda_ctx->device);
 
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector ggml_cuda_cpy_fn_ptrs;
-
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
@@ -2799,8 +2750,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
 
-        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
-                             ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
         if (use_cuda_graph && cuda_graph_update_required) {
@@ -2821,6 +2771,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
+    if (!use_cuda_graph) {
+        cuda_ctx->cuda_graph->use_cpy_indirection = false;
+    }
+
 #else
     bool use_cuda_graph = false;
     bool cuda_graph_update_required = false;
@@ -2828,7 +2782,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool graph_evaluated_or_captured = false;
 
-    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
 
     return GGML_STATUS_SUCCESS;
 }
@@ -3290,6 +3244,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
index 420b41b8d..1a28831b7 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
@@ -71,6 +71,8 @@
 #define cudaLaunchHostFunc hipLaunchHostFunc
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMallocManaged hipMallocManaged
+#define cudaMemAdvise hipMemAdvise
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpyAsync hipMemcpyAsync
 #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
diff --git a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
index e3762649f..1fe8fe3b8 100644
--- a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
@@ -89,10 +89,6 @@ endif()
 
 add_compile_definitions(GGML_USE_HIP)
 
-if (GGML_HIP_UMA)
-    add_compile_definitions(GGML_HIP_UMA)
-endif()
-
 if (GGML_CUDA_FORCE_MMQ)
     add_compile_definitions(GGML_CUDA_FORCE_MMQ)
 endif()
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
index eaf6a23d8..297933adb 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -6051,6 +6051,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]]         kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_h192")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6061,6 +6062,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6071,6 +6073,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6080,6 +6083,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6089,6 +6093,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6098,6 +6103,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6107,6 +6113,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #undef FA_TYPES
 
@@ -6464,6 +6471,16 @@ kernel void kernel_flash_attn_ext_vec(
 
 typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;
 
+template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
@@ -6504,6 +6521,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
 #undef FA_TYPES
 
 template
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
index fea505214..b2e95a66c 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
@@ -355,6 +355,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -363,6 +364,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -371,6 +373,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -379,6 +382,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -387,6 +391,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -395,6 +400,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -403,6 +409,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
@@ -431,6 +445,13 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_SET_I32,
     GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1014,6 +1035,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,         flash_attn_ext_f16_h192,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,  flash_attn_ext_f16_hk192_hv128,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,         flash_attn_ext_f16_h256,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,  flash_attn_ext_f16_hk576_hv512,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,         flash_attn_ext_bf16_h64,         has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,         flash_attn_ext_bf16_h80,         has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,         flash_attn_ext_bf16_h96,         has_simdgroup_mm && use_bfloat);
@@ -1022,6 +1044,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,        flash_attn_ext_bf16_h192,        has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,        flash_attn_ext_bf16_h256,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,         flash_attn_ext_q4_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,         flash_attn_ext_q4_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,         flash_attn_ext_q4_0_h96,         has_simdgroup_mm);
@@ -1030,6 +1053,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,        flash_attn_ext_q4_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,        flash_attn_ext_q4_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,         flash_attn_ext_q4_1_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,         flash_attn_ext_q4_1_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,         flash_attn_ext_q4_1_h96,         has_simdgroup_mm);
@@ -1038,6 +1062,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,        flash_attn_ext_q4_1_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,        flash_attn_ext_q4_1_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,         flash_attn_ext_q5_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,         flash_attn_ext_q5_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,         flash_attn_ext_q5_0_h96,         has_simdgroup_mm);
@@ -1046,6 +1071,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,        flash_attn_ext_q5_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,        flash_attn_ext_q5_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,         flash_attn_ext_q5_1_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,         flash_attn_ext_q5_1_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,         flash_attn_ext_q5_1_h96,         has_simdgroup_mm);
@@ -1054,6 +1080,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,        flash_attn_ext_q5_1_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,        flash_attn_ext_q5_1_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,         flash_attn_ext_q8_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,         flash_attn_ext_q8_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,         flash_attn_ext_q8_0_h96,         has_simdgroup_mm);
@@ -1062,6 +1089,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,        flash_attn_ext_q8_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,      flash_attn_ext_vec_f16_h96,      has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,     flash_attn_ext_vec_bf16_h96,     has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,     flash_attn_ext_vec_q4_0_h96,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,     flash_attn_ext_vec_q4_1_h96,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,     flash_attn_ext_vec_q5_0_h96,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,     flash_attn_ext_vec_q5_1_h96,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,     flash_attn_ext_vec_q8_0_h96,     has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,     flash_attn_ext_vec_f16_h128,     has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,    flash_attn_ext_vec_bf16_h128,    has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,    flash_attn_ext_vec_q4_0_h128,    has_simdgroup_reduction);
@@ -1090,6 +1125,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,    flash_attn_ext_vec_q5_0_h256,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,    flash_attn_ext_vec_q5_1_h256,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,    flash_attn_ext_vec_q8_0_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,     flash_attn_ext_vec_f16_hk576_hv512,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,    flash_attn_ext_vec_bf16_hk576_hv512,    has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,    flash_attn_ext_vec_q4_0_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,    flash_attn_ext_vec_q4_1_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,    flash_attn_ext_vec_q5_0_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,    flash_attn_ext_vec_q5_1_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,    flash_attn_ext_vec_q8_0_hk576_hv512,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                         set_f32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                         set_i32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                     cpy_f32_f32,                     true);
@@ -1357,6 +1399,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 // TODO: not sure if it is worth adding kernels for this size
                 return false;
             }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek sizes
+                // TODO: disabled for now, until optmized
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
@@ -3891,12 +3938,14 @@ static void ggml_metal_encode_node(
                 // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
                 //       for now avoiding mainly to keep the number of templates/kernels a bit lower
                 //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
+                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
                     switch (src1->type) {
                         case GGML_TYPE_F16:
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
@@ -3919,6 +3968,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
@@ -3941,6 +3992,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
@@ -3963,6 +4016,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
@@ -3985,6 +4040,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
@@ -4007,6 +4064,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
@@ -4029,6 +4088,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
@@ -4058,6 +4119,24 @@ static void ggml_metal_encode_node(
                     use_vec_kernel = true;
 
                     switch (ne00) {
+                        case 96:
+                            {
+                                switch (src1->type) {
+                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
+                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
+                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
+                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
+                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
+                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            GGML_LOG_ERROR("add template specialization for this type\n");
+                                            GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
                         case 128:
                             {
                                 switch (src1->type) {
@@ -4130,12 +4209,36 @@ static void ggml_metal_encode_node(
                                         }
                                 }
                             } break;
+                        case 576:
+                            {
+                                if (ne20 == 512) {
+                                    switch (src1->type) {
+                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
+                                        default:
+                                            {
+                                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                                GGML_ABORT("add template specialization for this type");
+                                            }
+                                    }
+                                } else {
+                                    GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
+                                    GGML_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ABORT("add template specialization for this size");
+                                }
+                            } break;
                         default:
-                                  {
-                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                      GGML_ABORT("add template specialization for this size");
-                                  }
+                            {
+                                GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ABORT("add template specialization for this size");
+                            }
                     }
                 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
index ede9d1e6c..71f0f97ff 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
@@ -3598,6 +3598,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]]         kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_h192")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3608,6 +3609,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3618,6 +3620,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3627,6 +3630,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3636,6 +3640,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3645,6 +3650,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3654,6 +3660,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #undef FA_TYPES
 
@@ -4011,6 +4018,16 @@ kernel void kernel_flash_attn_ext_vec(
 
 typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;
 
+template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
@@ -4051,6 +4068,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
 #undef FA_TYPES
 
 template