divisor.xflux1.layers
# SPDX-License-Identifier:Apache-2.0
# original XFlux code from https://github.com/TencentARC/FluxKits

# type: ignore
from einops import rearrange
import torch
from torch import Tensor, nn
import torch.nn.functional as F

from divisor.flux1.layers import Modulation, QKNorm
from divisor.flux1.math import attention


class LoRALinearLayer(nn.Module):
    """Low-rank adapter: a bias-free ``down`` projection followed by a bias-free ``up`` projection.

    ``up`` is zero-initialised, so a freshly constructed layer outputs zeros.
    """

    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        """Create the two projections.

        Args:
            in_features: Size of the last input dimension.
            out_features: Size of the last output dimension.
            rank: Width of the bottleneck between ``down`` and ``up``; must be positive.
            network_alpha: Optional scale numerator; when set, the output is multiplied
                by ``network_alpha / rank``.
            device: Forwarded to ``nn.Linear``.
            dtype: Forwarded to ``nn.Linear``.

        Raises:
            ValueError: If ``rank`` is not positive.
        """
        super().__init__()
        if rank <= 0:
            # Fail with a clear message instead of a ZeroDivisionError in the init std below.
            raise ValueError(f"rank must be a positive integer, got {rank!r}")

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        """Apply the adapter in the weights' dtype and return the result in the input's dtype."""
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class FLuxSelfAttnProcessor:
    """Plain self-attention processor that runs an attention module's own projections."""

    def __call__(self, attn, x, pe, **attention_kwargs):
        """Run qkv projection, q/k normalisation, attention and the output projection of ``attn``.

        ``attention_kwargs`` is accepted and ignored.
        """
        qkv = attn.qkv(x)
        # The head count lives on the attention module; this processor has no attributes of its own.
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x)
        return x


class LoraFluxAttnProcessor(nn.Module):
    """Self-attention processor that adds LoRA deltas to the qkv and output projections."""

    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        """Build one adapter for the fused qkv projection (``dim -> 3*dim``) and one for the output projection.

        ``lora_weight`` multiplies both adapter outputs.
        """
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, x, pe, **attention_kwargs):
        """Run ``attn``'s self-attention with the LoRA deltas added; ``attention_kwargs`` is ignored."""
        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
        # The head count lives on the attention module, not on this processor.
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        # Both the base and the LoRA output projections read the same attention output.
        x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
        return x


class SelfAttention(nn.Module):
    """Holds the qkv projection, q/k norm and output projection of one attention stream.

    The block processors in this module read ``qkv``, ``norm`` and ``proj`` directly.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        """Create the fused qkv projection, a ``QKNorm`` over ``dim // num_heads`` and the output projection."""
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Run self-attention over ``x``, mirroring ``FLuxSelfAttnProcessor``.

        The previous stub took no ``self`` and could not be called at all.
        """
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        return self.proj(x)


class DoubleStreamBlockLoraProcessor(nn.Module):
    """Double-stream block processor with LoRA deltas on both streams' qkv and output projections."""

    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        """Build the image-stream adapters (suffix ``1``) and text-stream adapters (suffix ``2``)."""
        super().__init__()
        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, img, txt, vec, pe):
        """Run the block ``attn`` on the image and text streams and return ``(img, txt)``."""
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the concatenated [txt, img] sequence
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        # text tokens come first, so split at the text length
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        return img, txt


class IPDoubleStreamBlockProcessor(nn.Module):
    """Attention processor for handling IP-adapter with double stream block."""

    def __init__(self, context_dim, hidden_dim):
        """Create zero-initialised key/value projections from ``context_dim`` to ``hidden_dim``.

        Raises:
            ImportError: If ``torch.nn.functional`` lacks ``scaled_dot_product_attention``.
        """
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")

        # Ensure context_dim matches the dimension of image_proj
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Initialize projections for IP-adapter
        self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
        self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)

        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)

        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)

    def __call__(self, attn, img, txt, vec, pe, image_proj=None, ip_scale=1.0, **attention_kwargs):
        """Run the double-stream block and, when ``image_proj`` is given, add the IP-adapter term to ``img``.

        ``image_proj`` is optional because ``DoubleStreamBlock.forward`` omits it when it is None;
        in that case the result matches the plain double-stream computation.
        ``attention_kwargs`` is accepted and ignored. Returns ``(img, txt)``.
        """
        # Prepare image for attention
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        if image_proj is not None:
            # IP-adapter processing
            ip_query = img_q  # latent sample query
            ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
            ip_value = self.ip_adapter_double_stream_v_proj(image_proj)

            # Reshape projections for multi-head attention
            ip_key = rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
            ip_value = rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)

            # Compute attention between IP projections and the latent query
            ip_attention = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
            ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)

            img = img + ip_scale * ip_attention

        return img, txt


