divisor.flux1.math
# SPDX-License-Identifier:Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux

from einops import rearrange
import torch
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotate ``q``/``k`` with ``pe``, attend over ``v``, and merge the heads.

    Args:
        q: Query tensor, rotated by :func:`apply_rope` before attention.
        k: Key tensor, rotated by :func:`apply_rope` before attention.
        v: Value tensor passed straight to ``scaled_dot_product_attention``.
        pe: Rotation matrices consumed by :func:`apply_rope`.

    Returns:
        The attention output with heads folded into the feature axis. The
        rearrange pattern requires that output to be 4-D ``(B, H, L, D)``
        and produces ``(B, L, H * D)``.
    """
    q_rot, k_rot = apply_rope(q, k, pe)
    attended = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v)
    # Put the head axis next to the feature axis and flatten them together.
    return rearrange(attended, "B H L D -> B L (H D)")


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build per-position 2x2 rotation matrices for rotary embeddings.

    Args:
        pos: Positions; the final reshape requires a 2-D ``(b, n)`` tensor.
        dim: Embedding width; must be even (one frequency per channel pair).
        theta: Base of the geometric frequency progression.

    Returns:
        Float32 tensor of shape ``(b, n, dim // 2, 2, 2)`` holding
        ``[[cos, -sin], [sin, cos]]`` for every position/frequency pair.
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    freqs = 1.0 / (theta**exponents)
    # Outer product over the last axis: angles[..., n, d] = pos[..., n] * freqs[d].
    angles = torch.einsum("...n,d->...nd", pos, freqs)
    cos_a = torch.cos(angles)
    sin_a = torch.sin(angles)
    packed = torch.stack([cos_a, -sin_a, sin_a, cos_a], dim=-1)
    # Split the trailing 4 entries into a row-major 2x2 matrix; the 4-way
    # unpack pins the input to rank 4, i.e. ``pos`` to rank 2.
    batch, seq, half, _ = packed.shape
    return packed.reshape(batch, seq, half, 2, 2).float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate consecutive channel pairs of ``xq`` and ``xk`` by ``freqs_cis``.

    The last axis of each input is viewed as pairs ``(x0, x1)``. Each pair is
    mapped to ``R @ (x0, x1)``, where ``R`` is the matching 2x2 matrix in
    ``freqs_cis`` (indexed ``[..., row, col]``). The arithmetic runs in
    float32, and results are cast back to each input's shape and dtype.

    Args:
        xq: Query tensor; its last dimension must be even.
        xk: Key tensor; its last dimension must be even.
        freqs_cis: Rotation matrices whose last two axes are ``(2, 2)`` and
            whose leading axes broadcast against the reshaped inputs.

    Returns:
        ``(rotated_q, rotated_k)`` with the shapes and dtypes of ``xq``/``xk``.
    """
    col0 = freqs_cis[..., 0]  # (R[0, 0], R[1, 0]) -- multiplies x0
    col1 = freqs_cis[..., 1]  # (R[0, 1], R[1, 1]) -- multiplies x1

    def _rotate(x: Tensor) -> Tensor:
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = col0 * pairs[..., 0] + col1 * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build per-position 2x2 rotation matrices for rotary embeddings.

    Args:
        pos: Positions; the final reshape requires a 2-D ``(b, n)`` tensor.
        dim: Embedding width; must be even (one frequency per channel pair).
        theta: Base of the geometric frequency progression.

    Returns:
        Float32 tensor of shape ``(b, n, dim // 2, 2, 2)`` holding
        ``[[cos, -sin], [sin, cos]]`` for every position/frequency pair.
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    freqs = 1.0 / (theta**exponents)
    # Outer product over the last axis: angles[..., n, d] = pos[..., n] * freqs[d].
    angles = torch.einsum("...n,d->...nd", pos, freqs)
    cos_a = torch.cos(angles)
    sin_a = torch.sin(angles)
    packed = torch.stack([cos_a, -sin_a, sin_a, cos_a], dim=-1)
    # Split the trailing 4 entries into a row-major 2x2 matrix; the 4-way
    # unpack pins the input to rank 4, i.e. ``pos`` to rank 2.
    batch, seq, half, _ = packed.shape
    return packed.reshape(batch, seq, half, 2, 2).float()
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate consecutive channel pairs of ``xq`` and ``xk`` by ``freqs_cis``.

    The last axis of each input is viewed as pairs ``(x0, x1)``. Each pair is
    mapped to ``R @ (x0, x1)``, where ``R`` is the matching 2x2 matrix in
    ``freqs_cis`` (indexed ``[..., row, col]``). The arithmetic runs in
    float32, and results are cast back to each input's shape and dtype.

    Args:
        xq: Query tensor; its last dimension must be even.
        xk: Key tensor; its last dimension must be even.
        freqs_cis: Rotation matrices whose last two axes are ``(2, 2)`` and
            whose leading axes broadcast against the reshaped inputs.

    Returns:
        ``(rotated_q, rotated_k)`` with the shapes and dtypes of ``xq``/``xk``.
    """
    col0 = freqs_cis[..., 0]  # (R[0, 0], R[1, 0]) -- multiplies x0
    col1 = freqs_cis[..., 1]  # (R[0, 1], R[1, 1]) -- multiplies x1

    def _rotate(x: Tensor) -> Tensor:
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = col0 * pairs[..., 0] + col1 * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)