Skip to content

vllm.model_executor.layers.rotary_embedding.xdrope

XDRotaryEmbedding

Bases: DynamicNTKAlphaRotaryEmbedding

DynamicNTKAlphaRotaryEmbedding extended with multimodal (XD) position sections.

Based on the original DynamicNTKAlphaRotaryEmbedding implementation.

Source code in vllm/model_executor/layers/rotary_embedding/xdrope.py
class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
    """DynamicNTKAlphaRotaryEmbedding extended with multimodal (XD) sections.

    Based on the original DynamicNTKAlphaRotaryEmbedding implementation. The
    cos/sin channels are split according to ``xdrope_section``; section ``i``
    takes its rotation angles from row ``i`` of the 2-D ``positions`` tensor
    passed to :meth:`forward`.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
        xdrope_section: list[int],
    ) -> None:
        # Channel widths of each position section, consumed by ``forward``
        # through ``Tensor.split``. Recorded before the base initializer runs.
        self.xdrope_section = xdrope_section
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            scaling_alpha,
            dtype,
        )

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply the XD rotary embedding to ``query`` and, if given, ``key``.

        Only the first ``self.rotary_dim`` channels of every head are rotated;
        the remaining channels pass through unchanged.

        Args:
            positions:
                [4, num_tokens] (P/W/H/T positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size], or ``None`` to
                rotate only ``query``.
            offsets: Accepted for interface compatibility; not used here.

        Returns:
            ``(query, key)`` reshaped back to their input shapes; ``key`` is
            ``None`` when it was passed as ``None``.
        """
        assert positions.ndim == 2

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)

        def select_sections(table: torch.Tensor) -> torch.Tensor:
            # Leading axis of ``table`` indexes the rows of ``positions``;
            # channel section i is read from position row i.
            chunks = table.split(self.xdrope_section, dim=-1)
            return torch.cat([chunk[i] for i, chunk in enumerate(chunks)], dim=-1)

        cos = select_sections(cos)
        sin = select_sections(sin)

        def rotate(x: torch.Tensor) -> torch.Tensor:
            # Rotate the leading ``rotary_dim`` channels of each head and keep
            # the rest, restoring the caller's original layout afterwards.
            orig_shape = x.shape
            x = x.view(num_tokens, -1, self.head_size)
            x_rot = apply_rotary_emb_dispatch(
                x[..., : self.rotary_dim], cos, sin, self.is_neox_style
            )
            x_pass = x[..., self.rotary_dim :]
            return torch.cat((x_rot, x_pass), dim=-1).reshape(orig_shape)

        query = rotate(query)
        if key is not None:
            key = rotate(key)
        return query, key

    @staticmethod
    def get_next_input_positions(
        context_len: int,
        seq_len: int,
        xd_sections: int = 4,
    ) -> list[list[int]]:
        """Return ``xd_sections`` separate lists of positions
        ``context_len .. seq_len - 1`` (one independent list per section)."""
        return [list(range(context_len, seq_len)) for _ in range(xd_sections)]

    @staticmethod
    def get_next_input_positions_tensor(
        out: np.ndarray,
        out_offset: int,
        context_len: int,
        num_new_tokens: int,
    ):
        """Write positions into ``out`` in place.

        Columns ``out_offset .. out_offset + num_new_tokens - 1`` of every row
        of ``out`` receive ``context_len .. context_len + num_new_tokens - 1``
        (created with ``out.dtype``). Returns nothing.
        """
        values = np.arange(
            context_len,
            context_len + num_new_tokens,
            dtype=out.dtype,
        )
        out[:, out_offset : out_offset + num_new_tokens] = values

xdrope_section instance-attribute

xdrope_section = xdrope_section

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_alpha: float,
    dtype: dtype,
    xdrope_section: list[int],
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/xdrope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_alpha: float,
    dtype: torch.dtype,
    xdrope_section: list[int],
) -> None:
    """Initialize the XD rotary embedding.

    ``xdrope_section`` (the channel widths of each position section) is
    recorded on the instance first; every other argument is then forwarded
    positionally, in its original order, to the parent initializer.
    """
    self.xdrope_section = xdrope_section
    parent_args = (
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        scaling_alpha,
        dtype,
    )
    super().__init__(*parent_args)

