divisor.xflux1.layers

  1# SPDX-License-Identifier:Apache-2.0
  2# original XFlux code from https://github.com/TencentARC/FluxKits
  3
  4# type: ignore
  5from einops import rearrange
  6import torch
  7from torch import Tensor, nn
  8import torch.nn.functional as F
  9
 10from divisor.flux1.layers import Modulation, QKNorm
 11from divisor.flux1.math import attention
 12
 13
 14class LoRALinearLayer(nn.Module):
 15    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
 16        super().__init__()
 17
 18        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
 19        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
 20        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
 21        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
 22        self.network_alpha = network_alpha
 23        self.rank = rank
 24
 25        nn.init.normal_(self.down.weight, std=1 / rank)
 26        nn.init.zeros_(self.up.weight)
 27
 28    def forward(self, hidden_states):
 29        orig_dtype = hidden_states.dtype
 30        dtype = self.down.weight.dtype
 31
 32        down_hidden_states = self.down(hidden_states.to(dtype))
 33        up_hidden_states = self.up(down_hidden_states)
 34
 35        if self.network_alpha is not None:
 36            up_hidden_states *= self.network_alpha / self.rank
 37
 38        return up_hidden_states.to(orig_dtype)
 39
 40
 41class FLuxSelfAttnProcessor:
 42    def __call__(self, attn, x, pe, **attention_kwargs):
 43        qkv = attn.qkv(x)
 44        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
 45        q, k = attn.norm(q, k, v)
 46        x = attention(q, k, v, pe=pe)
 47        x = attn.proj(x)
 48        return x
 49
 50
 51class LoraFluxAttnProcessor(nn.Module):
 52    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
 53        super().__init__()
 54        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
 55        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
 56        self.lora_weight = lora_weight
 57
 58    def __call__(self, attn, x, pe, **attention_kwargs):
 59        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
 60        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
 61        q, k = attn.norm(q, k, v)
 62        x = attention(q, k, v, pe=pe)
 63        x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
 64        return x
 65
 66
 67class SelfAttention(nn.Module):
 68    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
 69        super().__init__()
 70        self.num_heads = num_heads
 71        head_dim = dim // num_heads
 72
 73        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 74        self.norm = QKNorm(head_dim)
 75        self.proj = nn.Linear(dim, dim)
 76
 77    def forward():
 78        pass
 79
 80
 81class DoubleStreamBlockLoraProcessor(nn.Module):
 82    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
 83        super().__init__()
 84        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
 85        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
 86        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
 87        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
 88        self.lora_weight = lora_weight
 89
 90    def __call__(self, attn, img, txt, vec, pe):
 91        img_mod1, img_mod2 = attn.img_mod(vec)
 92        txt_mod1, txt_mod2 = attn.txt_mod(vec)
 93
 94        # prepare image for attention
 95        img_modulated = attn.img_norm1(img)
 96        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
 97        img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
 98        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
 99        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
100
101        # prepare txt for attention
102        txt_modulated = attn.txt_norm1(txt)
103        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
104        txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
105        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
106        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
107
108        # run actual attention
109        q = torch.cat((txt_q, img_q), dim=2)
110        k = torch.cat((txt_k, img_k), dim=2)
111        v = torch.cat((txt_v, img_v), dim=2)
112
113        attn1 = attention(q, k, v, pe=pe)
114        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
115
116        # calculate the img blocks
117        img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
118        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
119
120        # calculate the txt blocks
121        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
122        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
123
124        return img, txt
125
126
class IPDoubleStreamBlockProcessor(nn.Module):
    """Attention processor for handling IP-adapter with double stream block.

    After the regular joint attention and MLP updates, the image-stream queries
    attend (via ``F.scaled_dot_product_attention``) to keys/values projected from
    ``image_proj``. The result, scaled by ``ip_scale``, is added to ``img``.

    Args:
        context_dim: last dimension of ``image_proj``.
        hidden_dim: output width of the IP key/value projections; must equal
            ``attn.num_heads * attn.head_dim`` for the head split to succeed.
    """

    def __init__(self, context_dim, hidden_dim):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")

        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
        self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)

        # Zero weights and biases make the IP branch add nothing until it is trained.
        for projection in (self.ip_adapter_double_stream_k_proj, self.ip_adapter_double_stream_v_proj):
            nn.init.zeros_(projection.weight)
            nn.init.zeros_(projection.bias)

    @staticmethod
    def _stream_qkv(attn, tokens, norm, mod, stream_attn):
        """Modulate one stream and return its QK-normalised ``(q, k, v)``."""
        modulated = (1 + mod.scale) * norm(tokens) + mod.shift
        qkv = stream_attn.qkv(modulated)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        q, k = stream_attn.norm(q, k, v)
        return q, k, v

    def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs):
        """Apply block ``attn`` plus the IP-adapter branch; return the updated ``(img, txt)``."""
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        img_q, img_k, img_v = self._stream_qkv(attn, img, attn.img_norm1, img_mod1, attn.img_attn)
        txt_q, txt_k, txt_v = self._stream_qkv(attn, txt, attn.txt_norm1, txt_mod1, attn.txt_attn)

        joint = attention(
            torch.cat((txt_q, img_q), dim=2),
            torch.cat((txt_k, img_k), dim=2),
            torch.cat((txt_v, img_v), dim=2),
            pe=pe,
        )
        split = txt.shape[1]
        txt_attn, img_attn = joint[:, :split], joint[:, split:]

        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        # IP-adapter branch: the image-stream queries attend to the projected image prompt.
        ip_key = rearrange(
            self.ip_adapter_double_stream_k_proj(image_proj), "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim
        )
        ip_value = rearrange(
            self.ip_adapter_double_stream_v_proj(image_proj), "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim
        )
        ip_out = F.scaled_dot_product_attention(img_q, ip_key, ip_value, dropout_p=0.0, is_causal=False)
        ip_out = rearrange(ip_out, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)

        return img + ip_scale * ip_out, txt