class DoubleStreamBlockProcessor:
    """Default double-stream processor: joint attention over text and image tokens, then per-stream MLPs."""

    def __call__(self, attn, img, txt, vec, pe):
        """Run the block ``attn`` on the image and text streams and return ``(img, txt)``."""
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the concatenated [txt, img] sequence
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        # text tokens come first, so split at the text length
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        return img, txt


class DoubleStreamBlock(nn.Module):
    """Block with separate image and text streams whose computation is delegated to a swappable processor."""

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        """Build modulation, norms, attention and MLP for each stream and install the default processor."""
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads

        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        processor = DoubleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor):
        """Install the callable that ``forward`` delegates to."""
        self.processor = processor

    def get_processor(self):
        """Return the currently installed processor."""
        return self.processor

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
    ):
        """Delegate to the installed processor.

        ``image_proj`` and ``ip_scale`` are forwarded positionally only when ``image_proj``
        is not None, so the installed processor must accept them in that case.
        """
        if image_proj is None:
            return self.processor(self, img, txt, vec, pe)
        else:
            return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)


class IPSingleStreamBlockProcessor(nn.Module):
    """Attention processor for handling IP-adapter with single stream block."""

    def __init__(self, context_dim, hidden_dim):
        """Create zero-initialised, bias-free key/value projections from ``context_dim`` to ``hidden_dim``.

        Raises:
            ImportError: If ``torch.nn.functional`` lacks ``scaled_dot_product_attention``.
        """
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")

        # Ensure context_dim matches the dimension of image_proj
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Initialize projections for IP-adapter
        self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)

        nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
        """Run the single-stream block and, when ``image_proj`` is given, add the IP-adapter attention term.

        With ``image_proj=None`` (the declared default) the IP branch is skipped instead of
        passing None into the projections.
        """
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_out = attention(q, k, v, pe=pe)

        if image_proj is not None:
            # IP-adapter processing
            ip_query = q
            ip_key = self.ip_adapter_single_stream_k_proj(image_proj)
            ip_value = self.ip_adapter_single_stream_v_proj(image_proj)

            # Reshape projections for multi-head attention
            ip_key = rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
            ip_value = rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)

            # Compute attention between IP projections and the latent query
            ip_attention = F.scaled_dot_product_attention(ip_query, ip_key, ip_value)
            ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)")

            attn_out = attn_out + ip_scale * ip_attention

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
        out = x + mod.gate * output

        return out


class SingleStreamBlockLoraProcessor(nn.Module):
    """Single-stream block processor with LoRA deltas on the qkv slice and on the block output."""

    def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
        """Build one adapter for qkv (``dim -> 3*dim``) and one for the output (``dim -> dim``)."""
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
        """Run the single-stream block ``attn`` on ``x`` with the LoRA deltas applied."""
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
        qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        # the output adapter is applied to the result of linear2, not to its input
        output = output + self.proj_lora(output) * self.lora_weight
        output = x + mod.gate * output

        return output


class SingleStreamBlockProcessor:
    """Default single-stream processor: fused qkv/MLP input projection, attention, fused output projection."""

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
        """Run the single-stream block ``attn`` on ``x`` and return the gated residual output."""
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        output = x + mod.gate * output

        return output


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
    ):
        """Build the fused projections, norms, activation and modulation, then install the default processor."""
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(self.head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        processor = SingleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor):
        """Install the callable that ``forward`` delegates to."""
        self.processor = processor

    def get_processor(self):
        """Return the currently installed processor."""
        return self.processor

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
        """Delegate to the installed processor.

        ``image_proj`` and ``ip_scale`` are forwarded positionally only when ``image_proj``
        is not None, so the installed processor must accept them in that case.
        """
        if image_proj is None:
            return self.processor(self, x, vec, pe)
        else:
            return self.processor(self, x, vec, pe, image_proj, ip_scale)


