divisor.flux1.math

 1# SPDX-License-Identifier:Apache-2.0
 2# original BFL Flux code from https://github.com/black-forest-labs/flux
 3
 4from einops import rearrange
 5import torch
 6from torch import Tensor
 7
 8
 9def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
10    q, k = apply_rope(q, k, pe)
11
12    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
13    x = rearrange(x, "B H L D -> B L (H D)")
14
15    return x
16
17
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build rotary position embedding (RoPE) rotation matrices.

    Args:
        pos: Positions, shape ``(..., n)``. Any numeric dtype is accepted; the
            math is done in at least float32 (float64 input keeps float64).
        dim: Feature size covered by these frequencies; must be even.
        theta: Base of the geometric progression of frequencies
            ``omega_k = theta ** (-2k / dim)``.

    Returns:
        float32 tensor of shape ``(..., n, dim // 2, 2, 2)``. The last two axes
        hold, per position and frequency, the rotation
        ``[[cos, -sin], [sin, cos]]`` of angle ``pos * omega_k``.
    """
    assert dim % 2 == 0
    # Never compute frequencies/angles below float32: half-precision (or
    # integer) position ids would badly quantize the rotary phase.
    compute_dtype = torch.promote_types(pos.dtype, torch.float32)
    positions = pos.to(compute_dtype)
    scale = torch.arange(0, dim, 2, dtype=compute_dtype, device=positions.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", positions, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    # Split the trailing 4 entries row-major into a 2x2 matrix (i=row, j=col).
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate consecutive feature pairs of ``xq`` and ``xk`` by ``freqs_cis``.

    The last dimension of each input is viewed as D/2 pairs ``(a, b)``. Each
    pair is multiplied by the matching 2x2 matrix in ``freqs_cis``. The math
    runs in float32, and each output is cast back to its input's dtype and
    shape.

    Args:
        xq, xk: Tensors whose last dimension D is even.
        freqs_cis: Rotation matrices, presumably shaped to broadcast against
            ``(..., D/2, 2, 2)`` (e.g. ``rope()`` output with a head axis
            inserted) -- confirm with caller.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes/dtypes as inputs.
    """

    def _rotate(x: Tensor) -> Tensor:
        # (..., D) -> (..., D/2, 1, 2) so each pair broadcasts over matrix rows.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        # Matrix-vector product: column 0 scaled by a plus column 1 scaled by b.
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
    """Run scaled dot-product attention after rotating q and k with RoPE.

    Args:
        q, k, v: Query/key/value tensors. The head-merge pattern below implies
            a (B, H, L, D) layout for the attention output.
        pe: Rotation matrices consumed by ``apply_rope`` (e.g. from ``rope``).

    Returns:
        Attention output with heads merged: (B, L, H * D).
    """
    q_rot, k_rot = apply_rope(q, k, pe)

    # No mask or custom scale is passed, so SDPA uses its defaults.
    attended = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v)

    # Fold the head axis into the feature axis: (B, H, L, D) -> (B, L, H*D).
    return rearrange(attended, "B H L D -> B L (H D)")


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build rotary position embedding (RoPE) rotation matrices.

    Args:
        pos: Positions, shape ``(..., n)``. Any numeric dtype is accepted; the
            math is done in at least float32 (float64 input keeps float64).
        dim: Feature size covered by these frequencies; must be even.
        theta: Base of the geometric progression of frequencies
            ``omega_k = theta ** (-2k / dim)``.

    Returns:
        float32 tensor of shape ``(..., n, dim // 2, 2, 2)``. The last two axes
        hold, per position and frequency, the rotation
        ``[[cos, -sin], [sin, cos]]`` of angle ``pos * omega_k``.
    """
    assert dim % 2 == 0
    # Never compute frequencies/angles below float32: half-precision (or
    # integer) position ids would badly quantize the rotary phase.
    compute_dtype = torch.promote_types(pos.dtype, torch.float32)
    positions = pos.to(compute_dtype)
    scale = torch.arange(0, dim, 2, dtype=compute_dtype, device=positions.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", positions, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    # Split the trailing 4 entries row-major into a 2x2 matrix (i=row, j=col).
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.float()


def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate consecutive feature pairs of ``xq`` and ``xk`` by ``freqs_cis``.

    The last dimension of each input is viewed as D/2 pairs ``(a, b)``. Each
    pair is multiplied by the matching 2x2 matrix in ``freqs_cis``. The math
    runs in float32, and each output is cast back to its input's dtype and
    shape.

    Args:
        xq, xk: Tensors whose last dimension D is even.
        freqs_cis: Rotation matrices, presumably shaped to broadcast against
            ``(..., D/2, 2, 2)`` (e.g. ``rope()`` output with a head axis
            inserted) -- confirm with caller.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes/dtypes as inputs.
    """

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # (..., D) -> (..., D/2, 1, 2) so each pair broadcasts over matrix rows.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        # Matrix-vector product: column 0 scaled by a plus column 1 scaled by b.
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)