class DoubleStreamBlockProcessor:
    """Default (adapter-free) processor for :class:`DoubleStreamBlock`.

    Image and text tokens are modulated and projected to Q/K/V separately, then
    attended jointly (text tokens first). Each stream then receives a gated
    attention residual and a gated MLP residual.
    """

    def __call__(self, attn, img, txt, vec, pe):
        """Apply block ``attn`` to ``img`` and ``txt`` conditioned on ``vec``; return ``(img, txt)``."""
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        def project(tokens, norm, mod, stream_attn):
            # Shift/scale modulation, fused QKV, head split, then QK-norm.
            normed = norm(tokens)
            qkv = stream_attn.qkv((1 + mod.scale) * normed + mod.shift)
            q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
            q, k = stream_attn.norm(q, k, v)
            return q, k, v

        img_q, img_k, img_v = project(img, attn.img_norm1, img_mod1, attn.img_attn)
        txt_q, txt_k, txt_v = project(txt, attn.txt_norm1, txt_mod1, attn.txt_attn)

        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)
        joint = attention(q, k, v, pe=pe)
        n_txt = txt.shape[1]
        txt_attn, img_attn = joint[:, :n_txt], joint[:, n_txt:]

        # Image stream residual updates.
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # Text stream residual updates.
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        return img, txt
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image/text weights and joint attention.

    The computation is delegated to a swappable processor (default
    :class:`DoubleStreamBlockProcessor`), which reads this block's submodules.

    Args:
        hidden_size: model width.
        num_heads: number of attention heads; ``head_dim = hidden_size // num_heads``.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        qkv_bias: whether the QKV projections have a bias.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads

        # Submodules are created image-first, then text, which fixes the
        # parameter registration (and state_dict key) order.
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = self._make_norm(hidden_size)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = self._make_norm(hidden_size)
        self.img_mlp = self._make_mlp(hidden_size, mlp_hidden_dim)

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = self._make_norm(hidden_size)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = self._make_norm(hidden_size)
        self.txt_mlp = self._make_mlp(hidden_size, mlp_hidden_dim)

        self.set_processor(DoubleStreamBlockProcessor())

    @staticmethod
    def _make_norm(hidden_size):
        """Affine-free LayerNorm (eps=1e-6)."""
        return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    @staticmethod
    def _make_mlp(hidden_size, mlp_hidden_dim):
        """Two-layer MLP with tanh-approximated GELU."""
        return nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def set_processor(self, processor):
        """Replace the callable that implements this block's forward computation."""
        self.processor = processor

    def get_processor(self):
        """Return the current processor."""
        return self.processor

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
    ):
        """Run the processor; ``image_proj``/``ip_scale`` are passed on only when ``image_proj`` is given."""
        if image_proj is not None:
            return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
        return self.processor(self, img, txt, vec, pe)
