divisor.acestep.models.customer_attention_processor

  1# Copyright 2024 The HuggingFace Team. All rights reserved.
  2#
  3# Licensed under the Apache License, Version 2.0 (the "License");
  4# you may not use this file except in compliance with the License.
  5# You may obtain a copy of the License at
  6#
  7#     http://www.apache.org/licenses/LICENSE-2.0
  8#
  9# Unless required by applicable law or agreed to in writing, software
 10# distributed under the License is distributed on an "AS IS" BASIS,
 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12# See the License for the specific language governing permissions and
 13# limitations under the License.
 14from typing import Optional, Union, Tuple
 15
 16import torch
 17import torch.nn.functional as F
 18from torch import nn
 19
 20from diffusers.utils import logging
 21from diffusers.models.attention_processor import Attention
 22
 23logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 24
 25
 26class CustomLiteLAProcessor2_0:
 27    """Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
 28
 29    def __init__(self):
 30        self.kernel_func = nn.ReLU(inplace=False)
 31        self.eps = 1e-15
 32        self.pad_val = 1.0
 33
 34    def apply_rotary_emb(
 35        self,
 36        x: torch.Tensor,
 37        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
 38    ) -> Tuple[torch.Tensor, torch.Tensor]:
 39        """
 40        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
 41        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
 42        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
 43        tensors contain rotary embeddings and are returned as real tensors.
 44
 45        Args:
 46            x (`torch.Tensor`):
 47                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
 48            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
 49
 50        Returns:
 51            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
 52        """
 53        cos, sin = freqs_cis  # [S, D]
 54        cos = cos[None, None]
 55        sin = sin[None, None]
 56        cos, sin = cos.to(x.device), sin.to(x.device)
 57
 58        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
 59        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
 60        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
 61
 62        return out
 63
 64    def __call__(
 65        self,
 66        attn: Attention,
 67        hidden_states: torch.FloatTensor,
 68        encoder_hidden_states: torch.FloatTensor = None,
 69        attention_mask: Optional[torch.FloatTensor] = None,
 70        encoder_attention_mask: Optional[torch.FloatTensor] = None,
 71        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
 72        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
 73        *args,
 74        **kwargs,
 75    ) -> torch.FloatTensor:
 76        hidden_states_len = hidden_states.shape[1]
 77
 78        input_ndim = hidden_states.ndim
 79        if input_ndim == 4:
 80            batch_size, channel, height, width = hidden_states.shape
 81            hidden_states = hidden_states.view(
 82                batch_size, channel, height * width
 83            ).transpose(1, 2)
 84        if encoder_hidden_states is not None:
 85            context_input_ndim = encoder_hidden_states.ndim
 86            if context_input_ndim == 4:
 87                batch_size, channel, height, width = encoder_hidden_states.shape
 88                encoder_hidden_states = encoder_hidden_states.view(
 89                    batch_size, channel, height * width
 90                ).transpose(1, 2)
 91
 92        batch_size = hidden_states.shape[0]
 93
 94        # `sample` projections.
 95        dtype = hidden_states.dtype
 96        query = attn.to_q(hidden_states)
 97        key = attn.to_k(hidden_states)
 98        value = attn.to_v(hidden_states)
 99
100        # `context` projections.
101        has_encoder_hidden_state_proj = (
102            hasattr(attn, "add_q_proj")
103            and hasattr(attn, "add_k_proj")
104            and hasattr(attn, "add_v_proj")
105        )
106        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
107            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
108            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
109            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
110
111            # attention
112            if not attn.is_cross_attention:
113                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
114                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
115                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
116            else:
117                query = hidden_states
118                key = encoder_hidden_states
119                value = encoder_hidden_states
120
121        inner_dim = key.shape[-1]
122        head_dim = inner_dim // attn.heads
123
124        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
125        key = (
126            key.transpose(-1, -2)
127            .reshape(batch_size, attn.heads, head_dim, -1)
128            .transpose(-1, -2)
129        )
130        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
131
132        # RoPE需要 [B, H, S, D] 输入
133        # 此时 query是 [B, H, D, S], 需要转成 [B, H, S, D] 才能应用RoPE
134        query = query.permute(0, 1, 3, 2)  # [B, H, S, D]  (从 [B, H, D, S])
135
136        # Apply query and key normalization if needed
137        if attn.norm_q is not None:
138            query = attn.norm_q(query)
139        if attn.norm_k is not None:
140            key = attn.norm_k(key)
141
142        # Apply RoPE if needed
143        if rotary_freqs_cis is not None:
144            query = self.apply_rotary_emb(query, rotary_freqs_cis)
145            if not attn.is_cross_attention:
146                key = self.apply_rotary_emb(key, rotary_freqs_cis)
147            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
148                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
149
150        # 此时 query是 [B, H, S, D],需要还原成 [B, H, D, S]
151        query = query.permute(0, 1, 3, 2)  # [B, H, D, S]
152
153        if attention_mask is not None:
154            # attention_mask: [B, S] -> [B, 1, S, 1]
155            attention_mask = attention_mask[:, None, :, None].to(
156                key.dtype
157            )  # [B, 1, S, 1]
158            query = query * attention_mask.permute(
159                0, 1, 3, 2
160            )  # [B, H, S, D] * [B, 1, S, 1]
161            if not attn.is_cross_attention:
162                key = (
163                    key * attention_mask
164                )  # key: [B, h, S, D] 与 mask [B, 1, S, 1] 相乘
165                value = value * attention_mask.permute(
166                    0, 1, 3, 2
167                )  # 如果 value 是 [B, h, D, S],那么需调整mask以匹配S维度
168
169        if (
170            attn.is_cross_attention
171            and encoder_attention_mask is not None
172            and has_encoder_hidden_state_proj
173        ):
174            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(
175                key.dtype
176            )  # [B, 1, S_enc, 1]
177            # 此时 key: [B, h, S_enc, D], value: [B, h, D, S_enc]
178            key = key * encoder_attention_mask  # [B, h, S_enc, D] * [B, 1, S_enc, 1]
179            value = value * encoder_attention_mask.permute(
180                0, 1, 3, 2
181            )  # [B, h, D, S_enc] * [B, 1, 1, S_enc]
182
183        query = self.kernel_func(query)
184        key = self.kernel_func(key)
185
186        query, key, value = query.float(), key.float(), value.float()
187
188        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
189
190        vk = torch.matmul(value, key)
191
192        hidden_states = torch.matmul(vk, query)
193
194        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
195            hidden_states = hidden_states.float()
196
197        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
198
199        hidden_states = hidden_states.view(
200            batch_size, attn.heads * head_dim, -1
201        ).permute(0, 2, 1)
202
203        hidden_states = hidden_states.to(dtype)
204        if encoder_hidden_states is not None:
205            encoder_hidden_states = encoder_hidden_states.to(dtype)
206
207        # Split the attention outputs.
208        if (
209            encoder_hidden_states is not None
210            and not attn.is_cross_attention
211            and has_encoder_hidden_state_proj
212        ):
213            hidden_states, encoder_hidden_states = (
214                hidden_states[:, :hidden_states_len],
215                hidden_states[:, hidden_states_len:],
216            )
217
218        # linear proj
219        hidden_states = attn.to_out[0](hidden_states)
220        # dropout
221        hidden_states = attn.to_out[1](hidden_states)
222        if (
223            encoder_hidden_states is not None
224            and not attn.context_pre_only
225            and not attn.is_cross_attention
226            and hasattr(attn, "to_add_out")
227        ):
228            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
229
230        if input_ndim == 4:
231            hidden_states = hidden_states.transpose(-1, -2).reshape(
232                batch_size, channel, height, width
233            )
234        if encoder_hidden_states is not None and context_input_ndim == 4:
235            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
236                batch_size, channel, height, width
237            )
238
239        if torch.get_autocast_gpu_dtype() == torch.float16:
240            hidden_states = hidden_states.clip(-65504, 65504)
241            if encoder_hidden_states is not None:
242                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
243
244        return hidden_states, encoder_hidden_states
245
246
class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0),
    with optional query/key normalization and rotary position embeddings (RoPE).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to a query or key tensor.

        Adjacent pairs along the last dimension of `x` are treated as
        (real, imag) components and rotated with the given cos/sin tables.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings to. [B, H, S, D]
            freqs_cis (`Tuple[torch.Tensor]`):
                Precomputed `(cos, sin)` tables. ([S, D], [S, D],)

        Returns:
            torch.Tensor: Tensor with rotary embeddings applied, same shape and dtype as `x`.
        """
        cos, sin = freqs_cis  # each [S, D]
        # Add leading batch/head axes for broadcasting and match x's device.
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        # Compute in float32, then cast back to the input dtype.
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run scaled dot-product attention for `attn` on `hidden_states`.

        Args:
            attn: Attention module providing the projections, `heads`,
                `group_norm`, `norm_cross`, `norm_q`/`norm_k`,
                `is_cross_attention`, `residual_connection` and
                `rescale_output_factor`.
            hidden_states: [B, S, C] tokens, or a 4-D [B, C, H, W] map that is
                flattened to tokens and restored on return.
            encoder_hidden_states: Optional context tokens; when `None`,
                keys and values are computed from `hidden_states`.
            attention_mask: Optional mask. In the cross-attention path it is a
                [N, S1] 0/1 query mask; otherwise it is passed through
                `attn.prepare_attention_mask`.
            encoder_attention_mask: Optional [N, S2] 0/1 key mask for the
                cross-attention path.
            rotary_freqs_cis: Optional `(cos, sin)` tables for the query (and
                for the key when not cross-attention).
            rotary_freqs_cis_cross: Optional `(cos, sin)` tables for the key in
                the cross-attention path.

        Returns:
            The attention output, with the same rank as the input `hidden_states`.
        """

        residual = hidden_states
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        # sequence_length is the key/value length (context if present).
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [B, S, H * D] -> [B, H, S, D]
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            # attention_mask: N x S1
            # encoder_attention_mask: N x S2
            # Cross attention: merge the query mask and the key mask into one
            # N x S1 x S2 mask.
            if attention_mask is None:
                # No query mask given: every query position is valid, so the
                # key mask is simply repeated for each query position.
                combined_mask = encoder_attention_mask[:, None, :].expand(
                    -1, query.shape[2], -1
                )
            else:
                combined_mask = (
                    attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
                )
            # Convert the 0/1 mask into an additive bias (0 keeps, -inf drops).
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = (
                attention_mask[:, None, :, :]
                .expand(-1, attn.heads, -1, -1)
                .to(query.dtype)
            )

        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class CustomLiteLAProcessor2_0:
 27class CustomLiteLAProcessor2_0:
 28    """Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
 29
 30    def __init__(self):
 31        self.kernel_func = nn.ReLU(inplace=False)
 32        self.eps = 1e-15
 33        self.pad_val = 1.0
 34
 35    def apply_rotary_emb(
 36        self,
 37        x: torch.Tensor,
 38        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
 39    ) -> Tuple[torch.Tensor, torch.Tensor]:
 40        """
 41        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
 42        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
 43        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
 44        tensors contain rotary embeddings and are returned as real tensors.
 45
 46        Args:
 47            x (`torch.Tensor`):
 48                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
 49            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
 50
 51        Returns:
 52            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
 53        """
 54        cos, sin = freqs_cis  # [S, D]
 55        cos = cos[None, None]
 56        sin = sin[None, None]
 57        cos, sin = cos.to(x.device), sin.to(x.device)
 58
 59        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
 60        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
 61        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
 62
 63        return out
 64
 65    def __call__(
 66        self,
 67        attn: Attention,
 68        hidden_states: torch.FloatTensor,
 69        encoder_hidden_states: torch.FloatTensor = None,
 70        attention_mask: Optional[torch.FloatTensor] = None,
 71        encoder_attention_mask: Optional[torch.FloatTensor] = None,
 72        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
 73        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
 74        *args,
 75        **kwargs,
 76    ) -> torch.FloatTensor:
 77        hidden_states_len = hidden_states.shape[1]
 78
 79        input_ndim = hidden_states.ndim
 80        if input_ndim == 4:
 81            batch_size, channel, height, width = hidden_states.shape
 82            hidden_states = hidden_states.view(
 83                batch_size, channel, height * width
 84            ).transpose(1, 2)
 85        if encoder_hidden_states is not None:
 86            context_input_ndim = encoder_hidden_states.ndim
 87            if context_input_ndim == 4:
 88                batch_size, channel, height, width = encoder_hidden_states.shape
 89                encoder_hidden_states = encoder_hidden_states.view(
 90                    batch_size, channel, height * width
 91                ).transpose(1, 2)
 92
 93        batch_size = hidden_states.shape[0]
 94
 95        # `sample` projections.
 96        dtype = hidden_states.dtype
 97        query = attn.to_q(hidden_states)
 98        key = attn.to_k(hidden_states)
 99        value = attn.to_v(hidden_states)