forward

forward(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None = None,
    offsets: Tensor | None = None,
) -> tuple[Tensor, Tensor | None]

Apply the XD rotary embedding to the query and key tensors (PyTorch-native implementation).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| positions | Tensor | [4, num_tokens] (P/W/H/T positions with multimodal inputs) | required |
| query | Tensor | [num_tokens, num_heads * head_size] | required |
| key | Tensor \| None | [num_tokens, num_kv_heads * head_size] | None |
Source code in vllm/model_executor/layers/rotary_embedding/xdrope.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
    offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Apply the XD rotary embedding to ``query`` and, if given, ``key``.

    Each entry of ``self.xdrope_section`` selects a contiguous slice of the
    cos/sin channels; slice ``i`` takes its angles from row ``i`` of
    ``positions``. Only the first ``self.rotary_dim`` channels of every head
    are rotated; the remaining channels pass through unchanged.

    Args:
        positions:
            [4, num_tokens] (P/W/H/T positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size], or ``None`` to rotate
            only ``query``.
        offsets: Accepted for interface compatibility; not used here.

    Returns:
        ``(query, key)`` reshaped back to their input shapes; ``key`` is
        ``None`` when it was passed as ``None``.
    """
    assert positions.ndim == 2

    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)

    def select_sections(table: torch.Tensor) -> torch.Tensor:
        # Leading axis of ``table`` indexes the rows of ``positions``;
        # channel section i is read from position row i.
        chunks = table.split(self.xdrope_section, dim=-1)
        return torch.cat([chunk[i] for i, chunk in enumerate(chunks)], dim=-1)

    cos = select_sections(cos)
    sin = select_sections(sin)

    def rotate(x: torch.Tensor) -> torch.Tensor:
        # Rotate the leading ``rotary_dim`` channels of each head and keep
        # the rest, restoring the caller's original layout afterwards.
        orig_shape = x.shape
        x = x.view(num_tokens, -1, self.head_size)
        x_rot = apply_rotary_emb_dispatch(
            x[..., : self.rotary_dim], cos, sin, self.is_neox_style
        )
        x_pass = x[..., self.rotary_dim :]
        return torch.cat((x_rot, x_pass), dim=-1).reshape(orig_shape)

    query = rotate(query)
    if key is not None:
        key = rotate(key)
    return query, key

get_next_input_positions staticmethod

get_next_input_positions(
    context_len: int, seq_len: int, xd_sections: int = 4
) -> list[list[int]]
Source code in vllm/model_executor/layers/rotary_embedding/xdrope.py
@staticmethod
def get_next_input_positions(
    context_len: int,
    seq_len: int,
    xd_sections: int = 4,
) -> list[list[int]]:
    """Build one position list per XD section for a plain-text span.

    Every section receives the same positions ``context_len .. seq_len - 1``,
    but each inner list is a distinct object, so mutating one row leaves the
    others untouched.
    """
    span = range(context_len, seq_len)
    return [list(span) for _section in range(xd_sections)]

get_next_input_positions_tensor staticmethod

get_next_input_positions_tensor(
    out: ndarray,
    out_offset: int,
    context_len: int,
    num_new_tokens: int,
)
Source code in vllm/model_executor/layers/rotary_embedding/xdrope.py
@staticmethod
def get_next_input_positions_tensor(
    out: np.ndarray,
    out_offset: int,
    context_len: int,
    num_new_tokens: int,
):
    """Fill a window of ``out`` in place with consecutive positions.

    Columns ``out_offset .. out_offset + num_new_tokens - 1`` of every row of
    ``out`` are set to ``context_len .. context_len + num_new_tokens - 1``,
    generated directly in ``out.dtype``. Nothing is returned.
    """
    stop = context_len + num_new_tokens
    new_positions = np.arange(context_len, stop, dtype=out.dtype)
    window_end = out_offset + num_new_tokens
    out[:, out_offset:window_end] = new_positions