divisor.flux2.model
# SPDX-License-Identifier:Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux2

from dataclasses import dataclass, field
import math

from einops import rearrange
import torch
from torch import Tensor, nn

from divisor.layer_dropout import process_blocks_with_dropout


@dataclass
class Flux2Params:
    """Hyper-parameters for :class:`Flux2`.

    :class:`Flux2` requires ``hidden_size`` to be divisible by ``num_heads`` and
    ``sum(axes_dim)`` to equal ``hidden_size // num_heads``.
    """

    in_channels: int = 128  # features per image token; also the model's output width
    context_in_dim: int = 15360  # features per text-context token
    hidden_size: int = 6144  # transformer width
    num_heads: int = 48  # attention heads
    depth: int = 8  # number of DoubleStreamBlock layers
    depth_single_blocks: int = 48  # number of SingleStreamBlock layers
    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])  # RoPE dims per position-id column
    theta: int = 2000  # RoPE base
    mlp_ratio: float = 3.0  # MLP hidden width as a multiple of hidden_size


class Flux2(nn.Module):
    """FLUX.2 transformer.

    Image tokens and text-context tokens first pass through ``params.depth``
    :class:`DoubleStreamBlock` layers (separate weights per stream, joint
    attention). They are then concatenated (text first) and pass through
    ``params.depth_single_blocks`` :class:`SingleStreamBlock` layers. Only the
    image tokens reach the final projection.

    Args:
        params: model hyper-parameters.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads`` or
            ``sum(axes_dim)`` differs from the per-head dimension.
    """

    def __init__(self, params: Flux2Params):
        super().__init__()

        self.in_channels = params.in_channels
        self.out_channels = params.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        # RoPE is applied per head, so the positional dims must fill one head exactly.
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
        # Both embedders consume the 256-dim sinusoidal embeddings built in forward().
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth)
            ]
        )
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        # One set of modulation layers, shared by every block of the corresponding kind.
        self.double_stream_modulation_img = Modulation(self.hidden_size, double=True, disable_bias=True)
        self.double_stream_modulation_txt = Modulation(self.hidden_size, double=True, disable_bias=True)
        self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)

        self.final_layer = LastLayer(self.hidden_size, self.out_channels)

    def forward(
        self,
        x: Tensor,
        x_ids: Tensor,
        timesteps: Tensor,
        ctx: Tensor,
        ctx_ids: Tensor,
        guidance: Tensor,
        layer_dropouts: list[int] | None = None,
    ) -> Tensor:
        """Run one forward pass of the transformer.

        Args:
            x: image tokens whose last dimension is ``in_channels``
                (assumed ``(B, L_img, in_channels)`` — TODO confirm).
            x_ids: positional ids of the image tokens, one column per entry of
                ``axes_dim``, i.e. ``(B, L_img, len(axes_dim))``.
            timesteps: 1-D tensor with one timestep per batch element.
            ctx: text-context tokens; ``ctx.shape[1]`` is taken as the number of
                text tokens and the last dimension must be ``context_in_dim``.
            ctx_ids: positional ids of the context tokens, laid out like ``x_ids``.
            guidance: 1-D tensor with one guidance value per batch element.
            layer_dropouts: forwarded unchanged to ``process_blocks_with_dropout``.
                Presumably indices of blocks to skip, numbered over the double
                blocks first and then the single blocks — confirm in
                ``divisor.layer_dropout``.

        Returns:
            The image tokens after the final layer, with ``out_channels`` features each.
        """
        num_txt_tokens = ctx.shape[1]

        # Conditioning vector: timestep embedding plus guidance embedding.
        timestep_emb = timestep_embedding(timesteps, 256)
        vec = self.time_in(timestep_emb)
        guidance_emb = timestep_embedding(guidance, 256)
        vec = vec + self.guidance_in(guidance_emb)

        # Modulation parameters are computed once and reused by every block.
        double_block_mod_img = self.double_stream_modulation_img(vec)
        double_block_mod_txt = self.double_stream_modulation_txt(vec)
        single_block_mod, _ = self.single_stream_modulation(vec)

        img = self.img_in(x)
        txt = self.txt_in(ctx)

        pe_x = self.pe_embedder(x_ids)
        pe_ctx = self.pe_embedder(ctx_ids)

        img, txt = process_blocks_with_dropout(
            self.double_blocks,
            layer_dropouts,
            0,  # presumably the index of the first block in this group
            "double",
            lambda block, state: block(state[0], state[1], pe_x, pe_ctx, double_block_mod_img, double_block_mod_txt),
            (img, txt),
        )

        # Joint sequence: text tokens first, then image tokens; pe uses the same order.
        img = torch.cat((txt, img), dim=1)
        pe = torch.cat((pe_ctx, pe_x), dim=2)

        img = process_blocks_with_dropout(
            self.single_blocks,
            layer_dropouts,
            len(self.double_blocks),
            "single",
            lambda block, state: block(state, pe, single_block_mod),
            img,
        )
        # Drop the text tokens; only image tokens go through the final layer.
        img = img[:, num_txt_tokens:, ...]

        img = self.final_layer(img, vec)
        return img


