divisor.acestep.models.customer_attention_processor
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import logging
from diffusers.models.attention_processor import Attention

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CustomLiteLAProcessor2_0:
    """Linear-attention processor with a ReLU feature map, optional Q/K norm and RoPE.

    Typically used for SD3-like self-attention projections. Instead of a
    softmax over ``Q @ K^T`` it applies ``ReLU`` to queries and keys and
    evaluates ``(V @ K) @ Q``. A row of ``pad_val`` appended to ``V`` produces
    the per-token normaliser in the same matrix products.
    """

    def __init__(self):
        # Feature map applied to queries and keys before the matrix products.
        self.kernel_func = nn.ReLU(inplace=False)
        # Added to the normaliser to avoid division by zero.
        self.eps = 1e-15
        # Fill value of the extra row appended to `value`.
        self.pad_val = 1.0

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Apply rotary position embeddings (RoPE) to a query or key tensor.

        Adjacent pairs along the last dimension are rotated using the given
        cosine/sine tables. The arithmetic runs in float32 and the result is
        cast back to the dtype of `x`.

        Args:
            x (`torch.Tensor`):
                Query or key tensor, [B, H, S, D]. `flatten(3)` below requires
                a 4-D input and the pairing requires an even D.
            freqs_cis (`Tuple[torch.Tensor]`):
                Pair `(cos, sin)`, each [S, D].

        Returns:
            `torch.Tensor`: Tensor with the same shape and dtype as `x`.
        """
        cos, sin = freqs_cis  # [S, D]
        # Leading singleton axes broadcast over batch and heads.
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        # Split the last axis into (even, odd) pairs: each [B, H, S, D//2].
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
        # Map each pair (a, b) to (-b, a) and interleave back to [B, H, S, D].
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        rotary_freqs_cis_cross: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """Run ReLU linear attention for `attn` and return both output streams.

        Args:
            attn: Attention module supplying the projections (`to_q`, `to_k`,
                `to_v`, `to_out`, optional `add_*_proj` / `to_add_out`),
                `heads`, `norm_q`, `norm_k` and `is_cross_attention`.
            hidden_states: Sample tokens, [B, S, C] or [B, C, H, W].
            encoder_hidden_states: Optional context tokens, 3-D or 4-D.
            attention_mask: Optional [B, S] multiplicative mask for the sample
                tokens (it is indexed as a 2-D tensor below).
            encoder_attention_mask: Optional [B, S_enc] multiplicative mask
                for the context tokens; only used for cross-attention.
            rotary_freqs_cis: Optional `(cos, sin)` tables for the query (and
                for the key in self-attention).
            rotary_freqs_cis_cross: Optional `(cos, sin)` tables for the key
                in cross-attention.

        Returns:
            `(hidden_states, encoder_hidden_states)`; the second item is
            `None` when no context was passed.
        """
        # NOTE(review): read before the 4-D flattening below, so for 4-D input
        # this is the channel count rather than the token count - confirm.
        hidden_states_len = hidden_states.shape[1]

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten spatial dims: [B, C, H, W] -> [B, H*W, C].
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        if encoder_hidden_states is not None:
            context_input_ndim = encoder_hidden_states.ndim
            if context_input_ndim == 4:
                # NOTE(review): this rebinds batch_size/channel/height/width,
                # which are reused to restore `hidden_states` at the end.
                batch_size, channel, height, width = encoder_hidden_states.shape
                encoder_hidden_states = encoder_hidden_states.view(
                    batch_size, channel, height * width
                ).transpose(1, 2)

        batch_size = hidden_states.shape[0]

        # `sample` projections.
        dtype = hidden_states.dtype
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )
        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            # attention
            if not attn.is_cross_attention:
                # Joint attention: sample and context tokens share one sequence.
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            else:
                # Cross attention uses the unprojected inputs directly; the
                # projections computed above are not used on this path.
                query = hidden_states
                key = encoder_hidden_states
                value = encoder_hidden_states

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # query/value: [B, H, D, S]; key: [B, H, S, D].
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = (
            key.transpose(-1, -2)
            .reshape(batch_size, attn.heads, head_dim, -1)
            .transpose(-1, -2)
        )
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # RoPE expects [B, H, S, D] input.
        # query is [B, H, D, S] here, so convert it to [B, H, S, D] first.
        query = query.permute(0, 1, 3, 2)  # [B, H, S, D] (from [B, H, D, S])

        # Apply query and key normalization if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        # query is [B, H, S, D] here; restore it to [B, H, D, S].
        query = query.permute(0, 1, 3, 2)  # [B, H, D, S]

        if attention_mask is not None:
            # attention_mask: [B, S] -> [B, 1, S, 1]
            attention_mask = attention_mask[:, None, :, None].to(
                key.dtype
            )  # [B, 1, S, 1]
            query = query * attention_mask.permute(
                0, 1, 3, 2
            )  # [B, H, D, S] * [B, 1, 1, S]
            if not attn.is_cross_attention:
                key = (
                    key * attention_mask
                )  # key [B, h, S, D] multiplied by mask [B, 1, S, 1]
                value = value * attention_mask.permute(
                    0, 1, 3, 2
                )  # value is [B, h, D, S], so the mask is permuted to align with S

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(
                key.dtype
            )  # [B, 1, S_enc, 1]
            # Here key is [B, h, S_enc, D] and value is [B, h, D, S_enc].
            key = key * encoder_attention_mask  # [B, h, S_enc, D] * [B, 1, S_enc, 1]
            value = value * encoder_attention_mask.permute(
                0, 1, 3, 2
            )  # [B, h, D, S_enc] * [B, 1, 1, S_enc]

        query = self.kernel_func(query)
        key = self.kernel_func(key)

        query, key, value = query.float(), key.float(), value.float()

        # Append one row of `pad_val` along the D axis: [B, H, D + 1, S].
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)

        vk = torch.matmul(value, key)  # [B, H, D + 1, D]

        hidden_states = torch.matmul(vk, query)  # [B, H, D + 1, S_q]

        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.float()

        # The last row comes from the padded row and acts as the normaliser.
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)

        # [B, H, D, S_q] -> [B, S_q, H * D]
        hidden_states = hidden_states.view(
            batch_size, attn.heads * head_dim, -1
        ).permute(0, 2, 1)

        hidden_states = hidden_states.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        # Split the attention outputs.
        if (
            encoder_hidden_states is not None
            and not attn.is_cross_attention
            and has_encoder_hidden_state_proj
        ):
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :hidden_states_len],
                hidden_states[:, hidden_states_len:],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if (
            encoder_hidden_states is not None
            and not attn.context_pre_only
            and not attn.is_cross_attention
            and hasattr(attn, "to_add_out")
        ):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if encoder_hidden_states is not None and context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        # Clamp to the finite float16 range when the autocast dtype is float16.
        if torch.get_autocast_gpu_dtype() == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
            if encoder_hidden_states is not None:
                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return hidden_states, encoder_hidden_states


class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0),
    with optional query/key normalization and rotary position embeddings.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Apply rotary position embeddings (RoPE) to a query or key tensor.

        Adjacent pairs along the last dimension are rotated using the given
        cosine/sine tables. The arithmetic runs in float32 and the result is
        cast back to the dtype of `x`.

        Args:
            x (`torch.Tensor`):
                Query or key tensor, [B, H, S, D]. `flatten(3)` below requires
                a 4-D input and the pairing requires an even D.
            freqs_cis (`Tuple[torch.Tensor]`):
                Pair `(cos, sin)`, each [S, D].

        Returns:
            `torch.Tensor`: Tensor with the same shape and dtype as `x`.
        """
        cos, sin = freqs_cis  # [S, D]
        # Leading singleton axes broadcast over batch and heads.
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        # Split the last axis into (even, odd) pairs: each [B, H, S, D//2].
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
        # Map each pair (a, b) to (-b, a) and interleave back to [B, H, S, D].
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        rotary_freqs_cis_cross: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run scaled dot-product attention for `attn`.

        Args:
            attn: Attention module supplying projections, norms and flags.
            hidden_states: Query-side tokens, [B, S, C] or [B, C, H, W].
            encoder_hidden_states: Optional key/value tokens; when `None`,
                `hidden_states` is used (self-attention).
            attention_mask: Optional mask. In the cross-attention branch it is
                indexed as [N, S1]; in self-attention it is passed to
                `attn.prepare_attention_mask`.
            encoder_attention_mask: Optional [N, S2] mask for the key tokens.
            rotary_freqs_cis: Optional `(cos, sin)` tables for the query (and
                for the key in self-attention).
            rotary_freqs_cis_cross: Optional `(cos, sin)` tables for the key
                in cross-attention.

        Returns:
            The attended hidden states, in the layout of the input.
        """

        residual = hidden_states
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten spatial dims: [B, C, H, W] -> [B, H*W, C].
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split heads: [B, S, H * D] -> [B, H, S, D].
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            # attention_mask: N x S1
            # encoder_attention_mask: N x S2
            # cross attention: combine attention_mask and encoder_attention_mask
            # NOTE(review): `attention_mask` is indexed unconditionally here, so
            # this branch requires it to be provided.
            combined_mask = (
                attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
            )
            # Additive mask: 0 where both masks are 1, -inf elsewhere.
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = (
                attention_mask[:, None, :, :]
                .expand(-1, attn.heads, -1, -1)
                .to(query.dtype)
            )

        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads: [B, H, S, D] -> [B, S, H * D].
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class CustomLiteLAProcessor2_0:
    """Linear-attention processor with a ReLU feature map, optional Q/K norm and RoPE.

    Typically used for SD3-like self-attention projections. Instead of a
    softmax over ``Q @ K^T`` it applies ``ReLU`` to queries and keys and
    evaluates ``(V @ K) @ Q``. A row of ``pad_val`` appended to ``V`` produces
    the per-token normaliser in the same matrix products.
    """

    def __init__(self):
        # Feature map applied to queries and keys before the matrix products.
        self.kernel_func = nn.ReLU(inplace=False)
        # Added to the normaliser to avoid division by zero.
        self.eps = 1e-15
        # Fill value of the extra row appended to `value`.
        self.pad_val = 1.0

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """
        Apply rotary position embeddings (RoPE) to a query or key tensor.

        Adjacent pairs along the last dimension are rotated using the given
        cosine/sine tables. The arithmetic runs in float32 and the result is
        cast back to the dtype of `x`.

        Args:
            x (`torch.Tensor`):
                Query or key tensor, [B, H, S, D] with an even D.
            freqs_cis (`Tuple[torch.Tensor, torch.Tensor]`):
                Pair `(cos, sin)`, each [S, D].

        Returns:
            `torch.Tensor`: Tensor with the same shape and dtype as `x`.
        """
        cos, sin = freqs_cis  # each [S, D]
        # Leading singleton axes broadcast over batch and heads.
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)

        # Split the last axis into (even, odd) pairs: each [B, H, S, D // 2].
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # Map each pair (a, b) to (-b, a) and interleave back to [B, H, S, D].
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        rotary_freqs_cis_cross: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """Run ReLU linear attention for `attn` and return both output streams.

        Args:
            attn: Attention module supplying the projections (`to_q`, `to_k`,
                `to_v`, `to_out`, optional `add_*_proj` / `to_add_out`),
                `heads`, `norm_q`, `norm_k` and `is_cross_attention`.
            hidden_states: Sample tokens, [B, S, C] or [B, C, H, W].
            encoder_hidden_states: Optional context tokens, 3-D or 4-D.
            attention_mask: Optional [B, S] multiplicative mask for the sample
                tokens.
            encoder_attention_mask: Optional [B, S_enc] multiplicative mask
                for the context tokens; only used for cross-attention.
            rotary_freqs_cis: Optional `(cos, sin)` tables for the query (and
                for the key in self-attention).
            rotary_freqs_cis_cross: Optional `(cos, sin)` tables for the key
                in cross-attention.

        Returns:
            `(hidden_states, encoder_hidden_states)`; the second item is
            `None` when no context was passed.
        """
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten spatial dims: [B, C, H, W] -> [B, H*W, C].
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        # Token count of the sample stream, taken after flattening so that the
        # joint-attention split below is correct for 4-D inputs as well.
        hidden_states_len = hidden_states.shape[1]

        # Context layout is tracked separately so that restoring
        # `hidden_states` never uses the context's dimensions.
        context_input_ndim = None
        context_shape = None
        if encoder_hidden_states is not None:
            context_input_ndim = encoder_hidden_states.ndim
            if context_input_ndim == 4:
                context_shape = encoder_hidden_states.shape
                ctx_batch, ctx_channel, ctx_height, ctx_width = context_shape
                encoder_hidden_states = encoder_hidden_states.view(
                    ctx_batch, ctx_channel, ctx_height * ctx_width
                ).transpose(1, 2)

        batch_size = hidden_states.shape[0]

        # `sample` projections.
        dtype = hidden_states.dtype
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )
        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            if not attn.is_cross_attention:
                # Joint attention: sample and context tokens share one sequence.
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            else:
                # Cross attention uses the unprojected inputs directly.
                query = hidden_states
                key = encoder_hidden_states
                value = encoder_hidden_states

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # query/value: [B, H, D, S]; key: [B, H, S, D].
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = (
            key.transpose(-1, -2)
            .reshape(batch_size, attn.heads, head_dim, -1)
            .transpose(-1, -2)
        )
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # Normalization and RoPE operate on [B, H, S, D].
        query = query.permute(0, 1, 3, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        # Back to [B, H, D, S] for the matrix products below.
        query = query.permute(0, 1, 3, 2)

        if attention_mask is not None:
            # [B, S] -> [B, 1, S, 1]
            attention_mask = attention_mask[:, None, :, None].to(key.dtype)
            # query is [B, H, D, S], so the mask is permuted to [B, 1, 1, S].
            query = query * attention_mask.permute(0, 1, 3, 2)
            if not attn.is_cross_attention:
                key = key * attention_mask  # [B, H, S, D] * [B, 1, S, 1]
                value = value * attention_mask.permute(0, 1, 3, 2)  # [B, H, D, S]

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            # [B, S_enc] -> [B, 1, S_enc, 1]
            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(
                key.dtype
            )
            key = key * encoder_attention_mask  # [B, H, S_enc, D]
            value = value * encoder_attention_mask.permute(0, 1, 3, 2)  # [B, H, D, S_enc]

        query = self.kernel_func(query)
        key = self.kernel_func(key)

        query, key, value = query.float(), key.float(), value.float()

        # Append one row of `pad_val` along the D axis: [B, H, D + 1, S].
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)

        vk = torch.matmul(value, key)  # [B, H, D + 1, D]
        hidden_states = torch.matmul(vk, query)  # [B, H, D + 1, S_q]

        # Under autocast the matmuls may still return a reduced-precision dtype.
        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.float()

        # The last row comes from the padded row and acts as the normaliser.
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)

        # [B, H, D, S_q] -> [B, S_q, H * D]
        hidden_states = hidden_states.view(
            batch_size, attn.heads * head_dim, -1
        ).permute(0, 2, 1)

        hidden_states = hidden_states.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        # Split the joint sequence back into sample and context tokens.
        if (
            encoder_hidden_states is not None
            and not attn.is_cross_attention
            and has_encoder_hidden_state_proj
        ):
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :hidden_states_len],
                hidden_states[:, hidden_states_len:],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if (
            encoder_hidden_states is not None
            and not attn.context_pre_only
            and not attn.is_cross_attention
            and hasattr(attn, "to_add_out")
        ):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if encoder_hidden_states is not None and context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                context_shape
            )

        # Clamp to the finite float16 range when the autocast dtype is float16.
        if torch.get_autocast_gpu_dtype() == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
            if encoder_hidden_states is not None:
                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return hidden_states, encoder_hidden_states
Attention processor typically used for SD3-like self-attention projections. It applies the attention module's optional query/key normalization (e.g. RMS norm) and rotary position embeddings (RoPE).
def apply_rotary_emb(
    self,
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """
    Apply rotary position embeddings (RoPE) to a query or key tensor.

    Adjacent pairs along the last dimension are rotated using the given
    cosine/sine tables. The arithmetic runs in float32 and the result is
    cast back to the dtype of `x`.

    Args:
        x (`torch.Tensor`):
            Query or key tensor, [B, H, S, D] with an even D.
        freqs_cis (`Tuple[torch.Tensor, torch.Tensor]`):
            Pair `(cos, sin)`, each [S, D].

    Returns:
        `torch.Tensor`: Tensor with the same shape and dtype as `x`.
    """
    cos, sin = freqs_cis  # each [S, D]
    # Leading singleton axes broadcast over batch and heads.
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    # Split the last axis into (even, odd) pairs: each [B, H, S, D // 2].
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    # Map each pair (a, b) to (-b, a) and interleave back to [B, H, S, D].
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
Apply rotary position embeddings (RoPE) to the query or key tensor 'x' using the precomputed cosine and sine tables in 'freqs_cis'. Adjacent pairs along the last dimension of 'x' are treated as the real and imaginary parts of a complex number and rotated; the cosine and sine tables are given leading singleton axes so that they broadcast over the batch and head dimensions. The arithmetic is carried out in float32, and the result is cast back to the dtype of 'x' and returned as a real tensor.
Args:
x (torch.Tensor):
Query or key tensor to apply rotary embeddings to, with shape [B, H, S, D].
freqs_cis (Tuple[torch.Tensor]): Precomputed (cos, sin) tables, each with shape [S, D].
Returns: torch.Tensor: The input tensor with rotary embeddings applied, with the same shape and dtype as x.
class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0),
    with optional query/key normalization and rotary position embeddings.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """
        Apply rotary position embeddings (RoPE) to a query or key tensor.

        Adjacent pairs along the last dimension are rotated using the given
        cosine/sine tables. The arithmetic runs in float32 and the result is
        cast back to the dtype of `x`.

        Args:
            x (`torch.Tensor`):
                Query or key tensor, [B, H, S, D] with an even D.
            freqs_cis (`Tuple[torch.Tensor, torch.Tensor]`):
                Pair `(cos, sin)`, each [S, D].

        Returns:
            `torch.Tensor`: Tensor with the same shape and dtype as `x`.
        """
        cos, sin = freqs_cis  # each [S, D]
        # Leading singleton axes broadcast over batch and heads.
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)

        # Split the last axis into (even, odd) pairs: each [B, H, S, D // 2].
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # Map each pair (a, b) to (-b, a) and interleave back to [B, H, S, D].
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        rotary_freqs_cis_cross: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run scaled dot-product attention for `attn`.

        Args:
            attn: Attention module supplying projections, norms and flags.
            hidden_states: Query-side tokens, [B, S, C] or [B, C, H, W].
            encoder_hidden_states: Optional key/value tokens; when `None`,
                `hidden_states` is used (self-attention).
            attention_mask: Optional mask. For cross-attention it is a
                [N, S1] query mask (all ones when omitted); for self-attention
                it is passed to `attn.prepare_attention_mask`.
            encoder_attention_mask: Optional [N, S2] mask for the key tokens.
            rotary_freqs_cis: Optional `(cos, sin)` tables for the query (and
                for the key in self-attention).
            rotary_freqs_cis_cross: Optional `(cos, sin)` tables for the key
                in cross-attention.

        Returns:
            The attended hidden states, in the layout of the input.
        """
        residual = hidden_states
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten spatial dims: [B, C, H, W] -> [B, H*W, C].
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split heads: [B, S, H * D] -> [B, H, S, D].
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            # attention_mask: N x S1, encoder_attention_mask: N x S2.
            if attention_mask is None:
                # No query mask given: treat every query position as valid
                # instead of failing on `None[:, :, None]`.
                attention_mask = encoder_attention_mask.new_ones(
                    (query.shape[0], query.shape[2])
                )
            # Cross attention: combine both masks into one N x S1 x S2 mask.
            combined_mask = (
                attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
            )
            # Additive mask: 0 where both masks are 1, -inf elsewhere.
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = (
                attention_mask[:, None, :, :]
                .expand(-1, attn.heads, -1, -1)
                .to(query.dtype)
            )

        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads: [B, H, S, D] -> [B, S, H * D].
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
def apply_rotary_emb(
    self,
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """
    Apply rotary position embeddings (RoPE) to a query or key tensor.

    Adjacent pairs along the last dimension are rotated using the given
    cosine/sine tables. The arithmetic runs in float32 and the result is
    cast back to the dtype of `x`.

    Args:
        x (`torch.Tensor`):
            Query or key tensor, [B, H, S, D] with an even D.
        freqs_cis (`Tuple[torch.Tensor, torch.Tensor]`):
            Pair `(cos, sin)`, each [S, D].

    Returns:
        `torch.Tensor`: Tensor with the same shape and dtype as `x`.
    """
    cos, sin = freqs_cis  # each [S, D]
    # Leading singleton axes broadcast over batch and heads.
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    # Split the last axis into (even, odd) pairs: each [B, H, S, D // 2].
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    # Map each pair (a, b) to (-b, a) and interleave back to [B, H, S, D].
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
Apply rotary position embeddings (RoPE) to the query or key tensor 'x' using the precomputed cosine and sine tables in 'freqs_cis'. Adjacent pairs along the last dimension of 'x' are treated as the real and imaginary parts of a complex number and rotated; the cosine and sine tables are given leading singleton axes so that they broadcast over the batch and head dimensions. The arithmetic is carried out in float32, and the result is cast back to the dtype of 'x' and returned as a real tensor.
Args:
x (torch.Tensor):
Query or key tensor to apply rotary embeddings to, with shape [B, H, S, D].
freqs_cis (Tuple[torch.Tensor]): Precomputed (cos, sin) tables, each with shape [S, D].
Returns: torch.Tensor: The input tensor with rotary embeddings applied, with the same shape and dtype as x.