divisor.flux2.model

  1# SPDX-License-Identifier:Apache-2.0
  2# original BFL Flux code from https://github.com/black-forest-labs/flux2
  3
  4from dataclasses import dataclass, field
  5import math
  6
  7from einops import rearrange
  8import torch
  9from torch import Tensor, nn
 10
 11from divisor.layer_dropout import process_blocks_with_dropout
 12
 13
 14@dataclass
 15class Flux2Params:
 16    in_channels: int = 128
 17    context_in_dim: int = 15360
 18    hidden_size: int = 6144
 19    num_heads: int = 48
 20    depth: int = 8
 21    depth_single_blocks: int = 48
 22    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
 23    theta: int = 2000
 24    mlp_ratio: float = 3.0
 25
 26
class Flux2(nn.Module):
    """Flux2 transformer: double-stream blocks followed by single-stream blocks.

    Image and context tokens are projected to ``hidden_size``, run through the
    ``double_blocks`` as two separate streams, concatenated (context first),
    run through the ``single_blocks`` as one sequence, and finally the image
    part is projected back to ``out_channels`` by ``final_layer``.

    The modulation tensors are computed once per forward pass from the
    timestep/guidance embedding and shared by every block of the same kind.
    """

    def __init__(self, params: Flux2Params) -> None:
        """Build all sub-modules from ``params``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_heads``, or
                if ``sum(axes_dim)`` differs from the per-head dimension.
        """
        super().__init__()

        self.in_channels = params.in_channels
        # The model predicts tensors with the same channel count it receives.
        self.out_channels = params.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        # Rotary embeddings act on one head at a time, so their width is the head dim.
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
        # 256 matches the embedding width requested from timestep_embedding in forward().
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        # One modulation module per stream kind; their outputs are reused by every block.
        self.double_stream_modulation_img = Modulation(
            self.hidden_size,
            double=True,
            disable_bias=True,
        )
        self.double_stream_modulation_txt = Modulation(
            self.hidden_size,
            double=True,
            disable_bias=True,
        )
        self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)

        self.final_layer = LastLayer(
            self.hidden_size,
            self.out_channels,
        )

    def forward(
        self,
        x: Tensor,
        x_ids: Tensor,
        timesteps: Tensor,
        ctx: Tensor,
        ctx_ids: Tensor,
        guidance: Tensor,
        layer_dropouts: list[int] | None = None,
    ) -> Tensor:
        """Run the transformer and return the prediction for the image tokens.

        Args:
            x: image tokens; the last axis must have ``in_channels`` features
                (it is fed to ``img_in``).
            x_ids: position ids for ``x``; the last axis is indexed once per
                entry of ``axes_dim`` by ``EmbedND``.
            timesteps: 1-D tensor, one value per batch element (it is indexed
                as ``t[:, None]`` by ``timestep_embedding``).
            ctx: context tokens; axis 1 is the token axis and the last axis is
                fed to ``txt_in``.
            ctx_ids: position ids for ``ctx``, same layout as ``x_ids``.
            guidance: 1-D tensor embedded the same way as ``timesteps``.
            layer_dropouts: passed through to ``process_blocks_with_dropout``
                for both block stacks. Presumably a list of block indices to
                skip -- confirm against ``divisor.layer_dropout``.

        Returns:
            Output of ``final_layer`` for the image tokens only; the last axis
            has ``out_channels`` features.
        """
        # Remembered so the context tokens can be sliced off after the single-stream stack.
        num_txt_tokens = ctx.shape[1]

        # Conditioning vector: timestep embedding plus guidance embedding.
        timestep_emb = timestep_embedding(timesteps, 256)
        vec = self.time_in(timestep_emb)
        guidance_emb = timestep_embedding(guidance, 256)
        vec = vec + self.guidance_in(guidance_emb)

        # Each double modulation is a pair of (shift, scale, gate) triples.
        double_block_mod_img = self.double_stream_modulation_img(vec)
        double_block_mod_txt = self.double_stream_modulation_txt(vec)
        # With double=False the second element is None, hence the discard.
        single_block_mod, _ = self.single_stream_modulation(vec)

        img = self.img_in(x)
        txt = self.txt_in(ctx)

        pe_x = self.pe_embedder(x_ids)
        pe_ctx = self.pe_embedder(ctx_ids)

        # NOTE(review): the third argument looks like the global index of the first
        # block in this stack -- confirm against process_blocks_with_dropout.
        img, txt = process_blocks_with_dropout(
            self.double_blocks,
            layer_dropouts,  # forwarded unchanged from this method's own parameter
            0,
            "double",
            lambda block, state: block(state[0], state[1], pe_x, pe_ctx, double_block_mod_img, double_block_mod_txt),
            (img, txt),
        )

        # Single-stream blocks see one sequence: context tokens first, then image tokens.
        img = torch.cat((txt, img), dim=1)
        # The positional embeddings are concatenated in the same order along their token axis.
        pe = torch.cat((pe_ctx, pe_x), dim=2)

        img = process_blocks_with_dropout(
            self.single_blocks,
            layer_dropouts,  # forwarded unchanged from this method's own parameter
            len(self.double_blocks),
            "single",
            lambda block, state: block(state, pe, single_block_mod),
            img,
        )
        # Drop the leading context tokens; only image tokens are decoded.
        img = img[:, num_txt_tokens:, ...]

        img = self.final_layer(img, vec)
        return img
