divisor.flux1.layers

  1# SPDX-License-Identifier:Apache-2.0
  2# original BFL Flux code from https://github.com/black-forest-labs/flux
  3
  4from dataclasses import dataclass
  5import math
  6
  7from einops import rearrange
  8import torch
  9from torch import Tensor, nn
 10
 11from divisor.flux1.math import attention, rope
 12
 13
 14class EmbedND(nn.Module):
 15    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
 16        super().__init__()
 17        self.dim = dim
 18        self.theta = theta
 19        self.axes_dim = axes_dim
 20
 21    def forward(self, ids: Tensor) -> Tensor:
 22        n_axes = ids.shape[-1]
 23        emb = torch.cat(
 24            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
 25            dim=-3,
 26        )
 27
 28        return emb.unsqueeze(1)
 29
 30
 31def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
 32    """
 33    Create sinusoidal timestep embeddings.
 34    :param t: a 1-D Tensor of N indices, one per batch element.
 35                      These may be fractional.
 36    :param dim: the dimension of the output.
 37    :param max_period: controls the minimum frequency of the embeddings.
 38    :return: an (N, D) Tensor of positional embeddings.
 39    """
 40    t = time_factor * t
 41    half = dim // 2
 42    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
 43
 44    args = t[:, None].float() * freqs[None]
 45    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
 46    if dim % 2:
 47        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 48    if torch.is_floating_point(t):
 49        embedding = embedding.to(t)
 50    return embedding
 51
 52
 53class MLPEmbedder(nn.Module):
 54    def __init__(self, in_dim: int, hidden_dim: int):
 55        super().__init__()
 56        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
 57        self.silu = nn.SiLU()
 58        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
 59
 60    def forward(self, x: Tensor) -> Tensor:
 61        return self.out_layer(self.silu(self.in_layer(x)))
 62
 63
 64class RMSNorm(torch.nn.Module):
 65    def __init__(self, dim: int):
 66        super().__init__()
 67        self.scale = nn.Parameter(torch.ones(dim))
 68
 69    def forward(self, x: Tensor):
 70        x_dtype = x.dtype
 71        x = x.float()
 72        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
 73        return (x * rrms).to(dtype=x_dtype) * self.scale
 74
 75
 76class QKNorm(torch.nn.Module):
 77    def __init__(self, dim: int):
 78        super().__init__()
 79        self.query_norm = RMSNorm(dim)
 80        self.key_norm = RMSNorm(dim)
 81
 82    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
 83        q = self.query_norm(q)
 84        k = self.key_norm(k)
 85        return q.to(v), k.to(v)
 86
 87
 88class SelfAttention(nn.Module):
 89    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
 90        super().__init__()
 91        self.num_heads = num_heads
 92        head_dim = dim // num_heads
 93
 94        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 95        self.norm = QKNorm(head_dim)
 96        self.proj = nn.Linear(dim, dim)
 97
 98    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
 99        qkv = self.qkv(x)
100        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
101        q, k = self.norm(q, k, v)
102        x = attention(q, k, v, pe=pe)
103        x = self.proj(x)
104        return x
105
106
@dataclass
class ModulationOut:
    """One (shift, scale, gate) triple produced by :class:`Modulation`.

    The blocks in this module apply it as ``(1 + scale) * norm(h) + shift`` on
    the sub-layer input. They then multiply the sub-layer output by ``gate``
    before adding it to the residual stream.
    """

    shift: Tensor  # added after scaling the normalized input
    scale: Tensor  # applied as ``1 + scale``
    gate: Tensor  # multiplies the sub-layer output before the residual add
112
113
class Modulation(nn.Module):
    """Map a conditioning vector to adaptive-norm modulation parameters.

    ``forward`` applies SiLU and one linear layer (output width
    ``multiplier * dim``) to ``vec``. It then inserts a length-1 axis at
    position 1 and splits the last dimension into ``multiplier`` equal chunks.
    The first three chunks form one :class:`ModulationOut`. When ``double`` is
    set, the remaining three form a second one; otherwise ``None`` takes its place.

    Args:
        dim: last-dimension width of ``vec`` and of every returned tensor.
        double: produce two triples (6 chunks) instead of one (3 chunks).
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        projected = self.lin(nn.functional.silu(vec))[:, None, :]
        chunks = projected.chunk(self.multiplier, dim=-1)
        first = ModulationOut(*chunks[:3])
        if not self.is_double:
            return first, None
        return first, ModulationOut(*chunks[3:])
128
129
class DoubleStreamBlock(nn.Module):
    """Joint-attention block with separate image and text parameter streams.

    Each stream owns:
    - a double :class:`Modulation`;
    - two affine-free LayerNorms;
    - a :class:`SelfAttention`, used for its ``qkv``/``norm``/``proj`` layers;
    - a GELU MLP.

    Queries, keys and values from both streams are concatenated along the
    sequence axis (text first) and passed through a single ``attention`` call.
    The result is split at ``txt.shape[1]``. Each stream then applies gated
    residual updates, first for attention and then for its MLP.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        # Creation order determines state_dict key order; keep it stable.
        # Image stream.
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # Text stream: same layout, independent weights.
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
        heads = self.num_heads
        layout = "B L (K H D) -> K B H L D"

        # Image stream: modulated pre-norm, fused QKV projection, QK normalization.
        img_in = (1 + img_mod1.scale) * self.img_norm1(img) + img_mod1.shift
        img_q, img_k, img_v = rearrange(self.img_attn.qkv(img_in), layout, K=3, H=heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # Text stream: same steps with the text-specific layers.
        txt_in = (1 + txt_mod1.scale) * self.txt_norm1(txt) + txt_mod1.shift
        txt_q, txt_k, txt_v = rearrange(self.txt_attn.qkv(txt_in), layout, K=3, H=heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint sequence: text tokens first, then image tokens (dim 2 is L in B H L D).
        joint = attention(
            torch.cat((txt_q, img_q), dim=2),
            torch.cat((txt_k, img_k), dim=2),
            torch.cat((txt_v, img_v), dim=2),
            pe=pe,
        )
        n_txt = txt.shape[1]
        txt_attn, img_attn = joint[:, :n_txt], joint[:, n_txt:]

        # Gated residual updates: attention first, then the modulated MLP.
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt
193
194
class SingleStreamBlock(nn.Module):
    """A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    ``linear1`` computes the fused QKV projection and the MLP input in a single
    matmul. ``linear2`` maps the concatenated attention output and activated MLP
    branch back to ``hidden_size``. One non-double :class:`Modulation` supplies
    the pre-norm shift/scale and the residual gate.

    Args:
        hidden_size: model width.
        num_heads: number of attention heads (head width ``hidden_size // num_heads``).
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        qk_scale: stored as ``self.scale``, falling back to ``head_dim ** -0.5``
            when falsy. ``forward`` does not read it.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # Fused output layout: [q | k | v | mlp_in].
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # Fused input layout: [attention output | activated mlp].
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod = self.modulation(vec)[0]
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        fused = self.linear1(x_mod)
        qkv, mlp = torch.split(fused, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe)

        # Join the attention output and the activated MLP branch on dim 2 (the
        # feature axis of the B L C layout) and project both with one matmul.
        merged = torch.cat((attn, self.mlp_act(mlp)), 2)
        return x + mod.gate * self.linear2(merged)
238
239
class LastLayer(nn.Module):
    """Final adaptive LayerNorm followed by a linear output projection.

    ``vec`` is mapped (SiLU -> Linear) to a shift and a scale of width
    ``hidden_size``. These modulate the affine-free LayerNorm of ``x``, and
    ``linear`` then projects to ``patch_size * patch_size * out_channels`` features.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # Broadcast the per-sample modulation over axis 1 of ``x``.
        modulated = (1 + scale.unsqueeze(1)) * self.norm_final(x) + shift.unsqueeze(1)
        return self.linear(modulated)
class EmbedND(torch.nn.modules.module.Module):
class EmbedND(nn.Module):
    """Multi-axis rotary positional embedding.

    The last dimension of ``ids`` holds one coordinate per position axis.
    Coordinate ``i`` is encoded by ``rope`` with width ``axes_dim[i]`` and the
    shared base ``theta``; the per-axis encodings are concatenated along
    dimension ``-3`` and a length-1 axis is inserted at position 1.

    Args:
        dim: stored as ``self.dim``; ``forward`` does not read it.
        theta: base passed to ``rope`` for every axis.
        axes_dim: per-axis encoding widths, indexed by the last dimension of ``ids``.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        # NOTE(review): more coordinates in ids than entries in axes_dim raises IndexError.
        axis_encodings = [
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(ids.shape[-1])
        ]
        joined = torch.cat(axis_encodings, dim=-3)
        return joined.unsqueeze(1)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool

EmbedND(dim: int, theta: int, axes_dim: list[int])
16    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
17        super().__init__()
18        self.dim = dim
19        self.theta = theta
20        self.axes_dim = axes_dim

Initialize internal Module state, shared by both nn.Module and ScriptModule.

dim
theta
axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
22    def forward(self, ids: Tensor) -> Tensor:
23        n_axes = ids.shape[-1]
24        emb = torch.cat(
25            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
26            dim=-3,
27        )
28
29        return emb.unsqueeze(1)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def timestep_embedding(t: torch.Tensor, dim, max_period=10000, time_factor: float = 1000.0):
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before the frequencies are applied.
    :return: an (N, D) Tensor of positional embeddings: ``dim // 2`` cosine
             columns, then ``dim // 2`` sine columns, plus one trailing zero
             column when ``dim`` is odd. When the scaled ``t`` is floating point,
             the result is converted to its dtype and device.
    """
    scaled_t = time_factor * t
    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(scaled_t.device)

    phases = scaled_t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2:
        # Odd width: pad with one zero column so the output is exactly ``dim`` wide.
        zero_col = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_col], dim=-1)
    if torch.is_floating_point(scaled_t):
        embedding = embedding.to(scaled_t)
    return embedding

Create sinusoidal timestep embeddings.

Parameters
  • t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
  • dim: the dimension of the output.
  • max_period: controls the minimum frequency of the embeddings.
  • time_factor: multiplier applied to t before the embedding is computed (default 1000.0).
Returns

An (N, D) Tensor of positional embeddings.

class MLPEmbedder(torch.nn.modules.module.Module):
class MLPEmbedder(nn.Module):
    """Two-layer perceptron ``Linear(in_dim, hidden_dim) -> SiLU -> Linear(hidden_dim, hidden_dim)``.

    Both linear layers carry a bias.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        # Registration order determines state_dict key order; keep it stable.
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.in_layer(x)
        activated = self.silu(hidden)
        return self.out_layer(activated)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool

MLPEmbedder(in_dim: int, hidden_dim: int)
55    def __init__(self, in_dim: int, hidden_dim: int):
56        super().__init__()
57        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
58        self.silu = nn.SiLU()
59        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_layer
silu
out_layer
def forward(self, x: torch.Tensor) -> torch.Tensor:
61    def forward(self, x: Tensor) -> Tensor:
62        return self.out_layer(self.silu(self.in_layer(x)))

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class RMSNorm(torch.nn.modules.module.Module):
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    The mean square is computed in float32. The normalized tensor is cast back
    to the input dtype before it is multiplied by ``scale``. That product follows
    PyTorch type promotion, so a low-precision input combined with a float32
    ``scale`` yields float32.

    Args:
        dim: size of the last dimension; ``scale`` has this shape and starts at ones.
        eps: added to the mean square before the reciprocal square root to
            avoid division by zero. Defaults to ``1e-6`` (the previous hard-coded value).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Plain attribute (not a buffer), so existing state_dicts stay compatible.
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return (x * rrms).to(dtype=x_dtype) * self.scale

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool

RMSNorm(dim: int)
66    def __init__(self, dim: int):
67        super().__init__()
68        self.scale = nn.Parameter(torch.ones(dim))

Initialize internal Module state, shared by both nn.Module and ScriptModule.

scale
def forward(self, x: torch.Tensor):
70    def forward(self, x: Tensor):
71        x_dtype = x.dtype
72        x = x.float()
73        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
74        return (x * rrms).to(dtype=x_dtype) * self.scale

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class QKNorm(torch.nn.modules.module.Module):
class QKNorm(torch.nn.Module):
    """Independent RMS normalization for attention queries and keys.

    Args:
        dim: size of the last dimension of ``q`` and ``k``. Every use in this
            file passes the per-head width.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        # ``v`` only supplies the target dtype/device; it is not normalized or returned.
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool

QKNorm(dim: int)
78    def __init__(self, dim: int):
79        super().__init__()
80        self.query_norm = RMSNorm(dim)
81        self.key_norm = RMSNorm(dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

query_norm
key_norm
def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
83    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
84        q = self.query_norm(q)
85        k = self.key_norm(k)
86        return q.to(v), k.to(v)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class SelfAttention(torch.nn.modules.module.Module):
 89class SelfAttention(nn.Module):
 90    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
 91        super().__init__()
 92        self.num_heads = num_heads
 93        head_dim = dim // num_heads
 94
 95        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 96        self.norm = QKNorm(head_dim)
 97        self.proj = nn.Linear(dim, dim)
 98
 99    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
100        qkv = self.qkv(x)
101        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
102        q, k = self.norm(q, k, v)
103        x = attention(q, k, v, pe=pe)
104        x = self.proj(x)
105        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool

SelfAttention(dim: int, num_heads: int = 8, qkv_bias: bool = False)
90    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
91        super().__init__()
92        self.num_heads = num_heads
93        head_dim = dim // num_heads
94
95        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
96        self.norm = QKNorm(head_dim)
97        self.proj = nn.Linear(dim, dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_heads
qkv
norm
proj
def forward(self, x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
 99    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
100        qkv = self.qkv(x)
101        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
102        q, k = self.norm(q, k, v)
103        x = attention(q, k, v, pe=pe)
104        x = self.proj(x)
105        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

@dataclass
class ModulationOut:
@dataclass
class ModulationOut:
    """One (shift, scale, gate) triple produced by :class:`Modulation`.

    The blocks in this module apply it as ``(1 + scale) * norm(h) + shift`` on
    the sub-layer input. They then multiply the sub-layer output by ``gate``
    before adding it to the residual stream.
    """

    shift: Tensor  # added after scaling the normalized input
    scale: Tensor  # applied as ``1 + scale``
    gate: Tensor  # multiplies the sub-layer output before the residual add
ModulationOut(shift: torch.Tensor, scale: torch.Tensor, gate: torch.Tensor)
shift: torch.Tensor
scale: torch.Tensor
gate: torch.Tensor
class Modulation(torch.nn.modules.module.Module):
class Modulation(nn.Module):
    """Map a conditioning vector to adaptive-norm modulation parameters.

    ``forward`` applies SiLU and one linear layer (output width
    ``multiplier * dim``) to ``vec``. It then inserts a length-1 axis at
    position 1 and splits the last dimension into ``multiplier`` equal chunks.
    The first three chunks form one :class:`ModulationOut`. When ``double`` is
    set, the remaining three form a second one; otherwise ``None`` takes its place.

    Args:
        dim: last-dimension width of ``vec`` and of every returned tensor.
        double: produce two triples (6 chunks) instead of one (3 chunks).
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        projected = self.lin(nn.functional.silu(vec))[:, None, :]
        chunks = projected.chunk(self.multiplier, dim=-1)
        first = ModulationOut(*chunks[:3])
        if not self.is_double:
            return first, None
        return first, ModulationOut(*chunks[3:])

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool

Modulation(dim: int, double: bool)
116    def __init__(self, dim: int, double: bool):
117        super().__init__()
118        self.is_double = double
119        self.multiplier = 6 if double else 3
120        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

is_double
multiplier
lin
def forward( self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]:
122    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
123        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
124
125        return (
126            ModulationOut(*out[:3]),
127            ModulationOut(*out[3:]) if self.is_double else None,
128        )

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class DoubleStreamBlock(torch.nn.modules.module.Module):
class DoubleStreamBlock(nn.Module):
    """Joint-attention block with separate image and text parameter streams.

    Each stream owns:
    - a double :class:`Modulation`;
    - two affine-free LayerNorms;
    - a :class:`SelfAttention`, used for its ``qkv``/``norm``/``proj`` layers;
    - a GELU MLP.

    Queries, keys and values from both streams are concatenated along the
    sequence axis (text first) and passed through a single ``attention`` call.
    The result is split at ``txt.shape[1]``. Each stream then applies gated
    residual updates, first for attention and then for its MLP.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        # Creation order determines state_dict key order; keep it stable.
        # Image stream.
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # Text stream: same layout, independent weights.
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
        heads = self.num_heads
        layout = "B L (K H D) -> K B H L D"

        # Image stream: modulated pre-norm, fused QKV projection, QK normalization.
        img_in = (1 + img_mod1.scale) * self.img_norm1(img) + img_mod1.shift
        img_q, img_k, img_v = rearrange(self.img_attn.qkv(img_in), layout, K=3, H=heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # Text stream: same steps with the text-specific layers.
        txt_in = (1 + txt_mod1.scale) * self.txt_norm1(txt) + txt_mod1.shift
        txt_q, txt_k, txt_v = rearrange(self.txt_attn.qkv(txt_in), layout, K=3, H=heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint sequence: text tokens first, then image tokens (dim 2 is L in B H L D).
        joint = attention(
            torch.cat((txt_q, img_q), dim=2),
            torch.cat((txt_k, img_k), dim=2),
            torch.cat((txt_v, img_v), dim=2),
            pe=pe,
        )
        n_txt = txt.shape[1]
        txt_attn, img_attn = joint[:, :n_txt], joint[:, n_txt:]

        # Gated residual updates: attention first, then the modulated MLP.
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool

DoubleStreamBlock( hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False)
132    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
133        super().__init__()
134
135        mlp_hidden_dim = int(hidden_size * mlp_ratio)
136        self.num_heads = num_heads
137        self.hidden_size = hidden_size
138        self.img_mod = Modulation(hidden_size, double=True)
139        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
140        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
141
142        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
143        self.img_mlp = nn.Sequential(
144            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
145            nn.GELU(approximate="tanh"),
146            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
147        )
148
149        self.txt_mod = Modulation(hidden_size, double=True)
150        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
151        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
152
153        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
154        self.txt_mlp = nn.Sequential(
155            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
156            nn.GELU(approximate="tanh"),
157            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
158        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_heads
hidden_size
img_mod
img_norm1
img_attn
img_norm2
img_mlp
txt_mod
txt_norm1
txt_attn
txt_norm2
txt_mlp
def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
160    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
161        img_mod1, img_mod2 = self.img_mod(vec)
162        txt_mod1, txt_mod2 = self.txt_mod(vec)
163
164        # prepare image for attention
165        img_modulated = self.img_norm1(img)
166        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
167        img_qkv = self.img_attn.qkv(img_modulated)
168        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
169        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
170
171        # prepare txt for attention
172        txt_modulated = self.txt_norm1(txt)
173        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
174        txt_qkv = self.txt_attn.qkv(txt_modulated)
175        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
176        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
177
178        # run actual attention
179        q = torch.cat((txt_q, img_q), dim=2)
180        k = torch.cat((txt_k, img_k), dim=2)
181        v = torch.cat((txt_v, img_v), dim=2)
182
183        attn = attention(q, k, v, pe=pe)
184        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
185
186        # calculate the img blocks
187        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
188        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
189
190        # calculate the txt blocks
191        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
192        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
193        return img, txt

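Define the computation performed at every call. (For this block: each stream is normalized and modulated from `vec`, and its Q/K/V are computed and QK-normalized. Attention then runs jointly over the text tokens followed by the image tokens, using the positional embedding `pe`. The result is split back by `txt.shape[1]`, and each stream receives gated attention-projection and MLP residual updates. The method returns `(img, txt)`.)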

Should be overridden by all subclasses.

Although the recipe for the forward pass must be defined within this function, callers should invoke the Module instance itself (e.g. ``block(img, txt, vec, pe)``) rather than calling ``forward()`` directly. The instance call runs any registered hooks, whereas a direct ``forward()`` call silently skips them.

class SingleStreamBlock(torch.nn.modules.module.Module):
196class SingleStreamBlock(nn.Module):
197    """A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation interface."""
198
199    def __init__(
200        self,
201        hidden_size: int,
202        num_heads: int,
203        mlp_ratio: float = 4.0,
204        qk_scale: float | None = None,
205    ):
206        super().__init__()
207        self.hidden_dim = hidden_size
208        self.num_heads = num_heads
209        head_dim = hidden_size // num_heads
210        self.scale = qk_scale or head_dim**-0.5
211
212        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
213        # qkv and mlp_in
214        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
215        # proj and mlp_out
216        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
217
218        self.norm = QKNorm(head_dim)
219
220        self.hidden_size = hidden_size
221        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
222
223        self.mlp_act = nn.GELU(approximate="tanh")
224        self.modulation = Modulation(hidden_size, double=False)
225
226    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
227        mod, _ = self.modulation(vec)
228        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
229        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
230
231        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
232        q, k = self.norm(q, k, v)
233
234        # compute attention
235        attn = attention(q, k, v, pe=pe)
236        # compute activation in mlp stream, cat again and run second linear layer
237        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
238        return x + mod.gate * output

A DiT block with parallel linear layers, as described in https://arxiv.org/abs/2302.05442, and an adapted modulation interface.

SingleStreamBlock(hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, qk_scale: float | None = None)
199    def __init__(
200        self,
201        hidden_size: int,
202        num_heads: int,
203        mlp_ratio: float = 4.0,
204        qk_scale: float | None = None,
205    ):
206        super().__init__()
207        self.hidden_dim = hidden_size
208        self.num_heads = num_heads
209        head_dim = hidden_size // num_heads
210        self.scale = qk_scale or head_dim**-0.5
211
212        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
213        # qkv and mlp_in
214        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
215        # proj and mlp_out
216        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
217
218        self.norm = QKNorm(head_dim)
219
220        self.hidden_size = hidden_size
221        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
222
223        self.mlp_act = nn.GELU(approximate="tanh")
224        self.modulation = Modulation(hidden_size, double=False)

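Initialize internal Module state, shared by both nn.Module and ScriptModule. (This summary is inherited from `nn.Module.__init__`. The constructor shown above builds:
- a fused input projection `linear1` (`hidden_size` → `3 * hidden_size + mlp_hidden_dim`, i.e. QKV plus MLP input);
- a fused output projection `linear2` (`hidden_size + mlp_hidden_dim` → `hidden_size`);
- a `QKNorm` over `head_dim = hidden_size // num_heads`;
- an affine-free `pre_norm` LayerNorm, a GELU(tanh) activation, and a single (non-double) `Modulation`.

NOTE(review): `scale` (derived from `qk_scale`) is stored but is not referenced by the `forward` listed below, and `hidden_dim` duplicates `hidden_size`.)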

hidden_dim
num_heads
scale
mlp_hidden_dim
linear1
linear2
norm
hidden_size
pre_norm
mlp_act
modulation
def forward(self, x: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
226    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
227        mod, _ = self.modulation(vec)
228        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
229        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
230
231        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
232        q, k = self.norm(q, k, v)
233
234        # compute attention
235        attn = attention(q, k, v, pe=pe)
236        # compute activation in mlp stream, cat again and run second linear layer
237        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
238        return x + mod.gate * output

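Define the computation performed at every call. (For this block: the pre-normed input is modulated from `vec`, and the `linear1` output is split into a QKV stream and an MLP stream. Attention runs on the QK-normalized heads with `pe`. The attention output is concatenated with the GELU-activated MLP stream and projected by `linear2`, and the gated result is added back to `x`.)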

Should be overridden by all subclasses.

Although the recipe for the forward pass must be defined within this function, callers should invoke the Module instance itself (e.g. ``block(x, vec, pe)``) rather than calling ``forward()`` directly. The instance call runs any registered hooks, whereas a direct ``forward()`` call silently skips them.

class LastLayer(torch.nn.modules.module.Module):
241class LastLayer(nn.Module):
242    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
243        super().__init__()
244        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
245        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
246        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
247
248    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
249        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
250        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
251        x = self.linear(x)
252        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that indicates whether this module is in training or evaluation mode.
:vartype training: bool

LastLayer(hidden_size: int, patch_size: int, out_channels: int)
242    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
243        super().__init__()
244        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
245        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
246        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

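Initialize internal Module state, shared by both nn.Module and ScriptModule. (This summary is inherited from `nn.Module.__init__`. The constructor shown above builds an affine-free final LayerNorm `norm_final` and an output projection `linear` (`hidden_size` → `patch_size * patch_size * out_channels`). It also builds `adaLN_modulation` (SiLU followed by a `hidden_size` → `2 * hidden_size` linear layer), which produces the shift and scale.)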

norm_final
linear
adaLN_modulation
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
248    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
249        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
250        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
251        x = self.linear(x)
252        return x

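Define the computation performed at every call. (For this layer: `adaLN_modulation(vec)` is chunked into `shift` and `scale`. These are broadcast over dimension 1 of `x` to modulate `norm_final(x)` as `(1 + scale) * norm_final(x) + shift`, and the result is projected by `linear`.)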

Should be overridden by all subclasses.

Although the recipe for the forward pass must be defined within this function, callers should invoke the Module instance itself (e.g. ``layer(x, vec)``) rather than calling ``forward()`` directly. The instance call runs any registered hooks, whereas a direct ``forward()`` call silently skips them.