divisor.flux1.layers
# SPDX-License-Identifier:Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux

from dataclasses import dataclass
import math

from einops import rearrange
import torch
from torch import Tensor, nn

from divisor.flux1.math import attention, rope


class EmbedND(nn.Module):
    """N-axis rotary position embedding.

    Each coordinate column of ``ids`` is embedded with ``rope`` using its own
    width from ``axes_dim``; the per-axis results are concatenated.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        """Store the RoPE configuration.

        :param dim: total embedding width (stored; not read by ``forward``).
        :param theta: base passed to ``rope`` for every axis.
        :param axes_dim: per-axis widths passed to ``rope``, indexed by axis.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Embed position ids.

        :param ids: tensor whose last dimension indexes the position axes
            (presumably (batch, seq, n_axes) — TODO confirm against callers).
        :return: per-axis ``rope`` outputs concatenated along dim -3, with a
            singleton axis inserted at dim 1 (presumably to broadcast over
            attention heads — confirm).
        """
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor:
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before it is combined with the frequencies.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    # Geometric frequency ladder starting at 1; built in float32 on CPU, then moved to t's device.
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = t[:, None].float() * freqs[None]
    # Layout is [cos | sin] (concatenated halves, not interleaved).
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: append one zero column so the width is exactly `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        # The math above runs in float32; hand back the caller's float dtype/device.
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    """Two-layer MLP: Linear(in_dim -> hidden_dim), SiLU, Linear(hidden_dim -> hidden_dim)."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``x[..., in_dim]`` to ``[..., hidden_dim]``."""
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    """RMS normalization over the last dimension with a learned per-channel scale.

    The statistic is computed in float32 with eps = 1e-6, and the normalized
    value is cast back to the input dtype before ``scale`` is applied.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        # The final multiply follows normal type promotion, so the result takes
        # the wider of x_dtype and scale's dtype.
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    """Independent RMSNorms for queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Normalize ``q`` and ``k``; ``v`` only supplies the target dtype/device of the results."""
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection, QK RMSNorm and positional ``pe``."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Attend over ``x`` of shape (B, L, dim) and project back to ``dim``."""
        qkv = self.qkv(x)
        # Split the fused projection into q, k, v, each (B, H, L, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    """Shift/scale/gate triple produced by ``Modulation``."""

    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    """Project a conditioning vector to adaLN shift/scale/gate parameters.

    Yields one ``ModulationOut`` (3 chunks), or two of them (6 chunks) when
    ``double`` is true.
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        # SiLU -> Linear, then a length-1 axis at dim 1 so each chunk broadcasts
        # over the sequence axis of (B, L, dim) activations.
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text streams and joint attention.

    Each stream has its own modulation, QKV projection, QK norm, output
    projection and MLP. Their q/k/v are concatenated (text first) along the
    sequence axis for one shared attention call.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        """Run one joint-attention step and return the updated ``(img, txt)``.

        :param img: image tokens, (B, L_img, hidden_size).
        :param txt: text tokens, (B, L_txt, hidden_size).
        :param vec: conditioning vector fed to both modulation layers.
        :param pe: positional embedding forwarded to ``attention``.
        """
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the joint sequence (text tokens first, dim 2 = L)
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        # Split the joint output back into streams along dim 1
        # (presumes attention returns (B, L, ...) — confirm in divisor.flux1.math).
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks (gated residual updates)
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks (gated residual updates)
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class SingleStreamBlock(nn.Module):
    """A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    ``linear1`` produces the fused QKV and the MLP input in one projection;
    ``linear2`` consumes the attention output concatenated with the activated
    MLP stream.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Stored but not read by forward. Note: because of `or`, qk_scale=0 falls back to the default.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Apply the block to ``x`` (B, L, hidden_size) conditioned on ``vec``."""
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
    """Final adaLN-modulated LayerNorm followed by a projection to patch pixels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        """Modulate ``x`` (B, L, hidden_size) by ``vec`` (B, hidden_size) and project it."""
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # [:, None, :] broadcasts the per-sample shift/scale over the sequence axis.
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
class EmbedND(nn.Module):
    """N-axis rotary position embedding.

    Each coordinate column of ``ids`` is embedded with ``rope`` using its own
    width from ``axes_dim``; the per-axis results are concatenated.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        """Store the RoPE configuration.

        :param dim: total embedding width (stored; not read by ``forward``).
        :param theta: base passed to ``rope`` for every axis.
        :param axes_dim: per-axis widths passed to ``rope``, indexed by axis.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Embed position ids.

        :param ids: tensor whose last dimension indexes the position axes
            (presumably (batch, seq, n_axes) — TODO confirm against callers).
        :return: per-axis ``rope`` outputs concatenated along dim -3, with a
            singleton axis inserted at dim 1 (presumably to broadcast over
            attention heads — confirm).
        """
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
    """Store the RoPE configuration.

    :param dim: total embedding width (stored; not read by ``forward``).
    :param theta: base passed to ``rope`` for every axis.
    :param axes_dim: per-axis widths passed to ``rope``, indexed by axis.
    """
    super().__init__()
    self.dim = dim
    self.theta = theta
    self.axes_dim = axes_dim
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, ids: Tensor) -> Tensor:
    """Embed position ids.

    :param ids: tensor whose last dimension indexes the position axes
        (presumably (batch, seq, n_axes) — TODO confirm against callers).
    :return: per-axis ``rope`` outputs concatenated along dim -3, with a
        singleton axis inserted at dim 1 (presumably to broadcast over
        attention heads — confirm).
    """
    n_axes = ids.shape[-1]
    emb = torch.cat(
        [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
        dim=-3,
    )

    return emb.unsqueeze(1)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor:
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before it is combined with the frequencies.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    # Geometric frequency ladder starting at 1; built in float32 on CPU, then moved to t's device.
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = t[:, None].float() * freqs[None]
    # Layout is [cos | sin] (concatenated halves, not interleaved).
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: append one zero column so the width is exactly `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        # The math above runs in float32; hand back the caller's float dtype/device.
        embedding = embedding.to(t)
    return embedding
Create sinusoidal timestep embeddings.
Parameters
- t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
- dim: the dimension of the output.
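- max_period: controls the minimum frequency of the embeddings.
- time_factor: multiplier applied to t before it is combined with the frequencies (default 1000.0).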
Returns
an (N, D) Tensor of positional embeddings.
class MLPEmbedder(nn.Module):
    """Two-layer MLP: Linear(in_dim -> hidden_dim), SiLU, Linear(hidden_dim -> hidden_dim)."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``x[..., in_dim]`` to ``[..., hidden_dim]``."""
        return self.out_layer(self.silu(self.in_layer(x)))
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_dim: int, hidden_dim: int):
    """Build Linear(in_dim -> hidden_dim), SiLU and Linear(hidden_dim -> hidden_dim), all with bias."""
    super().__init__()
    self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
    self.silu = nn.SiLU()
    self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class RMSNorm(torch.nn.Module):
    """RMS normalization over the last dimension with a learned per-channel scale.

    The statistic is computed in float32 with eps = 1e-6, and the normalized
    value is cast back to the input dtype before ``scale`` is applied.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        # The final multiply follows normal type promotion, so the result takes
        # the wider of x_dtype and scale's dtype.
        return (x * rrms).to(dtype=x_dtype) * self.scale
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int):
    """Create the learned per-channel ``scale``, initialised to ones."""
    super().__init__()
    self.scale = nn.Parameter(torch.ones(dim))
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor) -> Tensor:
    """Normalize ``x`` by its RMS over the last dimension, then apply ``scale``.

    The statistic is computed in float32 (eps = 1e-6) and the normalized value
    is cast back to the input dtype before scaling.
    """
    x_dtype = x.dtype
    x = x.float()
    rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
    # The final multiply follows normal type promotion, so the result takes
    # the wider of x_dtype and scale's dtype.
    return (x * rrms).to(dtype=x_dtype) * self.scale
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class QKNorm(torch.nn.Module):
    """Independent RMSNorms for queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Normalize ``q`` and ``k``; ``v`` only supplies the target dtype/device of the results."""
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int):
    """Create separate ``RMSNorm(dim)`` modules for queries and keys."""
    super().__init__()
    self.query_norm = RMSNorm(dim)
    self.key_norm = RMSNorm(dim)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
    """Normalize ``q`` and ``k``; ``v`` only supplies the target dtype/device of the results."""
    q = self.query_norm(q)
    k = self.key_norm(k)
    return q.to(v), k.to(v)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection, QK RMSNorm and positional ``pe``."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Attend over ``x`` of shape (B, L, dim) and project back to ``dim``."""
        qkv = self.qkv(x)
        # Split the fused projection into q, k, v, each (B, H, L, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
    """Build the fused QKV projection, per-head QK norm and output projection.

    :param dim: model width; split into ``num_heads`` heads of ``dim // num_heads``.
    :param num_heads: number of attention heads.
    :param qkv_bias: whether the QKV projection has a bias.
    """
    super().__init__()
    self.num_heads = num_heads
    head_dim = dim // num_heads

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.norm = QKNorm(head_dim)
    self.proj = nn.Linear(dim, dim)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
    """Attend over ``x`` of shape (B, L, dim) and project back to ``dim``."""
    qkv = self.qkv(x)
    # Split the fused projection into q, k, v, each (B, H, L, D).
    q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k, v)
    x = attention(q, k, v, pe=pe)
    x = self.proj(x)
    return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
115class Modulation(nn.Module): 116 def __init__(self, dim: int, double: bool): 117 super().__init__() 118 self.is_double = double 119 self.multiplier = 6 if double else 3 120 self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 121 122 def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 123 out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) 124 125 return ( 126 ModulationOut(*out[:3]), 127 ModulationOut(*out[3:]) if self.is_double else None, 128 )
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim: int, double: bool):
    """Build the projection to 3 (single) or 6 (double) chunks of width ``dim``."""
    super().__init__()
    self.is_double = double
    self.multiplier = 6 if double else 3
    self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
    """Return ``(first, second)`` modulation triples; ``second`` is None unless ``is_double``."""
    # SiLU -> Linear, then a length-1 axis at dim 1 so each chunk broadcasts
    # over the sequence axis of (B, L, dim) activations.
    out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

    return (
        ModulationOut(*out[:3]),
        ModulationOut(*out[3:]) if self.is_double else None,
    )
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text streams and joint attention.

    Each stream has its own modulation, QKV projection, QK norm, output
    projection and MLP. Their q/k/v are concatenated (text first) along the
    sequence axis for one shared attention call.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        """Run one joint-attention step and return the updated ``(img, txt)``.

        :param img: image tokens, (B, L_img, hidden_size).
        :param txt: text tokens, (B, L_txt, hidden_size).
        :param vec: conditioning vector fed to both modulation layers.
        :param pe: positional embedding forwarded to ``attention``.
        """
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the joint sequence (text tokens first, dim 2 = L)
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        # Split the joint output back into streams along dim 1
        # (presumes attention returns (B, L, ...) — confirm in divisor.flux1.math).
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks (gated residual updates)
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks (gated residual updates)
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
    """Build matching modulation/norm/attention/MLP stacks for the image and text streams.

    :param hidden_size: token width of both streams.
    :param num_heads: attention heads per stream.
    :param mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
    :param qkv_bias: whether the QKV projections have a bias.
    """
    super().__init__()

    mlp_hidden_dim = int(hidden_size * mlp_ratio)
    self.num_heads = num_heads
    self.hidden_size = hidden_size
    self.img_mod = Modulation(hidden_size, double=True)
    self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

    self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_mlp = nn.Sequential(
        nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
        nn.GELU(approximate="tanh"),
        nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
    )

    self.txt_mod = Modulation(hidden_size, double=True)
    self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

    self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_mlp = nn.Sequential(
        nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
        nn.GELU(approximate="tanh"),
        nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
    """Run one joint-attention step and return the updated ``(img, txt)``.

    :param img: image tokens, (B, L_img, hidden_size).
    :param txt: text tokens, (B, L_txt, hidden_size).
    :param vec: conditioning vector fed to both modulation layers.
    :param pe: positional embedding forwarded to ``attention``.
    """
    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)

    # prepare image for attention
    img_modulated = self.img_norm1(img)
    img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
    img_qkv = self.img_attn.qkv(img_modulated)
    img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

    # prepare txt for attention
    txt_modulated = self.txt_norm1(txt)
    txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
    txt_qkv = self.txt_attn.qkv(txt_modulated)
    txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

    # run actual attention over the joint sequence (text tokens first, dim 2 = L)
    q = torch.cat((txt_q, img_q), dim=2)
    k = torch.cat((txt_k, img_k), dim=2)
    v = torch.cat((txt_v, img_v), dim=2)

    attn = attention(q, k, v, pe=pe)
    # Split the joint output back into streams along dim 1
    # (presumes attention returns (B, L, ...) — confirm in divisor.flux1.math).
    txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

    # calculate the img blocks (gated residual updates)
    img = img + img_mod1.gate * self.img_attn.proj(img_attn)
    img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

    # calculate the txt blocks (gated residual updates)
    txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
    txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
    return img, txt
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class SingleStreamBlock(nn.Module):
    """A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    ``linear1`` produces the fused QKV and the MLP input in one projection;
    ``linear2`` consumes the attention output concatenated with the activated
    MLP stream.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Stored but not read by forward. Note: because of `or`, qk_scale=0 falls back to the default.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Apply the block to ``x`` (B, L, hidden_size) conditioned on ``vec``."""
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output
A DiT block with parallel linear layers, as described in https://arxiv.org/abs/2302.05442, and an adapted modulation interface.
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    mlp_ratio: float = 4.0,
    qk_scale: float | None = None,
):
    """Build the fused input/output projections, QK norm, pre-norm and single modulation.

    :param hidden_size: token width.
    :param num_heads: attention heads.
    :param mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
    :param qk_scale: stored as ``self.scale`` (defaults to ``head_dim**-0.5``).
    """
    super().__init__()
    self.hidden_dim = hidden_size
    self.num_heads = num_heads
    head_dim = hidden_size // num_heads
    # Because of `or`, qk_scale=0 falls back to the default.
    self.scale = qk_scale or head_dim**-0.5

    self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
    # qkv and mlp_in
    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
    # proj and mlp_out
    self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

    self.norm = QKNorm(head_dim)

    self.hidden_size = hidden_size
    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    self.mlp_act = nn.GELU(approximate="tanh")
    self.modulation = Modulation(hidden_size, double=False)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
    """Apply the block to ``x`` (B, L, hidden_size) conditioned on ``vec``."""
    mod, _ = self.modulation(vec)
    x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
    # One projection yields both the fused QKV and the MLP input.
    qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

    q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k, v)

    # compute attention
    attn = attention(q, k, v, pe=pe)
    # compute activation in mlp stream, cat again and run second linear layer
    output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
    return x + mod.gate * output
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class LastLayer(nn.Module):
    """Final adaLN-modulated LayerNorm followed by a projection to patch pixels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        """Modulate ``x`` (B, L, hidden_size) by ``vec`` (B, hidden_size) and project it."""
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # [:, None, :] broadcasts the per-sample shift/scale over the sequence axis.
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
    """Build the final norm, the patch projection and the shift/scale modulation.

    :param hidden_size: token width.
    :param patch_size: side length of a patch; output width is ``patch_size**2 * out_channels``.
    :param out_channels: channels per output pixel.
    """
    super().__init__()
    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
    self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
    """Modulate ``x`` (B, L, hidden_size) by ``vec`` (B, hidden_size) and project it."""
    shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
    # [:, None, :] broadcasts the per-sample shift/scale over the sequence axis.
    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
    x = self.linear(x)
    return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.