136
137
class SelfAttention(nn.Module):
    """Parameter holder for one attention stream.

    The class defines no ``forward``; ``DoubleStreamBlock.forward`` calls the
    ``qkv``, ``norm`` and ``proj`` sub-modules directly.
    """

    def __init__(self, dim: int, num_heads: int = 8):
        """Create the fused QKV projection, the QK norm and the output projection.

        Args:
            dim: feature width of the stream.
            num_heads: number of attention heads; the QK norm is sized to
                ``dim // num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        # One matmul yields queries, keys and values side by side.
        self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=False)

        # Queries and keys are normalised per head, so the norm uses the head width.
        self.norm = QKNorm(dim // num_heads)
        self.proj = nn.Linear(in_features=dim, out_features=dim, bias=False)
151
152
class SiLUActivation(nn.Module):
    """Gated SiLU: halves the last axis and returns ``silu(first) * second``."""

    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gate; the output's last axis is half as wide as the input's."""
        gate, value = torch.chunk(x, 2, dim=-1)
        activated_gate = self.gate_fn(gate)
        return activated_gate * value
161
162
class Modulation(nn.Module):
    """Project a conditioning vector into (shift, scale, gate) modulation triples."""

    def __init__(self, dim: int, double: bool, disable_bias: bool = False):
        """Create the projection.

        Args:
            dim: width of the conditioning vector and of every returned chunk.
            double: produce two triples (six chunks) instead of one (three).
            disable_bias: build the linear layer without a bias term.
        """
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)

    def forward(self, vec: torch.Tensor):
        """Return ``(first_triple, second_triple)``.

        ``second_triple`` is ``None`` unless the module was built with
        ``double=True``. A 2-D projection gets a singleton axis inserted at
        position 1 before it is split along the last axis.
        """
        projected = self.lin(nn.functional.silu(vec))
        if projected.ndim == 2:
            projected = projected.unsqueeze(1)
        chunks = projected.chunk(self.multiplier, dim=-1)
        first_triple = chunks[:3]
        second_triple = chunks[3:] if self.is_double else None
        return first_triple, second_triple
176
177
class LastLayer(nn.Module):
    """Output head: modulated LayerNorm followed by a bias-free linear projection."""

    def __init__(self, hidden_size: int, out_channels: int):
        """Create the norm, the output projection and the shift/scale projection."""
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=False)
        # Emits shift and scale side by side, hence twice the hidden width.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=False),
        )

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        """Modulate the normalised ``x`` with shift/scale from ``vec`` and project it.

        When the modulation derived from ``vec`` is 2-D, a singleton axis is
        inserted at position 1 so it broadcasts against ``x``.
        """
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        if shift.ndim == 2:
            shift, scale = shift.unsqueeze(1), scale.unsqueeze(1)
        modulated = (1 + scale) * self.norm_final(x) + shift
        return self.linear(modulated)
198
199
class SingleStreamBlock(nn.Module):
    """Transformer block with attention and a gated MLP computed in parallel.

    A single fused input projection (``linear1``) produces both the QKV tensor
    and the MLP pre-activation; a single fused output projection (``linear2``)
    consumes the concatenated attention output and activated MLP branch.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
    ):
        """Create the fused projections and norms.

        Args:
            hidden_size: width of the residual stream.
            num_heads: number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        """
        super().__init__()

        # Same value as self.hidden_size below; both attributes are set.
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Stored but not read by forward() in this class.
        self.scale = head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # SiLUActivation halves its input width, so the pre-activation is twice as wide.
        self.mlp_mult_factor = 2

        # Fused projection: 3 * hidden_size for q/k/v plus the MLP pre-activation.
        self.linear1 = nn.Linear(
            hidden_size,
            hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
            bias=False,
        )

        # Input is the attention output (hidden_size) concatenated with the gated MLP (mlp_hidden_dim).
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=False)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = SiLUActivation()

    def forward(
        self,
        x: Tensor,
        pe: Tensor,
        mod: tuple[Tensor, Tensor, Tensor],
    ) -> Tensor:
        """Apply the block and return the updated residual stream.

        Args:
            x: token sequence laid out as (B, L, hidden_size), per the
                ``"B L (K H D)"`` rearrange pattern below.
            pe: rotary position tensor passed to ``attention``.
            mod: ``(shift, scale, gate)`` triple applied around the block.

        Returns:
            ``x + gate * block_output``, same shape as ``x``.
        """
        mod_shift, mod_scale, mod_gate = mod
        x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift

        # Undo the fusion done by linear1: first the q/k/v part, then the MLP part.
        qkv, mlp = torch.split(
            self.linear1(x_mod),
            [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
            dim=-1,
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        # v is only used by the norm as the dtype/device target for q and k.
        q, k = self.norm(q, k, v)

        attn = attention(q, k, v, pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod_gate * output
254
255
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text weights and joint attention.

    Each stream has its own norms, QKV/output projections and gated MLP. The
    two streams only interact inside the attention call, where their queries,
    keys and values are concatenated (text first) along the token axis.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
    ):
        """Create the per-stream sub-modules.

        Args:
            hidden_size: width of both residual streams; must be divisible by
                ``num_heads``.
            num_heads: number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        """
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        assert hidden_size % num_heads == 0, f"{hidden_size=} must be divisible by {num_heads=}"

        self.hidden_size = hidden_size
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # SiLUActivation halves its input width, so the first MLP layer is twice as wide.
        self.mlp_mult_factor = 2

        self.img_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(
                hidden_size,
                mlp_hidden_dim * self.mlp_mult_factor,
                bias=False,
            ),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        pe_ctx: Tensor,
        mod_img: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
        mod_txt: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
    ) -> tuple[Tensor, Tensor]:
        """Update both streams and return ``(img, txt)``.

        Args:
            img: image tokens laid out as (B, L, hidden_size), per the
                ``"B L (K H D)"`` rearrange pattern below.
            txt: text tokens, same layout as ``img``.
            pe: rotary position tensor for the image tokens.
            pe_ctx: rotary position tensor for the text tokens.
            mod_img: two ``(shift, scale, gate)`` triples for the image stream;
                the first wraps attention, the second wraps the MLP.
            mod_txt: same as ``mod_img`` for the text stream.
        """
        img_mod1, img_mod2 = mod_img
        txt_mod1, txt_mod2 = mod_txt

        img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
        img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
        txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift

        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift

        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint attention: text tokens first, then image tokens, along the token axis (dim 2 of B H L D).
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        # Positional tensors are concatenated in the same text-then-image order.
        pe = torch.cat((pe_ctx, pe), dim=2)
        attn = attention(q, k, v, pe)
        # attention() returns (B, L, H*D); split the token axis back at the text length.
        txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]

        # calculate the img blocks
        img = img + img_mod1_gate * self.img_attn.proj(img_attn)
        img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift)

        # calculate the txt blocks
        txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift)
        return img, txt
350
351
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) used to embed conditioning inputs."""

    def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
        """Create the layers.

        Args:
            in_dim: width of the input features.
            hidden_dim: width of the hidden layer and of the output.
            disable_bias: build both linear layers without a bias term.
        """
        super().__init__()
        use_bias = not disable_bias
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=use_bias)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return the embedding of ``x``; the last axis becomes ``hidden_dim`` wide."""
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
361
362
class EmbedND(nn.Module):
    """Build rotary position tensors from multi-axis position ids."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        """Store the configuration; the module has no parameters.

        Args:
            dim: total rotary width (stored, not read by ``forward``).
            theta: frequency base handed to ``rope``.
            axes_dim: rotary width used for each position-id axis.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Return the per-axis ``rope`` tensors joined along their frequency axis.

        Axis ``i`` of the last dimension of ``ids`` is encoded with
        ``axes_dim[i]`` rotary dimensions; the results are concatenated on
        dim ``-3`` and a singleton axis is inserted at position 1.
        """
        per_axis = []
        for axis, axis_dim in enumerate(self.axes_dim):
            per_axis.append(rope(ids[..., axis], axis_dim, self.theta))
        joined = torch.cat(per_axis, dim=-3)
        return joined.unsqueeze(1)
377
378
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to t before the embedding is computed.
    :return: an (N, D) Tensor of positional embeddings (cosines first, then sines,
             zero-padded by one column when dim is odd).
    """
    t = time_factor * t
    half = dim // 2

    # Frequencies are always built in float32, on the same device as t.
    index = torch.arange(start=0, end=half, device=t.device, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * index / half)

    phase = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

    if dim % 2:
        # Odd widths get one trailing zero column so the output is exactly dim wide.
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)

    # Match the dtype/device of the (scaled) input when it is a floating tensor.
    return embedding.to(t) if torch.is_floating_point(t) else embedding