100
101        # `context` projections.
102        has_encoder_hidden_state_proj = (
103            hasattr(attn, "add_q_proj")
104            and hasattr(attn, "add_k_proj")
105            and hasattr(attn, "add_v_proj")
106        )
107        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
108            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
109            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
110            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
111
112            # attention
113            if not attn.is_cross_attention:
114                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
115                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
116                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
117            else:
118                query = hidden_states
119                key = encoder_hidden_states
120                value = encoder_hidden_states
121
122        inner_dim = key.shape[-1]
123        head_dim = inner_dim // attn.heads
124
125        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
126        key = (
127            key.transpose(-1, -2)
128            .reshape(batch_size, attn.heads, head_dim, -1)
129            .transpose(-1, -2)
130        )
131        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
132
133        # RoPE需要 [B, H, S, D] 输入
134        # 此时 query是 [B, H, D, S], 需要转成 [B, H, S, D] 才能应用RoPE
135        query = query.permute(0, 1, 3, 2)  # [B, H, S, D]  (从 [B, H, D, S])
136
137        # Apply query and key normalization if needed
138        if attn.norm_q is not None:
139            query = attn.norm_q(query)
140        if attn.norm_k is not None:
141            key = attn.norm_k(key)
142
143        # Apply RoPE if needed
144        if rotary_freqs_cis is not None:
145            query = self.apply_rotary_emb(query, rotary_freqs_cis)
146            if not attn.is_cross_attention:
147                key = self.apply_rotary_emb(key, rotary_freqs_cis)
148            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
149                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
150
151        # 此时 query是 [B, H, S, D],需要还原成 [B, H, D, S]
152        query = query.permute(0, 1, 3, 2)  # [B, H, D, S]
153
154        if attention_mask is not None:
155            # attention_mask: [B, S] -> [B, 1, S, 1]
156            attention_mask = attention_mask[:, None, :, None].to(
157                key.dtype
158            )  # [B, 1, S, 1]
159            query = query * attention_mask.permute(
160                0, 1, 3, 2
161            )  # [B, H, S, D] * [B, 1, S, 1]
162            if not attn.is_cross_attention:
163                key = (
164                    key * attention_mask
165                )  # key: [B, h, S, D] 与 mask [B, 1, S, 1] 相乘
166                value = value * attention_mask.permute(
167                    0, 1, 3, 2
168                )  # 如果 value 是 [B, h, D, S],那么需调整mask以匹配S维度
169
170        if (
171            attn.is_cross_attention
172            and encoder_attention_mask is not None
173            and has_encoder_hidden_state_proj
174        ):
175            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(
176                key.dtype
177            )  # [B, 1, S_enc, 1]
178            # 此时 key: [B, h, S_enc, D], value: [B, h, D, S_enc]
179            key = key * encoder_attention_mask  # [B, h, S_enc, D] * [B, 1, S_enc, 1]
180            value = value * encoder_attention_mask.permute(
181                0, 1, 3, 2
182            )  # [B, h, D, S_enc] * [B, 1, 1, S_enc]
183
184        query = self.kernel_func(query)
185        key = self.kernel_func(key)
186
187        query, key, value = query.float(), key.float(), value.float()
188
189        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
190
191        vk = torch.matmul(value, key)
192
193        hidden_states = torch.matmul(vk, query)
194
195        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
196            hidden_states = hidden_states.float()
197
198        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
199
200        hidden_states = hidden_states.view(
201            batch_size, attn.heads * head_dim, -1
202        ).permute(0, 2, 1)
203
204        hidden_states = hidden_states.to(dtype)
205        if encoder_hidden_states is not None:
206            encoder_hidden_states = encoder_hidden_states.to(dtype)
207
208        # Split the attention outputs.
209        if (
210            encoder_hidden_states is not None
211            and not attn.is_cross_attention
212            and has_encoder_hidden_state_proj
213        ):
214            hidden_states, encoder_hidden_states = (
215                hidden_states[:, :hidden_states_len],
216                hidden_states[:, hidden_states_len:],
217            )
218
219        # linear proj
220        hidden_states = attn.to_out[0](hidden_states)
221        # dropout
222        hidden_states = attn.to_out[1](hidden_states)
223        if (
224            encoder_hidden_states is not None
225            and not attn.context_pre_only
226            and not attn.is_cross_attention
227            and hasattr(attn, "to_add_out")
228        ):
229            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
230
231        if input_ndim == 4:
232            hidden_states = hidden_states.transpose(-1, -2).reshape(
233                batch_size, channel, height, width
234            )
235        if encoder_hidden_states is not None and context_input_ndim == 4:
236            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
237                batch_size, channel, height, width
238            )
239
240        if torch.get_autocast_gpu_dtype() == torch.float16:
241            hidden_states = hidden_states.clip(-65504, 65504)
242            if encoder_hidden_states is not None:
243                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
244
245        return hidden_states, encoder_hidden_states