class ImageProjModel(torch.nn.Module):
    """Projection Model
    https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        """Create a linear projection to ``clip_extra_context_tokens * cross_attention_dim`` and a LayerNorm."""
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Project ``image_embeds``, reshape to ``(-1, clip_extra_context_tokens, cross_attention_dim)`` and normalise."""
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
class LoRALinearLayer(nn.Module):
    """Low-rank adapter: a bias-free ``down`` projection followed by a bias-free ``up`` projection.

    ``up`` is zero-initialised, so a freshly constructed layer outputs zeros.
    """

    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        """Create the two projections.

        Args:
            in_features: Size of the last input dimension.
            out_features: Size of the last output dimension.
            rank: Width of the bottleneck between ``down`` and ``up``; must be positive.
            network_alpha: Optional scale numerator; when set, the output is multiplied
                by ``network_alpha / rank``.
            device: Forwarded to ``nn.Linear``.
            dtype: Forwarded to ``nn.Linear``.

        Raises:
            ValueError: If ``rank`` is not positive.
        """
        super().__init__()
        if rank <= 0:
            # Fail with a clear message instead of a ZeroDivisionError in the init std below.
            raise ValueError(f"rank must be a positive integer, got {rank!r}")

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        """Apply the adapter in the weights' dtype and return the result in the input's dtype."""
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F
    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