class IPSingleStreamBlockProcessor(nn.Module):
    """Attention processor for handling IP-adapter with single stream block.

    Args:
        context_dim: last dimension of ``image_proj``.
        hidden_dim: output width of the IP key/value projections; must equal
            ``attn.num_heads * attn.head_dim`` for the head split to succeed.
    """

    def __init__(self, context_dim, hidden_dim):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")

        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Bias-free projections, zero-initialised so the IP branch starts as a no-op.
        self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)

        nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
        """Apply single-stream block ``attn`` to ``x``, adding the IP-adapter branch.

        Args:
            attn: block providing ``modulation``, ``pre_norm``, ``linear1``, ``norm``,
                ``mlp_act``, ``linear2``, ``hidden_size``, ``mlp_hidden_dim``,
                ``num_heads`` and ``head_dim``.
            x: input tokens.
            vec: conditioning vector for ``attn.modulation``.
            pe: positional embedding, forwarded to ``attention``.
            image_proj: projected image-prompt tokens. When ``None``, the IP branch is
                skipped (the default previously crashed inside ``nn.Linear``).
            ip_scale: weight of the IP-adapter attention output.

        Returns:
            ``x`` plus the gated block output.
        """
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        q, k = attn.norm(q, k, v)

        attn_out = attention(q, k, v, pe=pe)

        if image_proj is not None:
            # IP branch: the block's queries attend to keys/values from the image prompt.
            ip_key = rearrange(
                self.ip_adapter_single_stream_k_proj(image_proj), "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim
            )
            ip_value = rearrange(
                self.ip_adapter_single_stream_v_proj(image_proj), "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim
            )
            ip_attention = F.scaled_dot_product_attention(q, ip_key, ip_value)
            ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)")
            attn_out = attn_out + ip_scale * ip_attention

        # MLP activation, concatenated with the attention output, then the fused output projection.
        output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
        return x + mod.gate * output
