divisor.acestep.models.attention

  1# Copyright 2024 The HuggingFace Team. All rights reserved.
  2#
  3# Licensed under the Apache License, Version 2.0 (the "License");
  4# you may not use this file except in compliance with the License.
  5# You may obtain a copy of the License at
  6#
  7#     http://www.apache.org/licenses/LICENSE-2.0
  8#
  9# Unless required by applicable law or agreed to in writing, software
 10# distributed under the License is distributed on an "AS IS" BASIS,
 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12# See the License for the specific language governing permissions and
 13# limitations under the License.
 14from typing import Tuple, Union
 15
 16import torch
 17import torch.nn.functional as F
 18from torch import nn
 19
 20from diffusers.utils import logging
 21from diffusers.models.normalization import RMSNorm
 22
 23
 24try:
 25    # from .dcformer import DCMHAttention
 26    from .customer_attention_processor import (
 27        Attention,
 28        CustomLiteLAProcessor2_0,
 29        CustomerAttnProcessor2_0,
 30    )
 31except ImportError:
 32    # from dcformer import DCMHAttention
 33    from customer_attention_processor import (
 34        Attention,
 35        CustomLiteLAProcessor2_0,
 36        CustomerAttnProcessor2_0,
 37    )
 38
 39
 40logger = logging.get_logger(__name__)
 41
 42
 43def val2list(x: list or tuple or any, repeat_time=1) -> list:  # type: ignore
 44    """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
 45    if isinstance(x, (list, tuple)):
 46        return list(x)
 47    return [x for _ in range(repeat_time)]
 48
 49
 50def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:  # type: ignore
 51    """Return tuple with min_len by repeating element at idx_repeat."""
 52    # convert to list first
 53    x = val2list(x)
 54
 55    # repeat elements if necessary
 56    if len(x) > 0:
 57        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
 58
 59    return tuple(x)
 60
 61
 62def t2i_modulate(x, shift, scale):
 63    return x * (1 + scale) + shift
 64
 65
 66def get_same_padding(
 67    kernel_size: Union[int, Tuple[int, ...]],
 68) -> Union[int, Tuple[int, ...]]:
 69    if isinstance(kernel_size, tuple):
 70        return tuple([get_same_padding(ks) for ks in kernel_size])
 71    else:
 72        assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
 73        return kernel_size // 2
 74
 75
 76class ConvLayer(nn.Module):
 77    def __init__(
 78        self,
 79        in_dim: int,
 80        out_dim: int,
 81        kernel_size=3,
 82        stride=1,
 83        dilation=1,
 84        groups=1,
 85        padding: Union[int, None] = None,
 86        use_bias=False,
 87        norm=None,
 88        act=None,
 89    ):
 90        super().__init__()
 91        if padding is None:
 92            padding = get_same_padding(kernel_size)
 93            padding *= dilation
 94
 95        self.in_dim = in_dim
 96        self.out_dim = out_dim
 97        self.kernel_size = kernel_size
 98        self.stride = stride
 99        self.dilation = dilation
100        self.groups = groups
101        self.padding = padding
102        self.use_bias = use_bias
103
104        self.conv = nn.Conv1d(
105            in_dim,
106            out_dim,
107            kernel_size=kernel_size,
108            stride=stride,
109            padding=padding,
110            dilation=dilation,
111            groups=groups,
112            bias=use_bias,
113        )
114        if norm is not None:
115            self.norm = RMSNorm(out_dim, elementwise_affine=False)
116        else:
117            self.norm = None
118        if act is not None:
119            self.act = nn.SiLU(inplace=True)
120        else:
121            self.act = None
122
123    def forward(self, x: torch.Tensor) -> torch.Tensor:
124        x = self.conv(x)
125        if self.norm:
126            x = self.norm(x)
127        if self.act:
128            x = self.act(x)
129        return x
130
131
class GLUMBConv(nn.Module):
    """Gated feed-forward block built from three ``ConvLayer`` stages.

    The stages are, in order:

    1. ``inverted_conv``: kernel-size-1 conv from ``in_features`` to
       ``2 * hidden_features`` channels.
    2. ``depth_conv``: conv with ``groups`` equal to its channel count
       (``2 * hidden_features``), using ``kernel_size``, ``stride``,
       ``padding`` and ``dilation``. It never has an activation.
    3. ``point_conv``: kernel-size-1 conv from ``hidden_features`` to
       ``out_feature`` channels.

    Between stages 2 and 3 the channels are split in half and the first half
    is multiplied by SiLU of the second half.

    Args:
        in_features: Channel count of the input.
        hidden_features: Channel count after gating.
        out_feature: Channel count of the output; falls back to
            ``in_features`` when falsy (``None`` or ``0``).
        kernel_size: Kernel size of ``depth_conv``.
        stride: Stride of ``depth_conv``.
        padding: Padding of ``depth_conv``; ``None`` lets ``ConvLayer``
            derive it.
        use_bias: A bool, or a tuple of up to three bools (one per stage).
            Shorter inputs are padded with ``val2tuple``.
        norm: Per-stage norm selectors, forwarded to ``ConvLayer``.
        act: Per-stage activation selectors. ``act[0]`` and ``act[2]`` are
            forwarded to ``ConvLayer``; ``act[1]`` is not used, because the
            gate activation is always SiLU.
        dilation: Dilation of ``depth_conv``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_feature=None,
        kernel_size=3,
        stride=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=(None, None, None),
        act=("silu", "silu", None),
        dilation=1,
    ):
        super().__init__()
        out_feature = out_feature or in_features
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act = val2tuple(act, 3)

        # Not in-place: the gate is a chunk view of the depth_conv output.
        self.glu_act = nn.SiLU(inplace=False)
        self.inverted_conv = ConvLayer(
            in_features,
            hidden_features * 2,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=act[0],
        )
        self.depth_conv = ConvLayer(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size,
            stride=stride,
            groups=hidden_features * 2,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=None,
            dilation=dilation,
        )
        self.point_conv = ConvLayer(
            hidden_features,
            out_feature,
            1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated conv block.

        Dims 1 and 2 of ``x`` are swapped before the convolutions and swapped
        back afterwards, so the channel axis of the input is its last one.
        """
        x = x.transpose(1, 2)
        x = self.inverted_conv(x)
        x = self.depth_conv(x)

        # First half of the channels is the value, second half the gate.
        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.glu_act(gate)
        x = x * gate

        x = self.point_conv(x)
        x = x.transpose(1, 2)

        return x
195
196
class LinearTransformerBlock(nn.Module):
    """
    A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.

    The block runs, in order: RMSNorm -> self-attention (``self.attn``) ->
    residual add, an optional cross-attention (``self.cross_attn``) with its
    own residual add, then RMSNorm -> ``GLUMBConv`` feed-forward -> residual
    add. With ``use_adaln_single`` the normalized inputs of the attention and
    feed-forward stages are scaled and shifted, and their outputs gated, by
    values derived from ``temb`` plus a learned ``scale_shift_table``.

    Args:
        dim: Channel width of the hidden states.
        num_attention_heads: Number of heads for both attention modules.
        attention_head_dim: Per-head width for both attention modules.
        use_adaln_single: Enable the adaLN-single modulation; ``forward``
            then requires ``temb``.
        cross_attention_dim: Forwarded to ``self.attn``.
        added_kv_proj_dim: Forwarded to ``self.attn``.
        context_pre_only: Forwarded to ``self.cross_attn``.
        mlp_ratio: The feed-forward hidden width is ``int(dim * mlp_ratio)``.
        add_cross_attention: Run a separate cross-attention after the
            self-attention.
        add_cross_attention_dim: Context width for ``self.cross_attn``. It is
            only built when this is not ``None`` and ``add_cross_attention``
            is true.
        qk_norm: Forwarded to both attention modules.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        use_adaln_single=True,
        cross_attention_dim=None,
        added_kv_proj_dim=None,
        context_pre_only=False,
        mlp_ratio=4.0,
        add_cross_attention=False,
        add_cross_attention_dim=None,
        qk_norm=None,
    ):
        super().__init__()

        self.norm1 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm=qk_norm,
            processor=CustomLiteLAProcessor2_0(),
        )

        self.add_cross_attention = add_cross_attention
        self.context_pre_only = context_pre_only

        # Always define the attribute so ``forward`` can report a clear error
        # when cross-attention was requested without a context width.
        self.cross_attn = None
        if add_cross_attention and add_cross_attention_dim is not None:
            self.cross_attn = Attention(
                query_dim=dim,
                cross_attention_dim=add_cross_attention_dim,
                added_kv_proj_dim=add_cross_attention_dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                context_pre_only=context_pre_only,
                bias=True,
                qk_norm=qk_norm,
                processor=CustomerAttnProcessor2_0(),
            )

        self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)

        self.ff = GLUMBConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            use_bias=(True, True, False),
            norm=(None, None, None),
            act=("silu", "silu", None),
        )
        self.use_adaln_single = use_adaln_single
        if use_adaln_single:
            # One learned row per modulation term: shift/scale/gate for the
            # attention stage, then shift/scale/gate for the feed-forward.
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Union[torch.FloatTensor, None] = None,
        attention_mask: Union[torch.FloatTensor, None] = None,
        encoder_attention_mask: Union[torch.FloatTensor, None] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor], None] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor], None] = None,
        temb: Union[torch.FloatTensor, None] = None,
    ):
        """Run the block and return the updated hidden states.

        Args:
            hidden_states: Input tensor; dim 0 is used as the batch size.
            encoder_hidden_states: Context passed to ``self.attn`` when
                cross-attention is disabled, otherwise to ``self.cross_attn``.
            attention_mask: Forwarded to both attention modules.
            encoder_attention_mask: Forwarded together with
                ``encoder_hidden_states``.
            rotary_freqs_cis: Forwarded to both attention modules.
            rotary_freqs_cis_cross: Forwarded together with
                ``encoder_hidden_states``.
            temb: Conditioning tensor, reshaped to ``(batch, 6, -1)`` and
                added to ``scale_shift_table``. Required when
                ``use_adaln_single`` is true.

        Raises:
            ValueError: If ``use_adaln_single`` is true and ``temb`` is
                ``None``, or if cross-attention is enabled but the module was
                built without ``add_cross_attention_dim``.
        """
        if self.use_adaln_single and temb is None:
            raise ValueError("`temb` is required when `use_adaln_single=True`.")
        if self.add_cross_attention and self.cross_attn is None:
            raise ValueError(
                "`add_cross_attention=True` requires `add_cross_attention_dim` "
                "to be set when the block is constructed."
            )

        N = hidden_states.shape[0]

        # step 1: AdaLN single
        if self.use_adaln_single:
            # (1, 6, dim) + (N, 6, dim), split into six (N, 1, dim) terms.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + temb.reshape(N, 6, -1)
            ).chunk(6, dim=1)

        norm_hidden_states = self.norm1(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        # step 2: attention
        if not self.add_cross_attention:
            # No separate cross-attention: the context goes through
            # ``self.attn`` together with the hidden states.
            attn_output, encoder_hidden_states = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
        else:
            # Context is withheld here and consumed by ``self.cross_attn``.
            attn_output, _ = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=None,
            )

        if self.use_adaln_single:
            attn_output = gate_msa * attn_output
        hidden_states = attn_output + hidden_states

        if self.add_cross_attention:
            # Applied to the un-normalized residual stream, with no gating.
            attn_output = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
            hidden_states = attn_output + hidden_states

        # step 3: add norm
        norm_hidden_states = self.norm2(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        # step 4: feed forward
        ff_output = self.ff(norm_hidden_states)
        if self.use_adaln_single:
            ff_output = gate_mlp * ff_output

        hidden_states = hidden_states + ff_output

        return hidden_states
logger = <Logger divisor.acestep.models.attention (WARNING)>
def val2list(x: list, repeat_time=1) -> list:
44def val2list(x: list or tuple or any, repeat_time=1) -> list:  # type: ignore
45    """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
46    if isinstance(x, (list, tuple)):
47        return list(x)
48    return [x for _ in range(repeat_time)]

Repeat `x` `repeat_time` times and return the resulting list; if `x` is already a list or tuple, return its elements as a list instead.

def val2tuple(x: list, min_len: int = 1, idx_repeat: int = -1) -> tuple:
51def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:  # type: ignore
52    """Return tuple with min_len by repeating element at idx_repeat."""
53    # convert to list first
54    x = val2list(x)
55
56    # repeat elements if necessary
57    if len(x) > 0:
58        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
59
60    return tuple(x)

Return a tuple of at least `min_len` elements, padding a shorter input by repeating the element at `idx_repeat`.

def t2i_modulate(x, shift, scale):
63def t2i_modulate(x, shift, scale):
64    return x * (1 + scale) + shift
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
67def get_same_padding(
68    kernel_size: Union[int, Tuple[int, ...]],
69) -> Union[int, Tuple[int, ...]]:
70    if isinstance(kernel_size, tuple):
71        return tuple([get_same_padding(ks) for ks in kernel_size])
72    else:
73        assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
74        return kernel_size // 2
class ConvLayer(torch.nn.modules.module.Module):
 77class ConvLayer(nn.Module):
 78    def __init__(
 79        self,
 80        in_dim: int,
 81        out_dim: int,
 82        kernel_size=3,
 83        stride=1,
 84        dilation=1,
 85        groups=1,
 86        padding: Union[int, None] = None,
 87        use_bias=False,
 88        norm=None,
 89        act=None,
 90    ):
 91        super().__init__()
 92        if padding is None:
 93            padding = get_same_padding(kernel_size)
 94            padding *= dilation
 95
 96        self.in_dim = in_dim
 97        self.out_dim = out_dim
 98        self.kernel_size = kernel_size
 99        self.stride = stride
100        self.dilation = dilation
101        self.groups = groups
102        self.padding = padding
103        self.use_bias = use_bias
104
105        self.conv = nn.Conv1d(
106            in_dim,
107            out_dim,
108            kernel_size=kernel_size,
109            stride=stride,
110            padding=padding,
111            dilation=dilation,
112            groups=groups,
113            bias=use_bias,
114        )
115        if norm is not None:
116            self.norm = RMSNorm(out_dim, elementwise_affine=False)
117        else:
118            self.norm = None
119        if act is not None:
120            self.act = nn.SiLU(inplace=True)
121        else:
122            self.act = None
123
124    def forward(self, x: torch.Tensor) -> torch.Tensor:
125        x = self.conv(x)
126        if self.norm:
127            x = self.norm(x)
128        if self.act:
129            x = self.act(x)
130        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attributes: `training` (bool) — whether this module is in training or evaluation mode.

ConvLayer( in_dim: int, out_dim: int, kernel_size=3, stride=1, dilation=1, groups=1, padding: Optional[int] = None, use_bias=False, norm=None, act=None)
 78    def __init__(
 79        self,
 80        in_dim: int,
 81        out_dim: int,
 82        kernel_size=3,
 83        stride=1,
 84        dilation=1,
 85        groups=1,
 86        padding: Union[int, None] = None,
 87        use_bias=False,
 88        norm=None,
 89        act=None,
 90    ):
 91        super().__init__()
 92        if padding is None:
 93            padding = get_same_padding(kernel_size)
 94            padding *= dilation
 95
 96        self.in_dim = in_dim
 97        self.out_dim = out_dim
 98        self.kernel_size = kernel_size
 99        self.stride = stride
100        self.dilation = dilation
101        self.groups = groups
102        self.padding = padding
103        self.use_bias = use_bias
104
105        self.conv = nn.Conv1d(
106            in_dim,
107            out_dim,
108            kernel_size=kernel_size,
109            stride=stride,
110            padding=padding,
111            dilation=dilation,
112            groups=groups,
113            bias=use_bias,
114        )
115        if norm is not None:
116            self.norm = RMSNorm(out_dim, elementwise_affine=False)
117        else:
118            self.norm = None
119        if act is not None:
120            self.act = nn.SiLU(inplace=True)
121        else:
122            self.act = None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_dim
out_dim
kernel_size
stride
dilation
groups
padding
use_bias
conv
def forward(self, x: torch.Tensor) -> torch.Tensor:
124    def forward(self, x: torch.Tensor) -> torch.Tensor:
125        x = self.conv(x)
126        if self.norm:
127            x = self.norm(x)
128        if self.act:
129            x = self.act(x)
130        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class GLUMBConv(torch.nn.modules.module.Module):
133class GLUMBConv(nn.Module):
134    def __init__(
135        self,
136        in_features: int,
137        hidden_features: int,
138        out_feature=None,
139        kernel_size=3,
140        stride=1,
141        padding: Union[int, None] = None,
142        use_bias=False,
143        norm=(None, None, None),
144        act=("silu", "silu", None),
145        dilation=1,
146    ):
147        out_feature = out_feature or in_features
148        super().__init__()
149        use_bias = val2tuple(use_bias, 3)
150        norm = val2tuple(norm, 3)
151        act = val2tuple(act, 3)
152
153        self.glu_act = nn.SiLU(inplace=False)
154        self.inverted_conv = ConvLayer(
155            in_features,
156            hidden_features * 2,
157            1,
158            use_bias=use_bias[0],
159            norm=norm[0],
160            act=act[0],
161        )
162        self.depth_conv = ConvLayer(
163            hidden_features * 2,
164            hidden_features * 2,
165            kernel_size,
166            stride=stride,
167            groups=hidden_features * 2,
168            padding=padding,
169            use_bias=use_bias[1],
170            norm=norm[1],
171            act=None,
172            dilation=dilation,
173        )
174        self.point_conv = ConvLayer(
175            hidden_features,
176            out_feature,
177            1,
178            use_bias=use_bias[2],
179            norm=norm[2],
180            act=act[2],
181        )
182
183    def forward(self, x: torch.Tensor) -> torch.Tensor:
184        x = x.transpose(1, 2)
185        x = self.inverted_conv(x)
186        x = self.depth_conv(x)
187
188        x, gate = torch.chunk(x, 2, dim=1)
189        gate = self.glu_act(gate)
190        x = x * gate
191
192        x = self.point_conv(x)
193        x = x.transpose(1, 2)
194
195        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attributes: `training` (bool) — whether this module is in training or evaluation mode.

GLUMBConv( in_features: int, hidden_features: int, out_feature=None, kernel_size=3, stride=1, padding: Optional[int] = None, use_bias=False, norm=(None, None, None), act=('silu', 'silu', None), dilation=1)
134    def __init__(
135        self,
136        in_features: int,
137        hidden_features: int,
138        out_feature=None,
139        kernel_size=3,
140        stride=1,
141        padding: Union[int, None] = None,
142        use_bias=False,
143        norm=(None, None, None),
144        act=("silu", "silu", None),
145        dilation=1,
146    ):
147        out_feature = out_feature or in_features
148        super().__init__()
149        use_bias = val2tuple(use_bias, 3)
150        norm = val2tuple(norm, 3)
151        act = val2tuple(act, 3)
152
153        self.glu_act = nn.SiLU(inplace=False)
154        self.inverted_conv = ConvLayer(
155            in_features,
156            hidden_features * 2,
157            1,
158            use_bias=use_bias[0],
159            norm=norm[0],
160            act=act[0],
161        )
162        self.depth_conv = ConvLayer(
163            hidden_features * 2,
164            hidden_features * 2,
165            kernel_size,
166            stride=stride,
167            groups=hidden_features * 2,
168            padding=padding,
169            use_bias=use_bias[1],
170            norm=norm[1],
171            act=None,
172            dilation=dilation,
173        )
174        self.point_conv = ConvLayer(
175            hidden_features,
176            out_feature,
177            1,
178            use_bias=use_bias[2],
179            norm=norm[2],
180            act=act[2],
181        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

glu_act
inverted_conv
depth_conv
point_conv
def forward(self, x: torch.Tensor) -> torch.Tensor:
183    def forward(self, x: torch.Tensor) -> torch.Tensor:
184        x = x.transpose(1, 2)
185        x = self.inverted_conv(x)
186        x = self.depth_conv(x)
187
188        x, gate = torch.chunk(x, 2, dim=1)
189        gate = self.glu_act(gate)
190        x = x * gate
191
192        x = self.point_conv(x)
193        x = x.transpose(1, 2)
194
195        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class LinearTransformerBlock(torch.nn.modules.module.Module):
198class LinearTransformerBlock(nn.Module):
199    """
200    A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
201    """
202
203    def __init__(
204        self,
205        dim,
206        num_attention_heads,
207        attention_head_dim,
208        use_adaln_single=True,
209        cross_attention_dim=None,
210        added_kv_proj_dim=None,
211        context_pre_only=False,
212        mlp_ratio=4.0,
213        add_cross_attention=False,
214        add_cross_attention_dim=None,
215        qk_norm=None,
216    ):
217        super().__init__()
218
219        self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
220        self.attn = Attention(
221            query_dim=dim,
222            cross_attention_dim=cross_attention_dim,
223            added_kv_proj_dim=added_kv_proj_dim,
224            dim_head=attention_head_dim,
225            heads=num_attention_heads,
226            out_dim=dim,
227            bias=True,
228            qk_norm=qk_norm,
229            processor=CustomLiteLAProcessor2_0(),
230        )
231
232        self.add_cross_attention = add_cross_attention
233        self.context_pre_only = context_pre_only
234
235        if add_cross_attention and add_cross_attention_dim is not None:
236            self.cross_attn = Attention(
237                query_dim=dim,
238                cross_attention_dim=add_cross_attention_dim,
239                added_kv_proj_dim=add_cross_attention_dim,
240                dim_head=attention_head_dim,
241                heads=num_attention_heads,
242                out_dim=dim,
243                context_pre_only=context_pre_only,
244                bias=True,
245                qk_norm=qk_norm,
246                processor=CustomerAttnProcessor2_0(),
247            )
248
249        self.norm2 = RMSNorm(dim, 1e-06, elementwise_affine=False)
250
251        self.ff = GLUMBConv(
252            in_features=dim,
253            hidden_features=int(dim * mlp_ratio),
254            use_bias=(True, True, False),
255            norm=(None, None, None),
256            act=("silu", "silu", None),
257        )
258        self.use_adaln_single = use_adaln_single
259        if use_adaln_single:
260            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
261
262    def forward(
263        self,
264        hidden_states: torch.FloatTensor,
265        encoder_hidden_states: torch.FloatTensor = None,
266        attention_mask: torch.FloatTensor = None,
267        encoder_attention_mask: torch.FloatTensor = None,
268        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
269        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
270        temb: torch.FloatTensor = None,
271    ):
272
273        N = hidden_states.shape[0]
274
275        # step 1: AdaLN single
276        if self.use_adaln_single:
277            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
278                self.scale_shift_table[None] + temb.reshape(N, 6, -1)
279            ).chunk(6, dim=1)
280
281        norm_hidden_states = self.norm1(hidden_states)
282        if self.use_adaln_single:
283            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
284
285        # step 2: attention
286        if not self.add_cross_attention:
287            attn_output, encoder_hidden_states = self.attn(
288                hidden_states=norm_hidden_states,
289                attention_mask=attention_mask,
290                encoder_hidden_states=encoder_hidden_states,
291                encoder_attention_mask=encoder_attention_mask,
292                rotary_freqs_cis=rotary_freqs_cis,
293                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
294            )
295        else:
296            attn_output, _ = self.attn(
297                hidden_states=norm_hidden_states,
298                attention_mask=attention_mask,
299                encoder_hidden_states=None,
300                encoder_attention_mask=None,
301                rotary_freqs_cis=rotary_freqs_cis,
302                rotary_freqs_cis_cross=None,
303            )
304
305        if self.use_adaln_single:
306            attn_output = gate_msa * attn_output
307        hidden_states = attn_output + hidden_states
308
309        if self.add_cross_attention:
310            attn_output = self.cross_attn(
311                hidden_states=hidden_states,
312                attention_mask=attention_mask,
313                encoder_hidden_states=encoder_hidden_states,
314                encoder_attention_mask=encoder_attention_mask,
315                rotary_freqs_cis=rotary_freqs_cis,
316                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
317            )
318            hidden_states = attn_output + hidden_states
319
320        # step 3: add norm
321        norm_hidden_states = self.norm2(hidden_states)
322        if self.use_adaln_single:
323            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
324
325        # step 4: feed forward
326        ff_output = self.ff(norm_hidden_states)
327        if self.use_adaln_single:
328            ff_output = gate_mlp * ff_output
329
330        hidden_states = hidden_states + ff_output
331
332        return hidden_states

A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.

LinearTransformerBlock( dim, num_attention_heads, attention_head_dim, use_adaln_single=True, cross_attention_dim=None, added_kv_proj_dim=None, context_pre_only=False, mlp_ratio=4.0, add_cross_attention=False, add_cross_attention_dim=None, qk_norm=None)
203    def __init__(
204        self,
205        dim,
206        num_attention_heads,
207        attention_head_dim,
208        use_adaln_single=True,
209        cross_attention_dim=None,
210        added_kv_proj_dim=None,
211        context_pre_only=False,
212        mlp_ratio=4.0,
213        add_cross_attention=False,
214        add_cross_attention_dim=None,
215        qk_norm=None,
216    ):
217        super().__init__()
218
219        self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
220        self.attn = Attention(
221            query_dim=dim,
222            cross_attention_dim=cross_attention_dim,
223            added_kv_proj_dim=added_kv_proj_dim,
224            dim_head=attention_head_dim,
225            heads=num_attention_heads,
226            out_dim=dim,
227            bias=True,
228            qk_norm=qk_norm,
229            processor=CustomLiteLAProcessor2_0(),
230        )
231
232        self.add_cross_attention = add_cross_attention
233        self.context_pre_only = context_pre_only
234
235        if add_cross_attention and add_cross_attention_dim is not None:
236            self.cross_attn = Attention(
237                query_dim=dim,
238                cross_attention_dim=add_cross_attention_dim,
239                added_kv_proj_dim=add_cross_attention_dim,
240                dim_head=attention_head_dim,
241                heads=num_attention_heads,
242                out_dim=dim,
243                context_pre_only=context_pre_only,
244                bias=True,
245                qk_norm=qk_norm,
246                processor=CustomerAttnProcessor2_0(),
247            )
248
249        self.norm2 = RMSNorm(dim, 1e-06, elementwise_affine=False)
250
251        self.ff = GLUMBConv(
252            in_features=dim,
253            hidden_features=int(dim * mlp_ratio),
254            use_bias=(True, True, False),
255            norm=(None, None, None),
256            act=("silu", "silu", None),
257        )
258        self.use_adaln_single = use_adaln_single
259        if use_adaln_single:
260            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

norm1
attn
add_cross_attention
context_pre_only
norm2
ff
use_adaln_single
def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, attention_mask: torch.FloatTensor = None, encoder_attention_mask: torch.FloatTensor = None, rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, temb: torch.FloatTensor = None):
262    def forward(
263        self,
264        hidden_states: torch.FloatTensor,
265        encoder_hidden_states: torch.FloatTensor = None,
266        attention_mask: torch.FloatTensor = None,
267        encoder_attention_mask: torch.FloatTensor = None,
268        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
269        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
270        temb: torch.FloatTensor = None,
271    ):
272
273        N = hidden_states.shape[0]
274
275        # step 1: AdaLN single
276        if self.use_adaln_single:
277            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
278                self.scale_shift_table[None] + temb.reshape(N, 6, -1)
279            ).chunk(6, dim=1)
280
281        norm_hidden_states = self.norm1(hidden_states)
282        if self.use_adaln_single:
283            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
284
285        # step 2: attention
286        if not self.add_cross_attention:
287            attn_output, encoder_hidden_states = self.attn(
288                hidden_states=norm_hidden_states,
289                attention_mask=attention_mask,
290                encoder_hidden_states=encoder_hidden_states,
291                encoder_attention_mask=encoder_attention_mask,
292                rotary_freqs_cis=rotary_freqs_cis,
293                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
294            )
295        else:
296            attn_output, _ = self.attn(
297                hidden_states=norm_hidden_states,
298                attention_mask=attention_mask,
299                encoder_hidden_states=None,
300                encoder_attention_mask=None,
301                rotary_freqs_cis=rotary_freqs_cis,
302                rotary_freqs_cis_cross=None,
303            )
304
305        if self.use_adaln_single:
306            attn_output = gate_msa * attn_output
307        hidden_states = attn_output + hidden_states
308
309        if self.add_cross_attention:
310            attn_output = self.cross_attn(
311                hidden_states=hidden_states,
312                attention_mask=attention_mask,
313                encoder_hidden_states=encoder_hidden_states,
314                encoder_attention_mask=encoder_attention_mask,
315                rotary_freqs_cis=rotary_freqs_cis,
316                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
317            )
318            hidden_states = attn_output + hidden_states
319
320        # step 3: add norm
321        norm_hidden_states = self.norm2(hidden_states)
322        if self.use_adaln_single:
323            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
324
325        # step 4: feed forward
326        ff_output = self.ff(norm_hidden_states)
327        if self.use_adaln_single:
328            ff_output = gate_mlp * ff_output
329
330        hidden_states = hidden_states + ff_output
331
332        return hidden_states

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.