Skip to content

vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse

logger module-attribute

logger = init_logger(__name__)

ROCMAiterMLASparseBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
class ROCMAiterMLASparseBackend(AttentionBackend):
    """Registry entry for the ROCm AITER sparse MLA attention backend.

    Ties together the metadata dataclass, its builder and the attention
    implementation, and declares the KV-cache layout, dtypes and head
    sizes this backend accepts.
    """

    # The implementation writes its result into a caller-provided tensor.
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the name under which this backend is registered."""
        backend_name = "ROCM_AITER_MLA_SPARSE"
        return backend_name

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        """Return the per-batch metadata type used by this backend."""
        metadata_cls = ROCMAiterMLASparseMetadata
        return metadata_cls

    @staticmethod
    def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
        """Return the class that builds per-batch metadata."""
        builder_cls = ROCMAiterMLASparseMetadataBuilder
        return builder_cls

    @staticmethod
    def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
        """Return the attention implementation class."""
        impl_cls = ROCMAiterMLASparseImpl
        return impl_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` (assumed to be 1 for MLA) and ``cache_dtype_str``
        do not influence the result.
        """
        cache_shape = (num_blocks, block_size, head_size)
        return cache_shape

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the supported model dtypes (bfloat16 only)."""
        supported = [torch.bfloat16]
        return supported

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes (576 only)."""
        supported = [576]
        return supported

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[
    ROCMAiterMLASparseMetadataBuilder
]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
    """Return the class that builds per-batch sparse MLA metadata."""
    builder_cls = ROCMAiterMLASparseMetadataBuilder
    return builder_cls

get_impl_cls staticmethod

get_impl_cls() -> type[ROCMAiterMLASparseImpl]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
    """Return the sparse MLA attention implementation class."""
    impl_cls = ROCMAiterMLASparseImpl
    return impl_cls

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,  # assumed to be 1 for MLA
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    """Return the KV-cache shape ``(num_blocks, block_size, head_size)``.

    ``num_kv_heads`` (assumed to be 1 for MLA) and ``cache_dtype_str`` are
    accepted for interface compatibility and do not affect the shape.
    """
    cache_shape = (num_blocks, block_size, head_size)
    return cache_shape

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_metadata_cls() -> type[AttentionMetadata]:
    """Return the per-batch metadata type for this backend."""
    metadata_cls = ROCMAiterMLASparseMetadata
    return metadata_cls

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_name() -> str:
    """Return the name under which this backend is registered."""
    backend_name = "ROCM_AITER_MLA_SPARSE"
    return backend_name

get_supported_dtypes classmethod