16 def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): 17 super().__init__() 18 19 self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 20 self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 21 # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 22 # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 23 self.network_alpha = network_alpha 24 self.rank = rank 25 26 nn.init.normal_(self.down.weight, std=1 / rank) 27 nn.init.zeros_(self.up.weight)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, hidden_states):
    """Run ``down`` then ``up`` in the weights' dtype and cast the result back to the input's dtype.

    When ``network_alpha`` is set, the result is scaled by ``network_alpha / rank``.
    """
    input_dtype = hidden_states.dtype
    weight_dtype = self.down.weight.dtype

    projected = self.up(self.down(hidden_states.to(weight_dtype)))

    if self.network_alpha is None:
        return projected.to(input_dtype)

    # Scale in place, as the adapter output is a fresh tensor owned by this call.
    projected.mul_(self.network_alpha / self.rank)
    return projected.to(input_dtype)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class LoraFluxAttnProcessor(nn.Module):
    """Self-attention processor that adds LoRA deltas to the qkv and output projections."""

    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        """Build one adapter for the fused qkv projection (``dim -> 3*dim``) and one for the output projection.

        ``lora_weight`` multiplies both adapter outputs.
        """
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, x, pe, **attention_kwargs):
        """Run ``attn``'s self-attention on ``x`` with the LoRA deltas added.

        ``attention_kwargs`` is accepted and ignored.
        """
        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
        # The head count lives on the attention module; this processor never defines
        # ``num_heads``, so reading it from ``self`` raised AttributeError.
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        # Both the base and the LoRA output projections read the same attention output.
        x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
        return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F
    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
    """Create the qkv and output-projection adapters and store the LoRA mixing weight.

    Args:
        dim: Input width of both adapters; the qkv adapter outputs ``dim * 3``.
        rank: Forwarded to ``LoRALinearLayer``.
        network_alpha: Forwarded to ``LoRALinearLayer``.
        lora_weight: Stored as ``self.lora_weight``.
    """
    super().__init__()
    self.lora_weight = lora_weight
    # Adapters are created in qkv -> proj order so registration and init order stay the same.
    self.qkv_lora = LoRALinearLayer(dim, 3 * dim, rank, network_alpha)
    self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class SelfAttention(nn.Module):
    """Holds the qkv projection, q/k norm and output projection of one attention stream."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        """Create the fused qkv projection, a ``QKNorm`` over ``dim // num_heads`` and the output projection.

        Args:
            dim: Model width; the qkv projection maps ``dim -> dim * 3``.
            num_heads: Number of attention heads, stored as ``self.num_heads``.
            qkv_bias: Whether the qkv projection has a bias.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Run self-attention over ``x``.

        The previous stub was declared without ``self`` and could not be called at all.

        Args:
            x: Input whose last dimension is ``dim``; the rearrange pattern expects ``(B, L, dim)``.
            pe: Forwarded to ``attention`` as ``pe``.
        """
        qkv = self.qkv(x)
        # Split the fused projection into q, k, v, each laid out as (B, H, L, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        return self.proj(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F
    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
    """Create the fused qkv projection, the q/k norm and the output projection.

    Args:
        dim: Model width; the qkv projection maps ``dim -> 3 * dim``.
        num_heads: Number of heads, stored as ``self.num_heads``.
        qkv_bias: Whether the qkv projection has a bias.
    """
    super().__init__()
    self.num_heads = num_heads

    # Sub-modules are created in qkv -> norm -> proj order so registration and init order stay the same.
    self.qkv = nn.Linear(in_features=dim, out_features=3 * dim, bias=qkv_bias)
    self.norm = QKNorm(dim // num_heads)
    self.proj = nn.Linear(in_features=dim, out_features=dim)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class DoubleStreamBlockLoraProcessor(nn.Module):
    """Double-stream block processor with LoRA deltas on both streams' qkv and output projections."""

    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        """Build the image-stream adapters (suffix ``1``) and text-stream adapters (suffix ``2``)."""
        super().__init__()
        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, img, txt, vec, pe):
        """Run the block ``attn`` on the image and text streams and return ``(img, txt)``.

        ``vec`` feeds both modulation modules; ``pe`` is forwarded to ``attention``.
        """
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)
        split_heads = "B L (K H D) -> K B H L D"
        txt_len = txt.shape[1]

        # Image stream: norm, modulate, base + LoRA qkv, split heads, normalise q/k.
        img_in = (1 + img_mod1.scale) * attn.img_norm1(img) + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_in) + self.qkv_lora1(img_in) * self.lora_weight
        img_q, img_k, img_v = rearrange(img_qkv, split_heads, K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # Text stream: same steps with the text modules and the second adapter pair.
        txt_in = (1 + txt_mod1.scale) * attn.txt_norm1(txt) + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_in) + self.qkv_lora2(txt_in) * self.lora_weight
        txt_q, txt_k, txt_v = rearrange(txt_qkv, split_heads, K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint attention over the concatenated [text, image] token sequence.
        joint = attention(
            torch.cat((txt_q, img_q), dim=2),
            torch.cat((txt_k, img_k), dim=2),
            torch.cat((txt_v, img_v), dim=2),
            pe=pe,
        )
        txt_attn = joint[:, :txt_len]
        img_attn = joint[:, txt_len:]

        # Image residuals: gated attention output (base + LoRA), then gated MLP.
        img_base = attn.img_attn.proj(img_attn)
        img_delta = self.proj_lora1(img_attn)
        img = img + img_mod1.gate * img_base + img_mod1.gate * img_delta * self.lora_weight
        img_mlp_in = (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift
        img = img + img_mod2.gate * attn.img_mlp(img_mlp_in)

        # Text residuals: same structure with the text modules.
        txt_base = attn.txt_attn.proj(txt_attn)
        txt_delta = self.proj_lora2(txt_attn)
        txt = txt + txt_mod1.gate * txt_base + txt_mod1.gate * txt_delta * self.lora_weight
        txt_mlp_in = (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift
        txt = txt + txt_mod2.gate * attn.txt_mlp(txt_mlp_in)

        return img, txt
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F
    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
    """Create the four adapters (qkv and output projection, two of each) and store the mixing weight.

    Args:
        dim: Input width of every adapter; the qkv adapters output ``dim * 3``.
        rank: Forwarded to ``LoRALinearLayer``.
        network_alpha: Forwarded to ``LoRALinearLayer``.
        lora_weight: Stored as ``self.lora_weight``.
    """
    super().__init__()
    self.lora_weight = lora_weight
    qkv_width = 3 * dim
    # Creation order (qkv1, proj1, qkv2, proj2) is kept so registration and init order stay the same.
    self.qkv_lora1 = LoRALinearLayer(dim, qkv_width, rank, network_alpha)
    self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
    self.qkv_lora2 = LoRALinearLayer(dim, qkv_width, rank, network_alpha)
    self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class IPDoubleStreamBlockProcessor(nn.Module):
    """Attention processor for handling IP-adapter with double stream block."""

    def __init__(self, context_dim, hidden_dim):
        """Create zero-initialised key/value projections from ``context_dim`` to ``hidden_dim``.

        Raises:
            ImportError: If ``torch.nn.functional`` lacks ``scaled_dot_product_attention``.
        """
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")

        # Ensure context_dim matches the dimension of image_proj
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Initialize projections for IP-adapter
        self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
        self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)

        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)

        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)

    def __call__(self, attn, img, txt, vec, pe, image_proj=None, ip_scale=1.0, **attention_kwargs):
        """Run the double-stream block and, when ``image_proj`` is given, add the IP-adapter term to ``img``.

        ``image_proj`` is optional so the processor can also be called without IP inputs;
        in that case the result matches the plain double-stream computation.
        ``attention_kwargs`` is accepted and ignored. Returns ``(img, txt)``.
        """
        # Prepare image for attention
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint attention over the concatenated [txt, img] sequence
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        # text tokens come first, so split at the text length
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        if image_proj is not None:
            # IP-adapter processing
            ip_query = img_q  # latent sample query
            ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
            ip_value = self.ip_adapter_double_stream_v_proj(image_proj)

            # Reshape projections for multi-head attention
            ip_key = rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
            ip_value = rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)

            # Compute attention between IP projections and the latent query
            ip_attention = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
            ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)

            img = img + ip_scale * ip_attention

        return img, txt