401
402
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, dim: int):
        """Create a per-feature scale of length ``dim``, initialised to ones."""
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        """Normalise ``x`` in float32, cast back to its dtype, then apply the scale."""
        orig_dtype = x.dtype
        x32 = x.float()
        mean_square = x32.pow(2).mean(dim=-1, keepdim=True)
        # The 1e-6 term keeps rsqrt finite for all-zero rows.
        normalised = x32 * torch.rsqrt(mean_square + 1e-6)
        return normalised.to(dtype=orig_dtype) * self.scale
413
414
class QKNorm(torch.nn.Module):
    """Independent RMS normalisation for attention queries and keys."""

    def __init__(self, dim: int):
        """Create one ``RMSNorm`` of width ``dim`` for queries and one for keys."""
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Return normalised ``(q, k)``, each converted with ``.to(v)``.

        ``v`` itself is not normalised; it only serves as the conversion target.
        """
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k
425
426
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotary-embed q and k, run scaled dot-product attention, and merge the heads.

    Inputs follow the (B, H, L, D) layout named in the rearrange pattern; the
    result is (B, L, H * D).
    """
    q_rot, k_rot = apply_rope(q, k, pe)
    per_head = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v)
    return rearrange(per_head, "B H L D -> B L (H D)")
434
435
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotation matrices for rotary position embedding.

    Args:
        pos: positions of any shape ``(..., n)``. The einsum below already
            accepts arbitrary leading axes, and the final reshape now does
            too (it previously required exactly two axes).
        dim: rotary width for this axis; must be even. ``dim // 2``
            frequencies are generated.
        theta: base of the frequency schedule ``theta ** -(2i / dim)``.

    Returns:
        A float32 tensor of shape ``(..., n, dim // 2, 2, 2)`` holding
        ``[[cos, -sin], [sin, cos]]`` for every position/frequency pair.
    """
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    angles = torch.einsum("...n,d->...nd", pos, omega)
    # Evaluate each trigonometric function once and reuse it for both matrix entries.
    cos, sin = torch.cos(angles), torch.sin(angles)
    out = torch.stack([cos, -sin, sin, cos], dim=-1)
    # Split the trailing 4 entries into a row-major 2x2 matrix, whatever the leading shape.
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.float()
444
445
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate queries and keys with the 2x2 matrices in ``freqs_cis``.

    The last axis of each input is viewed as consecutive (even, odd) pairs and
    every pair is multiplied by its rotation matrix. The arithmetic runs in
    float32; each result is returned with the shape and dtype of its input.
    """

    def _rotate(x: Tensor) -> Tensor:
        # (..., D) -> (..., D/2, 1, 2) so each pair broadcasts against a 2x2 matrix.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
@dataclass
class Flux2Params:
15@dataclass
16class Flux2Params:
17    in_channels: int = 128
18    context_in_dim: int = 15360
19    hidden_size: int = 6144
20    num_heads: int = 48
21    depth: int = 8
22    depth_single_blocks: int = 48
23    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
24    theta: int = 2000
25    mlp_ratio: float = 3.0
Flux2Params( in_channels: int = 128, context_in_dim: int = 15360, hidden_size: int = 6144, num_heads: int = 48, depth: int = 8, depth_single_blocks: int = 48, axes_dim: list[int] = <factory>, theta: int = 2000, mlp_ratio: float = 3.0)
in_channels: int = 128
context_in_dim: int = 15360
hidden_size: int = 6144
num_heads: int = 48
depth: int = 8
depth_single_blocks: int = 48
axes_dim: list[int]
theta: int = 2000
mlp_ratio: float = 3.0
class Flux2(torch.nn.modules.module.Module):
 28class Flux2(nn.Module):
 29    def __init__(self, params: Flux2Params):
 30        super().__init__()
 31
 32        self.in_channels = params.in_channels
 33        self.out_channels = params.in_channels
 34        if params.hidden_size % params.num_heads != 0:
 35            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
 36        pe_dim = params.hidden_size // params.num_heads
 37        if sum(params.axes_dim) != pe_dim:
 38            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
 39        self.hidden_size = params.hidden_size
 40        self.num_heads = params.num_heads
 41        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
 42        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
 43        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
 44        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
 45        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
 46
 47        self.double_blocks = nn.ModuleList(
 48            [
 49                DoubleStreamBlock(
 50                    self.hidden_size,
 51                    self.num_heads,
 52                    mlp_ratio=params.mlp_ratio,
 53                )
 54                for _ in range(params.depth)
 55            ]
 56        )
 57
 58        self.single_blocks = nn.ModuleList(
 59            [
 60                SingleStreamBlock(
 61                    self.hidden_size,
 62                    self.num_heads,
 63                    mlp_ratio=params.mlp_ratio,
 64                )
 65                for _ in range(params.depth_single_blocks)
 66            ]
 67        )
 68
 69        self.double_stream_modulation_img = Modulation(
 70            self.hidden_size,
 71            double=True,
 72            disable_bias=True,
 73        )
 74        self.double_stream_modulation_txt = Modulation(
 75            self.hidden_size,
 76            double=True,
 77            disable_bias=True,
 78        )
 79        self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)
 80
 81        self.final_layer = LastLayer(
 82            self.hidden_size,
 83            self.out_channels,
 84        )
 85
 86    def forward(
 87        self,
 88        x: Tensor,
 89        x_ids: Tensor,
 90        timesteps: Tensor,
 91        ctx: Tensor,
 92        ctx_ids: Tensor,
 93        guidance: Tensor,
 94        layer_dropouts: list[int] | None = None,
 95    ):
 96        num_txt_tokens = ctx.shape[1]
 97
 98        timestep_emb = timestep_embedding(timesteps, 256)
 99        vec = self.time_in(timestep_emb)
