divisor.acestep.models.attention
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sana-style transformer building blocks.

Contains small argument helpers, a 1-D convolution layer (``ConvLayer``), a
gated MBConv-style feed-forward (``GLUMBConv``) and ``LinearTransformerBlock``,
which combines attention with that feed-forward.
"""
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import logging
from diffusers.models.normalization import RMSNorm


try:
    from .customer_attention_processor import (
        Attention,
        CustomLiteLAProcessor2_0,
        CustomerAttnProcessor2_0,
    )
except ImportError:
    # Presumably taken when this file is imported outside its package
    # (relative import unavailable); falls back to a top-level module.
    from customer_attention_processor import (
        Attention,
        CustomLiteLAProcessor2_0,
        CustomerAttnProcessor2_0,
    )


logger = logging.get_logger(__name__)


def val2list(x: object, repeat_time: int = 1) -> list:
    """Return ``x`` as a list.

    A list or tuple is converted with ``list()`` (``repeat_time`` is ignored);
    any other value is repeated ``repeat_time`` times.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def val2tuple(x: object, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return ``x`` as a tuple with at least ``min_len`` elements.

    ``x`` is first normalized with ``val2list``. If it is shorter than
    ``min_len``, copies of the element at index ``idx_repeat`` are inserted
    next to it until the length reaches ``min_len``. An empty input is
    returned unchanged (there is nothing to repeat) and longer inputs are
    never truncated.

    Example: ``val2tuple("silu", 3) == ("silu", "silu", "silu")``.
    """
    x = val2list(x)

    # Slice assignment ``x[i:i] = [...]`` inserts before index i without
    # replacing anything; a non-positive count inserts nothing.
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)


def t2i_modulate(x, shift, scale):
    """Apply adaptive shift/scale modulation: ``x * (1 + scale) + shift``."""
    return x * (1 + scale) + shift


def get_same_padding(
    kernel_size: Union[int, Tuple[int, ...]],
) -> Union[int, Tuple[int, ...]]:
    """Return the padding that preserves length for a stride-1, dilation-1 conv.

    For an odd kernel ``k`` this is ``k // 2``. A tuple ``kernel_size`` is
    handled element-wise and yields a tuple.

    Raises:
        AssertionError: if any kernel size is even.
    """
    if not isinstance(kernel_size, tuple):
        assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
        return kernel_size // 2
    return tuple(get_same_padding(ks) for ks in kernel_size)


class ConvLayer(nn.Module):
    """1-D convolution optionally followed by ``RMSNorm`` and ``SiLU``.

    Args:
        in_dim: Input channel count of the ``nn.Conv1d``.
        out_dim: Output channel count of the ``nn.Conv1d``.
        kernel_size: Kernel size (int or tuple); must be odd when ``padding``
            is ``None`` (enforced by ``get_same_padding``).
        stride: Convolution stride.
        dilation: Convolution dilation; also multiplies the automatically
            computed "same" padding.
        groups: Passed to ``nn.Conv1d``.
        padding: Explicit padding, used as-is. ``None`` derives "same"
            padding from ``kernel_size`` and ``dilation``.
        use_bias: Whether the convolution has a bias.
        norm: If not ``None``, ``RMSNorm(out_dim, elementwise_affine=False)``
            is applied after the convolution. Only presence matters; the value
            itself is not inspected.
        act: If not ``None``, an in-place ``nn.SiLU`` is applied last. Only
            presence matters; the value itself is not inspected.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        padding: Union[int, Tuple[int, ...], None] = None,
        use_bias=False,
        norm=None,
        act=None,
    ):
        super().__init__()
        if padding is None:
            padding = get_same_padding(kernel_size)
            # Scale the "same" padding by the dilation. A tuple must be scaled
            # element-wise: ``tuple * int`` would repeat the tuple instead and
            # produce a padding nn.Conv1d rejects.
            if isinstance(padding, tuple):
                padding = tuple(p * dilation for p in padding)
            else:
                padding *= dilation

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.padding = padding
        self.use_bias = use_bias

        self.conv = nn.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=use_bias,
        )
        # NOTE(review): diffusers' RMSNorm appears to normalize over the last
        # dimension, which on a (batch, channels, length) Conv1d output is the
        # length axis, not channels -- confirm before enabling `norm`.
        self.norm = RMSNorm(out_dim, elementwise_affine=False) if norm is not None else None
        self.act = nn.SiLU(inplace=True) if act is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, then the optional norm and activation."""
        x = self.conv(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.act is not None:
            x = self.act(x)
        return x


class GLUMBConv(nn.Module):
    """Gated MBConv-style feed-forward (1x1 expand, depthwise, 1x1 project).

    Stages:
      1. ``inverted_conv``: 1x1 conv ``in_features -> 2 * hidden_features``.
      2. ``depth_conv``: depthwise conv (``groups == 2 * hidden_features``)
         with ``kernel_size``; it never has an activation of its own.
      3. The channels are split in half; the second half, passed through
         SiLU, multiplies the first half.
      4. ``point_conv``: 1x1 conv ``hidden_features -> out_feature``.

    Args:
        in_features: Channel count of the input (dim 2 of ``x``).
        hidden_features: Width of the gated hidden representation.
        out_feature: Output channel count; a falsy value (``None``/``0``)
            falls back to ``in_features``.
        kernel_size, stride, padding, dilation: Depthwise conv settings.
        use_bias, norm, act: Per-stage settings for the three convs,
            broadcast with ``val2tuple(..., 3)``. ``act[1]`` is not used:
            the depthwise conv has no activation and the gate is always SiLU.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_feature=None,
        kernel_size=3,
        stride=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=(None, None, None),
        act=("silu", "silu", None),
        dilation=1,
    ):
        out_feature = out_feature or in_features
        super().__init__()
        # Broadcast each per-stage setting to at least three entries
        # (inverted conv, depthwise conv, point conv).
        use_bias, norm, act = (val2tuple(setting, 3) for setting in (use_bias, norm, act))

        expanded = hidden_features * 2  # value half + gate half

        # Creation order below fixes parameter order and initialization RNG draws.
        self.glu_act = nn.SiLU(inplace=False)
        self.inverted_conv = ConvLayer(
            in_dim=in_features,
            out_dim=expanded,
            kernel_size=1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=act[0],
        )
        self.depth_conv = ConvLayer(
            in_dim=expanded,
            out_dim=expanded,
            kernel_size=kernel_size,
            stride=stride,
            groups=expanded,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=None,
            dilation=dilation,
        )
        self.point_conv = ConvLayer(
            in_dim=hidden_features,
            out_dim=out_feature,
            kernel_size=1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the gated feed-forward on ``x``.

        Dims 1 and 2 are swapped before and after the Conv1d stack, so dim 2
        of ``x`` must equal ``in_features``; presumably ``x`` is
        (batch, length, in_features) -- TODO confirm with callers.
        """
        h = self.depth_conv(self.inverted_conv(x.transpose(1, 2)))
        value, gate = torch.chunk(h, 2, dim=1)
        h = value * self.glu_act(gate)
        return self.point_conv(h).transpose(1, 2)


class LinearTransformerBlock(nn.Module):
    """
    A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.

    Residual sub-layers, in order:
      1. RMSNorm -> optional adaLN shift/scale -> ``attn`` (built with
         ``CustomLiteLAProcessor2_0``) -> optional adaLN gate.
      2. Only when ``add_cross_attention``: ``cross_attn`` (built with
         ``CustomerAttnProcessor2_0``) applied to the un-normalized states.
      3. RMSNorm -> optional adaLN shift/scale -> ``GLUMBConv`` feed-forward
         -> optional adaLN gate.

    With ``use_adaln_single``, ``temb`` must provide ``6 * dim`` values per
    batch element; they are added to the learned ``(6, dim)``
    ``scale_shift_table`` and split into shift/scale/gate for sub-layers 1
    and 3.

    NOTE(review): ``cross_attn`` is only created when ``add_cross_attention``
    is true AND ``add_cross_attention_dim`` is not ``None``; with
    ``add_cross_attention=True`` and no dim, ``forward`` fails on the missing
    attribute.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        use_adaln_single=True,
        cross_attention_dim=None,
        added_kv_proj_dim=None,
        context_pre_only=False,
        mlp_ratio=4.0,
        add_cross_attention=False,
        add_cross_attention_dim=None,
        qk_norm=None,
    ):
        super().__init__()

        # Plain flags; setting them first does not affect module registration.
        self.add_cross_attention = add_cross_attention
        self.context_pre_only = context_pre_only
        self.use_adaln_single = use_adaln_single

        # Keyword arguments shared by both attention modules.
        shared_attn_kwargs = dict(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm=qk_norm,
        )

        # Creation order below fixes parameter order and initialization RNG draws.
        self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            cross_attention_dim=cross_attention_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            processor=CustomLiteLAProcessor2_0(),
            **shared_attn_kwargs,
        )

        if add_cross_attention and add_cross_attention_dim is not None:
            self.cross_attn = Attention(
                cross_attention_dim=add_cross_attention_dim,
                added_kv_proj_dim=add_cross_attention_dim,
                context_pre_only=context_pre_only,
                processor=CustomerAttnProcessor2_0(),
                **shared_attn_kwargs,
            )

        self.norm2 = RMSNorm(dim, eps=1e-06, elementwise_affine=False)

        self.ff = GLUMBConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            use_bias=(True, True, False),
            norm=(None, None, None),
            act=("silu", "silu", None),
        )

        if use_adaln_single:
            # Rows: shift/scale/gate for attention, then shift/scale/gate for the FF.
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        temb: torch.FloatTensor = None,
    ):
        """Run one block and return the updated ``hidden_states``.

        Args:
            hidden_states: Block input; dim 0 is the batch size used to
                reshape ``temb``.
            encoder_hidden_states, encoder_attention_mask,
            rotary_freqs_cis_cross: Context inputs. Without
                ``add_cross_attention`` they go to ``self.attn``; with it they
                go only to ``self.cross_attn``.
            attention_mask, rotary_freqs_cis: Passed to every attention call.
            temb: Conditioning embedding; required when ``use_adaln_single``.
                Reshaped to ``(batch, 6, -1)``.

        Returns:
            The hidden states after all residual sub-layers.
        """
        adaln = self.use_adaln_single
        batch_size = hidden_states.shape[0]

        if adaln:
            modulation = self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
            (
                shift_msa,
                scale_msa,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = modulation.chunk(6, dim=1)

        # Self-attention sub-layer.
        attn_input = self.norm1(hidden_states)
        if adaln:
            attn_input = t2i_modulate(attn_input, shift_msa, scale_msa)

        if self.add_cross_attention:
            # Context is reserved for `cross_attn`; self-attention sees none.
            attn_out, _ = self.attn(
                hidden_states=attn_input,
                attention_mask=attention_mask,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=None,
            )
        else:
            attn_out, encoder_hidden_states = self.attn(
                hidden_states=attn_input,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )

        if adaln:
            attn_out = gate_msa * attn_out
        hidden_states = attn_out + hidden_states

        # Optional cross-attention sub-layer (no norm, no adaLN gate).
        if self.add_cross_attention:
            hidden_states = (
                self.cross_attn(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=rotary_freqs_cis_cross,
                )
                + hidden_states
            )

        # Feed-forward sub-layer.
        ff_input = self.norm2(hidden_states)
        if adaln:
            ff_input = t2i_modulate(ff_input, shift_mlp, scale_mlp)

        ff_out = self.ff(ff_input)
        if adaln:
            ff_out = gate_mlp * ff_out

        return hidden_states + ff_out
def val2list(x: object, repeat_time: int = 1) -> list:
    """Return ``x`` as a list.

    A list or tuple is converted with ``list()`` (``repeat_time`` is ignored);
    any other value is repeated ``repeat_time`` times.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]
Return x as a list: a list or tuple is converted to a list, and any other value is repeated repeat_time times.
def val2tuple(x: object, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return ``x`` as a tuple with at least ``min_len`` elements.

    ``x`` is first normalized with ``val2list``. If it is shorter than
    ``min_len``, copies of the element at index ``idx_repeat`` are inserted
    next to it until the length reaches ``min_len``. An empty input is
    returned unchanged (there is nothing to repeat) and longer inputs are
    never truncated.

    Example: ``val2tuple("silu", 3) == ("silu", "silu", "silu")``.
    """
    x = val2list(x)

    # Slice assignment ``x[i:i] = [...]`` inserts before index i without
    # replacing anything; a non-positive count inserts nothing.
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)
Return x as a tuple with at least min_len elements, padding by repeating the element at index idx_repeat.
def get_same_padding(
    kernel_size: Union[int, Tuple[int, ...]],
) -> Union[int, Tuple[int, ...]]:
    """Return the padding that preserves length for a stride-1, dilation-1 conv.

    For an odd kernel ``k`` this is ``k // 2``. A tuple ``kernel_size`` is
    handled element-wise and yields a tuple.

    Raises:
        AssertionError: if any kernel size is even.
    """
    if not isinstance(kernel_size, tuple):
        assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
        return kernel_size // 2
    return tuple(get_same_padding(ks) for ks in kernel_size)