class SelfAttention(nn.Module):
    """Parameter container for one attention stream.

    Holds the fused QKV projection, the per-head QK RMS normalisation and the
    output projection. It defines no ``forward`` of its own; callers such as
    :class:`DoubleStreamBlock` use ``qkv``, ``norm`` and ``proj`` directly.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)

        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim, bias=False)


class SiLUActivation(nn.Module):
    """Gated SiLU: split the last dimension in half and return ``silu(a) * b``.

    The output's last dimension is half the input's.
    """

    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return self.gate_fn(x1) * x2


class Modulation(nn.Module):
    """Compute adaptive-norm ``(shift, scale, gate)`` triples from a conditioning vector.

    Args:
        dim: width of the conditioning vector and of each produced tensor.
        double: produce two triples (for the two sub-layers of a
            :class:`DoubleStreamBlock`) instead of one.
        disable_bias: build the projection without a bias term.
    """

    def __init__(self, dim: int, double: bool, disable_bias: bool = False):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)

    def forward(self, vec: torch.Tensor) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...] | None]:
        """Return ``(first_triple, second_triple)``; the second is ``None`` unless ``double``.

        A 2-D ``vec`` gets a singleton axis inserted at position 1 so that each
        chunk broadcasts over the token axis of 3-D activations.
        """
        out = self.lin(nn.functional.silu(vec))
        if out.ndim == 2:
            out = out[:, None, :]
        chunks = out.chunk(self.multiplier, dim=-1)
        first = chunks[:3]
        second = chunks[3:] if self.is_double else None
        return first, second


class LastLayer(nn.Module):
    """Final adaptive LayerNorm followed by a projection to ``out_channels``."""

    def __init__(
        self,
        hidden_size: int,
        out_channels: int,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=False)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False))

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        """Modulate ``norm_final(x)`` with a shift/scale derived from ``vec``, then project."""
        mod = self.adaLN_modulation(vec)
        shift, scale = mod.chunk(2, dim=-1)
        if shift.ndim == 2:
            # Broadcast the per-sample modulation over the token axis.
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        x = (1 + scale) * self.norm_final(x) + shift
        x = self.linear(x)
        return x


class SingleStreamBlock(nn.Module):
    """Transformer block over the joint (text + image) token sequence.

    One fused input projection produces Q, K, V and the gated-MLP input in
    parallel. The attention output and the activated MLP stream are
    concatenated and mapped back to ``hidden_size`` by a single output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Not read by forward(): scaled_dot_product_attention applies its own default scale.
        self.scale = head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # The gated activation halves the MLP stream, so linear1 emits twice its width.
        self.mlp_mult_factor = 2

        self.linear1 = nn.Linear(
            hidden_size,
            hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
            bias=False,
        )

        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=False)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = SiLUActivation()

    def forward(
        self,
        x: Tensor,
        pe: Tensor,
        mod: tuple[Tensor, Tensor, Tensor],
    ) -> Tensor:
        """Apply the block.

        Args:
            x: joint token sequence ``(B, L, hidden_size)``.
            pe: RoPE rotation tensor covering all ``L`` tokens.
            mod: ``(shift, scale, gate)`` triple produced by :class:`Modulation`.

        Returns:
            ``x`` plus the gated block output (same shape as ``x``).
        """
        mod_shift, mod_scale, mod_gate = mod
        x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift

        qkv, mlp = torch.split(
            self.linear1(x_mod),
            [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
            dim=-1,
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        attn = attention(q, k, v, pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod_gate * output


class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text weights but joint attention.

    Each stream has its own norms, QKV/output projections and MLP. The queries,
    keys and values of both streams are concatenated (text first) so that the
    two streams attend to each other.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
    ):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        assert hidden_size % num_heads == 0, f"{hidden_size=} must be divisible by {num_heads=}"

        self.hidden_size = hidden_size
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # The gated activation halves its input, so each MLP's first layer is twice as wide.
        self.mlp_mult_factor = 2

        self.img_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(
                hidden_size,
                mlp_hidden_dim * self.mlp_mult_factor,
                bias=False,
            ),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        pe_ctx: Tensor,
        mod_img: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
        mod_txt: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
    ) -> tuple[Tensor, Tensor]:
        """Apply the block to both streams.

        Args:
            img: image tokens ``(B, L_img, hidden_size)``.
            txt: text tokens ``(B, L_txt, hidden_size)``.
            pe: RoPE rotation tensor for the image tokens.
            pe_ctx: RoPE rotation tensor for the text tokens.
            mod_img: two ``(shift, scale, gate)`` triples for the image stream
                (attention sub-layer, then MLP sub-layer).
            mod_txt: the same for the text stream.

        Returns:
            Updated ``(img, txt)``.
        """
        img_mod1, img_mod2 = mod_img
        txt_mod1, txt_mod2 = mod_txt

        img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
        img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
        txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift

        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift

        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint attention over [text, image] along the token axis (dim 2 of B H L D).
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        pe = torch.cat((pe_ctx, pe), dim=2)
        attn = attention(q, k, v, pe)
        # attention() returns (B, L, H*D); split the token axis back into the two streams.
        txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]

        # calculate the img blocks
        img = img + img_mod1_gate * self.img_attn.proj(img_attn)
        img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift)

        # calculate the txt blocks
        txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift)
        return img, txt