class SingleStreamBlockLoraProcessor(nn.Module):
    """Single-stream block processor with LoRA adapters.

    ``qkv_lora`` is added to the QKV slice of ``linear1``'s output. ``proj_lora`` is
    applied to the output of ``linear2`` (the fused attention/MLP projection). Both
    adapter outputs are scaled by ``lora_weight``.
    """

    def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
        """Apply single-stream block ``attn`` to ``x`` with LoRA deltas; return the residual-updated tokens."""
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift

        fused = attn.linear1(x_mod)
        qkv, mlp = torch.split(fused, [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
        qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        attn_out = attention(q, k, v, pe=pe)

        projected = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
        projected = projected + self.proj_lora(projected) * self.lora_weight
        return x + mod.gate * projected
class SingleStreamBlockProcessor:
    """Default (adapter-free) processor for :class:`SingleStreamBlock`."""

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
        """Run one single-stream block and return ``x`` plus the gated block output.

        The steps are: modulated pre-norm, fused ``linear1`` split into QKV and
        MLP inputs, attention over QKV, concatenation with the activated MLP
        stream, and fused ``linear2``.
        """
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        split_sizes = [3 * attn.hidden_size, attn.mlp_hidden_dim]
        qkv, mlp = torch.split(attn.linear1(x_mod), split_sizes, dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        merged = torch.cat((attention(q, k, v, pe=pe), attn.mlp_act(mlp)), 2)
        return x + mod.gate * attn.linear2(merged)
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    ``linear1`` jointly produces QKV and the MLP input. ``linear2`` jointly
    projects the attention output and the activated MLP stream. The computation
    is delegated to a swappable processor (default :class:`SingleStreamBlockProcessor`).

    Args:
        hidden_size: model width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        qk_scale: stored as ``self.scale``; a falsy value falls back to ``head_dim ** -0.5``.
            NOTE(review): the processors in this file do not appear to read ``scale``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # Fused QKV + MLP-in projection, and fused attention-out + MLP-out projection.
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
        self.norm = QKNorm(self.head_dim)

        # Same value as ``hidden_dim``; the processors read ``hidden_size``.
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        self.set_processor(SingleStreamBlockProcessor())

    def set_processor(self, processor):
        """Replace the callable that implements this block's forward computation."""
        self.processor = processor

    def get_processor(self):
        """Return the current processor."""
        return self.processor

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
        """Run the processor; ``image_proj``/``ip_scale`` are passed on only when ``image_proj`` is given."""
        call_args = (self, x, vec, pe)
        if image_proj is not None:
            call_args += (image_proj, ip_scale)
        return self.processor(*call_args)
class ImageProjModel(torch.nn.Module):
    """Projection Model
    https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28

    Maps each image embedding to ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` using a linear layer followed by LayerNorm.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.generator = None  # not used by this class; kept as an attribute
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        projected_width = clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, projected_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens shaped ``(-1, clip_extra_context_tokens, cross_attention_dim)``."""
        tokens = self.proj(image_embeds)
        tokens = tokens.reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)
class LoRALinearLayer(torch.nn.modules.module.Module):
class LoRALinearLayer(nn.Module):
    """Low-rank adapter computing ``up(down(x))``, meant to be added to a frozen linear layer.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        rank: bottleneck width of the factorisation.
        network_alpha: optional scale numerator; when set, the output is multiplied by
            ``network_alpha / rank``. Same meaning as the ``--network_alpha`` option of the
            kohya-ss trainer script, see
            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        device: device for the adapter weights.
        dtype: dtype for the adapter weights.
    """

    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.down = nn.Linear(in_features, rank, bias=False, **factory)
        self.up = nn.Linear(rank, out_features, bias=False, **factory)
        self.network_alpha = network_alpha
        self.rank = rank

        # A zero-initialised ``up`` makes the adapter output exactly zero until it is trained.
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        """Apply the adapter in the weight dtype and return the result in the input dtype."""
        input_dtype = hidden_states.dtype
        result = self.up(self.down(hidden_states.to(self.down.weight.dtype)))
        if self.network_alpha is not None:
            result *= self.network_alpha / self.rank
        return result.to(input_dtype)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        hidden = self.conv2(hidden)
        return F.relu(hidden)

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Whether this module is in training or evaluation mode.
:vartype training: bool

LoRALinearLayer(in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None)
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
    """Build the bias-free ``down`` (in -> rank) and ``up`` (rank -> out) projections.

    ``network_alpha`` has the same meaning as the ``--network_alpha`` option of the
    kohya-ss trainer script, see
    https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
    """
    super().__init__()
    factory = {"device": device, "dtype": dtype}
    self.down = nn.Linear(in_features, rank, bias=False, **factory)
    self.up = nn.Linear(rank, out_features, bias=False, **factory)
    self.network_alpha = network_alpha
    self.rank = rank

    # A zero-initialised ``up`` makes the adapter output exactly zero until it is trained.
    nn.init.normal_(self.down.weight, std=1 / rank)
    nn.init.zeros_(self.up.weight)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

down
up
network_alpha
rank
def forward(self, hidden_states):
def forward(self, hidden_states):
    """Return ``up(down(hidden_states))``, scaled by ``network_alpha / rank`` when alpha is set.

    The computation runs in the adapter's weight dtype; the result is cast back
    to the input dtype.
    """
    input_dtype = hidden_states.dtype
    result = self.up(self.down(hidden_states.to(self.down.weight.dtype)))
    if self.network_alpha is not None:
        result *= self.network_alpha / self.rank
    return result.to(input_dtype)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class FLuxSelfAttnProcessor:
class FLuxSelfAttnProcessor:
    """Plain (adapter-free) attention processor for :class:`SelfAttention`.

    The processor is stateless: all weights and the head count are read from
    the ``attn`` module it is called with.
    """

    def __call__(self, attn, x, pe, **attention_kwargs):
        """Run self-attention over ``x`` using the projections owned by ``attn``.

        Args:
            attn: module exposing ``qkv``, ``norm``, ``proj`` and ``num_heads``.
            x: ``(B, L, dim)`` tokens (rank fixed by the ``rearrange`` pattern below).
            pe: positional embedding, forwarded unchanged to ``attention``.
            **attention_kwargs: accepted for interface compatibility; ignored.

        Returns:
            ``attn.proj`` applied to the attention output.
        """
        qkv = attn.qkv(x)
        # The head count lives on the attention module. This processor has no
        # ``num_heads`` attribute, so ``self.num_heads`` raised AttributeError.
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        return attn.proj(x)
class LoraFluxAttnProcessor(torch.nn.modules.module.Module):
class LoraFluxAttnProcessor(nn.Module):
    """Self-attention processor that adds LoRA deltas to the QKV and output projections.

    Args:
        dim: model width; ``qkv_lora`` maps ``dim -> 3 * dim`` and ``proj_lora``
            maps ``dim -> dim``.
        rank: LoRA rank for both adapters.
        network_alpha: optional LoRA scale numerator for both adapters.
        lora_weight: multiplier applied to each adapter output before it is added.
    """

    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, x, pe, **attention_kwargs):
        """Run LoRA-adapted self-attention over ``x`` with the base weights of ``attn``.

        Args:
            attn: module exposing ``qkv``, ``norm``, ``proj`` and ``num_heads``.
            x: ``(B, L, dim)`` tokens (rank fixed by the ``rearrange`` pattern below).
            pe: positional embedding, forwarded unchanged to ``attention``.
            **attention_kwargs: accepted for interface compatibility; ignored.
        """
        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
        # The head count comes from ``attn``; this processor has no ``num_heads`` attribute.
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        return attn.proj(x) + self.proj_lora(x) * self.lora_weight

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        hidden = self.conv2(hidden)
        return F.relu(hidden)

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Whether this module is in training or evaluation mode.
:vartype training: bool

LoraFluxAttnProcessor(dim: int, rank=4, network_alpha=None, lora_weight=1)
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
    """Create the QKV (``dim -> 3 * dim``) and output (``dim -> dim``) LoRA adapters."""
    super().__init__()
    adapter_args = (rank, network_alpha)
    self.qkv_lora = LoRALinearLayer(dim, dim * 3, *adapter_args)
    self.proj_lora = LoRALinearLayer(dim, dim, *adapter_args)
    # Multiplier applied to each adapter output before it is added.
    self.lora_weight = lora_weight

Initialize internal Module state, shared by both nn.Module and ScriptModule.

qkv_lora
proj_lora
lora_weight
class SelfAttention(torch.nn.modules.module.Module):
class SelfAttention(nn.Module):
    """Multi-head self-attention weights: fused QKV, per-head QK norm, output projection.

    Args:
        dim: model width, split into ``num_heads`` heads of ``dim // num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection has a bias.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Attend over ``(B, L, dim)`` tokens ``x`` with positional embedding ``pe``, then project.

        The previous stub ``def forward(): pass`` had no ``self`` parameter, so
        calling the module always raised ``TypeError``.
        """
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        return self.proj(x)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        hidden = self.conv2(hidden)
        return F.relu(hidden)

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Whether this module is in training or evaluation mode.
:vartype training: bool

SelfAttention(dim: int, num_heads: int = 8, qkv_bias: bool = False)
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
    """Create the fused QKV projection, per-head QK norm and output projection."""
    super().__init__()
    self.num_heads = num_heads
    per_head = dim // num_heads  # width of one attention head, used by QKNorm

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.norm = QKNorm(per_head)
    self.proj = nn.Linear(dim, dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_heads
qkv
norm
proj
def forward():
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
    """Attend over ``(B, L, dim)`` tokens ``x`` with positional embedding ``pe``, then project.

    The previous stub ``def forward(): pass`` had no ``self`` parameter, so
    calling the module always raised ``TypeError``.
    """
    qkv = self.qkv(x)
    q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k, v)
    x = attention(q, k, v, pe=pe)
    return self.proj(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class DoubleStreamBlockLoraProcessor(torch.nn.modules.module.Module):
 82class DoubleStreamBlockLoraProcessor(nn.Module):
 83    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
 84        super().__init__()
 85        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
 86        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
 87        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
 88        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
 89        self.lora_weight = lora_weight
 90
 91    def __call__(self, attn, img, txt, vec, pe):
 92        img_mod1, img_mod2 = attn.img_mod(vec)
 93        txt_mod1, txt_mod2 = attn.txt_mod(vec)
 94
 95        # prepare image for attention
 96        img_modulated = attn.img_norm1(img)
 97        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
 98        img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
 99        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
100        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
101
102        # prepare txt for attention
103        txt_modulated = attn.txt_norm1(txt)
104        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
105        txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
106        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
107        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
108
109        # run actual attention
110        q = torch.cat((txt_q, img_q), dim=2)
111        k = torch.cat((txt_k, img_k), dim=2)
112        v = torch.cat((txt_v, img_v), dim=2)
113
114        attn1 = attention(q, k, v, pe=pe)
115        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
116
117        # calculate the img blocks
118        img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
119        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
120
121        # calculate the txt blocks
122        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
123        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
124
125        return img, txt

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

DoubleStreamBlockLoraProcessor(dim: int, rank=4, network_alpha=None, lora_weight=1)
83    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
84        super().__init__()
85        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
86        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
87        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
88        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
89        self.lora_weight = lora_weight

Initialize internal Module state, shared by both nn.Module and ScriptModule.

qkv_lora1
proj_lora1
qkv_lora2
proj_lora2
lora_weight
class IPDoubleStreamBlockProcessor(torch.nn.modules.module.Module):
128class IPDoubleStreamBlockProcessor(nn.Module):
129    """Attention processor for handling IP-adapter with double stream block."""
130
131    def __init__(self, context_dim, hidden_dim):
132        super().__init__()
133        if not hasattr(F, "scaled_dot_product_attention"):
134            raise ImportError("IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")
135
136        # Ensure context_dim matches the dimension of image_proj
137        self.context_dim = context_dim
138        self.hidden_dim = hidden_dim
139
140        # Initialize projections for IP-adapter
141        self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
142        self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)
143
144        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
145        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
146
147        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
148        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)
149
150    def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs):
151        # Prepare image for attention
152        img_mod1, img_mod2 = attn.img_mod(vec)
153        txt_mod1, txt_mod2 = attn.txt_mod(vec)
154
155        img_modulated = attn.img_norm1(img)
156        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
157        img_qkv = attn.img_attn.qkv(img_modulated)
158        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
159        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
160
161        txt_modulated = attn.txt_norm1(txt)
162        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
163        txt_qkv = attn.txt_attn.qkv(txt_modulated)
164        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
165        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
166
167        q = torch.cat((txt_q, img_q), dim=2)
168        k = torch.cat((txt_k, img_k), dim=2)
169        v = torch.cat((txt_v, img_v), dim=2)
170
171        attn1 = attention(q, k, v, pe=pe)
172        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
173
174        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
175        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
176
177        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
178        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
179
180        # IP-adapter processing
181        ip_query = img_q  # latent sample query
182        ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
183        ip_value = self.ip_adapter_double_stream_v_proj(image_proj)
184
185        # Reshape projections for multi-head attention
186        ip_key = rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
187        ip_value = rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
188
189        # Compute attention between IP projections and the latent query
190        ip_attention = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
191        ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)
192
193        img = img + ip_scale * ip_attention
194
195        return img, txt