Attention processor for handling IP-adapter with double stream block.
131 def __init__(self, context_dim, hidden_dim): 132 super().__init__() 133 if not hasattr(F, "scaled_dot_product_attention"): 134 raise ImportError("IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.") 135 136 # Ensure context_dim matches the dimension of image_proj 137 self.context_dim = context_dim 138 self.hidden_dim = hidden_dim 139 140 # Initialize projections for IP-adapter 141 self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True) 142 self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True) 143 144 nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight) 145 nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias) 146 147 nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight) 148 nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class DoubleStreamBlockProcessor:
    """Default processor for a double-stream block.

    Image and text tokens go through their own modulation, norm and QKV
    layers, share one joint attention call, and are then updated by their
    own projection and MLP.
    """

    def __call__(self, attn, img, txt, vec, pe):
        """Apply the block ``attn`` to the image and text streams.

        Args:
            attn: Owning block providing the ``img_*``/``txt_*`` sub-modules,
                ``num_heads`` and ``head_dim``.
            img: Image-stream tokens; the sequence axis is dim 1.
            txt: Text-stream tokens; ``txt.shape[1]`` splits the joint output.
            vec: Conditioning input handed to both modulation modules.
            pe: Positional encoding forwarded to ``attention``.

        Returns:
            Tuple ``(img, txt)`` of updated streams.
        """
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        def project_heads(tokens, pre_norm, mod, stream_attn):
            # modulate -> fused QKV -> split into heads -> normalise q and k
            modulated = (1 + mod.scale) * pre_norm(tokens) + mod.shift
            fused = stream_attn.qkv(modulated)
            q_s, k_s, v_s = rearrange(fused, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
            q_s, k_s = stream_attn.norm(q_s, k_s, v_s)
            return q_s, k_s, v_s

        img_q, img_k, img_v = project_heads(img, attn.img_norm1, img_mod1, attn.img_attn)
        txt_q, txt_k, txt_v = project_heads(txt, attn.txt_norm1, txt_mod1, attn.txt_attn)

        # joint attention with text tokens placed before image tokens
        joint = attention(
            torch.cat((txt_q, img_q), dim=2),
            torch.cat((txt_k, img_k), dim=2),
            torch.cat((txt_v, img_v), dim=2),
            pe=pe,
        )
        n_txt = txt.shape[1]
        txt_attn = joint[:, :n_txt]
        img_attn = joint[:, n_txt:]

        # image stream: gated attention residual, then gated MLP residual
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img_mlp_in = (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift
        img = img + img_mod2.gate * attn.img_mlp(img_mlp_in)

        # text stream: gated attention residual, then gated MLP residual
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt_mlp_in = (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift
        txt = txt + txt_mod2.gate * attn.txt_mlp(txt_mlp_in)

        return img, txt
class DoubleStreamBlock(nn.Module):
    """Transformer block holding parallel image and text sub-layers.

    The block only owns the layers; the computation itself is delegated to a
    swappable processor (``DoubleStreamBlockProcessor`` by default).
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        """Build the image and text sub-layers and install the default processor.

        Args:
            hidden_size: Feature size of both streams.
            num_heads: Number of attention heads; ``head_dim`` is
                ``hidden_size // num_heads``.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``
                (truncated to ``int``).
            qkv_bias: Forwarded to both ``SelfAttention`` modules.
        """
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads
        mlp_width = int(hidden_size * mlp_ratio)

        def make_norm():
            return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        def make_mlp():
            return nn.Sequential(
                nn.Linear(hidden_size, mlp_width, bias=True),
                nn.GELU(approximate="tanh"),
                nn.Linear(mlp_width, hidden_size, bias=True),
            )

        # image stream
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = make_norm()
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = make_norm()
        self.img_mlp = make_mlp()

        # text stream
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = make_norm()
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = make_norm()
        self.txt_mlp = make_mlp()

        self.set_processor(DoubleStreamBlockProcessor())

    def set_processor(self, processor):
        """Replace the callable that implements this block's computation."""
        self.processor = processor

    def get_processor(self):
        """Return the currently installed processor."""
        return self.processor

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
    ):
        """Delegate to the processor.

        ``image_proj`` and ``ip_scale`` are passed on (positionally) only when
        ``image_proj`` is not ``None``.
        """
        call_args = (self, img, txt, vec, pe)
        if image_proj is not None:
            call_args = call_args + (image_proj, ip_scale)
        return self.processor(*call_args)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F
    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
    """Build the image and text sub-layers and install the default processor.

    Args:
        hidden_size: Feature size of both streams.
        num_heads: Number of attention heads; ``head_dim`` is
            ``hidden_size // num_heads``.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``
            (truncated to ``int``).
        qkv_bias: Forwarded to both ``SelfAttention`` modules.
    """
    super().__init__()
    self.num_heads = num_heads
    self.hidden_size = hidden_size
    self.head_dim = hidden_size // num_heads
    mlp_width = int(hidden_size * mlp_ratio)

    def make_norm():
        return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    def make_mlp():
        return nn.Sequential(
            nn.Linear(hidden_size, mlp_width, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_width, hidden_size, bias=True),
        )

    # image stream
    self.img_mod = Modulation(hidden_size, double=True)
    self.img_norm1 = make_norm()
    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
    self.img_norm2 = make_norm()
    self.img_mlp = make_mlp()

    # text stream
    self.txt_mod = Modulation(hidden_size, double=True)
    self.txt_norm1 = make_norm()
    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
    self.txt_norm2 = make_norm()
    self.txt_mlp = make_mlp()

    self.set_processor(DoubleStreamBlockProcessor())
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(
    self,
    img: Tensor,
    txt: Tensor,
    vec: Tensor,
    pe: Tensor,
    image_proj: Tensor = None,
    ip_scale: float = 1.0,
):
    """Delegate the block computation to ``self.processor``.

    ``image_proj`` and ``ip_scale`` are appended to the processor call
    (positionally) only when ``image_proj`` is not ``None``.
    """
    call_args = (self, img, txt, vec, pe)
    if image_proj is not None:
        call_args = call_args + (image_proj, ip_scale)
    return self.processor(*call_args)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class IPSingleStreamBlockProcessor(nn.Module):
    """Attention processor for handling IP-adapter with single stream block.

    Performs the single-stream update of the owning block and adds an extra
    attention term in which the block's queries attend to keys/values
    projected from ``image_proj``.
    """

    def __init__(self, context_dim, hidden_dim):
        """Create the zero-initialised, bias-free IP-adapter projections.

        Args:
            context_dim: Input feature size of both projections (the last
                dimension of ``image_proj``).
            hidden_dim: Output feature size of both projections; ``__call__``
                splits it into ``attn.num_heads * attn.head_dim``.

        Raises:
            ImportError: If ``F.scaled_dot_product_attention`` is unavailable.
        """
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")

        # Ensure context_dim matches the dimension of image_proj
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Initialize projections for IP-adapter
        self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)

        # All-zero weights give all-zero values, so the adapter term is
        # exactly zero until weights are trained or loaded.
        nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
        """Run the single-stream block and add the IP-adapter attention term.

        Args:
            attn: Owning block. Must expose ``modulation``, ``pre_norm``,
                ``linear1``, ``linear2``, ``norm``, ``mlp_act``,
                ``hidden_size``, ``mlp_hidden_dim``, ``num_heads``, ``head_dim``.
            x: Input tokens; also used as the residual.
            vec: Conditioning input handed to ``attn.modulation``.
            pe: Positional encoding forwarded to ``attention``.
            image_proj: Image-prompt tokens with last dimension
                ``context_dim``. When ``None`` (the declared default, and
                what the owning block passes when it has no image prompt)
                the adapter term is skipped instead of failing inside the
                projections.
            ip_scale: Multiplier applied to the adapter term.

        Returns:
            ``x`` plus the gated block output.
        """
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        # linear1 is a fused layer: first 3*hidden_size features are QKV, the rest feed the MLP
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_out = attention(q, k, v, pe=pe)

        if image_proj is not None:
            # IP-adapter processing
            ip_query = q
            ip_key = self.ip_adapter_single_stream_k_proj(image_proj)
            ip_value = self.ip_adapter_single_stream_v_proj(image_proj)

            # Reshape projections for multi-head attention
            ip_key = rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
            ip_value = rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)

            # Compute attention between IP projections and the latent query
            ip_attention = F.scaled_dot_product_attention(ip_query, ip_key, ip_value)
            ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)")

            attn_out = attn_out + ip_scale * ip_attention

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
        out = x + mod.gate * output

        return out