100        guidance_emb = timestep_embedding(guidance, 256)
101        vec = vec + self.guidance_in(guidance_emb)
102
103        double_block_mod_img = self.double_stream_modulation_img(vec)
104        double_block_mod_txt = self.double_stream_modulation_txt(vec)
105        single_block_mod, _ = self.single_stream_modulation(vec)
106
107        img = self.img_in(x)
108        txt = self.txt_in(ctx)
109
110        pe_x = self.pe_embedder(x_ids)
111        pe_ctx = self.pe_embedder(ctx_ids)
112
113        img, txt = process_blocks_with_dropout(
114            self.double_blocks,
115            layer_dropouts,  # Would need to be added to flux2 forward signature
116            0,
117            "double",
118            lambda block, state: block(state[0], state[1], pe_x, pe_ctx, double_block_mod_img, double_block_mod_txt),
119            (img, txt),
120        )
121
122        img = torch.cat((txt, img), dim=1)
123        pe = torch.cat((pe_ctx, pe_x), dim=2)
124
125        img = process_blocks_with_dropout(
126            self.single_blocks,
127            layer_dropouts,  # Would need to be added to flux2 forward signature
128            len(self.double_blocks),
129            "single",
130            lambda block, state: block(state, pe, single_block_mod),
131            img,
132        )
133        img = img[:, num_txt_tokens:, ...]
134
135        img = self.final_layer(img, vec)
136        return img

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): Whether this module is in training or evaluation mode.