class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping ``in_dim`` to ``hidden_dim``."""

    def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class EmbedND(nn.Module):
    """Build multi-axis RoPE rotation matrices from integer/float position ids.

    Args:
        dim: total positional dimension (stored; not read by ``forward``).
        theta: RoPE base passed to :func:`rope`.
        axes_dim: rotary dimension assigned to each id column.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Map ``ids`` of shape ``(B, L, len(axes_dim))`` to ``(B, 1, L, sum(axes_dim) // 2, 2, 2)``.

        The singleton axis 1 lets the result broadcast over attention heads.
        """
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(len(self.axes_dim))],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim: int, max_period: float = 10000, time_factor: float = 1000.0) -> Tensor:
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: ``t`` is multiplied by this before embedding.
    :return: an (N, dim) Tensor of positional embeddings (cosines first, then
             sines), cast to the dtype and device of the scaled ``t``.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        # frequencies are built in float32
        -math.log(max_period) * torch.arange(start=0, end=half, device=t.device, dtype=torch.float32) / half
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # pad odd dims with a zero column so the output width is exactly `dim`
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class RMSNorm(torch.nn.Module):
    """RMS normalisation over the last dimension with a learned per-feature scale.

    The statistic is computed in float32; the result is cast back to the input
    dtype before the scale is applied.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    """Separate RMS normalisation for queries and keys, cast to ``v``'s dtype and device."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Apply RoPE to ``q``/``k``, run unmasked SDPA and merge heads.

    ``q``, ``k`` and ``v`` are ``(B, H, L, D)``; the result is ``(B, L, H * D)``.
    """
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Return float32 2x2 rotation matrices for positions ``pos`` of shape ``(B, N)``.

    The output has shape ``(B, N, dim // 2, 2, 2)``. ``dim`` must be even.
    """
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate consecutive feature pairs of ``xq``/``xk`` by ``freqs_cis``.

    The rotation is computed in float32 and cast back to each input's dtype.
    """
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
@dataclass
class Flux2Params:
    """Hyper-parameters for :class:`Flux2`.

    :class:`Flux2` requires ``hidden_size`` to be divisible by ``num_heads`` and
    ``sum(axes_dim)`` to equal ``hidden_size // num_heads``.
    """

    in_channels: int = 128  # features per image token; also the model's output width
    context_in_dim: int = 15360  # features per text-context token
    hidden_size: int = 6144  # transformer width
    num_heads: int = 48  # attention heads
    depth: int = 8  # number of DoubleStreamBlock layers
    depth_single_blocks: int = 48  # number of SingleStreamBlock layers
    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])  # RoPE dims per position-id column
    theta: int = 2000  # RoPE base
    mlp_ratio: float = 3.0  # MLP hidden width as a multiple of hidden_size