Attention processor for handling IP-adapter with single stream block.
293 def __init__(self, context_dim, hidden_dim): 294 super().__init__() 295 if not hasattr(F, "scaled_dot_product_attention"): 296 raise ImportError("IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.") 297 298 # Ensure context_dim matches the dimension of image_proj 299 self.context_dim = context_dim 300 self.hidden_dim = hidden_dim 301 302 # Initialize projections for IP-adapter 303 self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False) 304 self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False) 305 306 nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight) 307 nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class SingleStreamBlockLoraProcessor(nn.Module):
    """Single-stream processor that adds LoRA terms to the QKV and the block output."""

    def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
        """Create the two LoRA layers.

        Args:
            dim: Feature size; the QKV LoRA maps ``dim -> 3 * dim`` and the
                projection LoRA maps ``dim -> dim``.
            rank: Forwarded to ``LoRALinearLayer``.
            network_alpha: Forwarded to ``LoRALinearLayer``.
            lora_weight: Multiplier applied to both LoRA outputs.
        """
        super().__init__()
        self.qkv_lora = LoRALinearLayer(in_features=dim, out_features=dim * 3, rank=rank, network_alpha=network_alpha)
        self.proj_lora = LoRALinearLayer(in_features=dim, out_features=dim, rank=rank, network_alpha=network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
        """Run the single-stream block ``attn`` on ``x`` with LoRA corrections.

        Args:
            attn: Owning block providing ``modulation``, ``pre_norm``,
                ``linear1``, ``linear2``, ``norm``, ``mlp_act``,
                ``hidden_size``, ``mlp_hidden_dim`` and ``num_heads``.
            x: Input tokens; also used as the residual.
            vec: Conditioning input handed to ``attn.modulation``.
            pe: Positional encoding forwarded to ``attention``.

        Returns:
            ``x`` plus the gated, LoRA-corrected block output.
        """
        mod, _ = attn.modulation(vec)
        normed = attn.pre_norm(x)
        x_mod = (1 + mod.scale) * normed + mod.shift

        # fused layer: first 3*hidden_size features are QKV, the rest feed the MLP
        fused = attn.linear1(x_mod)
        qkv, mlp = torch.split(fused, [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
        qkv_delta = self.qkv_lora(x_mod) * self.lora_weight
        qkv = qkv + qkv_delta

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        attended = attention(q, k, v, pe=pe)

        # attention and activated MLP branch are concatenated before linear2
        merged = torch.cat((attended, attn.mlp_act(mlp)), 2)
        output = attn.linear2(merged)
        # the projection LoRA is applied to the linear2 output itself
        proj_delta = self.proj_lora(output) * self.lora_weight
        output = output + proj_delta

        return x + mod.gate * output
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F
    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
    """Create the QKV and projection LoRA layers.

    Args:
        dim: Feature size; the QKV LoRA maps ``dim -> 3 * dim`` and the
            projection LoRA maps ``dim -> dim``.
        rank: Forwarded to ``LoRALinearLayer``.
        network_alpha: Forwarded to ``LoRALinearLayer``.
        lora_weight: Multiplier stored for use on both LoRA outputs.
    """
    super().__init__()
    self.qkv_lora = LoRALinearLayer(in_features=dim, out_features=dim * 3, rank=rank, network_alpha=network_alpha)
    self.proj_lora = LoRALinearLayer(in_features=dim, out_features=dim, rank=rank, network_alpha=network_alpha)
    self.lora_weight = lora_weight
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class SingleStreamBlockProcessor:
    """Default processor for a single-stream block with a fused QKV/MLP input layer."""

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
        """Run the single-stream block ``attn`` on ``x``.

        Args:
            attn: Owning block providing ``modulation``, ``pre_norm``,
                ``linear1``, ``linear2``, ``norm``, ``mlp_act``,
                ``hidden_size``, ``mlp_hidden_dim`` and ``num_heads``.
            x: Input tokens; also used as the residual.
            vec: Conditioning input handed to ``attn.modulation``.
            pe: Positional encoding forwarded to ``attention``.

        Returns:
            ``x`` plus the gated block output.
        """
        mod, _ = attn.modulation(vec)
        normed = attn.pre_norm(x)
        x_mod = (1 + mod.scale) * normed + mod.shift

        # fused layer: first 3*hidden_size features are QKV, the rest feed the MLP
        fused = attn.linear1(x_mod)
        qkv, mlp = torch.split(fused, [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        attended = attention(q, k, v, pe=pe)

        # attention and activated MLP branch are concatenated before linear2
        merged = torch.cat((attended, attn.mlp_act(mlp)), 2)
        return x + mod.gate * attn.linear2(merged)
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    The block only owns the layers; the computation is delegated to a
    swappable processor (``SingleStreamBlockProcessor`` by default).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
    ):
        """Build the fused layers and install the default processor.

        Args:
            hidden_size: Token feature size (stored as both ``hidden_dim``
                and ``hidden_size``).
            num_heads: Number of attention heads; ``head_dim`` is
                ``hidden_size // num_heads``.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``
                (truncated to ``int``).
            qk_scale: Stored as ``self.scale``; any falsy value falls back to
                ``head_dim ** -0.5``.
        """
        super().__init__()
        head_dim = hidden_size // num_heads
        mlp_width = int(hidden_size * mlp_ratio)

        self.hidden_dim = hidden_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = qk_scale or head_dim**-0.5
        self.mlp_hidden_dim = mlp_width

        # linear1 produces qkv and mlp_in together; linear2 consumes proj and mlp_out together
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_width)
        self.linear2 = nn.Linear(hidden_size + mlp_width, hidden_size)

        self.norm = QKNorm(head_dim)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        self.set_processor(SingleStreamBlockProcessor())

    def set_processor(self, processor):
        """Replace the callable that implements this block's computation."""
        self.processor = processor

    def get_processor(self):
        """Return the currently installed processor."""
        return self.processor

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
        """Delegate to the processor.

        ``image_proj`` and ``ip_scale`` are passed on (positionally) only when
        ``image_proj`` is not ``None``.
        """
        call_args = (self, x, vec, pe)
        if image_proj is not None:
            call_args = call_args + (image_proj, ip_scale)
        return self.processor(*call_args)
A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation interface.
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    mlp_ratio: float = 4.0,
    qk_scale: float = None,
):
    """Build the fused layers and install the default processor.

    Args:
        hidden_size: Token feature size (stored as both ``hidden_dim`` and
            ``hidden_size``).
        num_heads: Number of attention heads; ``head_dim`` is
            ``hidden_size // num_heads``.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``
            (truncated to ``int``).
        qk_scale: Stored as ``self.scale``; any falsy value falls back to
            ``head_dim ** -0.5``.
    """
    super().__init__()
    head_dim = hidden_size // num_heads
    mlp_width = int(hidden_size * mlp_ratio)

    self.hidden_dim = hidden_size
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.scale = qk_scale or head_dim**-0.5
    self.mlp_hidden_dim = mlp_width

    # linear1 produces qkv and mlp_in together; linear2 consumes proj and mlp_out together
    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_width)
    self.linear2 = nn.Linear(hidden_size + mlp_width, hidden_size)

    self.norm = QKNorm(head_dim)
    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.mlp_act = nn.GELU(approximate="tanh")
    self.modulation = Modulation(hidden_size, double=False)

    self.set_processor(SingleStreamBlockProcessor())
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
    """Delegate the block computation to ``self.processor``.

    ``image_proj`` and ``ip_scale`` are appended to the processor call
    (positionally) only when ``image_proj`` is not ``None``.
    """
    call_args = (self, x, vec, pe)
    if image_proj is not None:
        call_args = call_args + (image_proj, ip_scale)
    return self.processor(*call_args)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class ImageProjModel(torch.nn.Module):
    """Projection Model
    https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28

    Maps each image embedding to ``clip_extra_context_tokens`` layer-normalised
    tokens of size ``cross_attention_dim``.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        """Create the projection and the token-wise layer norm.

        Args:
            cross_attention_dim: Feature size of each produced token.
            clip_embeddings_dim: Feature size of the incoming embeddings.
            clip_extra_context_tokens: Number of tokens produced per embedding.
        """
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        # one linear layer emits all tokens at once; forward() splits them apart
        projected_width = clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, projected_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Project ``image_embeds`` and return normalised context tokens.

        Returns:
            Tensor reshaped to ``(-1, clip_extra_context_tokens,
            cross_attention_dim)`` and passed through ``self.norm``.
        """
        projected = self.proj(image_embeds)
        tokens = projected.reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)
442 def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 443 super().__init__() 444 445 self.generator = None 446 self.cross_attention_dim = cross_attention_dim 447 self.clip_extra_context_tokens = clip_extra_context_tokens 448 self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 449 self.norm = torch.nn.LayerNorm(cross_attention_dim)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, image_embeds):
    """Project ``image_embeds`` and return normalised context tokens.

    Returns:
        Tensor reshaped to ``(-1, clip_extra_context_tokens,
        cross_attention_dim)`` and passed through ``self.norm``.
    """
    projected = self.proj(image_embeds)
    tokens = projected.reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
    return self.norm(tokens)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.