Flux2(params: Flux2Params)
29    def __init__(self, params: Flux2Params):
30        super().__init__()
31
32        self.in_channels = params.in_channels
33        self.out_channels = params.in_channels
34        if params.hidden_size % params.num_heads != 0:
35            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
36        pe_dim = params.hidden_size // params.num_heads
37        if sum(params.axes_dim) != pe_dim:
38            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
39        self.hidden_size = params.hidden_size
40        self.num_heads = params.num_heads
41        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
42        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
43        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
44        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
45        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
46
47        self.double_blocks = nn.ModuleList(
48            [
49                DoubleStreamBlock(
50                    self.hidden_size,
51                    self.num_heads,
52                    mlp_ratio=params.mlp_ratio,
53                )
54                for _ in range(params.depth)
55            ]
56        )
57
58        self.single_blocks = nn.ModuleList(
59            [
60                SingleStreamBlock(
61                    self.hidden_size,
62                    self.num_heads,
63                    mlp_ratio=params.mlp_ratio,
64                )
65                for _ in range(params.depth_single_blocks)
66            ]
67        )
68
69        self.double_stream_modulation_img = Modulation(
70            self.hidden_size,
71            double=True,
72            disable_bias=True,
73        )
74        self.double_stream_modulation_txt = Modulation(
75            self.hidden_size,
76            double=True,
77            disable_bias=True,
78        )
79        self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)
80
81        self.final_layer = LastLayer(
82            self.hidden_size,
83            self.out_channels,
84        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_channels
out_channels
hidden_size
num_heads
pe_embedder
img_in
time_in
guidance_in
txt_in
double_blocks
single_blocks
double_stream_modulation_img
double_stream_modulation_txt
single_stream_modulation
final_layer
def forward( self, x: torch.Tensor, x_ids: torch.Tensor, timesteps: torch.Tensor, ctx: torch.Tensor, ctx_ids: torch.Tensor, guidance: torch.Tensor, layer_dropouts: list[int] | None = None):
 86    def forward(
 87        self,
 88        x: Tensor,
 89        x_ids: Tensor,
 90        timesteps: Tensor,
 91        ctx: Tensor,
 92        ctx_ids: Tensor,
 93        guidance: Tensor,
 94        layer_dropouts: list[int] | None = None,
 95    ):
 96        num_txt_tokens = ctx.shape[1]
 97
 98        timestep_emb = timestep_embedding(timesteps, 256)
 99        vec = self.time_in(timestep_emb)
100        guidance_emb = timestep_embedding(guidance, 256)
101        vec = vec + self.guidance_in(guidance_emb)
102
103        double_block_mod_img = self.double_stream_modulation_img(vec)
104        double_block_mod_txt = self.double_stream_modulation_txt(vec)
105        single_block_mod, _ = self.single_stream_modulation(vec)
106
107        img = self.img_in(x)
108        txt = self.txt_in(ctx)
109
110        pe_x = self.pe_embedder(x_ids)
111        pe_ctx = self.pe_embedder(ctx_ids)
112
113        img, txt = process_blocks_with_dropout(
114            self.double_blocks,
115            layer_dropouts,  # Would need to be added to flux2 forward signature
116            0,
117            "double",
118            lambda block, state: block(state[0], state[1], pe_x, pe_ctx, double_block_mod_img, double_block_mod_txt),
119            (img, txt),
120        )
121
122        img = torch.cat((txt, img), dim=1)
123        pe = torch.cat((pe_ctx, pe_x), dim=2)
124
125        img = process_blocks_with_dropout(
126            self.single_blocks,
127            layer_dropouts,  # Would need to be added to flux2 forward signature
128            len(self.double_blocks),
129            "single",
130            lambda block, state: block(state, pe, single_block_mod),
131            img,
132        )
133        img = img[:, num_txt_tokens:, ...]
134
135        img = self.final_layer(img, vec)
136        return img

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class SelfAttention(torch.nn.modules.module.Module):
139class SelfAttention(nn.Module):
140    def __init__(
141        self,
142        dim: int,
143        num_heads: int = 8,
144    ):
145        super().__init__()
146        self.num_heads = num_heads
147        head_dim = dim // num_heads
148        self.qkv = nn.Linear(dim, dim * 3, bias=False)
149
150        self.norm = QKNorm(head_dim)
151        self.proj = nn.Linear(dim, dim, bias=False)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): Whether this module is in training or evaluation mode.

SelfAttention(dim: int, num_heads: int = 8)
140    def __init__(
141        self,
142        dim: int,
143        num_heads: int = 8,
144    ):
145        super().__init__()
146        self.num_heads = num_heads
147        head_dim = dim // num_heads
148        self.qkv = nn.Linear(dim, dim * 3, bias=False)
149
150        self.norm = QKNorm(head_dim)
151        self.proj = nn.Linear(dim, dim, bias=False)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_heads
qkv
norm
proj
class SiLUActivation(torch.nn.modules.module.Module):
154class SiLUActivation(nn.Module):
155    def __init__(self):
156        super().__init__()
157        self.gate_fn = nn.SiLU()
158
159    def forward(self, x: Tensor) -> Tensor:
160        x1, x2 = x.chunk(2, dim=-1)
161        return self.gate_fn(x1) * x2

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): Whether this module is in training or evaluation mode.

SiLUActivation()
155    def __init__(self):
156        super().__init__()
157        self.gate_fn = nn.SiLU()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

gate_fn
def forward(self, x: torch.Tensor) -> torch.Tensor:
159    def forward(self, x: Tensor) -> Tensor:
160        x1, x2 = x.chunk(2, dim=-1)
161        return self.gate_fn(x1) * x2

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class Modulation(torch.nn.modules.module.Module):
164class Modulation(nn.Module):
165    def __init__(self, dim: int, double: bool, disable_bias: bool = False):
166        super().__init__()
167        self.is_double = double
168        self.multiplier = 6 if double else 3
169        self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)
170
171    def forward(self, vec: torch.Tensor):
172        out = self.lin(nn.functional.silu(vec))
173        if out.ndim == 2:
174            out = out[:, None, :]
175        out = out.chunk(self.multiplier, dim=-1)
176        return out[:3], out[3:] if self.is_double else None

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

Modulation(dim: int, double: bool, disable_bias: bool = False)
165    def __init__(self, dim: int, double: bool, disable_bias: bool = False):
166        super().__init__()
167        self.is_double = double
168        self.multiplier = 6 if double else 3
169        self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