class Flux2(nn.Module):
    """FLUX.2 transformer.

    Image tokens and text-context tokens first pass through ``params.depth``
    :class:`DoubleStreamBlock` layers (separate weights per stream, joint
    attention). They are then concatenated (text first) and pass through
    ``params.depth_single_blocks`` :class:`SingleStreamBlock` layers. Only the
    image tokens reach the final projection.

    Args:
        params: model hyper-parameters.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads`` or
            ``sum(axes_dim)`` differs from the per-head dimension.
    """

    def __init__(self, params: Flux2Params):
        super().__init__()

        self.in_channels = params.in_channels
        self.out_channels = params.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        # RoPE is applied per head, so the positional dims must fill one head exactly.
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
        # Both embedders consume the 256-dim sinusoidal embeddings built in forward().
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth)
            ]
        )
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        # One set of modulation layers, shared by every block of the corresponding kind.
        self.double_stream_modulation_img = Modulation(self.hidden_size, double=True, disable_bias=True)
        self.double_stream_modulation_txt = Modulation(self.hidden_size, double=True, disable_bias=True)
        self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)

        self.final_layer = LastLayer(self.hidden_size, self.out_channels)

    def forward(
        self,
        x: Tensor,
        x_ids: Tensor,
        timesteps: Tensor,
        ctx: Tensor,
        ctx_ids: Tensor,
        guidance: Tensor,
        layer_dropouts: list[int] | None = None,
    ) -> Tensor:
        """Run one forward pass of the transformer.

        Args:
            x: image tokens whose last dimension is ``in_channels``
                (assumed ``(B, L_img, in_channels)`` — TODO confirm).
            x_ids: positional ids of the image tokens, one column per entry of ``axes_dim``.
            timesteps: 1-D tensor with one timestep per batch element.
            ctx: text-context tokens; ``ctx.shape[1]`` is taken as the number of
                text tokens and the last dimension must be ``context_in_dim``.
            ctx_ids: positional ids of the context tokens, laid out like ``x_ids``.
            guidance: 1-D tensor with one guidance value per batch element.
            layer_dropouts: forwarded unchanged to ``process_blocks_with_dropout``.
                Presumably indices of blocks to skip, numbered over the double
                blocks first and then the single blocks — confirm in
                ``divisor.layer_dropout``.

        Returns:
            The image tokens after the final layer, with ``out_channels`` features each.
        """
        num_txt_tokens = ctx.shape[1]

        # Conditioning vector: timestep embedding plus guidance embedding.
        timestep_emb = timestep_embedding(timesteps, 256)
        vec = self.time_in(timestep_emb)
        guidance_emb = timestep_embedding(guidance, 256)
        vec = vec + self.guidance_in(guidance_emb)

        # Modulation parameters are computed once and reused by every block.
        double_block_mod_img = self.double_stream_modulation_img(vec)
        double_block_mod_txt = self.double_stream_modulation_txt(vec)
        single_block_mod, _ = self.single_stream_modulation(vec)

        img = self.img_in(x)
        txt = self.txt_in(ctx)

        pe_x = self.pe_embedder(x_ids)
        pe_ctx = self.pe_embedder(ctx_ids)

        img, txt = process_blocks_with_dropout(
            self.double_blocks,
            layer_dropouts,
            0,  # presumably the index of the first block in this group
            "double",
            lambda block, state: block(state[0], state[1], pe_x, pe_ctx, double_block_mod_img, double_block_mod_txt),
            (img, txt),
        )

        # Joint sequence: text tokens first, then image tokens; pe uses the same order.
        img = torch.cat((txt, img), dim=1)
        pe = torch.cat((pe_ctx, pe_x), dim=2)

        img = process_blocks_with_dropout(
            self.single_blocks,
            layer_dropouts,
            len(self.double_blocks),
            "single",
            lambda block, state: block(state, pe, single_block_mod),
            img,
        )
        # Drop the text tokens; only image tokens go through the final layer.
        img = img[:, num_txt_tokens:, ...]

        img = self.final_layer(img, vec)
        return img
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, params: Flux2Params):
    """Build all Flux2 sub-modules from ``params``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads`` or
            ``sum(axes_dim)`` differs from the per-head dimension.
    """
    super().__init__()

    self.in_channels = params.in_channels
    self.out_channels = params.in_channels
    if params.hidden_size % params.num_heads != 0:
        raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
    # RoPE is applied per head, so the positional dims must fill one head exactly.
    pe_dim = params.hidden_size // params.num_heads
    if sum(params.axes_dim) != pe_dim:
        raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
    self.hidden_size = params.hidden_size
    self.num_heads = params.num_heads
    self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
    # Both embedders take 256-dim inputs.
    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
    self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
    self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)

    self.double_blocks = nn.ModuleList(
        [
            DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
            for _ in range(params.depth)
        ]
    )
    self.single_blocks = nn.ModuleList(
        [
            SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
            for _ in range(params.depth_single_blocks)
        ]
    )

    # One set of modulation layers per block kind.
    self.double_stream_modulation_img = Modulation(self.hidden_size, double=True, disable_bias=True)
    self.double_stream_modulation_txt = Modulation(self.hidden_size, double=True, disable_bias=True)
    self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)

    self.final_layer = LastLayer(self.hidden_size, self.out_channels)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(
    self,
    x: Tensor,
    x_ids: Tensor,
    timesteps: Tensor,
    ctx: Tensor,
    ctx_ids: Tensor,
    guidance: Tensor,
    layer_dropouts: list[int] | None = None,
) -> Tensor:
    """Run one forward pass of the Flux2 transformer.

    Args:
        x: image tokens whose last dimension is ``in_channels``
            (assumed ``(B, L_img, in_channels)`` — TODO confirm).
        x_ids: positional ids of the image tokens.
        timesteps: 1-D tensor with one timestep per batch element.
        ctx: text-context tokens; ``ctx.shape[1]`` is taken as the number of text tokens.
        ctx_ids: positional ids of the context tokens.
        guidance: 1-D tensor with one guidance value per batch element.
        layer_dropouts: forwarded unchanged to ``process_blocks_with_dropout``
            (presumably indices of blocks to skip — confirm in ``divisor.layer_dropout``).

    Returns:
        The image tokens after the final layer.
    """
    num_txt_tokens = ctx.shape[1]

    # Conditioning vector: timestep embedding plus guidance embedding.
    timestep_emb = timestep_embedding(timesteps, 256)
    vec = self.time_in(timestep_emb)
    guidance_emb = timestep_embedding(guidance, 256)
    vec = vec + self.guidance_in(guidance_emb)

    double_block_mod_img = self.double_stream_modulation_img(vec)
    double_block_mod_txt = self.double_stream_modulation_txt(vec)
    single_block_mod, _ = self.single_stream_modulation(vec)

    img = self.img_in(x)
    txt = self.txt_in(ctx)

    pe_x = self.pe_embedder(x_ids)
    pe_ctx = self.pe_embedder(ctx_ids)

    img, txt = process_blocks_with_dropout(
        self.double_blocks,
        layer_dropouts,
        0,  # presumably the index of the first block in this group
        "double",
        lambda block, state: block(state[0], state[1], pe_x, pe_ctx, double_block_mod_img, double_block_mod_txt),
        (img, txt),
    )

    # Joint sequence: text tokens first, then image tokens; pe uses the same order.
    img = torch.cat((txt, img), dim=1)
    pe = torch.cat((pe_ctx, pe_x), dim=2)

    img = process_blocks_with_dropout(
        self.single_blocks,
        layer_dropouts,
        len(self.double_blocks),
        "single",
        lambda block, state: block(state, pe, single_block_mod),
        img,
    )
    # Drop the text tokens; only image tokens go through the final layer.
    img = img[:, num_txt_tokens:, ...]

    img = self.final_layer(img, vec)
    return img
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class SelfAttention(nn.Module):
    """Parameter container for one attention stream.

    Holds the fused QKV projection, the per-head QK RMS normalisation and the
    output projection. It defines no ``forward`` of its own; callers use
    ``qkv``, ``norm`` and ``proj`` directly.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)

        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim, bias=False)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    dim: int,
    num_heads: int = 8,
):
    """Create the QKV projection, per-head QK norm and output projection for width ``dim``."""
    super().__init__()
    self.num_heads = num_heads
    head_dim = dim // num_heads
    self.qkv = nn.Linear(dim, dim * 3, bias=False)

    self.norm = QKNorm(head_dim)
    self.proj = nn.Linear(dim, dim, bias=False)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class SiLUActivation(nn.Module):
    """Gated SiLU: split the last dimension in half and return ``silu(a) * b``.

    The output's last dimension is half the input's.
    """

    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        gate, value = x.chunk(2, dim=-1)
        return self.gate_fn(gate) * value
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def forward(self, x: Tensor) -> Tensor:
    """Return ``gate_fn(first_half) * second_half`` along the last dimension of ``x``."""
    gate, value = x.chunk(2, dim=-1)
    return self.gate_fn(gate) * value
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
164class Modulation(nn.Module): 165 def __init__(self, dim: int, double: bool, disable_bias: bool = False): 166 super().__init__() 167 self.is_double = double 168 self.multiplier = 6 if double else 3 169 self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias) 170 171 def forward(self, vec: torch.Tensor): 172 out = self.lin(nn.functional.silu(vec)) 173 if out.ndim == 2: 174 out = out[:, None, :] 175 out = out.chunk(self.multiplier, dim=-1) 176 return out[:3], out[3:] if self.is_double else None
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int, double: bool, disable_bias: bool = False):
    """Create a projection from ``dim`` to 6*dim (``double``) or 3*dim features."""
    super().__init__()
    self.is_double = double
    self.multiplier = 6 if double else 3
    self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
171 def forward(self, vec: torch.Tensor): 172 out = self.lin(nn.functional.silu(vec)) 173 if out.ndim == 2: 174 out = out[:, None, :] 175 out = out.chunk(self.multiplier, dim=-1) 176 return out[:3], out[3:] if self.is_double else None
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class LastLayer(nn.Module):
    """Final adaptive LayerNorm followed by a projection to ``out_channels``."""

    def __init__(
        self,
        hidden_size: int,
        out_channels: int,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=False)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False))

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        """Modulate ``norm_final(x)`` with a shift/scale derived from ``vec``, then project."""
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        if shift.ndim == 2:
            # Broadcast the per-sample modulation over the token axis.
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        x = (1 + scale) * self.norm_final(x) + shift
        return self.linear(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    hidden_size: int,
    out_channels: int,
):
    """Create the final norm, the output projection and the shift/scale modulation MLP."""
    super().__init__()
    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, out_channels, bias=False)
    self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False))
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Modulate ``norm_final(x)`` with a shift/scale derived from ``vec``, then project."""
    shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
    if shift.ndim == 2:
        # Broadcast the per-sample modulation over the token axis.
        shift = shift[:, None, :]
        scale = scale[:, None, :]
    x = (1 + scale) * self.norm_final(x) + shift
    return self.linear(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class SingleStreamBlock(nn.Module):
    """Transformer block over the joint (text + image) token sequence.

    One fused input projection produces Q, K, V and the gated-MLP input in
    parallel. The attention output and the activated MLP stream are
    concatenated and mapped back to ``hidden_size`` by a single output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Not read by forward(): scaled_dot_product_attention applies its own default scale.
        self.scale = head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # The gated activation halves the MLP stream, so linear1 emits twice its width.
        self.mlp_mult_factor = 2

        self.linear1 = nn.Linear(
            hidden_size,
            hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
            bias=False,
        )

        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=False)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = SiLUActivation()

    def forward(
        self,
        x: Tensor,
        pe: Tensor,
        mod: tuple[Tensor, Tensor, Tensor],
    ) -> Tensor:
        """Apply the block.

        Args:
            x: joint token sequence ``(B, L, hidden_size)``.
            pe: RoPE rotation tensor covering all ``L`` tokens.
            mod: ``(shift, scale, gate)`` modulation triple.

        Returns:
            ``x`` plus the gated block output (same shape as ``x``).
        """
        mod_shift, mod_scale, mod_gate = mod
        x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift

        qkv, mlp = torch.split(
            self.linear1(x_mod),
            [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
            dim=-1,
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        attn = attention(q, k, v, pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod_gate * output
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    mlp_ratio: float = 4.0,
):
    """Build the fused projections and norms of a single-stream block.

    ``linear1`` maps ``hidden_size`` to q/k/v (3 * hidden_size) plus a doubled MLP
    width. ``linear2`` maps the attention output plus ``mlp_hidden_dim`` MLP
    features back to ``hidden_size``. Submodules are created in a fixed order
    (linear1, linear2, norm, pre_norm, mlp_act), which fixes parameter order and
    RNG consumption.
    """
    super().__init__()

    per_head = hidden_size // num_heads
    mlp_width = int(hidden_size * mlp_ratio)

    # plain (non-module) configuration attributes
    self.hidden_dim = hidden_size
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.scale = per_head**-0.5
    self.mlp_hidden_dim = mlp_width
    self.mlp_mult_factor = 2

    fused_width = 3 * hidden_size + mlp_width * self.mlp_mult_factor
    self.linear1 = nn.Linear(hidden_size, fused_width, bias=False)
    self.linear2 = nn.Linear(hidden_size + mlp_width, hidden_size, bias=False)
    self.norm = QKNorm(per_head)
    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.mlp_act = SiLUActivation()
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(
    self,
    x: Tensor,
    pe: Tensor,
    mod: tuple[Tensor, Tensor, Tensor],
) -> Tensor:
    """Apply one modulated single-stream update.

    Args:
        x: Token sequence. The ``"B L (K H D)"`` rearrange and ``torch.cat(..., 2)``
            below require a 3-D ``(batch, tokens, hidden_size)`` layout.
        pe: Positional embedding, forwarded unchanged to ``attention``.
        mod: ``(shift, scale, gate)`` modulation tensors; each must broadcast
            against ``x``. (Previously mis-annotated as a 2-tuple.)

    Returns:
        ``x + gate * linear2(cat(attn, mlp_act(mlp)))``.
    """
    mod_shift, mod_scale, mod_gate = mod
    # affine-free LayerNorm, then the modulation-supplied scale and shift
    x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift

    # a single matmul yields both the attention qkv and the MLP pre-activation
    qkv, mlp = torch.split(
        self.linear1(x_mod),
        [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
        dim=-1,
    )

    q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k, v)

    # NOTE(review): the cat on dim 2 below assumes `attention` (defined elsewhere)
    # returns (B, L, H*D) with heads merged -- confirm.
    attn = attention(q, k, v, pe)

    # compute activation in mlp stream, cat again and run second linear layer
    output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
    return x + mod_gate * output
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class DoubleStreamBlock(nn.Module):
    """Dual-stream transformer block: separate weights per stream, joint attention.

    Image and text tokens each get their own norms, qkv/output projections
    (``SelfAttention``) and MLPs. Attention is computed once over the concatenated
    ``[txt; img]`` sequence.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads`` (asserted).
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
    ):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        assert hidden_size % num_heads == 0, f"{hidden_size=} must be divisible by {num_heads=}"

        self.hidden_size = hidden_size
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # the first MLP linear is twice as wide as the second one's input, so
        # SiLUActivation (defined elsewhere) presumably halves it -- TODO confirm
        self.mlp_mult_factor = 2

        self.img_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(
                hidden_size,
                mlp_hidden_dim * self.mlp_mult_factor,
                bias=False,
            ),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        pe_ctx: Tensor,
        mod_img: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
        mod_txt: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
    ) -> tuple[Tensor, Tensor]:
        """Run joint text/image attention followed by per-stream MLPs.

        Args:
            img: Image token sequence. The ``"B L (K H D)"`` rearrange below requires a
                3-D ``(batch, tokens, hidden_size)`` layout.
            txt: Text token sequence, same layout as ``img``.
            pe: Positional embedding for the image tokens.
            pe_ctx: Positional embedding for the text tokens.
            mod_img: Two ``(shift, scale, gate)`` triples for the image stream. The
                first modulates the attention sub-block, the second the MLP sub-block.
            mod_txt: Same structure as ``mod_img``, for the text stream.

        Returns:
            ``(img, txt)`` after the residual attention and MLP updates.
        """
        img_mod1, img_mod2 = mod_img
        txt_mod1, txt_mod2 = mod_txt

        img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
        img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
        txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift

        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift

        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # joint attention: text tokens first, then image tokens, along the L axis of B H L D
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        # positional embeddings are concatenated in the same txt-then-img order
        pe = torch.cat((pe_ctx, pe), dim=2)
        attn = attention(q, k, v, pe)
        # NOTE(review): slicing on dim 1 assumes `attention` (defined elsewhere)
        # returns (B, L, H*D) with heads merged -- confirm.
        txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]

        # calculate the img blocks
        img = img + img_mod1_gate * self.img_attn.proj(img_attn)
        img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift)

        # calculate the txt blocks
        txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift)
        return img, txt
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    mlp_ratio: float,
):
    """Create the per-stream norms, attention projections and MLPs.

    The image stream is built first, then the text stream. Each gets two
    affine-free LayerNorms, a ``SelfAttention`` and a two-linear MLP whose first
    layer is ``mlp_mult_factor`` times wider than the second layer's input.
    Registration and RNG order: img_attn, img_mlp, txt_attn, txt_mlp.
    """
    super().__init__()
    inner_width = int(hidden_size * mlp_ratio)
    self.num_heads = num_heads
    assert hidden_size % num_heads == 0, f"{hidden_size=} must be divisible by {num_heads=}"
    self.hidden_size = hidden_size
    self.mlp_mult_factor = 2

    def make_norm() -> nn.LayerNorm:
        # parameter-free LayerNorm; scale/shift come from the modulation inputs
        return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    def make_mlp() -> nn.Sequential:
        return nn.Sequential(
            nn.Linear(hidden_size, inner_width * self.mlp_mult_factor, bias=False),
            SiLUActivation(),
            nn.Linear(inner_width, hidden_size, bias=False),
        )

    # image stream
    self.img_norm1 = make_norm()
    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads)
    self.img_norm2 = make_norm()
    self.img_mlp = make_mlp()

    # text stream
    self.txt_norm1 = make_norm()
    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads)
    self.txt_norm2 = make_norm()
    self.txt_mlp = make_mlp()
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(
    self,
    img: Tensor,
    txt: Tensor,
    pe: Tensor,
    pe_ctx: Tensor,
    mod_img: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
    mod_txt: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
) -> tuple[Tensor, Tensor]:
    """Run joint text/image attention followed by per-stream MLPs.

    Args:
        img: Image token sequence. The ``"B L (K H D)"`` rearrange below requires a
            3-D ``(batch, tokens, hidden_size)`` layout.
        txt: Text token sequence, same layout as ``img``.
        pe: Positional embedding for the image tokens.
        pe_ctx: Positional embedding for the text tokens.
        mod_img: Two ``(shift, scale, gate)`` triples for the image stream. The first
            modulates the attention sub-block, the second the MLP sub-block.
            (Previously mis-annotated as ``tuple[Tensor, Tensor]``.)
        mod_txt: Same structure as ``mod_img``, for the text stream.

    Returns:
        ``(img, txt)`` after the residual attention and MLP updates.
    """
    img_mod1, img_mod2 = mod_img
    txt_mod1, txt_mod2 = mod_txt

    img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
    img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
    txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
    txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2

    # prepare image for attention
    img_modulated = self.img_norm1(img)
    img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift

    img_qkv = self.img_attn.qkv(img_modulated)
    img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

    # prepare txt for attention
    txt_modulated = self.txt_norm1(txt)
    txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift

    txt_qkv = self.txt_attn.qkv(txt_modulated)
    txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

    # joint attention: text tokens first, then image tokens, along the L axis of B H L D
    q = torch.cat((txt_q, img_q), dim=2)
    k = torch.cat((txt_k, img_k), dim=2)
    v = torch.cat((txt_v, img_v), dim=2)

    # positional embeddings are concatenated in the same txt-then-img order
    pe = torch.cat((pe_ctx, pe), dim=2)
    attn = attention(q, k, v, pe)
    # NOTE(review): slicing on dim 1 assumes `attention` (defined elsewhere)
    # returns (B, L, H*D) with heads merged -- confirm.
    txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]

    # calculate the img blocks
    img = img + img_mod1_gate * self.img_attn.proj(img_attn)
    img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift)

    # calculate the txt blocks
    txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
    txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift)
    return img, txt
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class MLPEmbedder(nn.Module):
    """Two-layer perceptron ``in_dim -> hidden_dim -> hidden_dim`` with SiLU in between.

    Args:
        in_dim: Size of the input's last dimension.
        hidden_dim: Output width of both linear layers.
        disable_bias: When True, both linear layers are created without a bias term.
    """

    def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
        super().__init__()
        with_bias = not disable_bias
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=with_bias)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=with_bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``out_layer(silu(in_layer(x)))``."""
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
    """Create ``in_layer`` (in_dim -> hidden_dim), a SiLU, and ``out_layer`` (hidden_dim -> hidden_dim).

    Both linear layers carry a bias unless ``disable_bias`` is True.
    """
    super().__init__()
    with_bias = not disable_bias
    self.in_layer = nn.Linear(in_dim, hidden_dim, bias=with_bias)
    self.silu = nn.SiLU()
    self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=with_bias)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding built by concatenating per-axis ``rope`` tables.

    Args:
        dim: Total embedding size. It is stored as ``self.dim`` but not read by ``forward``.
        theta: Frequency base forwarded to ``rope``.
        axes_dim: Rotary size per axis; coordinate ``ids[..., i]`` is embedded with
            ``axes_dim[i]``.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim, self.theta, self.axes_dim = dim, theta, axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Embed each coordinate axis of ``ids`` and join the results.

        The per-axis tables are concatenated along dim -3 (the frequency axis of
        ``rope``'s output), and a singleton dimension is inserted at position 1.
        """
        tables = [rope(ids[..., axis], width, self.theta) for axis, width in enumerate(self.axes_dim)]
        return torch.cat(tables, dim=-3).unsqueeze(1)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
    """Record the embedding configuration; no parameters or buffers are created."""
    super().__init__()
    self.dim, self.theta, self.axes_dim = dim, theta, axes_dim
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, ids: Tensor) -> Tensor:
    """Embed each coordinate axis of ``ids`` with ``rope`` and join the results.

    Axis ``i`` uses ``self.axes_dim[i]``. The per-axis tables are concatenated
    along dim -3, and a singleton dimension is inserted at position 1.
    """
    tables = [rope(ids[..., axis], width, self.theta) for axis, width in enumerate(self.axes_dim)]
    return torch.cat(tables, dim=-3).unsqueeze(1)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before the angles are computed.
    :return: an (N, dim) Tensor. The first ``dim // 2`` columns are cosines, the next
             ``dim // 2`` are sines, and a zero column is appended when ``dim`` is odd.
             The result is cast to the scaled ``t``'s dtype/device when that is
             floating point.
    """
    scaled = time_factor * t
    half = dim // 2
    # frequencies are always built in float32
    steps = torch.arange(start=0, end=half, device=scaled.device, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * steps / half)

    angles = scaled[:, None].float() * freqs[None]
    pieces = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        # pad odd sizes with a single zero column
        pieces.append(torch.zeros_like(angles[:, :1]))
    embedding = torch.cat(pieces, dim=-1)

    if torch.is_floating_point(scaled):
        embedding = embedding.to(scaled)
    return embedding
Create sinusoidal timestep embeddings.
Parameters
- t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
- dim: the dimension of the output.
- max_period: controls the minimum frequency of the embeddings.
- time_factor: multiplier applied to t before the embedding is computed (default 1000.0).
Returns
an (N, D) Tensor of positional embeddings.
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    Args:
        dim: Size of the last dimension. The learnable ``scale`` starts at ones.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        """Normalise ``x`` in float32, cast back to its dtype, then multiply by ``scale``.

        The final multiply follows PyTorch type promotion between ``x``'s dtype and
        ``scale``'s dtype, so the result can be wider than the input.
        """
        in_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
        return (x32 * inv_rms).to(dtype=in_dtype) * self.scale
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int):
    """Create the learnable per-channel ``scale`` vector, initialised to ones."""
    super().__init__()
    initial = torch.ones(dim)
    self.scale = nn.Parameter(initial)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor):
    """Normalise ``x`` by its RMS over the last dim (in float32), then apply ``self.scale``.

    The normalised values are cast back to ``x``'s dtype before the scale
    multiply, which then follows PyTorch type promotion.
    """
    in_dtype = x.dtype
    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
    return (x32 * inv_rms).to(dtype=in_dtype) * self.scale
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class QKNorm(torch.nn.Module):
    """Independent RMS normalisation of attention queries and keys.

    Args:
        dim: Feature size given to both ``RMSNorm`` instances.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm, self.key_norm = RMSNorm(dim), RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Normalise ``q`` and ``k``.

        ``v`` is not normalised; it only serves as the dtype/device target that
        both outputs are converted to.
        """
        normed_q, normed_k = self.query_norm(q), self.key_norm(k)
        return normed_q.to(v), normed_k.to(v)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int):
    """Create one ``RMSNorm(dim)`` for queries and a separate one for keys."""
    super().__init__()
    self.query_norm, self.key_norm = RMSNorm(dim), RMSNorm(dim)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
    """Normalise ``q`` with ``query_norm`` and ``k`` with ``key_norm``.

    ``v`` is only used as the dtype/device target that both outputs are converted to.
    """
    normed_q, normed_k = self.query_norm(q), self.key_norm(k)
    return normed_q.to(v), normed_k.to(v)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotary-embedding rotation matrices for a tensor of positions.

    For every position ``p`` and frequency ``omega_k = theta ** (-2k / dim)``
    (``k = 0 .. dim/2 - 1``), the result holds
    ``[[cos(p*omega_k), -sin(p*omega_k)], [sin(p*omega_k), cos(p*omega_k)]]``.

    Args:
        pos: Positions of shape ``(..., n)``; any number of leading dimensions
            (including none) is accepted.
        dim: Rotary dimension; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``(..., n, dim // 2, 2, 2)``.

    Raises:
        AssertionError: if ``dim`` is odd.
    """
    assert dim % 2 == 0
    # NOTE(review): frequencies are computed in pos.dtype, so low-precision positions
    # (e.g. bf16) lose accuracy here -- confirm callers pass int or fp32 ids.
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    # Split the trailing 4 entries into a row-major 2x2 matrix. Unlike the rank-4
    # pattern "b n d (i j)", reshape works for any number of leading dims.
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate consecutive feature pairs of ``xq`` and ``xk`` by the 2x2 matrices in ``freqs_cis``.

    The last dimension of each input is viewed as pairs ``(x0, x1)``. Each pair is
    mapped to ``(F[0,0]*x0 + F[0,1]*x1, F[1,0]*x0 + F[1,1]*x1)``, where ``F`` is the
    matching ``freqs_cis[..., :, :]`` matrix. The arithmetic runs in float32, and
    each result is reshaped to its input's shape and cast back to its input's dtype.
    """

    def _rotate(x: Tensor) -> Tensor:
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)