get_supported_dtypes() -> list[dtype]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
    """Return the model dtypes this backend supports (bfloat16 only)."""
    supported = [torch.bfloat16]
    return supported

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the head sizes this backend supports (576 only)."""
    supported = [576]
    return supported

ROCMAiterMLASparseImpl

Bases: MLACommonBaseImpl[ROCMAiterMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
    """Sparse MLA attention: each query token attends to a top-k KV subset.

    The top-k token indices are read from the ``indexer`` supplied at
    construction. Attention is computed by the pure-PyTorch
    ``reference_mla_sparse_prefill`` in ``_forward_bf16_kv``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        """Initialize the implementation.

        ``topk_indice_buffer`` is accepted but not used. The top-k indices
        are taken from ``indexer.topk_indices_buffer``, so ``indexer`` must
        not be None. All other arguments go to the base MLA implementation.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        self.softmax_scale = scale
        # The indexer is mandatory: it owns the top-k index buffer read in forward().
        assert indexer is not None
        self.topk_indices_buffer = indexer.topk_indices_buffer
        # Chooses between the AITER FP8 BMM and torch.bmm for the W_UK product.
        self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: ROCMAiterMLASparseMetadata,
    ) -> torch.Tensor:
        """Run sparse attention of ``q`` over the selected KV-cache rows.

        Args:
            q: Per-token queries; ``q.shape[0]`` is taken as the token count.
            kv_c_and_k_pe_cache: KV cache, flattened to one row per cached
                entry before gathering.
            topk_indices: Global cache-row indices for each token.
            attn_metadata: Not used by this path.

        Returns:
            The attention output, keeping only the first ``self.num_heads``
            heads.
        """
        num_tokens = q.shape[0]
        # Flatten to (rows, 1, width) so a single global index addresses
        # one cached entry.
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        # d_v=512: only the first 512 channels of each cache row act as the
        # value (the "576/512" MQA layout noted in forward()).
        # [0] keeps the attention output and drops the log-sum-exp.
        output = reference_mla_sparse_prefill(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512
        )[0]
        # NOTE(review): presumably drops padded heads beyond num_heads; confirm.
        return output[:, : self.num_heads, :]

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: ROCMAiterMLASparseMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Write new KV into the cache and compute sparse MLA attention.

        The result is written into ``output[:num_actual_tokens]`` and
        ``output`` is returned.

        Raises:
            NotImplementedError: if ``output_scale`` or ``output_block_scale``
                is given; fused output quantization is not supported.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for ROCMAiterMLASparse"
            )

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs

        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Project the no-RoPE query part into the latent width L.
        if self.is_fp8bmm_enabled:
            # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
            ql_nope = rocm_aiter_ops.triton_fp8_bmm(
                q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
            )
        else:
            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            ql_nope = torch.bmm(q_nope, self.W_UK_T)
            # Convert from (N, B, L) to (B, N, L)
            ql_nope = ql_nope.transpose(0, 1)

        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        # Convert per-request top-k indices into global cache-row indices
        # using req_id_per_token and the block table.
        topk_indices_global = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        )

        # Query layout: [latent no-RoPE part | RoPE part].
        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        attn_out = self._forward_bf16_kv(
            q, kv_cache, topk_indices_global, attn_metadata
        )

        # V up-projection from the latent space, written into the unpadded
        # prefix of the caller's output buffer.
        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output

is_fp8bmm_enabled instance-attribute

is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

softmax_scale instance-attribute

softmax_scale = scale

topk_indices_buffer instance-attribute

topk_indices_buffer = indexer.topk_indices_buffer

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    topk_indice_buffer: Tensor | None = None,
    indexer: Optional[Indexer] = None,
    **mla_args,
) -> None
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    # MLA Specific Arguments
    topk_indice_buffer: torch.Tensor | None = None,
    indexer: Optional["Indexer"] = None,
    **mla_args,
) -> None:
    """Set up the sparse MLA implementation.

    The common attention arguments and ``mla_args`` go to the base
    implementation. ``topk_indice_buffer`` is ignored: the top-k buffer is
    taken from ``indexer``, which therefore must not be None.
    """
    base_positional = (
        num_heads, head_size, scale, num_kv_heads, alibi_slopes,
        sliding_window, kv_cache_dtype, logits_soft_cap, attn_type,
        kv_sharing_target_layer_name,
    )
    super().__init__(*base_positional, **mla_args)

    self.softmax_scale = scale
    assert indexer is not None
    # The indexer owns the buffer holding each token's top-k indices.
    self.topk_indices_buffer = indexer.topk_indices_buffer
    self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

_forward_bf16_kv

_forward_bf16_kv(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    attn_metadata: ROCMAiterMLASparseMetadata,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def _forward_bf16_kv(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: ROCMAiterMLASparseMetadata,
) -> torch.Tensor:
    """Attend ``q`` to the cache rows selected by ``topk_indices``.

    The cache is flattened to one row per cached entry. Only the first 512
    channels of each row are used as the value. ``attn_metadata`` is
    unused. The result keeps only the first ``self.num_heads`` heads.
    """
    row_width = kv_c_and_k_pe_cache.shape[-1]
    flat_cache = kv_c_and_k_pe_cache.view(-1, 1, row_width)
    per_token_indices = topk_indices.view(q.shape[0], 1, -1)

    attn_result, _lse = reference_mla_sparse_prefill(
        q, flat_cache, per_token_indices, self.softmax_scale, 512
    )
    return attn_result[:, : self.num_heads, :]

forward

forward(
    layer: AttentionLayer,
    q: Tensor,
    k_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: ROCMAiterMLASparseMetadata,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: ROCMAiterMLASparseMetadata,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Write new KV into the cache and compute sparse MLA attention.

    The result is written into ``output[:num_actual_tokens]`` and ``output``
    is returned. If ``attn_metadata`` is None, ``output`` is zero-filled
    instead.

    Raises:
        NotImplementedError: if ``output_scale`` or ``output_block_scale`` is
            given; fused output quantization is not supported.
    """
    # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
    # MQA 576/512 approach for both prefill and decode

    assert output is not None, "Output tensor must be provided."

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for ROCMAiterMLASparse"
        )

    if attn_metadata is None:
        # The zero fill is required when used with DP + EP
        # to ensure all ranks within a DP group compute the
        # same expert outputs.
        return output.fill_(0)

    num_actual_toks = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs

    q = q[:num_actual_toks, ...]
    k_c_normed = k_c_normed[:num_actual_toks, ...]
    k_pe = k_pe[:num_actual_toks, ...]

    q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    # Convert from (B, N, P) to (N, B, P)
    q_nope = q_nope.transpose(0, 1)
    # Project the no-RoPE query part into the latent width L.
    if self.is_fp8bmm_enabled:
        # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
        ql_nope = rocm_aiter_ops.triton_fp8_bmm(
            q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
        )
    else:
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

    topk_indices = self.topk_indices_buffer[:num_actual_toks]

    # Convert per-request top-k indices into global cache-row indices using
    # req_id_per_token and the block table.
    topk_indices_global = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
    )

    # Query layout: [latent no-RoPE part | RoPE part].
    q = torch.cat([ql_nope, q_pe], dim=-1)

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )

    attn_out = self._forward_bf16_kv(
        q, kv_cache, topk_indices_global, attn_metadata
    )

    # V up-projection from the latent space, written into the unpadded prefix
    # of the caller's output buffer.
    self._v_up_proj(attn_out, out=output[:num_actual_toks])
    return output