is_double
multiplier
lin
def forward(self, vec: torch.Tensor):
171    def forward(self, vec: torch.Tensor):
172        out = self.lin(nn.functional.silu(vec))
173        if out.ndim == 2:
174            out = out[:, None, :]
175        out = out.chunk(self.multiplier, dim=-1)
176        return out[:3], out[3:] if self.is_double else None

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class LastLayer(torch.nn.modules.module.Module):
class LastLayer(nn.Module):
    """Output head: modulated LayerNorm followed by a bias-free linear projection."""

    def __init__(self, hidden_size: int, out_channels: int):
        """
        :param hidden_size: width of the incoming features and of the conditioning vector.
        :param out_channels: width of the projected output.
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=False)
        # Emits shift and scale side by side, hence 2 * hidden_size outputs.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=False),
        )

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        """Modulate the normalised ``x`` with shift/scale derived from ``vec`` and project it.

        :param x: features whose last dimension is ``hidden_size``.
        :param vec: conditioning tensor whose last dimension is ``hidden_size``.
        :return: tensor whose last dimension is ``out_channels``.
        """
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        if shift.ndim == 2:
            # Insert a singleton axis at position 1 so both broadcast over x.
            shift, scale = shift.unsqueeze(1), scale.unsqueeze(1)
        modulated = (1 + scale) * self.norm_final(x) + shift
        return self.linear(modulated)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

LastLayer(hidden_size: int, out_channels: int)
180    def __init__(
181        self,
182        hidden_size: int,
183        out_channels: int,
184    ):
185        super().__init__()
186        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
187        self.linear = nn.Linear(hidden_size, out_channels, bias=False)
188        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False))

Initialize internal Module state, shared by both nn.Module and ScriptModule.

norm_final
linear
adaLN_modulation
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
190    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
191        mod = self.adaLN_modulation(vec)
192        shift, scale = mod.chunk(2, dim=-1)
193        if shift.ndim == 2:
194            shift = shift[:, None, :]
195            scale = scale[:, None, :]
196        x = (1 + scale) * self.norm_final(x) + shift
197        x = self.linear(x)
198        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class SingleStreamBlock(torch.nn.modules.module.Module):
class SingleStreamBlock(nn.Module):
    """Single-stream transformer block with a fused attention/MLP projection.

    ``linear1`` emits the QKV projection and the gated-MLP input in one
    matmul; ``linear2`` maps the concatenated attention output and MLP
    activation back to ``hidden_size``. The result is added to the input
    through a modulation gate.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
    ):
        """
        :param hidden_size: width of the token features.
        :param num_heads: number of attention heads; ``hidden_size // num_heads`` is the head width.
        :param mlp_ratio: MLP hidden width as a multiple of ``hidden_size`` (truncated to int).
        """
        super().__init__()

        self.hidden_dim = hidden_size  # same value as self.hidden_size below; both names are kept
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = head_dim**-0.5  # not read inside this class
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # SiLUActivation splits its input in two, so linear1 emits twice the MLP hidden width.
        self.mlp_mult_factor = 2

        # Fused projection: [q | k | v | mlp input].
        self.linear1 = nn.Linear(
            hidden_size,
            hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
            bias=False,
        )

        # Consumes [attention output | mlp activation].
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=False)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = SiLUActivation()

    def forward(
        self,
        x: Tensor,
        pe: Tensor,
        mod: tuple[Tensor, Tensor, Tensor],
    ) -> Tensor:
        """
        :param x: token features; the rearrange below requires ``(B, L, hidden_size)``.
        :param pe: rotary position embedding, passed through to ``attention``.
        :param mod: ``(shift, scale, gate)`` modulation tensors.
        :return: ``x + gate * block_output``.
        """
        mod_shift, mod_scale, mod_gate = mod
        x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift

        qkv, mlp = torch.split(
            self.linear1(x_mod),
            [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
            dim=-1,
        )

        # (B, L, 3*H*D) -> three tensors of shape (B, H, L, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        attn = attention(q, k, v, pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod_gate * output

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

SingleStreamBlock(hidden_size: int, num_heads: int, mlp_ratio: float = 4.0)
202    def __init__(
203        self,
204        hidden_size: int,
205        num_heads: int,
206        mlp_ratio: float = 4.0,
207    ):
208        super().__init__()
209
210        self.hidden_dim = hidden_size
211        self.num_heads = num_heads
212        head_dim = hidden_size // num_heads
213        self.scale = head_dim**-0.5
214        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
215        self.mlp_mult_factor = 2
216
217        self.linear1 = nn.Linear(
218            hidden_size,
219            hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
220            bias=False,
221        )
222
223        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=False)
224
225        self.norm = QKNorm(head_dim)
226
227        self.hidden_size = hidden_size
228        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
229
230        self.mlp_act = SiLUActivation()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

hidden_dim
num_heads
scale
mlp_hidden_dim
mlp_mult_factor
linear1
linear2
norm
hidden_size
pre_norm
mlp_act
def forward(self, x: torch.Tensor, pe: torch.Tensor, mod: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
232    def forward(
233        self,
234        x: Tensor,
235        pe: Tensor,
236        mod: tuple[Tensor, Tensor],
237    ) -> Tensor:
238        mod_shift, mod_scale, mod_gate = mod  # type: ignore
239        x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
240
241        qkv, mlp = torch.split(
242            self.linear1(x_mod),
243            [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
244            dim=-1,
245        )
246
247        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
248        q, k = self.norm(q, k, v)
249
250        attn = attention(q, k, v, pe)
251
252        # compute activation in mlp stream, cat again and run second linear layer
253        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
254        return x + mod_gate * output

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class DoubleStreamBlock(torch.nn.modules.module.Module):
class DoubleStreamBlock(nn.Module):
    """Two-stream transformer block with joint attention over text and image tokens.

    Each stream has its own norms, QKV/projection weights and gated MLP; the
    streams only interact inside the shared attention call, where text tokens
    are placed before image tokens along the sequence axis.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
    ):
        """
        :param hidden_size: width of the token features in both streams.
        :param num_heads: number of attention heads; must divide ``hidden_size``.
        :param mlp_ratio: MLP hidden width as a multiple of ``hidden_size`` (truncated to int).
        :raises ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # Raise instead of assert so the check survives ``python -O``.
        if hidden_size % num_heads != 0:
            raise ValueError(f"{hidden_size=} must be divisible by {num_heads=}")
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads

        self.hidden_size = hidden_size
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # SiLUActivation splits its input in two, so the first MLP layer emits twice the hidden width.
        self.mlp_mult_factor = 2

        self.img_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(
                hidden_size,
                mlp_hidden_dim * self.mlp_mult_factor,
                bias=False,
            ),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        pe_ctx: Tensor,
        mod_img: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
        mod_txt: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
    ) -> tuple[Tensor, Tensor]:
        """
        :param img: image-stream features; the rearrange below requires ``(B, L_img, hidden_size)``.
        :param txt: text-stream features; the rearrange below requires ``(B, L_txt, hidden_size)``.
        :param pe: rotary position embedding for the image tokens.
        :param pe_ctx: rotary position embedding for the text tokens.
        :param mod_img: two ``(shift, scale, gate)`` triples: attention first, MLP second.
        :param mod_txt: same layout as ``mod_img``, for the text stream.
        :return: updated ``(img, txt)``.
        """
        img_mod1, img_mod2 = mod_img
        txt_mod1, txt_mod2 = mod_txt

        img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
        img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
        txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift

        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift

        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint sequence: text tokens first, then image tokens (axis 2 is L in B H L D).
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        # Position embeddings are concatenated in the same text-then-image order.
        pe = torch.cat((pe_ctx, pe), dim=2)
        attn = attention(q, k, v, pe)
        # Split the joint output back into the two streams along the sequence axis.
        txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]

        # calculate the img blocks
        img = img + img_mod1_gate * self.img_attn.proj(img_attn)
        img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift)

        # calculate the txt blocks
        txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift)
        return img, txt

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