Attention processor for handling the IP-Adapter in a double-stream block.

IPDoubleStreamBlockProcessor(context_dim, hidden_dim)
131    def __init__(self, context_dim, hidden_dim):
132        super().__init__()
133        if not hasattr(F, "scaled_dot_product_attention"):
134            raise ImportError("IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")
135
136        # Ensure context_dim matches the dimension of image_proj
137        self.context_dim = context_dim
138        self.hidden_dim = hidden_dim
139
140        # Initialize projections for IP-adapter
141        self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
142        self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)
143
144        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
145        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
146
147        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
148        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

context_dim
hidden_dim
ip_adapter_double_stream_k_proj
ip_adapter_double_stream_v_proj
class DoubleStreamBlockProcessor:
198class DoubleStreamBlockProcessor:
199    def __call__(self, attn, img, txt, vec, pe):
200        img_mod1, img_mod2 = attn.img_mod(vec)
201        txt_mod1, txt_mod2 = attn.txt_mod(vec)
202
203        # prepare image for attention
204        img_modulated = attn.img_norm1(img)
205        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
206        img_qkv = attn.img_attn.qkv(img_modulated)
207        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
208        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
209
210        # prepare txt for attention
211        txt_modulated = attn.txt_norm1(txt)
212        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
213        txt_qkv = attn.txt_attn.qkv(txt_modulated)
214        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
215        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
216
217        # run actual attention
218        q = torch.cat((txt_q, img_q), dim=2)
219        k = torch.cat((txt_k, img_k), dim=2)
220        v = torch.cat((txt_v, img_v), dim=2)
221
222        attn1 = attention(q, k, v, pe=pe)
223        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
224
225        # calculate the img blocks
226        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
227        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
228
229        # calculate the txt blocks
230        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
231        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
232
233        return img, txt
class DoubleStreamBlock(torch.nn.modules.module.Module):
236class DoubleStreamBlock(nn.Module):
237    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
238        super().__init__()
239        mlp_hidden_dim = int(hidden_size * mlp_ratio)
240        self.num_heads = num_heads
241        self.hidden_size = hidden_size
242        self.head_dim = hidden_size // num_heads
243
244        self.img_mod = Modulation(hidden_size, double=True)
245        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
247
248        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
249        self.img_mlp = nn.Sequential(
250            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
251            nn.GELU(approximate="tanh"),
252            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
253        )
254
255        self.txt_mod = Modulation(hidden_size, double=True)
256        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
257        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
258
259        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
260        self.txt_mlp = nn.Sequential(
261            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
262            nn.GELU(approximate="tanh"),
263            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
264        )
265
266        processor = DoubleStreamBlockProcessor()
267        self.set_processor(processor)
268
269    def set_processor(self, processor):
270        self.processor = processor
271
272    def get_processor(self):
273        return self.processor
274
275    def forward(
276        self,
277        img: Tensor,
278        txt: Tensor,
279        vec: Tensor,
280        pe: Tensor,
281        image_proj: Tensor = None,
282        ip_scale: float = 1.0,
283    ):
284        if image_proj is None:
285            return self.processor(self, img, txt, vec, pe)
286        else:
287            return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