ROCMAiterMLASparseMetadata dataclass

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@dataclass
class ROCMAiterMLASparseMetadata:
    """Per-batch metadata read by the sparse MLA attention forward pass.

    Instances are produced by ``ROCMAiterMLASparseMetadataBuilder.build``.
    """

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    # Cache slots for the new tokens; flattened and passed to concat_and_cache_mla.
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    # Index of the owning request for each token (built with np.repeat).
    req_id_per_token: torch.Tensor
    # KV-cache block size; the builder sets it from kv_cache_spec.block_size.
    block_size: int = 1
    # Number of top-k tokens per query; the builder sets it from hf_config.index_topk.
    topk_tokens: int = 2048

block_size class-attribute instance-attribute

block_size: int = 1

block_table instance-attribute

block_table: Tensor

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_reqs instance-attribute

num_reqs: int

query_start_loc instance-attribute

query_start_loc: Tensor

req_id_per_token instance-attribute

req_id_per_token: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

topk_tokens class-attribute instance-attribute

topk_tokens: int = 2048

__init__

__init__(
    num_reqs: int,
    max_query_len: int,
    max_seq_len: int,
    num_actual_tokens: int,
    query_start_loc: Tensor,
    slot_mapping: Tensor,
    block_table: Tensor,
    req_id_per_token: Tensor,
    block_size: int = 1,
    topk_tokens: int = 2048,
) -> None

ROCMAiterMLASparseMetadataBuilder dataclass