Attention processor typically used for SD3-like self-attention projections. It adds RMS normalization for the query and key and applies rotary position embeddings (RoPE).

kernel_func
eps
pad_val
def apply_rotary_emb( self, x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
35    def apply_rotary_emb(
36        self,
37        x: torch.Tensor,
38        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
39    ) -> Tuple[torch.Tensor, torch.Tensor]:
40        """
41        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
42        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
43        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
44        tensors contain rotary embeddings and are returned as real tensors.
45
46        Args:
47            x (`torch.Tensor`):
48                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
49            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
50
51        Returns:
52            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
53        """
54        cos, sin = freqs_cis  # [S, D]
55        cos = cos[None, None]
56        sin = sin[None, None]
57        cos, sin = cos.to(x.device), sin.to(x.device)
58
59        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
60        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
61        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
62
63        return out

Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors.

Args: x (torch.Tensor): Query or key tensor to apply rotary embeddings to, shaped [B, H, S, D]. freqs_cis (Tuple[torch.Tensor]): Precomputed (cos, sin) frequency tensors, each shaped [S, D].

Returns: torch.Tensor: The input query or key tensor with rotary embeddings applied, in the same shape and dtype as the input.

class CustomerAttnProcessor2_0:
class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    On top of the plain scaled dot-product attention path, this processor applies the optional
    `attn.norm_q` / `attn.norm_k` modules to the per-head query and key, optionally applies rotary
    position embeddings, and for cross-attention merges the query-side and encoder-side padding
    masks into a single additive mask.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Apply rotary position embeddings to a single query or key tensor.

        Consecutive pairs along the last dimension of `x` are treated as (real, imaginary)
        components and rotated using the supplied cosine / sine tables. The arithmetic is done
        in float32 and the result is cast back to the dtype of `x`.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings to. [B, H, S, D]
            freqs_cis (`Tuple[torch.Tensor]`):
                Precomputed `(cos, sin)` tables. ([S, D], [S, D],)

        Returns:
            `torch.Tensor`: Tensor with the same shape and dtype as `x`, with rotary embeddings applied.
        """
        cos, sin = freqs_cis  # each [S, D]
        # Add leading batch and head axes so the tables broadcast against [B, H, S, D].
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)

        # Split the last dim into (real, imag) pairs: each [B, H, S, D//2].
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # Rotate every pair by 90 degrees, (a, b) -> (-b, a), then restore [B, H, S, D].
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        # String annotation: the class stays importable without evaluating the
        # diffusers type at definition time.
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        rotary_freqs_cis_cross: Optional[
            Union[torch.Tensor, Tuple[torch.Tensor]]
        ] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run self- or cross-attention for `attn` using `F.scaled_dot_product_attention`.

        Args:
            attn: Attention module supplying the projections (`to_q`, `to_k`, `to_v`, `to_out`),
                the optional norms and the configuration flags read below.
            hidden_states: Query-side input, either 3-D `[B, S1, C]` or 4-D `[B, C, H, W]`
                (4-D input is flattened to a sequence and restored on output).
            encoder_hidden_states: Key/value-side input; when `None`, `hidden_states` is used.
            attention_mask: Query-side mask. In the cross-attention mask branch it is expected
                as `N x S1` with 1 marking valid positions; if it is `None` there, every query
                position is treated as valid.
            encoder_attention_mask: Key-side mask, `N x S2`, with 1 marking valid positions.
            rotary_freqs_cis: `(cos, sin)` tables applied to the query (and to the key for
                self-attention).
            rotary_freqs_cis_cross: `(cos, sin)` tables applied to the key for cross-attention.

        Returns:
            `torch.Tensor`: Attention output with the same layout as the incoming `hidden_states`.
        """

        residual = hidden_states
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial dims into a sequence: [B, C, H, W] -> [B, H*W, C].
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split heads: [B, S, heads * head_dim] -> [B, heads, S, head_dim].
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            # attention_mask: N x S1
            # encoder_attention_mask: N x S2
            # cross attention: merge attention_mask and encoder_attention_mask
            if attention_mask is None:
                # No query-side mask supplied: every query position is valid, so the
                # combined mask is the encoder mask repeated for each query position.
                combined_mask = encoder_attention_mask[:, None, :].expand(
                    -1, query.shape[2], -1
                )
            else:
                combined_mask = (
                    attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
                )
            # Convert the 0/1 mask to an additive mask: 0 keeps a pair, -inf drops it.
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = (
                attention_mask[:, None, :, :]
                .expand(-1, attn.heads, -1, -1)
                .to(query.dtype)
            )

        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads back: [B, heads, S, head_dim] -> [B, S, heads * head_dim].
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # Restore the original [B, C, H, W] layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

def apply_rotary_emb( self, x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
259    def apply_rotary_emb(
260        self,
261        x: torch.Tensor,
262        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
263    ) -> Tuple[torch.Tensor, torch.Tensor]:
264        """
265        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
266        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
267        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
268        tensors contain rotary embeddings and are returned as real tensors.
269
270        Args:
271            x (`torch.Tensor`):
272                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
273            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
274
275        Returns:
276            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
277        """
278        cos, sin = freqs_cis  # [S, D]
279        cos = cos[None, None]
280        sin = sin[None, None]
281        cos, sin = cos.to(x.device), sin.to(x.device)
282
283        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
284        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
285        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
286
287        return out

Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors.

Args: x (torch.Tensor): Query or key tensor to apply rotary embeddings to, shape [B, H, S, D]. freqs_cis (Tuple[torch.Tensor]): Precomputed (cos, sin) frequency tensors, each of shape [S, D].

Returns: torch.Tensor: The input query or key tensor with rotary embeddings applied, cast back to the dtype of x.