DoubleStreamBlock( hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False)
237    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
238        super().__init__()
239        mlp_hidden_dim = int(hidden_size * mlp_ratio)
240        self.num_heads = num_heads
241        self.hidden_size = hidden_size
242        self.head_dim = hidden_size // num_heads
243
244        self.img_mod = Modulation(hidden_size, double=True)
245        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
247
248        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
249        self.img_mlp = nn.Sequential(
250            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
251            nn.GELU(approximate="tanh"),
252            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
253        )
254
255        self.txt_mod = Modulation(hidden_size, double=True)
256        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
257        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
258
259        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
260        self.txt_mlp = nn.Sequential(
261            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
262            nn.GELU(approximate="tanh"),
263            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
264        )
265
266        processor = DoubleStreamBlockProcessor()
267        self.set_processor(processor)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_heads
hidden_size
head_dim
img_mod
img_norm1
img_attn
img_norm2
img_mlp
txt_mod
txt_norm1
txt_attn
txt_norm2
txt_mlp
def set_processor(self, processor):
269    def set_processor(self, processor):
270        self.processor = processor
def get_processor(self):
272    def get_processor(self):
273        return self.processor
def forward( self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor, image_proj: torch.Tensor = None, ip_scale: float = 1.0):
275    def forward(
276        self,
277        img: Tensor,
278        txt: Tensor,
279        vec: Tensor,
280        pe: Tensor,
281        image_proj: Tensor = None,
282        ip_scale: float = 1.0,
283    ):
284        if image_proj is None:
285            return self.processor(self, img, txt, vec, pe)
286        else:
287            return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.