Bases: AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@dataclass
class ROCMAiterMLASparseMetadataBuilder(
    AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
):
    """Builds ``ROCMAiterMLASparseMetadata`` for each scheduled batch.

    CUDA graphs are not supported (``cudagraph_support`` is NEVER).
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache config-derived values and allocate device-side buffers.

        ``layer_names`` is accepted for interface compatibility and unused.
        """
        model_cfg = vllm_config.model_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = model_cfg
        self.device = device

        self.num_heads = model_cfg.get_num_attention_heads(vllm_config.parallel_config)
        self.mla_dims = get_mla_dims(model_cfg)
        self.topk_tokens = model_cfg.hf_config.index_topk

        int32_on_device = {"device": device, "dtype": torch.int32}
        self.topk_tokens_tensor = torch.tensor([self.topk_tokens], **int32_on_device)
        self.max_model_len_tensor = torch.tensor(
            [model_cfg.max_model_len], **int32_on_device
        )
        # Placeholder block table; per the original note it is ignored by
        # `flash_mla_with_kvcache` when indices are supplied.
        self.dummy_block_table = torch.empty((1, 1), **int32_on_device)

        # Reusable per-token request-id buffer, sized for the largest batch.
        token_capacity = vllm_config.scheduler_config.max_num_batched_tokens
        self.req_id_per_token_buffer = torch.empty((token_capacity,), **int32_on_device)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> ROCMAiterMLASparseMetadata:
        """Create the sparse metadata for one batch.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and ignored.
        """
        cam = common_attn_metadata
        query_starts = np.asarray(cam.query_start_loc_cpu, dtype=np.int32)
        tokens_per_req = query_starts[1:] - query_starts[:-1]
        request_ids = np.arange(tokens_per_req.shape[0], dtype=np.int32)
        token_owner = np.repeat(request_ids, tokens_per_req)

        buf = self.req_id_per_token_buffer
        # Zero-fill for cudagraphs
        buf.fill_(0)
        buf[: token_owner.shape[0]].copy_(
            torch.from_numpy(token_owner), non_blocking=True
        )

        return ROCMAiterMLASparseMetadata(
            num_reqs=cam.num_reqs,
            max_query_len=cam.max_query_len,
            max_seq_len=cam.max_seq_len,
            num_actual_tokens=cam.num_actual_tokens,
            query_start_loc=cam.query_start_loc,
            slot_mapping=cam.slot_mapping,
            block_table=cam.block_table_tensor,
            req_id_per_token=buf[: cam.num_actual_tokens],
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
        )

cudagraph_support class-attribute

cudagraph_support: AttentionCGSupport = NEVER

device instance-attribute

device = device

dummy_block_table instance-attribute

dummy_block_table = empty(
    (1, 1), dtype=int32, device=device
)

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

max_model_len_tensor instance-attribute

max_model_len_tensor = tensor(
    [max_model_len], device=device, dtype=int32
)

mla_dims instance-attribute

mla_dims = get_mla_dims(model_config)

model_config instance-attribute

model_config = model_config

num_heads instance-attribute

num_heads = get_num_attention_heads(parallel_config)

req_id_per_token_buffer instance-attribute

req_id_per_token_buffer = empty(
    (max_num_batched_tokens,), dtype=int32, device=device
)

topk_tokens instance-attribute

topk_tokens = vllm_config.model_config.hf_config.index_topk

topk_tokens_tensor instance-attribute

