divisor.acestep.models.ace_step_transformer

  1# Copyright 2024 The HuggingFace Team. All rights reserved.
  2#
  3# Licensed under the Apache License, Version 2.0 (the "License");
  4# you may not use this file except in compliance with the License.
  5# You may obtain a copy of the License at
  6#
  7#     http://www.apache.org/licenses/LICENSE-2.0
  8#
  9# Unless required by applicable law or agreed to in writing, software
 10# distributed under the License is distributed on an "AS IS" BASIS,
 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12# See the License for the specific language governing permissions and
 13# limitations under the License.
 14from dataclasses import dataclass
 15from typing import Any, Dict, Optional, Tuple, List, Union
 16
 17import torch
 18import torch.nn.functional as F
 19from torch import nn
 20
 21from diffusers.configuration_utils import ConfigMixin, register_to_config
 22from diffusers.utils import BaseOutput, is_torch_version
 23from diffusers.models.modeling_utils import ModelMixin
 24from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 25from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
 26
 27
 28from divisor.acestep.models.attention import LinearTransformerBlock, t2i_modulate
 29from divisor.acestep.models.lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder
 30
 31
 32def cross_norm(hidden_states, controlnet_input):
 33    # input N x T x c
 34    mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1, 2), keepdim=True), hidden_states.std(dim=(1, 2), keepdim=True)
 35    mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1, 2), keepdim=True), controlnet_input.std(dim=(1, 2), keepdim=True)
 36    controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
 37    return controlnet_input
 38
 39
 40# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
 41class Qwen2RotaryEmbedding(nn.Module):
 42    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
 43        super().__init__()
 44
 45        self.dim = dim
 46        self.max_position_embeddings = max_position_embeddings
 47        self.base = base
 48        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
 49        self.register_buffer("inv_freq", inv_freq, persistent=False)
 50
 51        # Build here to make `torch.jit.trace` work.
 52        self._set_cos_sin_cache(
 53            seq_len=max_position_embeddings,
 54            device=self.inv_freq.device,
 55            dtype=torch.get_default_dtype(),
 56        )
 57
 58    def _set_cos_sin_cache(self, seq_len, device, dtype):
 59        self.max_seq_len_cached = seq_len
 60        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
 61
 62        freqs = torch.outer(t, self.inv_freq)
 63        # Different from paper, but it uses a different permutation in order to obtain the same calculation
 64        emb = torch.cat((freqs, freqs), dim=-1)
 65        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
 66        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 67
 68    def forward(self, x, seq_len=None):
 69        # x: [bs, num_attention_heads, seq_len, head_size]
 70        if seq_len > self.max_seq_len_cached:
 71            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 72
 73        return (
 74            self.cos_cached[:seq_len].to(dtype=x.dtype),
 75            self.sin_cached[:seq_len].to(dtype=x.dtype),
 76        )
 77
 78
 79class T2IFinalLayer(nn.Module):
 80    """
 81    The final layer of Sana.
 82    """
 83
 84    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
 85        super().__init__()
 86        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
 87        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
 88        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
 89        self.out_channels = out_channels
 90        self.patch_size = patch_size
 91
 92    def unpatchfy(
 93        self,
 94        hidden_states: torch.Tensor,
 95        width: int,
 96    ):
 97        # 4 unpatchify
 98        new_height, new_width = 1, hidden_states.size(1)
 99        hidden_states = hidden_states.reshape(
100            shape=(
101                hidden_states.shape[0],
102                new_height,
103                new_width,
104                self.patch_size[0],
105                self.patch_size[1],
106                self.out_channels,
107            )
108        ).contiguous()
109        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
110        output = hidden_states.reshape(
111            shape=(
112                hidden_states.shape[0],
113                self.out_channels,
114                new_height * self.patch_size[0],
115                new_width * self.patch_size[1],
116            )
117        ).contiguous()
118        if width > new_width:
119            output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), "constant", 0)
120        elif width < new_width:
121            output = output[:, :, :, :width]
122        return output
123
124    def forward(self, x, t, output_length):
125        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
126        x = t2i_modulate(self.norm_final(x), shift, scale)
127        x = self.linear(x)
128        # unpatchify
129        output = self.unpatchfy(x, output_length)
130        return output
131
132
class PatchEmbed(nn.Module):
    """Patchify a 2-D latent with a strided convolution stem.

    A ``patch_size``-strided convolution expands ``in_channels`` to
    ``in_channels * 256`` channels. GroupNorm (32 groups) and a 1x1 convolution
    to ``embed_dim`` follow, and the resulting grid is flattened into tokens.

    Args:
        height, width: Nominal input size; only used to record the patch-grid
            size in ``self.height`` / ``self.width`` / ``self.base_size``.
        patch_size: ``(patch_h, patch_w)`` kernel size and stride.
        in_channels: Channels of the input latent.
        embed_dim: Output token feature size.
        bias: Whether both convolutions use a bias.
    """

    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
    ):
        super().__init__()
        patch_h, patch_w = patch_size
        hidden_channels = in_channels * 256
        stem = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias),
            torch.nn.GroupNorm(num_groups=32, num_channels=hidden_channels, eps=1e-6, affine=True),
            nn.Conv2d(hidden_channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias),
        ]
        # Keep the Sequential and its attribute name so state-dict keys remain
        # ``early_conv_layers.{0,1,2}.*``.
        self.early_conv_layers = nn.Sequential(*stem)
        self.patch_size = patch_size
        grid_w = width // patch_w
        self.height = height // patch_h
        self.width = grid_w
        self.base_size = grid_w

    def forward(self, latent):
        """Embed ``latent`` (``N x C x H x W``) into a token sequence.

        Returns:
            ``(N, (H // patch_h) * (W // patch_w), embed_dim)``.
        """
        # N x C x H x W -> N x embed_dim x H' x W' -> N x (H' * W') x embed_dim
        features = self.early_conv_layers(latent)
        return features.flatten(2).transpose(1, 2)
175
176
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """Output of ``ACEStepTransformer2DModel.decode`` / ``forward``.

    Attributes:
        sample: Output of the final layer.
        proj_losses: ``(ssl_name, loss)`` pairs for the SSL projection losses.
            ``decode`` builds this as a list, which is empty when no projection
            loss was computed.
    """

    sample: torch.FloatTensor
    proj_losses: Optional[List[Tuple[str, torch.Tensor]]] = None
182
class ACEStepTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """ACE-Step transformer over patchified 2-D latents.

    :meth:`encode` projects the speaker embedding, text/genre hidden states and
    encoded lyric tokens to ``inner_dim``. It concatenates them into one
    cross-attention sequence.

    :meth:`decode` patchifies the latent, runs the ``LinearTransformerBlock``
    stack and maps the result back with :class:`T2IFinalLayer`. When
    ``ssl_hidden_states`` are supplied, it projects the hidden states from the
    blocks listed in ``ssl_encoder_depths`` and compares them with those targets
    through a cosine loss.

    Note: the ``inner_dim`` argument is recorded in the config, but the
    effective width is always ``num_attention_heads * attention_head_dim``.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        **kwargs,
    ):
        super().__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # The ``inner_dim`` argument is overridden: width = heads * head_dim.
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size

        self.rope_theta = rope_theta

        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
        )

        # 2. Define input layers
        # NOTE(review): ``in_channels`` is stored but not forwarded to ``PatchEmbed``
        # below (which uses its own default of 8) — confirm this is intended.
        self.in_channels = in_channels

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )
        self.num_layers = num_layers

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
        # Produces a 6 * inner_dim conditioning vector consumed by each block as ``temb``.
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))

        # speaker
        self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)

        # genre
        self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)

        # lyric
        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
        self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)

        projector_dim = 2 * self.inner_dim

        # One MLP projector per SSL target, mapping inner_dim -> that SSL model's latent dim.
        self.projectors = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(self.inner_dim, projector_dim),
                    nn.SiLU(),
                    nn.Linear(projector_dim, projector_dim),
                    nn.SiLU(),
                    nn.Linear(projector_dim, ssl_dim),
                )
                for ssl_dim in ssl_latent_dims
            ]
        )

        self.ssl_latent_dims = ssl_latent_dims
        self.ssl_encoder_depths = ssl_encoder_depths

        self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction="mean")
        self.ssl_names = ssl_names

        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            embed_dim=self.inner_dim,
            bias=True,
        )

        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
        self.gradient_checkpointing = False

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def forward_lyric_encoder(
        self,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        """Embed lyric tokens, run the lyric encoder and project to ``inner_dim``.

        Args:
            lyric_token_idx: Lyric token ids (required despite the default).
            lyric_mask: Mask passed to the lyric encoder.

        Returns:
            The projected encoder output (the encoder's returned mask is discarded).
        """
        # N x T x D
        lyric_embs = self.lyric_embs(lyric_token_idx)
        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
        return prompt_prenet_out

    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        """Build the cross-attention sequence and its mask.

        All arguments are required in practice despite the ``None`` defaults:
        each one is used unconditionally below.

        Returns:
            ``(encoder_hidden_states, encoder_hidden_mask)``, concatenated along
            dim 1 in the order [speaker token, text tokens, lyric tokens].
        """
        bs = encoder_text_hidden_states.shape[0]
        device = encoder_text_hidden_states.device

        # speaker embedding: a single token per sample, always attended to
        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
        speaker_mask = torch.ones(bs, 1, device=device)

        # genre embedding
        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)

        # lyric
        encoder_lyric_hidden_states = self.forward_lyric_encoder(
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        encoder_hidden_states = torch.cat(
            [
                encoder_spk_hidden_states,
                encoder_text_hidden_states,
                encoder_lyric_hidden_states,
            ],
            dim=1,
        )
        encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
        return encoder_hidden_states, encoder_hidden_mask

    def _compute_proj_losses(self, inner_hidden_states, ssl_hidden_states):
        """Cosine projection losses between projected inner states and SSL targets.

        Args:
            inner_hidden_states: One entry per ``ssl_encoder_depths`` item; ``None``
                when that depth matched no transformer block.
            ssl_hidden_states: Targets aligned with ``ssl_names``. Entries may be
                ``None`` to skip that target. Each is iterated per sample, so it
                is presumably a batch tensor or a list of ``(T_i, d)`` tensors.

        Returns:
            A list of ``(ssl_name, loss)`` pairs, empty when nothing was compared.
        """
        proj_losses = []
        if ssl_hidden_states is None or len(ssl_hidden_states) == 0:
            return proj_losses

        for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names):
            if inner_hidden_state is None or ssl_hidden_state is None:
                continue
            # 1. N x T x D1 -> N x T x D2 (D2 = this SSL target's latent dim)
            est_ssl_hidden_state = projector(inner_hidden_state)
            bs = inner_hidden_state.shape[0]
            proj_loss = 0.0
            for z, z_tilde in zip(ssl_hidden_state, est_ssl_hidden_state):
                # 2. Resample the estimate along time to the target's length.
                z_tilde = (
                    F.interpolate(
                        z_tilde.unsqueeze(0).transpose(1, 2),
                        size=len(z),
                        mode="linear",
                        align_corners=False,
                    )
                    .transpose(1, 2)
                    .squeeze(0)
                )

                z_tilde = F.normalize(z_tilde, dim=-1)
                z = F.normalize(z, dim=-1)
                # 3. Mean cosine loss over time steps (all targets are +1).
                target = torch.ones(z.shape[0], device=z.device)
                proj_loss += self.cosine_loss(z, z_tilde, target)
            proj_losses.append((ssl_name, proj_loss / bs))
        return proj_losses

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        """Run the transformer stack on the latent ``hidden_states``.

        Args:
            hidden_states: Latent input to ``proj_in``.
            attention_mask: Self-attention mask forwarded to every block.
            encoder_hidden_states: Cross-attention sequence from :meth:`encode`.
            encoder_hidden_mask: Mask from :meth:`encode`.
            timestep: Timesteps, embedded via ``time_proj`` and ``timestep_embedder``.
            ssl_hidden_states: Optional SSL targets aligned with ``ssl_names``.
            output_length: Size of the output's last axis (see ``T2IFinalLayer``).
            block_controlnet_hidden_states: Optional control signal. It is
                re-normalized with ``cross_norm``, scaled by ``controlnet_scale``
                and added to the patch embeddings.
            return_dict: Return ``Transformer2DModelOutput`` if True, else a
                ``(sample, proj_losses)`` tuple.
        """
        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)

        hidden_states = self.proj_in(hidden_states)

        # controlnet logic
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale

        # Hidden states captured after the blocks named in ``ssl_encoder_depths``,
        # keyed by block index.
        capture_depths = set(self.ssl_encoder_depths)
        captured_hidden_states = {}

        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                    use_reentrant=False,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                )

            if index_block in capture_depths:
                captured_hidden_states[index_block] = hidden_states

        # Re-align the captures with ``ssl_encoder_depths`` so that the i-th entry
        # pairs with the i-th projector / SSL target / name whatever the depth order.
        inner_hidden_states = [captured_hidden_states.get(depth) for depth in self.ssl_encoder_depths]
        proj_losses = self._compute_proj_losses(inner_hidden_states, ssl_hidden_states)

        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        if not return_dict:
            return (output, proj_losses)

        return Transformer2DModelOutput(sample=output, proj_losses=proj_losses)

    # @torch.compile
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        timestep: Optional[torch.Tensor] = None,
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        """Encode the conditioning, then :meth:`decode` the latent.

        The output's last axis is sized to match ``hidden_states.shape[-1]``.
        """
        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        output_length = hidden_states.shape[-1]

        output = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            ssl_hidden_states=ssl_hidden_states,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
            return_dict=return_dict,
        )

        return output
def cross_norm(hidden_states, controlnet_input):
    """Re-standardize ``controlnet_input`` to the statistics of ``hidden_states``.

    Each sample of ``controlnet_input`` is centered and scaled by its own
    mean / (unbiased) std over dims 1 and 2, then rescaled and shifted so that
    it takes on the mean / std of the matching sample of ``hidden_states``.
    The inputs are presumably ``N x T x C`` (per the original comment).

    Args:
        hidden_states: Reference tensor whose per-sample statistics are matched.
        controlnet_input: Tensor to be re-normalized.

    Returns:
        The re-normalized ``controlnet_input``.
    """
    reduce_dims = (1, 2)
    ref_mean = hidden_states.mean(dim=reduce_dims, keepdim=True)
    ref_std = hidden_states.std(dim=reduce_dims, keepdim=True)
    src_mean = controlnet_input.mean(dim=reduce_dims, keepdim=True)
    src_std = controlnet_input.std(dim=reduce_dims, keepdim=True)
    # The tiny epsilon only guards against an exactly-zero source std.
    gain = ref_std / (src_std + 1e-12)
    return (controlnet_input - src_mean) * gain + ref_mean
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables with a growable cache.

    Args:
        dim: Rotated feature size. ``ceil(dim / 2)`` inverse frequencies are
            generated and duplicated, so for even ``dim`` the cached tables have
            last dimension ``dim``.
        max_position_embeddings: Number of positions cached at construction.
        base: RoPE frequency base.
        device: Device for the inverse-frequency buffer.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = 1 / base ** (2i / dim)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        # Non-persistent buffers: derived data, not saved in the state dict.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor used for its dtype/device and, when ``seq_len`` is omitted,
                its length. The sequence axis is taken to be dim -2, which fits
                both ``[bs, num_attention_heads, seq_len, head_size]`` and
                ``[bs, seq_len, hidden]`` inputs.
            seq_len: Number of positions to return; defaults to ``x.shape[-2]``.

        Returns:
            ``(cos, sin)``, each sliced to ``seq_len`` rows and cast to ``x.dtype``.
        """
        if seq_len is None:
            # Infer the length instead of comparing ``None`` with an int below.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            # Grow the cache on demand for longer sequences than seen so far.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool

Qwen2RotaryEmbedding(dim, max_position_embeddings=2048, base=10000, device=None)
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
    """Precompute the inverse frequencies and the initial cos/sin cache.

    Args:
        dim: Rotated feature size.
        max_position_embeddings: Number of positions cached up front.
        base: RoPE frequency base.
        device: Device for the ``inv_freq`` buffer.
    """
    super().__init__()

    self.dim = dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    # Exponents 2i / dim for i in [0, ceil(dim / 2)).
    exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
    self.register_buffer("inv_freq", 1.0 / (self.base**exponents), persistent=False)

    # Build here to make `torch.jit.trace` work.
    self._set_cos_sin_cache(
        seq_len=max_position_embeddings,
        device=self.inv_freq.device,
        dtype=torch.get_default_dtype(),
    )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

dim
max_position_embeddings
base
def forward(self, x, seq_len=None):
    """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

    Args:
        x: Tensor used for its dtype/device and, when ``seq_len`` is omitted,
            its length. The sequence axis is taken to be dim -2, which fits both
            ``[bs, num_attention_heads, seq_len, head_size]`` and
            ``[bs, seq_len, hidden]`` inputs.
        seq_len: Number of positions to return; defaults to ``x.shape[-2]``.

    Returns:
        ``(cos, sin)``, each sliced to ``seq_len`` rows and cast to ``x.dtype``.
    """
    if seq_len is None:
        # Infer the length instead of comparing ``None`` with an int below.
        seq_len = x.shape[-2]
    if seq_len > self.max_seq_len_cached:
        # Grow the cache on demand for longer sequences than seen so far.
        self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

    return (
        self.cos_cached[:seq_len].to(dtype=x.dtype),
        self.sin_cached[:seq_len].to(dtype=x.dtype),
    )

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class T2IFinalLayer(nn.Module):
    """Output head adapted from Sana's final layer.

    Normalizes tokens with a non-affine RMSNorm and modulates them with
    ``t2i_modulate``, using a shift/scale pair taken from ``scale_shift_table``
    plus the timestep embedding. It then projects every token to a
    ``patch_size[0] * patch_size[1] * out_channels`` vector and folds the token
    sequence back into an ``(N, out_channels, patch_size[0], width)`` map.

    Args:
        hidden_size: Feature size of the incoming tokens.
        patch_size: ``[patch_h, patch_w]`` produced per token.
        out_channels: Channels of the reconstructed output.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
        super().__init__()
        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
        # Row 0 -> shift, row 1 -> scale (split in ``forward`` via ``chunk``).
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        """Fold a token sequence into a 2-D map and fit its last axis to ``width``.

        The tokens are treated as a single row of patches (patched height is
        fixed to 1), so the output height is ``patch_size[0]``.

        Args:
            hidden_states: ``(N, num_tokens, patch_h * patch_w * out_channels)``.
            width: Desired size of the output's last axis.

        Returns:
            ``(N, out_channels, patch_h, width)``. The map is zero-padded on the
            right, or truncated, when ``num_tokens * patch_w != width``.
        """
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(
                hidden_states.shape[0],
                new_height,
                new_width,
                patch_h,
                patch_w,
                self.out_channels,
            )
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(
                hidden_states.shape[0],
                self.out_channels,
                new_height * patch_h,
                new_width * patch_w,
            )
        ).contiguous()
        # Compare against the unpatchified width (tokens * patch_w), not the
        # token count; the two differ whenever patch_w != 1.
        current_width = new_width * patch_w
        if width > current_width:
            output = torch.nn.functional.pad(output, (0, width - current_width, 0, 0), "constant", 0)
        elif width < current_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        """Modulate, project and un-patchify ``x``.

        Args:
            x: Token features, presumably ``(N, num_tokens, hidden_size)``.
            t: Timestep embedding. ``t[:, None]`` is added to the
                ``(2, hidden_size)`` table, so presumably ``(N, hidden_size)``.
            output_length: Target size of the output's last axis.

        Returns:
            The un-patchified output from :meth:`unpatchfy`.
        """
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return self.unpatchfy(x, output_length)

The final layer of Sana.

T2IFinalLayer(hidden_size, patch_size=[16, 1], out_channels=256)
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
    """Create the final RMSNorm, the per-token projection and the modulation table.

    Args:
        hidden_size: Feature size of the incoming tokens.
        patch_size: ``[patch_h, patch_w]`` produced per token.
        out_channels: Channels of the reconstructed output.
    """
    super().__init__()
    patch_numel = patch_size[0] * patch_size[1]
    self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, patch_numel * out_channels, bias=True)
    # Two rows (shift, scale); ``forward`` adds the timestep embedding and splits them.
    self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
    self.out_channels = out_channels
    self.patch_size = patch_size

Initialize internal Module state, shared by both nn.Module and ScriptModule.

norm_final
linear
scale_shift_table
out_channels
patch_size
def unpatchfy(self, hidden_states: torch.Tensor, width: int):
    """Fold a token sequence into a 2-D map and fit its last axis to ``width``.

    The tokens are treated as a single row of patches (patched height is fixed
    to 1), so the output height is ``patch_size[0]``.

    Args:
        hidden_states: ``(N, num_tokens, patch_h * patch_w * out_channels)``.
        width: Desired size of the output's last axis.

    Returns:
        ``(N, out_channels, patch_h, width)``. The map is zero-padded on the
        right, or truncated, when ``num_tokens * patch_w != width``.
    """
    patch_h, patch_w = self.patch_size[0], self.patch_size[1]
    new_height, new_width = 1, hidden_states.size(1)
    hidden_states = hidden_states.reshape(
        shape=(
            hidden_states.shape[0],
            new_height,
            new_width,
            patch_h,
            patch_w,
            self.out_channels,
        )
    ).contiguous()
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(
            hidden_states.shape[0],
            self.out_channels,
            new_height * patch_h,
            new_width * patch_w,
        )
    ).contiguous()
    # Compare against the unpatchified width (tokens * patch_w), not the token
    # count; the two differ whenever patch_w != 1.
    current_width = new_width * patch_w
    if width > current_width:
        output = torch.nn.functional.pad(output, (0, width - current_width, 0, 0), "constant", 0)
    elif width < current_width:
        output = output[:, :, :, :width]
    return output
def forward(self, x, t, output_length):
    """Modulate, project and un-patchify ``x``.

    Args:
        x: Token features, presumably ``(N, num_tokens, hidden_size)``.
        t: Timestep embedding. ``t[:, None]`` is added to the ``(2, hidden_size)``
            table, so presumably ``(N, hidden_size)``.
        output_length: Target size of the output's last axis.

    Returns:
        The un-patchified output from ``unpatchfy``.
    """
    modulation = self.scale_shift_table[None] + t[:, None]
    shift, scale = modulation.chunk(2, dim=1)
    normed = self.norm_final(x)
    projected = self.linear(t2i_modulate(normed, shift, scale))
    # unpatchify
    return self.unpatchfy(projected, output_length)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class PatchEmbed(torch.nn.modules.module.Module):
134class PatchEmbed(nn.Module):
135    """2D Image to Patch Embedding"""
136
137    def __init__(
138        self,
139        height=16,
140        width=4096,
141        patch_size=(16, 1),
142        in_channels=8,
143        embed_dim=1152,
144        bias=True,
145    ):
146        super().__init__()
147        patch_size_h, patch_size_w = patch_size
148        self.early_conv_layers = nn.Sequential(
149            nn.Conv2d(
150                in_channels,
151                in_channels * 256,
152                kernel_size=patch_size,
153                stride=patch_size,
154                padding=0,
155                bias=bias,
156            ),
157            torch.nn.GroupNorm(num_groups=32, num_channels=in_channels * 256, eps=1e-6, affine=True),
158            nn.Conv2d(
159                in_channels * 256,
160                embed_dim,
161                kernel_size=1,
162                stride=1,
163                padding=0,
164                bias=bias,
165            ),
166        )
167        self.patch_size = patch_size
168        self.height, self.width = height // patch_size_h, width // patch_size_w
169        self.base_size = self.width
170
171    def forward(self, latent):
172        # early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
173        latent = self.early_conv_layers(latent)
174        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
175        return latent

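2D image-to-patch embedding. A Conv2d with kernel and stride equal to `patch_size` maps the input to `in_channels * 256` channels. This is followed by a 32-group GroupNorm and a 1x1 Conv2d to `embed_dim`. The resulting grid is flattened into a token sequence.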

PatchEmbed( height=16, width=4096, patch_size=(16, 1), in_channels=8, embed_dim=1152, bias=True)
137    def __init__(
138        self,
139        height=16,
140        width=4096,
141        patch_size=(16, 1),
142        in_channels=8,
143        embed_dim=1152,
144        bias=True,
145    ):
146        super().__init__()
147        patch_size_h, patch_size_w = patch_size
148        self.early_conv_layers = nn.Sequential(
149            nn.Conv2d(
150                in_channels,
151                in_channels * 256,
152                kernel_size=patch_size,
153                stride=patch_size,
154                padding=0,
155                bias=bias,
156            ),
157            torch.nn.GroupNorm(num_groups=32, num_channels=in_channels * 256, eps=1e-6, affine=True),
158            nn.Conv2d(
159                in_channels * 256,
160                embed_dim,
161                kernel_size=1,
162                stride=1,
163                padding=0,
164                bias=bias,
165            ),
166        )
167        self.patch_size = patch_size
168        self.height, self.width = height // patch_size_h, width // patch_size_w
169        self.base_size = self.width

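Build the patch-embedding stack (`early_conv_layers`) and record `patch_size`. Also record the patched grid size: `height // patch_size_h` by `width // patch_size_w`, with `base_size` set to the patched width.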

early_conv_layers
patch_size
base_size
def forward(self, latent):
171    def forward(self, latent):
172        # early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
173        latent = self.early_conv_layers(latent)
174        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
175        return latent

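Embed a latent. Apply `early_conv_layers` to the N x C x H x W input, then flatten the spatial grid (BCHW -> BNC). The result is a sequence of patch tokens with `embed_dim` channels.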

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

@dataclass
class Transformer2DModelOutput(diffusers.utils.outputs.BaseOutput):
178@dataclass
179class Transformer2DModelOutput(BaseOutput):
180    sample: torch.FloatTensor
181    proj_losses: Optional[Tuple[Tuple[str, torch.Tensor]]] = None
sample: torch.FloatTensor
proj_losses: Optional[Tuple[Tuple[str, torch.Tensor]]] = None
class ACEStepTransformer2DModel(diffusers.models.modeling_utils.ModelMixin, diffusers.configuration_utils.ConfigMixin, diffusers.loaders.peft.PeftAdapterMixin, diffusers.loaders.single_file_model.FromOriginalModelMixin):
184class ACEStepTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
185    _supports_gradient_checkpointing = True
186
187    @register_to_config
188    def __init__(
189        self,
190        in_channels: Optional[int] = 8,
191        num_layers: int = 28,
192        inner_dim: int = 1536,
193        attention_head_dim: int = 64,
194        num_attention_heads: int = 24,
195        mlp_ratio: float = 4.0,
196        out_channels: int = 8,
197        max_position: int = 32768,
198        rope_theta: float = 1000000.0,
199        speaker_embedding_dim: int = 512,
200        text_embedding_dim: int = 768,
201        ssl_encoder_depths: List[int] = [9, 9],
202        ssl_names: List[str] = ["mert", "m-hubert"],
203        ssl_latent_dims: List[int] = [1024, 768],
204        lyric_encoder_vocab_size: int = 6681,
205        lyric_hidden_size: int = 1024,
206        patch_size: List[int] = [16, 1],
207        max_height: int = 16,
208        max_width: int = 4096,
209        **kwargs,
210    ):
211        super().__init__()
212
213        self.num_attention_heads = num_attention_heads
214        self.attention_head_dim = attention_head_dim
215        inner_dim = num_attention_heads * attention_head_dim
216        self.inner_dim = inner_dim
217        self.out_channels = out_channels
218        self.max_position = max_position
219        self.patch_size = patch_size
220
221        self.rope_theta = rope_theta
222
223        self.rotary_emb = Qwen2RotaryEmbedding(
224            dim=self.attention_head_dim,
225            max_position_embeddings=self.max_position,
226            base=self.rope_theta,
227        )
228
229        # 2. Define input layers
230        self.in_channels = in_channels
231
232        # 3. Define transformers blocks
233        self.transformer_blocks = nn.ModuleList(
234            [
235                LinearTransformerBlock(
236                    dim=self.inner_dim,
237                    num_attention_heads=self.num_attention_heads,
238                    attention_head_dim=attention_head_dim,
239                    mlp_ratio=mlp_ratio,
240                    add_cross_attention=True,
241                    add_cross_attention_dim=self.inner_dim,
242                )
243                for i in range(self.config.num_layers)
244            ]
245        )
246        self.num_layers = num_layers
247
248        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
249        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
250        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))
251
252        # speaker
253        self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)
254
255        # genre
256        self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)
257
258        # lyric
259        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
260        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
261        self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)
262
263        projector_dim = 2 * self.inner_dim
264
265        self.projectors = nn.ModuleList(
266            [
267                nn.Sequential(
268                    nn.Linear(self.inner_dim, projector_dim),
269                    nn.SiLU(),
270                    nn.Linear(projector_dim, projector_dim),
271                    nn.SiLU(),
272                    nn.Linear(projector_dim, ssl_dim),
273                )
274                for ssl_dim in ssl_latent_dims
275            ]
276        )
277
278        self.ssl_latent_dims = ssl_latent_dims
279        self.ssl_encoder_depths = ssl_encoder_depths
280
281        self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction="mean")
282        self.ssl_names = ssl_names
283
284        self.proj_in = PatchEmbed(
285            height=max_height,
286            width=max_width,
287            patch_size=patch_size,
288            embed_dim=self.inner_dim,
289            bias=True,
290        )
291
292        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
293        self.gradient_checkpointing = False
294
295    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
296    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
297        """
298        Sets the attention processor to use [feed forward
299        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
300
301        Parameters:
302            chunk_size (`int`, *optional*):
303                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
304                over each tensor of dim=`dim`.
305            dim (`int`, *optional*, defaults to `0`):
306                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
307                or dim=1 (sequence length).
308        """
309        if dim not in [0, 1]:
310            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
311
312        # By default chunk size is 1
313        chunk_size = chunk_size or 1
314
315        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
316            if hasattr(module, "set_chunk_feed_forward"):
317                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
318
319            for child in module.children():
320                fn_recursive_feed_forward(child, chunk_size, dim)
321
322        for module in self.children():
323            fn_recursive_feed_forward(module, chunk_size, dim)
324
325    def forward_lyric_encoder(
326        self,
327        lyric_token_idx: Optional[torch.LongTensor] = None,
328        lyric_mask: Optional[torch.LongTensor] = None,
329    ):
330        # N x T x D
331        lyric_embs = self.lyric_embs(lyric_token_idx)
332        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
333        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
334        return prompt_prenet_out
335
336    def encode(
337        self,
338        encoder_text_hidden_states: Optional[torch.Tensor] = None,
339        text_attention_mask: Optional[torch.LongTensor] = None,
340        speaker_embeds: Optional[torch.FloatTensor] = None,
341        lyric_token_idx: Optional[torch.LongTensor] = None,
342        lyric_mask: Optional[torch.LongTensor] = None,
343    ):
344        bs = encoder_text_hidden_states.shape[0]
345        device = encoder_text_hidden_states.device
346
347        # speaker embedding
348        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
349        speaker_mask = torch.ones(bs, 1, device=device)
350
351        # genre embedding
352        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
353
354        # lyric
355        encoder_lyric_hidden_states = self.forward_lyric_encoder(
356            lyric_token_idx=lyric_token_idx,
357            lyric_mask=lyric_mask,
358        )
359
360        encoder_hidden_states = torch.cat(
361            [
362                encoder_spk_hidden_states,
363                encoder_text_hidden_states,
364                encoder_lyric_hidden_states,
365            ],
366            dim=1,
367        )
368        encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
369        return encoder_hidden_states, encoder_hidden_mask
370
371    def decode(
372        self,
373        hidden_states: torch.Tensor,
374        attention_mask: torch.Tensor,
375        encoder_hidden_states: torch.Tensor,
376        encoder_hidden_mask: torch.Tensor,
377        timestep: Optional[torch.Tensor],
378        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
379        output_length: int = 0,
380        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
381        controlnet_scale: Union[float, torch.Tensor] = 1.0,
382        return_dict: bool = True,
383    ):
384        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
385        temb = self.t_block(embedded_timestep)
386
387        hidden_states = self.proj_in(hidden_states)
388
389        # controlnet logic
390        if block_controlnet_hidden_states is not None:
391            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
392            hidden_states = hidden_states + control_condi * controlnet_scale
393
394        inner_hidden_states = []
395
396        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
397        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
398
399        for index_block, block in enumerate(self.transformer_blocks):
400            if self.training and self.gradient_checkpointing:
401                hidden_states = torch.utils.checkpoint.checkpoint(
402                    block,
403                    hidden_states=hidden_states,
404                    attention_mask=attention_mask,
405                    encoder_hidden_states=encoder_hidden_states,
406                    encoder_attention_mask=encoder_hidden_mask,
407                    rotary_freqs_cis=rotary_freqs_cis,
408                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
409                    temb=temb,
410                    use_reentrant=False,
411                )
412
413            else:
414                hidden_states = block(
415                    hidden_states=hidden_states,
416                    attention_mask=attention_mask,
417                    encoder_hidden_states=encoder_hidden_states,
418                    encoder_attention_mask=encoder_hidden_mask,
419                    rotary_freqs_cis=rotary_freqs_cis,
420                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
421                    temb=temb,
422                )
423
424            for ssl_encoder_depth in self.ssl_encoder_depths:
425                if index_block == ssl_encoder_depth:
426                    inner_hidden_states.append(hidden_states)
427
428        proj_losses = []
429        if len(inner_hidden_states) > 0 and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:
430            for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names):
431                if ssl_hidden_state is None:
432                    continue
433                # 1. N x T x D1 -> N x D x D2
434                est_ssl_hidden_state = projector(inner_hidden_state)
435                # 3. projection loss
436                bs = inner_hidden_state.shape[0]
437                proj_loss = 0.0
438                for i, (z, z_tilde) in enumerate(zip(ssl_hidden_state, est_ssl_hidden_state)):
439                    # 2. interpolate
440                    z_tilde = (
441                        F.interpolate(
442                            z_tilde.unsqueeze(0).transpose(1, 2),
443                            size=len(z),
444                            mode="linear",
445                            align_corners=False,
446                        )
447                        .transpose(1, 2)
448                        .squeeze(0)
449                    )
450
451                    z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
452                    z = torch.nn.functional.normalize(z, dim=-1)
453                    # T x d -> T x 1 -> 1
454                    target = torch.ones(z.shape[0], device=z.device)
455                    proj_loss += self.cosine_loss(z, z_tilde, target)
456                proj_losses.append((ssl_name, proj_loss / bs))
457
458        output = self.final_layer(hidden_states, embedded_timestep, output_length)
459        if not return_dict:
460            return (output, proj_losses)
461
462        return Transformer2DModelOutput(sample=output, proj_losses=proj_losses)
463
464    # @torch.compile
465    def forward(
466        self,
467        hidden_states: torch.Tensor,
468        attention_mask: torch.Tensor,
469        encoder_text_hidden_states: Optional[torch.Tensor] = None,
470        text_attention_mask: Optional[torch.LongTensor] = None,
471        speaker_embeds: Optional[torch.FloatTensor] = None,
472        lyric_token_idx: Optional[torch.LongTensor] = None,
473        lyric_mask: Optional[torch.LongTensor] = None,
474        timestep: Optional[torch.Tensor] = None,
475        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
476        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
477        controlnet_scale: Union[float, torch.Tensor] = 1.0,
478        return_dict: bool = True,
479    ):
480        encoder_hidden_states, encoder_hidden_mask = self.encode(
481            encoder_text_hidden_states=encoder_text_hidden_states,
482            text_attention_mask=text_attention_mask,
483            speaker_embeds=speaker_embeds,
484            lyric_token_idx=lyric_token_idx,
485            lyric_mask=lyric_mask,
486        )
487
488        output_length = hidden_states.shape[-1]
489
490        output = self.decode(
491            hidden_states=hidden_states,
492            attention_mask=attention_mask,
493            encoder_hidden_states=encoder_hidden_states,
494            encoder_hidden_mask=encoder_hidden_mask,
495            timestep=timestep,
496            ssl_hidden_states=ssl_hidden_states,
497            output_length=output_length,
498            block_controlnet_hidden_states=block_controlnet_hidden_states,
499            controlnet_scale=controlnet_scale,
500            return_dict=return_dict,
501        )
502
503        return output

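ACE-Step transformer model. `encode` projects speaker embeddings, genre-text hidden states and lyric tokens into one conditioning sequence with a matching mask. `decode` runs the patch-embedded latent through the transformer blocks, cross-attending to that sequence. The remaining text is inherited from ModelMixin, the diffusers base class for all models.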

[ModelMixin] takes care of storing the model configuration and provides methods for loading, downloading and saving models.

- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
@register_to_config
ACEStepTransformer2DModel( in_channels: Optional[int] = 8, num_layers: int = 28, inner_dim: int = 1536, attention_head_dim: int = 64, num_attention_heads: int = 24, mlp_ratio: float = 4.0, out_channels: int = 8, max_position: int = 32768, rope_theta: float = 1000000.0, speaker_embedding_dim: int = 512, text_embedding_dim: int = 768, ssl_encoder_depths: List[int] = [9, 9], ssl_names: List[str] = ['mert', 'm-hubert'], ssl_latent_dims: List[int] = [1024, 768], lyric_encoder_vocab_size: int = 6681, lyric_hidden_size: int = 1024, patch_size: List[int] = [16, 1], max_height: int = 16, max_width: int = 4096, **kwargs)
187    @register_to_config
188    def __init__(
189        self,
190        in_channels: Optional[int] = 8,
191        num_layers: int = 28,
192        inner_dim: int = 1536,
193        attention_head_dim: int = 64,
194        num_attention_heads: int = 24,
195        mlp_ratio: float = 4.0,
196        out_channels: int = 8,
197        max_position: int = 32768,
198        rope_theta: float = 1000000.0,
199        speaker_embedding_dim: int = 512,
200        text_embedding_dim: int = 768,
201        ssl_encoder_depths: List[int] = [9, 9],
202        ssl_names: List[str] = ["mert", "m-hubert"],
203        ssl_latent_dims: List[int] = [1024, 768],
204        lyric_encoder_vocab_size: int = 6681,
205        lyric_hidden_size: int = 1024,
206        patch_size: List[int] = [16, 1],
207        max_height: int = 16,
208        max_width: int = 4096,
209        **kwargs,
210    ):
211        super().__init__()
212
213        self.num_attention_heads = num_attention_heads
214        self.attention_head_dim = attention_head_dim
215        inner_dim = num_attention_heads * attention_head_dim
216        self.inner_dim = inner_dim
217        self.out_channels = out_channels
218        self.max_position = max_position
219        self.patch_size = patch_size
220
221        self.rope_theta = rope_theta
222
223        self.rotary_emb = Qwen2RotaryEmbedding(
224            dim=self.attention_head_dim,
225            max_position_embeddings=self.max_position,
226            base=self.rope_theta,
227        )
228
229        # 2. Define input layers
230        self.in_channels = in_channels
231
232        # 3. Define transformers blocks
233        self.transformer_blocks = nn.ModuleList(
234            [
235                LinearTransformerBlock(
236                    dim=self.inner_dim,
237                    num_attention_heads=self.num_attention_heads,
238                    attention_head_dim=attention_head_dim,
239                    mlp_ratio=mlp_ratio,
240                    add_cross_attention=True,
241                    add_cross_attention_dim=self.inner_dim,
242                )
243                for i in range(self.config.num_layers)
244            ]
245        )
246        self.num_layers = num_layers
247
248        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
249        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
250        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))
251
252        # speaker
253        self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)
254
255        # genre
256        self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)
257
258        # lyric
259        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
260        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
261        self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)
262
263        projector_dim = 2 * self.inner_dim
264
265        self.projectors = nn.ModuleList(
266            [
267                nn.Sequential(
268                    nn.Linear(self.inner_dim, projector_dim),
269                    nn.SiLU(),
270                    nn.Linear(projector_dim, projector_dim),
271                    nn.SiLU(),
272                    nn.Linear(projector_dim, ssl_dim),
273                )
274                for ssl_dim in ssl_latent_dims
275            ]
276        )
277
278        self.ssl_latent_dims = ssl_latent_dims
279        self.ssl_encoder_depths = ssl_encoder_depths
280
281        self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction="mean")
282        self.ssl_names = ssl_names
283
284        self.proj_in = PatchEmbed(
285            height=max_height,
286            width=max_width,
287            patch_size=patch_size,
288            embed_dim=self.inner_dim,
289            bias=True,
290        )
291
292        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
293        self.gradient_checkpointing = False

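Build the model's submodules:
- a Qwen2 rotary embedding;
- `num_layers` cross-attention `LinearTransformerBlock`s;
- the timestep embedding (`time_proj`, `timestep_embedder`, `t_block`);
- the speaker, genre and lyric conditioning encoders;
- one SSL projector per entry of `ssl_latent_dims`;
- the `PatchEmbed` input layer (`proj_in`) and the `T2IFinalLayer` output layer.

Note that the `inner_dim` argument is overridden: the model width is always `num_attention_heads * attention_head_dim`.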

num_attention_heads
attention_head_dim
inner_dim
out_channels
max_position
patch_size
rope_theta
rotary_emb
in_channels
transformer_blocks
num_layers
time_proj
timestep_embedder
t_block
speaker_embedder
genre_embedder
lyric_embs
lyric_encoder
lyric_proj
projectors
ssl_latent_dims
ssl_encoder_depths
cosine_loss
ssl_names
proj_in
final_layer
gradient_checkpointing
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
296    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
297        """
298        Sets the attention processor to use [feed forward
299        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
300
301        Parameters:
302            chunk_size (`int`, *optional*):
303                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
304                over each tensor of dim=`dim`.
305            dim (`int`, *optional*, defaults to `0`):
306                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
307                or dim=1 (sequence length).
308        """
309        if dim not in [0, 1]:
310            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
311
312        # By default chunk size is 1
313        chunk_size = chunk_size or 1
314
315        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
316            if hasattr(module, "set_chunk_feed_forward"):
317                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
318
319            for child in module.children():
320                fn_recursive_feed_forward(child, chunk_size, dim)
321
322        for module in self.children():
323            fn_recursive_feed_forward(module, chunk_size, dim)

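Enable feed-forward chunking by calling `set_chunk_feed_forward(chunk_size=..., dim=...)` on every descendant module that defines it. It does not change the attention processor.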

Parameters:
- chunk_size (int, optional): Chunk size for the feed-forward layers. If not specified (or falsy), it defaults to 1, so the feed-forward layer runs individually over each tensor slice along `dim`.
- dim (int, optional, defaults to 0): Dimension over which the feed-forward computation is chunked: 0 (batch) or 1 (sequence length). Any other value raises `ValueError`.

def forward_lyric_encoder( self, lyric_token_idx: Optional[torch.LongTensor] = None, lyric_mask: Optional[torch.LongTensor] = None):
325    def forward_lyric_encoder(
326        self,
327        lyric_token_idx: Optional[torch.LongTensor] = None,
328        lyric_mask: Optional[torch.LongTensor] = None,
329    ):
330        # N x T x D
331        lyric_embs = self.lyric_embs(lyric_token_idx)
332        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
333        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
334        return prompt_prenet_out
def encode( self, encoder_text_hidden_states: Optional[torch.Tensor] = None, text_attention_mask: Optional[torch.LongTensor] = None, speaker_embeds: Optional[torch.FloatTensor] = None, lyric_token_idx: Optional[torch.LongTensor] = None, lyric_mask: Optional[torch.LongTensor] = None):
336    def encode(
337        self,
338        encoder_text_hidden_states: Optional[torch.Tensor] = None,
339        text_attention_mask: Optional[torch.LongTensor] = None,
340        speaker_embeds: Optional[torch.FloatTensor] = None,
341        lyric_token_idx: Optional[torch.LongTensor] = None,
342        lyric_mask: Optional[torch.LongTensor] = None,
343    ):
344        bs = encoder_text_hidden_states.shape[0]
345        device = encoder_text_hidden_states.device
346
347        # speaker embedding
348        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
349        speaker_mask = torch.ones(bs, 1, device=device)
350
351        # genre embedding
352        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
353
354        # lyric
355        encoder_lyric_hidden_states = self.forward_lyric_encoder(
356            lyric_token_idx=lyric_token_idx,
357            lyric_mask=lyric_mask,
358        )
359
360        encoder_hidden_states = torch.cat(
361            [
362                encoder_spk_hidden_states,
363                encoder_text_hidden_states,
364                encoder_lyric_hidden_states,
365            ],
366            dim=1,
367        )
368        encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
369        return encoder_hidden_states, encoder_hidden_mask
def decode( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_mask: torch.Tensor, timestep: Optional[torch.Tensor], ssl_hidden_states: Optional[List[torch.Tensor]] = None, output_length: int = 0, block_controlnet_hidden_states: Union[List[torch.Tensor], torch.Tensor, NoneType] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, return_dict: bool = True):
371    def decode(
372        self,
373        hidden_states: torch.Tensor,
374        attention_mask: torch.Tensor,
375        encoder_hidden_states: torch.Tensor,
376        encoder_hidden_mask: torch.Tensor,
377        timestep: Optional[torch.Tensor],
378        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
379        output_length: int = 0,
380        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
381        controlnet_scale: Union[float, torch.Tensor] = 1.0,
382        return_dict: bool = True,
383    ):
384        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
385        temb = self.t_block(embedded_timestep)
386
387        hidden_states = self.proj_in(hidden_states)
388
389        # controlnet logic
390        if block_controlnet_hidden_states is not None:
391            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
392            hidden_states = hidden_states + control_condi * controlnet_scale
393
394        inner_hidden_states = []
395
396        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
397        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
398
399        for index_block, block in enumerate(self.transformer_blocks):
400            if self.training and self.gradient_checkpointing:
401                hidden_states = torch.utils.checkpoint.checkpoint(
402                    block,
403                    hidden_states=hidden_states,
404                    attention_mask=attention_mask,
405                    encoder_hidden_states=encoder_hidden_states,
406                    encoder_attention_mask=encoder_hidden_mask,
407                    rotary_freqs_cis=rotary_freqs_cis,
408                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
409                    temb=temb,
410                    use_reentrant=False,
411                )
412
413            else:
414                hidden_states = block(
415                    hidden_states=hidden_states,
416                    attention_mask=attention_mask,
417                    encoder_hidden_states=encoder_hidden_states,
418                    encoder_attention_mask=encoder_hidden_mask,
419                    rotary_freqs_cis=rotary_freqs_cis,
420                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
421                    temb=temb,
422                )
423
424            for ssl_encoder_depth in self.ssl_encoder_depths:
425                if index_block == ssl_encoder_depth:
426                    inner_hidden_states.append(hidden_states)
427
428        proj_losses = []
429        if len(inner_hidden_states) > 0 and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:
430            for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names):
431                if ssl_hidden_state is None:
432                    continue
433                # 1. N x T x D1 -> N x D x D2
434                est_ssl_hidden_state = projector(inner_hidden_state)
435                # 3. projection loss
436                bs = inner_hidden_state.shape[0]
437                proj_loss = 0.0
438                for i, (z, z_tilde) in enumerate(zip(ssl_hidden_state, est_ssl_hidden_state)):
439                    # 2. interpolate
440                    z_tilde = (
441                        F.interpolate(
442                            z_tilde.unsqueeze(0).transpose(1, 2),
443                            size=len(z),
444                            mode="linear",
445                            align_corners=False,
446                        )
447                        .transpose(1, 2)
448                        .squeeze(0)
449                    )
450
451                    z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
452                    z = torch.nn.functional.normalize(z, dim=-1)
453                    # T x d -> T x 1 -> 1
454                    target = torch.ones(z.shape[0], device=z.device)
455                    proj_loss += self.cosine_loss(z, z_tilde, target)
456                proj_losses.append((ssl_name, proj_loss / bs))
457
458        output = self.final_layer(hidden_states, embedded_timestep, output_length)
459        if not return_dict:
460            return (output, proj_losses)
461
462        return Transformer2DModelOutput(sample=output, proj_losses=proj_losses)
def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, encoder_text_hidden_states: Optional[torch.Tensor] = None, text_attention_mask: Optional[torch.LongTensor] = None, speaker_embeds: Optional[torch.FloatTensor] = None, lyric_token_idx: Optional[torch.LongTensor] = None, lyric_mask: Optional[torch.LongTensor] = None, timestep: Optional[torch.Tensor] = None, ssl_hidden_states: Optional[List[torch.Tensor]] = None, block_controlnet_hidden_states: Union[List[torch.Tensor], torch.Tensor, NoneType] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, return_dict: bool = True):
465    def forward(
466        self,
467        hidden_states: torch.Tensor,
468        attention_mask: torch.Tensor,
469        encoder_text_hidden_states: Optional[torch.Tensor] = None,
470        text_attention_mask: Optional[torch.LongTensor] = None,
471        speaker_embeds: Optional[torch.FloatTensor] = None,
472        lyric_token_idx: Optional[torch.LongTensor] = None,
473        lyric_mask: Optional[torch.LongTensor] = None,
474        timestep: Optional[torch.Tensor] = None,
475        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
476        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
477        controlnet_scale: Union[float, torch.Tensor] = 1.0,
478        return_dict: bool = True,
479    ):
480        encoder_hidden_states, encoder_hidden_mask = self.encode(
481            encoder_text_hidden_states=encoder_text_hidden_states,
482            text_attention_mask=text_attention_mask,
483            speaker_embeds=speaker_embeds,
484            lyric_token_idx=lyric_token_idx,
485            lyric_mask=lyric_mask,
486        )
487
488        output_length = hidden_states.shape[-1]
489
490        output = self.decode(
491            hidden_states=hidden_states,
492            attention_mask=attention_mask,
493            encoder_hidden_states=encoder_hidden_states,
494            encoder_hidden_mask=encoder_hidden_mask,
495            timestep=timestep,
496            ssl_hidden_states=ssl_hidden_states,
497            output_length=output_length,
498            block_controlnet_hidden_states=block_controlnet_hidden_states,
499            controlnet_scale=controlnet_scale,
500            return_dict=return_dict,
501        )
502
503        return output

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.