DoubleStreamBlock(hidden_size: int, num_heads: int, mlp_ratio: float)
258    def __init__(
259        self,
260        hidden_size: int,
261        num_heads: int,
262        mlp_ratio: float,
263    ):
264        super().__init__()
265        mlp_hidden_dim = int(hidden_size * mlp_ratio)
266        self.num_heads = num_heads
267        assert hidden_size % num_heads == 0, f"{hidden_size=} must be divisible by {num_heads=}"
268
269        self.hidden_size = hidden_size
270        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
271        self.mlp_mult_factor = 2
272
273        self.img_attn = SelfAttention(
274            dim=hidden_size,
275            num_heads=num_heads,
276        )
277
278        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
279        self.img_mlp = nn.Sequential(
280            nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
281            SiLUActivation(),
282            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
283        )
284
285        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
286        self.txt_attn = SelfAttention(
287            dim=hidden_size,
288            num_heads=num_heads,
289        )
290
291        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
292        self.txt_mlp = nn.Sequential(
293            nn.Linear(
294                hidden_size,
295                mlp_hidden_dim * self.mlp_mult_factor,
296                bias=False,
297            ),
298            SiLUActivation(),
299            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
300        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_heads
hidden_size
img_norm1
mlp_mult_factor
img_attn
img_norm2
img_mlp
txt_norm1
txt_attn
txt_norm2
txt_mlp
def forward(self, img: torch.Tensor, txt: torch.Tensor, pe: torch.Tensor, pe_ctx: torch.Tensor, mod_img: tuple[torch.Tensor, torch.Tensor], mod_txt: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
302    def forward(
303        self,
304        img: Tensor,
305        txt: Tensor,
306        pe: Tensor,
307        pe_ctx: Tensor,
308        mod_img: tuple[Tensor, Tensor],
309        mod_txt: tuple[Tensor, Tensor],
310    ) -> tuple[Tensor, Tensor]:
311        img_mod1, img_mod2 = mod_img
312        txt_mod1, txt_mod2 = mod_txt
313
314        img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
315        img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
316        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
317        txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2
318
319        # prepare image for attention
320        img_modulated = self.img_norm1(img)
321        img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
322
323        img_qkv = self.img_attn.qkv(img_modulated)
324        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
325        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
326
327        # prepare txt for attention
328        txt_modulated = self.txt_norm1(txt)
329        txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
330
331        txt_qkv = self.txt_attn.qkv(txt_modulated)
332        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
333        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
334
335        q = torch.cat((txt_q, img_q), dim=2)
336        k = torch.cat((txt_k, img_k), dim=2)
337        v = torch.cat((txt_v, img_v), dim=2)
338
339        pe = torch.cat((pe_ctx, pe), dim=2)
340        attn = attention(q, k, v, pe)
341        txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]
342
343        # calculate the img blocks
344        img = img + img_mod1_gate * self.img_attn.proj(img_attn)
345        img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift)
346
347        # calculate the txt blocks
348        txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
349        txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift)
350        return img, txt

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class MLPEmbedder(torch.nn.modules.module.Module):
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping ``in_dim`` to ``hidden_dim``."""

    def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
        """
        :param in_dim: width of the input features.
        :param hidden_dim: width of the hidden layer and of the output.
        :param disable_bias: build both linear layers without bias terms.
        """
        super().__init__()
        use_bias = not disable_bias
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=use_bias)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``out_layer(silu(in_layer(x)))``."""
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