class IPSingleStreamBlockProcessor(torch.nn.modules.module.Module):
290class IPSingleStreamBlockProcessor(nn.Module):
291    """Attention processor for handling IP-adapter with single stream block."""
292
293    def __init__(self, context_dim, hidden_dim):
294        super().__init__()
295        if not hasattr(F, "scaled_dot_product_attention"):
296            raise ImportError("IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")
297
298        # Ensure context_dim matches the dimension of image_proj
299        self.context_dim = context_dim
300        self.hidden_dim = hidden_dim
301
302        # Initialize projections for IP-adapter
303        self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
304        self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
305
306        nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
307        nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)
308
309    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
310        mod, _ = attn.modulation(vec)
311        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
312        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
313
314        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
315        q, k = attn.norm(q, k, v)
316
317        # compute attention
318        attn_1 = attention(q, k, v, pe=pe)
319
320        # IP-adapter processing
321        ip_query = q
322        ip_key = self.ip_adapter_single_stream_k_proj(image_proj)
323        ip_value = self.ip_adapter_single_stream_v_proj(image_proj)
324
325        # Reshape projections for multi-head attention
326        ip_key = rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
327        ip_value = rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
328
329        # Compute attention between IP projections and the latent query
330        ip_attention = F.scaled_dot_product_attention(ip_query, ip_key, ip_value)
331        ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)")
332
333        attn_out = attn_1 + ip_scale * ip_attention
334
335        # compute activation in mlp stream, cat again and run second linear layer
336        output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
337        out = x + mod.gate * output
338
339        return out

Attention processor for handling the IP-Adapter in a single-stream block.

IPSingleStreamBlockProcessor(context_dim, hidden_dim)
293    def __init__(self, context_dim, hidden_dim):
294        super().__init__()
295        if not hasattr(F, "scaled_dot_product_attention"):
296            raise ImportError("IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch.")
297
298        # Ensure context_dim matches the dimension of image_proj
299        self.context_dim = context_dim
300        self.hidden_dim = hidden_dim
301
302        # Initialize projections for IP-adapter
303        self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
304        self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
305
306        nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
307        nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

context_dim
hidden_dim
ip_adapter_single_stream_k_proj
ip_adapter_single_stream_v_proj
class SingleStreamBlockLoraProcessor(torch.nn.modules.module.Module):
342class SingleStreamBlockLoraProcessor(nn.Module):
343    def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
344        super().__init__()
345        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
346        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
347        self.lora_weight = lora_weight
348
349    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
350        mod, _ = attn.modulation(vec)
351        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
352        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
353        qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
354
355        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
356        q, k = attn.norm(q, k, v)
357
358        # compute attention
359        attn_1 = attention(q, k, v, pe=pe)
360
361        # compute activation in mlp stream, cat again and run second linear layer
362        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
363        output = output + self.proj_lora(output) * self.lora_weight
364        output = x + mod.gate * output
365
366        return output

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

SingleStreamBlockLoraProcessor(dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1)
343    def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
344        super().__init__()
345        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
346        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
347        self.lora_weight = lora_weight

Initialize internal Module state, shared by both nn.Module and ScriptModule.