class ConvLayer(nn.Module):
    """1-D convolution optionally followed by ``RMSNorm`` and ``SiLU``.

    Args:
        in_dim: Input channel count of the ``nn.Conv1d``.
        out_dim: Output channel count of the ``nn.Conv1d``.
        kernel_size: Kernel size (int or tuple); must be odd when ``padding``
            is ``None`` (enforced by ``get_same_padding``).
        stride: Convolution stride.
        dilation: Convolution dilation; also multiplies the automatically
            computed "same" padding.
        groups: Passed to ``nn.Conv1d``.
        padding: Explicit padding, used as-is. ``None`` derives "same"
            padding from ``kernel_size`` and ``dilation``.
        use_bias: Whether the convolution has a bias.
        norm: If not ``None``, ``RMSNorm(out_dim, elementwise_affine=False)``
            is applied after the convolution. Only presence matters; the value
            itself is not inspected.
        act: If not ``None``, an in-place ``nn.SiLU`` is applied last. Only
            presence matters; the value itself is not inspected.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        padding: Union[int, Tuple[int, ...], None] = None,
        use_bias=False,
        norm=None,
        act=None,
    ):
        super().__init__()
        if padding is None:
            padding = get_same_padding(kernel_size)
            # Scale the "same" padding by the dilation. A tuple must be scaled
            # element-wise: ``tuple * int`` would repeat the tuple instead and
            # produce a padding nn.Conv1d rejects.
            if isinstance(padding, tuple):
                padding = tuple(p * dilation for p in padding)
            else:
                padding *= dilation

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.padding = padding
        self.use_bias = use_bias

        self.conv = nn.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=use_bias,
        )
        # NOTE(review): diffusers' RMSNorm appears to normalize over the last
        # dimension, which on a (batch, channels, length) Conv1d output is the
        # length axis, not channels -- confirm before enabling `norm`.
        self.norm = RMSNorm(out_dim, elementwise_affine=False) if norm is not None else None
        self.act = nn.SiLU(inplace=True) if act is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, then the optional norm and activation."""
        x = self.conv(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.act is not None:
            x = self.act(x)
        return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    in_dim: int,
    out_dim: int,
    kernel_size=3,
    stride=1,
    dilation=1,
    groups=1,
    padding: Union[int, Tuple[int, ...], None] = None,
    use_bias=False,
    norm=None,
    act=None,
):
    """Build the Conv1d and the optional RMSNorm / SiLU stages.

    ``padding=None`` derives "same" padding from ``kernel_size`` (which must
    then be odd) and scales it by ``dilation``; an explicit ``padding`` is used
    as-is. ``norm`` / ``act`` only matter by presence: any non-``None`` value
    adds ``RMSNorm(out_dim, elementwise_affine=False)`` / in-place ``SiLU``.
    """
    super().__init__()
    if padding is None:
        padding = get_same_padding(kernel_size)
        # Scale the "same" padding by the dilation. A tuple must be scaled
        # element-wise: ``tuple * int`` would repeat the tuple instead and
        # produce a padding nn.Conv1d rejects.
        if isinstance(padding, tuple):
            padding = tuple(p * dilation for p in padding)
        else:
            padding *= dilation

    self.in_dim = in_dim
    self.out_dim = out_dim
    self.kernel_size = kernel_size
    self.stride = stride
    self.dilation = dilation
    self.groups = groups
    self.padding = padding
    self.use_bias = use_bias

    self.conv = nn.Conv1d(
        in_dim,
        out_dim,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=use_bias,
    )
    # NOTE(review): diffusers' RMSNorm appears to normalize over the last
    # dimension, which on a (batch, channels, length) Conv1d output is the
    # length axis, not channels -- confirm before enabling `norm`.
    self.norm = RMSNorm(out_dim, elementwise_affine=False) if norm is not None else None
    self.act = nn.SiLU(inplace=True) if act is not None else None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Convolve ``x``, then apply the norm and activation stages if present."""
    out = self.conv(x)
    # Stages that were not configured are stored as None and skipped.
    for stage in (self.norm, self.act):
        if stage:
            out = stage(out)
    return out
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class GLUMBConv(nn.Module):
    """Gated MBConv-style feed-forward (1x1 expand, depthwise, 1x1 project).

    Stages:
      1. ``inverted_conv``: 1x1 conv ``in_features -> 2 * hidden_features``.
      2. ``depth_conv``: depthwise conv (``groups == 2 * hidden_features``)
         with ``kernel_size``; it never has an activation of its own.
      3. The channels are split in half; the second half, passed through
         SiLU, multiplies the first half.
      4. ``point_conv``: 1x1 conv ``hidden_features -> out_feature``.

    Args:
        in_features: Channel count of the input (dim 2 of ``x``).
        hidden_features: Width of the gated hidden representation.
        out_feature: Output channel count; a falsy value (``None``/``0``)
            falls back to ``in_features``.
        kernel_size, stride, padding, dilation: Depthwise conv settings.
        use_bias, norm, act: Per-stage settings for the three convs,
            broadcast with ``val2tuple(..., 3)``. ``act[1]`` is not used:
            the depthwise conv has no activation and the gate is always SiLU.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_feature=None,
        kernel_size=3,
        stride=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=(None, None, None),
        act=("silu", "silu", None),
        dilation=1,
    ):
        out_feature = out_feature or in_features
        super().__init__()
        # Broadcast each per-stage setting to at least three entries
        # (inverted conv, depthwise conv, point conv).
        use_bias, norm, act = (val2tuple(setting, 3) for setting in (use_bias, norm, act))

        expanded = hidden_features * 2  # value half + gate half

        # Creation order below fixes parameter order and initialization RNG draws.
        self.glu_act = nn.SiLU(inplace=False)
        self.inverted_conv = ConvLayer(
            in_dim=in_features,
            out_dim=expanded,
            kernel_size=1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=act[0],
        )
        self.depth_conv = ConvLayer(
            in_dim=expanded,
            out_dim=expanded,
            kernel_size=kernel_size,
            stride=stride,
            groups=expanded,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=None,
            dilation=dilation,
        )
        self.point_conv = ConvLayer(
            in_dim=hidden_features,
            out_dim=out_feature,
            kernel_size=1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the gated feed-forward on ``x``.

        Dims 1 and 2 are swapped before and after the Conv1d stack, so dim 2
        of ``x`` must equal ``in_features``; presumably ``x`` is
        (batch, length, in_features) -- TODO confirm with callers.
        """
        h = self.depth_conv(self.inverted_conv(x.transpose(1, 2)))
        value, gate = torch.chunk(h, 2, dim=1)
        h = value * self.glu_act(gate)
        return self.point_conv(h).transpose(1, 2)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    in_features: int,
    hidden_features: int,
    out_feature=None,
    kernel_size=3,
    stride=1,
    padding: Union[int, None] = None,
    use_bias=False,
    norm=(None, None, None),
    act=("silu", "silu", None),
    dilation=1,
):
    """Build the expand / depthwise / project convs and the SiLU gate.

    ``out_feature`` falls back to ``in_features`` when falsy. ``use_bias``,
    ``norm`` and ``act`` are broadcast to three per-stage entries with
    ``val2tuple``; ``act[1]`` is not used (the depthwise conv gets
    ``act=None`` and the gate is always SiLU).
    """
    out_feature = out_feature or in_features
    super().__init__()
    # Broadcast each per-stage setting to at least three entries
    # (inverted conv, depthwise conv, point conv).
    use_bias, norm, act = (val2tuple(setting, 3) for setting in (use_bias, norm, act))

    expanded = hidden_features * 2  # value half + gate half

    # Creation order below fixes parameter order and initialization RNG draws.
    self.glu_act = nn.SiLU(inplace=False)
    # 1x1 expansion: in_features -> 2 * hidden_features.
    self.inverted_conv = ConvLayer(
        in_dim=in_features,
        out_dim=expanded,
        kernel_size=1,
        use_bias=use_bias[0],
        norm=norm[0],
        act=act[0],
    )
    # Depthwise conv (one group per channel) with no activation of its own.
    self.depth_conv = ConvLayer(
        in_dim=expanded,
        out_dim=expanded,
        kernel_size=kernel_size,
        stride=stride,
        groups=expanded,
        padding=padding,
        use_bias=use_bias[1],
        norm=norm[1],
        act=None,
        dilation=dilation,
    )
    # 1x1 projection of the gated half: hidden_features -> out_feature.
    self.point_conv = ConvLayer(
        in_dim=hidden_features,
        out_dim=out_feature,
        kernel_size=1,
        use_bias=use_bias[2],
        norm=norm[2],
        act=act[2],
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the gated feed-forward on ``x``.

    Dims 1 and 2 are swapped before and after the Conv1d stack, so dim 2 of
    ``x`` must equal ``in_features``; presumably ``x`` is
    (batch, length, in_features) -- TODO confirm with callers.
    """
    h = self.depth_conv(self.inverted_conv(x.transpose(1, 2)))
    # First half of the channels is the value, second half the gate.
    value, gate = torch.chunk(h, 2, dim=1)
    h = value * self.glu_act(gate)
    return self.point_conv(h).transpose(1, 2)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class LinearTransformerBlock(nn.Module):
    """
    A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.

    Residual sub-layers, in order:
      1. RMSNorm -> optional adaLN shift/scale -> ``attn`` (built with
         ``CustomLiteLAProcessor2_0``) -> optional adaLN gate.
      2. Only when ``add_cross_attention``: ``cross_attn`` (built with
         ``CustomerAttnProcessor2_0``) applied to the un-normalized states.
      3. RMSNorm -> optional adaLN shift/scale -> ``GLUMBConv`` feed-forward
         -> optional adaLN gate.

    With ``use_adaln_single``, ``temb`` must provide ``6 * dim`` values per
    batch element; they are added to the learned ``(6, dim)``
    ``scale_shift_table`` and split into shift/scale/gate for sub-layers 1
    and 3.

    NOTE(review): ``cross_attn`` is only created when ``add_cross_attention``
    is true AND ``add_cross_attention_dim`` is not ``None``; with
    ``add_cross_attention=True`` and no dim, ``forward`` fails on the missing
    attribute.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        use_adaln_single=True,
        cross_attention_dim=None,
        added_kv_proj_dim=None,
        context_pre_only=False,
        mlp_ratio=4.0,
        add_cross_attention=False,
        add_cross_attention_dim=None,
        qk_norm=None,
    ):
        super().__init__()

        # Plain flags; setting them first does not affect module registration.
        self.add_cross_attention = add_cross_attention
        self.context_pre_only = context_pre_only
        self.use_adaln_single = use_adaln_single

        # Keyword arguments shared by both attention modules.
        shared_attn_kwargs = dict(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm=qk_norm,
        )

        # Creation order below fixes parameter order and initialization RNG draws.
        self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            cross_attention_dim=cross_attention_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            processor=CustomLiteLAProcessor2_0(),
            **shared_attn_kwargs,
        )

        if add_cross_attention and add_cross_attention_dim is not None:
            self.cross_attn = Attention(
                cross_attention_dim=add_cross_attention_dim,
                added_kv_proj_dim=add_cross_attention_dim,
                context_pre_only=context_pre_only,
                processor=CustomerAttnProcessor2_0(),
                **shared_attn_kwargs,
            )

        self.norm2 = RMSNorm(dim, eps=1e-06, elementwise_affine=False)

        self.ff = GLUMBConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            use_bias=(True, True, False),
            norm=(None, None, None),
            act=("silu", "silu", None),
        )

        if use_adaln_single:
            # Rows: shift/scale/gate for attention, then shift/scale/gate for the FF.
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        temb: torch.FloatTensor = None,
    ):
        """Run one block and return the updated ``hidden_states``.

        Args:
            hidden_states: Block input; dim 0 is the batch size used to
                reshape ``temb``.
            encoder_hidden_states, encoder_attention_mask,
            rotary_freqs_cis_cross: Context inputs. Without
                ``add_cross_attention`` they go to ``self.attn``; with it they
                go only to ``self.cross_attn``.
            attention_mask, rotary_freqs_cis: Passed to every attention call.
            temb: Conditioning embedding; required when ``use_adaln_single``.
                Reshaped to ``(batch, 6, -1)``.

        Returns:
            The hidden states after all residual sub-layers.
        """
        adaln = self.use_adaln_single
        batch_size = hidden_states.shape[0]

        if adaln:
            modulation = self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
            (
                shift_msa,
                scale_msa,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = modulation.chunk(6, dim=1)

        # Self-attention sub-layer.
        attn_input = self.norm1(hidden_states)
        if adaln:
            attn_input = attn_input * (1 + scale_msa) + shift_msa

        if self.add_cross_attention:
            # Context is reserved for `cross_attn`; self-attention sees none.
            attn_out, _ = self.attn(
                hidden_states=attn_input,
                attention_mask=attention_mask,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=None,
            )
        else:
            attn_out, encoder_hidden_states = self.attn(
                hidden_states=attn_input,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )

        if adaln:
            attn_out = gate_msa * attn_out
        hidden_states = attn_out + hidden_states

        # Optional cross-attention sub-layer (no norm, no adaLN gate).
        if self.add_cross_attention:
            hidden_states = (
                self.cross_attn(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=rotary_freqs_cis_cross,
                )
                + hidden_states
            )

        # Feed-forward sub-layer.
        ff_input = self.norm2(hidden_states)
        if adaln:
            ff_input = ff_input * (1 + scale_mlp) + shift_mlp

        ff_out = self.ff(ff_input)
        if adaln:
            ff_out = gate_mlp * ff_out

        return hidden_states + ff_out
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
def __init__(
    self,
    dim,
    num_attention_heads,
    attention_head_dim,
    use_adaln_single=True,
    cross_attention_dim=None,
    added_kv_proj_dim=None,
    context_pre_only=False,
    mlp_ratio=4.0,
    add_cross_attention=False,
    add_cross_attention_dim=None,
    qk_norm=None,
):
    """Build the norms, attention module(s), feed-forward and adaLN table.

    ``cross_attn`` is only created when ``add_cross_attention`` is true AND
    ``add_cross_attention_dim`` is not ``None``. ``scale_shift_table`` (shape
    ``(6, dim)``) is only created when ``use_adaln_single`` is true. The
    feed-forward is a ``GLUMBConv`` with hidden width ``int(dim * mlp_ratio)``.
    """
    super().__init__()

    # Plain flags; setting them first does not affect module registration.
    self.add_cross_attention = add_cross_attention
    self.context_pre_only = context_pre_only
    self.use_adaln_single = use_adaln_single

    # Keyword arguments shared by both attention modules.
    shared_attn_kwargs = dict(
        query_dim=dim,
        dim_head=attention_head_dim,
        heads=num_attention_heads,
        out_dim=dim,
        bias=True,
        qk_norm=qk_norm,
    )

    # Creation order below fixes parameter order and initialization RNG draws.
    self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
    self.attn = Attention(
        cross_attention_dim=cross_attention_dim,
        added_kv_proj_dim=added_kv_proj_dim,
        processor=CustomLiteLAProcessor2_0(),
        **shared_attn_kwargs,
    )

    if add_cross_attention and add_cross_attention_dim is not None:
        self.cross_attn = Attention(
            cross_attention_dim=add_cross_attention_dim,
            added_kv_proj_dim=add_cross_attention_dim,
            context_pre_only=context_pre_only,
            processor=CustomerAttnProcessor2_0(),
            **shared_attn_kwargs,
        )

    self.norm2 = RMSNorm(dim, eps=1e-06, elementwise_affine=False)

    self.ff = GLUMBConv(
        in_features=dim,
        hidden_features=int(dim * mlp_ratio),
        use_bias=(True, True, False),
        norm=(None, None, None),
        act=("silu", "silu", None),
    )

    if use_adaln_single:
        # Rows: shift/scale/gate for attention, then shift/scale/gate for the FF.
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    attention_mask: torch.FloatTensor = None,
    encoder_attention_mask: torch.FloatTensor = None,
    rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
    rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
    temb: torch.FloatTensor = None,
):
    """Run one block and return the updated ``hidden_states``.

    Args:
        hidden_states: Block input; dim 0 is the batch size used to reshape
            ``temb``.
        encoder_hidden_states, encoder_attention_mask,
        rotary_freqs_cis_cross: Context inputs. Without
            ``add_cross_attention`` they go to ``self.attn``; with it they go
            only to ``self.cross_attn``.
        attention_mask, rotary_freqs_cis: Passed to every attention call.
        temb: Conditioning embedding; required when ``use_adaln_single``.
            Reshaped to ``(batch, 6, -1)`` and added to ``scale_shift_table``.

    Returns:
        The hidden states after all residual sub-layers.
    """
    adaln = self.use_adaln_single
    batch_size = hidden_states.shape[0]

    if adaln:
        modulation = self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = modulation.chunk(6, dim=1)

    # Self-attention sub-layer.
    attn_input = self.norm1(hidden_states)
    if adaln:
        attn_input = attn_input * (1 + scale_msa) + shift_msa

    if self.add_cross_attention:
        # Context is reserved for `cross_attn`; self-attention sees none.
        attn_out, _ = self.attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            rotary_freqs_cis=rotary_freqs_cis,
            rotary_freqs_cis_cross=None,
        )
    else:
        attn_out, encoder_hidden_states = self.attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            rotary_freqs_cis=rotary_freqs_cis,
            rotary_freqs_cis_cross=rotary_freqs_cis_cross,
        )

    if adaln:
        attn_out = gate_msa * attn_out
    hidden_states = attn_out + hidden_states

    # Optional cross-attention sub-layer (no norm, no adaLN gate).
    if self.add_cross_attention:
        hidden_states = (
            self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
            + hidden_states
        )

    # Feed-forward sub-layer.
    ff_input = self.norm2(hidden_states)
    if adaln:
        ff_input = ff_input * (1 + scale_mlp) + shift_mlp

    ff_out = self.ff(ff_input)
    if adaln:
        ff_out = gate_mlp * ff_out

    return hidden_states + ff_out
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.