MLPEmbedder(in_dim: int, hidden_dim: int, disable_bias: bool = False)
354    def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
355        super().__init__()
356        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
357        self.silu = nn.SiLU()
358        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_layer
silu
out_layer
def forward(self, x: torch.Tensor) -> torch.Tensor:
360    def forward(self, x: Tensor) -> Tensor:
361        return self.out_layer(self.silu(self.in_layer(x)))

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class EmbedND(torch.nn.modules.module.Module):
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding built from one ``rope`` call per axis."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        """
        :param dim: stored as ``self.dim``; not read by ``forward``.
        :param theta: base passed to ``rope`` for every axis.
        :param axes_dim: rotary width assigned to each position axis.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Embed each position axis of ``ids`` and concatenate the results.

        :param ids: position ids; the last dimension indexes the axes in ``axes_dim``.
        :return: the per-axis ``rope`` outputs concatenated along dim ``-3``,
            with a singleton axis inserted at position 1.
        """
        per_axis = [
            rope(ids[..., axis], axis_dim, self.theta)
            for axis, axis_dim in enumerate(self.axes_dim)
        ]
        emb = torch.cat(per_axis, dim=-3)
        return emb.unsqueeze(1)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

EmbedND(dim: int, theta: int, axes_dim: list[int])
365    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
366        super().__init__()
367        self.dim = dim
368        self.theta = theta
369        self.axes_dim = axes_dim

Initialize internal Module state, shared by both nn.Module and ScriptModule.

dim
theta
axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
371    def forward(self, ids: Tensor) -> Tensor:
372        emb = torch.cat(
373            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(len(self.axes_dim))],
374            dim=-3,
375        )
376
377        return emb.unsqueeze(1)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def timestep_embedding(t: torch.Tensor, dim, max_period=10000, time_factor: float = 1000.0):
def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor:
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before the sinusoids are evaluated.
    :return: an (N, D) Tensor of positional embeddings: ``dim // 2`` cosine
             terms, then ``dim // 2`` sine terms, then one zero column when
             ``dim`` is odd.
    """
    t = time_factor * t
    half = dim // 2
    steps = torch.arange(start=0, end=half, device=t.device, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * steps / half)  # float32 originally

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Build the pad column explicitly: slicing ``embedding[:, :1]`` is
        # empty when ``half == 0`` (dim == 1) and would leave the output at
        # width 0 instead of 1.
        pad = embedding.new_zeros(embedding.shape[0], 1)
        embedding = torch.cat([embedding, pad], dim=-1)
    if torch.is_floating_point(t):
        # Match the dtype and device of the (scaled) input.
        embedding = embedding.to(t)
    return embedding

Create sinusoidal timestep embeddings.

Parameters
  • t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
  • dim: the dimension of the output.
  • max_period: controls the minimum frequency of the embeddings.
Returns

an (N, D) Tensor of positional embeddings.

class RMSNorm(torch.nn.modules.module.Module):
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int):
        """
        :param dim: length of the learnable ``scale`` vector (initialised to ones).
        """
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        """Normalise ``x`` by its RMS over the last dimension and multiply by ``scale``.

        The statistics are computed in float32; the normalised tensor is cast
        back to the input dtype before the multiplication by ``scale``.
        """
        input_dtype = x.dtype
        x_fp32 = x.float()
        mean_square = torch.mean(x_fp32**2, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-6)
        normalized = (x_fp32 * inv_rms).to(dtype=input_dtype)
        return normalized * self.scale

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

RMSNorm(dim: int)
405    def __init__(self, dim: int):
406        super().__init__()
407        self.scale = nn.Parameter(torch.ones(dim))

Initialize internal Module state, shared by both nn.Module and ScriptModule.

scale
def forward(self, x: torch.Tensor):
409    def forward(self, x: Tensor):
410        x_dtype = x.dtype
411        x = x.float()
412        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
413        return (x * rrms).to(dtype=x_dtype) * self.scale

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class QKNorm(torch.nn.modules.module.Module):
class QKNorm(torch.nn.Module):
    """Apply separate RMS norms to queries and keys."""

    def __init__(self, dim: int):
        """
        :param dim: size passed to each ``RMSNorm``.
        """
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Normalise ``q`` and ``k`` and convert both with ``.to(v)``.

        :param v: only used as the conversion target; it is not normalised or returned.
        :return: the normalised ``(q, k)`` pair.
        """
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q.to(v), normed_k.to(v)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

QKNorm(dim: int)
417    def __init__(self, dim: int):
418        super().__init__()
419        self.query_norm = RMSNorm(dim)
420        self.key_norm = RMSNorm(dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

query_norm
key_norm
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
422    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
423        q = self.query_norm(q)
424        k = self.key_norm(k)
425        return q.to(v), k.to(v)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotate ``q``/``k`` with ``pe``, run scaled dot-product attention, merge heads.

    :param q: queries; the final rearrange requires a ``(B, H, L, D)`` attention output.
    :param k: keys.
    :param v: values.
    :param pe: rotary position embedding consumed by ``apply_rope``.
    :return: attention output of shape ``(B, L, H * D)``.
    """
    q_rot, k_rot = apply_rope(q, k, pe)

    per_head = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v)
    # Fold the head axis into the feature axis.
    return rearrange(per_head, "B H L D -> B L (H D)")
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotation matrices for rotary position embedding.

    :param pos: positions; the rearrange below requires shape ``(b, n)``.
    :param dim: rotary width for this axis; must be even.
    :param theta: base of the geometric frequency progression.
    :return: float32 tensor of shape ``(b, n, dim // 2, 2, 2)`` holding
        ``[[cos, -sin], [sin, cos]]`` for every position/frequency pair.
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**exponents)
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos, sin = torch.cos(angles), torch.sin(angles)
    # Row-major flattening of the rotation matrix [[cos, -sin], [sin, cos]].
    flat = torch.stack([cos, -sin, sin, cos], dim=-1)
    matrices = rearrange(flat, "b n d (i j) -> b n d i j", i=2, j=2)
    return matrices.float()
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Apply 2x2 rotary matrices to consecutive feature pairs of ``xq`` and ``xk``.

    :param xq: queries; the last dimension is viewed as pairs of features.
    :param xk: keys, handled identically to ``xq``.
    :param freqs_cis: rotation matrices whose trailing dimensions are ``(pairs, 2, 2)``
        and which broadcast against the paired view of ``xq``/``xk``.
    :return: rotated ``(xq, xk)`` with the original shapes and dtypes; the
        arithmetic itself runs in float32.
    """

    def _rotate(x: Tensor) -> Tensor:
        # View the last dim as (pairs, 1, 2) so each pair meets its 2x2 matrix.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)