qkv_lora
proj_lora
lora_weight
class SingleStreamBlockProcessor:
369class SingleStreamBlockProcessor:
370    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
371        mod, _ = attn.modulation(vec)
372        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
373        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
374
375        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
376        q, k = attn.norm(q, k, v)
377
378        # compute attention
379        attn_1 = attention(q, k, v, pe=pe)
380
381        # compute activation in mlp stream, cat again and run second linear layer
382        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
383        output = x + mod.gate * output
384
385        return output
class SingleStreamBlock(torch.nn.modules.module.Module):
388class SingleStreamBlock(nn.Module):
389    """
390    A DiT block with parallel linear layers as described in
391    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
392    """
393
394    def __init__(
395        self,
396        hidden_size: int,
397        num_heads: int,
398        mlp_ratio: float = 4.0,
399        qk_scale: float = None,
400    ):
401        super().__init__()
402        self.hidden_dim = hidden_size
403        self.num_heads = num_heads
404        self.head_dim = hidden_size // num_heads
405        self.scale = qk_scale or self.head_dim**-0.5
406
407        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
408        # qkv and mlp_in
409        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
410        # proj and mlp_out
411        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
412
413        self.norm = QKNorm(self.head_dim)
414
415        self.hidden_size = hidden_size
416        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
417
418        self.mlp_act = nn.GELU(approximate="tanh")
419        self.modulation = Modulation(hidden_size, double=False)
420
421        processor = SingleStreamBlockProcessor()
422        self.set_processor(processor)
423
424    def set_processor(self, processor):
425        self.processor = processor
426
427    def get_processor(self):
428        return self.processor
429
430    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
431        if image_proj is None:
432            return self.processor(self, x, vec, pe)
433        else:
434            return self.processor(self, x, vec, pe, image_proj, ip_scale)

A DiT block with parallel linear layers, as described in https://arxiv.org/abs/2302.05442, and an adapted modulation interface.

SingleStreamBlock( hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, qk_scale: float = None)
394    def __init__(
395        self,
396        hidden_size: int,
397        num_heads: int,
398        mlp_ratio: float = 4.0,
399        qk_scale: float = None,
400    ):
401        super().__init__()
402        self.hidden_dim = hidden_size
403        self.num_heads = num_heads
404        self.head_dim = hidden_size // num_heads
405        self.scale = qk_scale or self.head_dim**-0.5
406
407        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
408        # qkv and mlp_in
409        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
410        # proj and mlp_out
411        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
412
413        self.norm = QKNorm(self.head_dim)
414
415        self.hidden_size = hidden_size
416        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
417
418        self.mlp_act = nn.GELU(approximate="tanh")
419        self.modulation = Modulation(hidden_size, double=False)
420
421        processor = SingleStreamBlockProcessor()
422        self.set_processor(processor)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

hidden_dim
num_heads
head_dim
scale
mlp_hidden_dim
linear1
linear2
norm
hidden_size
pre_norm
mlp_act
modulation
def set_processor(self, processor):
424    def set_processor(self, processor):
425        self.processor = processor
def get_processor(self):
427    def get_processor(self):
428        return self.processor
def forward( self, x: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor, image_proj: torch.Tensor = None, ip_scale: float = 1.0):
430    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, image_proj: Tensor = None, ip_scale: float = 1.0):
431        if image_proj is None:
432            return self.processor(self, x, vec, pe)
433        else:
434            return self.processor(self, x, vec, pe, image_proj, ip_scale)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.

class ImageProjModel(torch.nn.modules.module.Module):
437class ImageProjModel(torch.nn.Module):
438    """Projection Model
439    https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28
440    """
441
442    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
443        super().__init__()
444
445        self.generator = None
446        self.cross_attention_dim = cross_attention_dim
447        self.clip_extra_context_tokens = clip_extra_context_tokens
448        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
449        self.norm = torch.nn.LayerNorm(cross_attention_dim)
450
451    def forward(self, image_embeds):
452        embeds = image_embeds
453        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
454        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
455        return clip_extra_context_tokens
ImageProjModel( cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
442    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
443        super().__init__()
444
445        self.generator = None
446        self.cross_attention_dim = cross_attention_dim
447        self.clip_extra_context_tokens = clip_extra_context_tokens
448        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
449        self.norm = torch.nn.LayerNorm(cross_attention_dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

generator
cross_attention_dim
clip_extra_context_tokens
proj
norm
def forward(self, image_embeds):
451    def forward(self, image_embeds):
452        embeds = image_embeds
453        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
454        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
455        return clip_extra_context_tokens

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.