topk_tokens_tensor = tensor(
    [topk_tokens], device=device, dtype=int32
)

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Cache config-derived values and allocate device-side buffers.

    ``layer_names`` is accepted for interface compatibility and unused.
    """
    model_cfg = vllm_config.model_config
    self.kv_cache_spec = kv_cache_spec
    self.model_config = model_cfg
    self.device = device

    self.num_heads = model_cfg.get_num_attention_heads(vllm_config.parallel_config)
    self.mla_dims = get_mla_dims(model_cfg)
    self.topk_tokens = model_cfg.hf_config.index_topk

    int32_on_device = {"device": device, "dtype": torch.int32}
    self.topk_tokens_tensor = torch.tensor([self.topk_tokens], **int32_on_device)
    self.max_model_len_tensor = torch.tensor(
        [model_cfg.max_model_len], **int32_on_device
    )
    # Placeholder block table; per the original note it is ignored by
    # `flash_mla_with_kvcache` when indices are supplied.
    self.dummy_block_table = torch.empty((1, 1), **int32_on_device)

    # Reusable per-token request-id buffer, sized for the largest batch.
    token_capacity = vllm_config.scheduler_config.max_num_batched_tokens
    self.req_id_per_token_buffer = torch.empty((token_capacity,), **int32_on_device)

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> ROCMAiterMLASparseMetadata
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> ROCMAiterMLASparseMetadata:
    """Create the sparse metadata for one batch.

    Each token is tagged with the index of the request it belongs to.
    ``common_prefix_len`` and ``fast_build`` are accepted for interface
    compatibility and ignored.
    """
    cam = common_attn_metadata
    query_starts = np.asarray(cam.query_start_loc_cpu, dtype=np.int32)
    tokens_per_req = query_starts[1:] - query_starts[:-1]
    request_ids = np.arange(tokens_per_req.shape[0], dtype=np.int32)
    token_owner = np.repeat(request_ids, tokens_per_req)

    buf = self.req_id_per_token_buffer
    # Zero-fill for cudagraphs
    buf.fill_(0)
    buf[: token_owner.shape[0]].copy_(
        torch.from_numpy(token_owner), non_blocking=True
    )

    return ROCMAiterMLASparseMetadata(
        num_reqs=cam.num_reqs,
        max_query_len=cam.max_query_len,
        max_seq_len=cam.max_seq_len,
        num_actual_tokens=cam.num_actual_tokens,
        query_start_loc=cam.query_start_loc,
        slot_mapping=cam.slot_mapping,
        block_table=cam.block_table_tensor,
        req_id_per_token=buf[: cam.num_actual_tokens],
        block_size=self.kv_cache_spec.block_size,
        topk_tokens=self.topk_tokens,
    )

reference_mla_sparse_prefill

reference_mla_sparse_prefill(
    q: Tensor,
    kv: Tensor,
    indices: Tensor,
    sm_scale: float,
    d_v: int,
) -> tuple[Tensor, Tensor]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def reference_mla_sparse_prefill(
    q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch sparse MLA attention over gathered KV rows.

    Each query token attends only to the KV rows listed for it in
    ``indices``. Indices that are negative or ``>= kv.shape[0]`` are treated
    as padding and masked out.

    Args:
        q: Queries, indexed as ``[s_q, h_q, d_qk]``.
        kv: KV rows, indexed as ``[s_kv, 1, d_qk]``; only ``kv[:, 0, :]`` is read.
        indices: Per-token row ids, ``[s_q, 1, topk]``; only
            ``indices[:, 0, :]`` is read. The tensor is not modified.
        sm_scale: Softmax scale applied to the raw ``q @ k`` scores.
        d_v: Number of leading KV channels used as the value vector.

    Returns:
        ``(out, lse)``:

        - ``out`` has shape ``[s_q, h_q, d_v]`` and dtype ``q.dtype``.
        - ``lse`` is the base-2 log-sum-exp of the scaled scores, with shape
          ``[s_q, h_q]``.

        A token with no valid index (or an empty ``kv``) gets zeros in
        ``out`` and ``-inf`` in ``lse`` rather than NaN.
    """
    import math

    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        # log2(sum(2 ** a)), computed through the natural-log logsumexp.
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    skv = kv.shape[0]
    sq = q.shape[0]
    topk = indices.shape[-1]
    dqk = q.shape[-1]

    if skv == 0:
        # Nothing to gather from: every index is invalid.
        empty_out = q.new_zeros((sq, q.shape[1], d_v))
        empty_lse = torch.full(
            (sq, q.shape[1]), float("-inf"), dtype=torch.float32, device=q.device
        )
        return (empty_out, empty_lse)

    token_indices = indices[:, 0, :]  # [s_q, topk], a view into the caller's tensor
    invalid_indices_mask = (token_indices < 0) | (token_indices >= skv)
    # Out-of-place: writing through the view would corrupt the caller's indices.
    safe_indices = token_indices.masked_fill(invalid_indices_mask, 0)
    kvs = kv[:, 0, :][safe_indices].view(sq, topk, dqk)  # [s_q, topk, d_qk]

    attn_score = (q @ kvs.transpose(1, 2)).float()  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
    attn_score *= sm_scale * math.log2(math.e)
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    # A fully masked row gives lse = -inf, and (-inf) - (-inf) would be NaN.
    # Normalizing such rows by 0 yields exp2(-inf) = 0 weights instead.
    normalizer = torch.where(torch.isneginf(lse), torch.zeros_like(lse), lse)
    probs = torch.exp2(attn_score - normalizer.unsqueeze(-1))  # [s_q, h_q, topk]
    result = probs.to(q.dtype) @ kvs[:, :, :d_v]
    return (result, lse)