divisor.acestep.models.ace_step_transformer
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, List, Union

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin


from divisor.acestep.models.attention import LinearTransformerBlock, t2i_modulate
from divisor.acestep.models.lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder


def cross_norm(hidden_states, controlnet_input):
    """Re-normalize ``controlnet_input`` to the statistics of ``hidden_states``.

    For every batch element the mean and standard deviation are taken over
    dims 1 and 2 (the original comment gives the layout as N x T x C). The
    control signal is centred, rescaled by the ratio of the two standard
    deviations and shifted to the mean of ``hidden_states``.

    Returns:
        The re-normalized control tensor, same shape as ``controlnet_input``.
    """
    # input N x T x c
    mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1, 2), keepdim=True), hidden_states.std(dim=(1, 2), keepdim=True)
    mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1, 2), keepdim=True), controlnet_input.std(dim=(1, 2), keepdim=True)
    # The small epsilon guards the division when the control input is constant.
    controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
    return controlnet_input


# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding table with a cos/sin cache that grows on demand."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Build the inverse frequencies and pre-fill the cache.

        Args:
            dim: Rotary dimension; ``dim // 2`` frequencies are created.
            max_position_embeddings: Number of positions cached up front.
            base: Base of the geometric frequency progression.
            device: Device on which the inverse frequencies are created.
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cached cos/sin tables for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor used only for its dtype/device (and its second-to-last
                dimension when ``seq_len`` is omitted).
            seq_len: Number of positions wanted. Defaults to ``x.shape[-2]``.

        Returns:
            Two tensors of shape ``(seq_len, dim)`` cast to ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to crash in the comparison below.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


class T2IFinalLayer(nn.Module):
    """
    The final layer of Sana.

    Normalizes the token sequence, applies a timestep-conditioned shift/scale,
    projects every token to one patch and folds the patches back into a
    ``(N, out_channels, patch_h, width)`` map.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
        """Create the norm, the patch projection and the shift/scale table."""
        super().__init__()
        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        """Fold per-token patches back into a 2D map and fit it to ``width``.

        Args:
            hidden_states: ``(N, tokens, patch_h * patch_w * out_channels)``.
            width: Requested size of the last output dimension.

        Returns:
            Tensor of shape ``(N, out_channels, patch_h, width)``; the map is
            zero-padded on the right or cropped as needed.
        """
        # 4 unpatchify
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(
                hidden_states.shape[0],
                new_height,
                new_width,
                self.patch_size[0],
                self.patch_size[1],
                self.out_channels,
            )
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(
                hidden_states.shape[0],
                self.out_channels,
                new_height * self.patch_size[0],
                new_width * self.patch_size[1],
            )
        ).contiguous()
        # Compare against the real output width (tokens * patch_w), not the
        # token count: the two only coincide when patch_w == 1.
        current_width = output.shape[-1]
        if width > current_width:
            output = torch.nn.functional.pad(output, (0, width - current_width, 0, 0), "constant", 0)
        elif width < current_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        """Modulate, project and unpatchify ``x``.

        Args:
            x: Token sequence fed to the final norm.
            t: Timestep embedding added to the learned shift/scale table
                (presumably ``(N, hidden_size)`` — confirm against caller).
            output_length: Width passed on to :meth:`unpatchfy`.
        """
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
    ):
        """Build the patchifying conv, group norm and 1x1 projection."""
        super().__init__()
        patch_size_h, patch_size_w = patch_size
        self.early_conv_layers = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels * 256,
                kernel_size=patch_size,
                stride=patch_size,
                padding=0,
                bias=bias,
            ),
            torch.nn.GroupNorm(num_groups=32, num_channels=in_channels * 256, eps=1e-6, affine=True),
            nn.Conv2d(
                in_channels * 256,
                embed_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=bias,
            ),
        )
        self.patch_size = patch_size
        self.height, self.width = height // patch_size_h, width // patch_size_w
        self.base_size = self.width

    def forward(self, latent):
        """Map ``(N, C, H, W)`` latents to a ``(N, tokens, embed_dim)`` sequence."""
        # early convolutions: N x C x H x W -> N x embed_dim x H/patch_h x W/patch_w
        latent = self.early_conv_layers(latent)
        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return latent


@dataclass
class Transformer2DModelOutput(BaseOutput):
    # sample: tensor produced by the final layer.
    # proj_losses: (ssl_name, loss) pairs from the SSL projection heads.
    sample: torch.FloatTensor
    proj_losses: Optional[Tuple[Tuple[str, torch.Tensor]]] = None


class ACEStepTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """Transformer over patchified latents conditioned on speaker, text and lyrics.

    ``encode`` builds the conditioning sequence, ``decode`` runs the
    transformer blocks and the final layer, and ``forward`` chains both.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        **kwargs,
    ):
        super().__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # NOTE: the `inner_dim` argument is overridden by heads * head_dim.
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size

        self.rope_theta = rope_theta

        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
        )

        # 2. Define input layers
        self.in_channels = in_channels

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                )
                for _ in range(self.config.num_layers)
            ]
        )
        self.num_layers = num_layers

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))

        # speaker
        self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)

        # genre
        self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)

        # lyric
        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
        self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)

        projector_dim = 2 * self.inner_dim

        # One projection head per SSL target, mapping inner_dim -> ssl_dim.
        self.projectors = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(self.inner_dim, projector_dim),
                    nn.SiLU(),
                    nn.Linear(projector_dim, projector_dim),
                    nn.SiLU(),
                    nn.Linear(projector_dim, ssl_dim),
                )
                for ssl_dim in ssl_latent_dims
            ]
        )

        self.ssl_latent_dims = ssl_latent_dims
        self.ssl_encoder_depths = ssl_encoder_depths

        self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction="mean")
        self.ssl_names = ssl_names

        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            embed_dim=self.inner_dim,
            bias=True,
        )

        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
        self.gradient_checkpointing = False

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def forward_lyric_encoder(
        self,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        """Embed lyric tokens, run the lyric encoder and project to ``inner_dim``."""
        # N x T x D
        lyric_embs = self.lyric_embs(lyric_token_idx)
        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
        return prompt_prenet_out

    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        """Build the conditioning sequence and its mask.

        The speaker embedding (one token), the projected text states and the
        encoded lyrics are concatenated along dim 1, in that order; the masks
        are concatenated the same way with an all-ones speaker mask.

        Returns:
            ``(encoder_hidden_states, encoder_hidden_mask)``.
        """
        bs = encoder_text_hidden_states.shape[0]
        device = encoder_text_hidden_states.device

        # speaker embedding
        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
        speaker_mask = torch.ones(bs, 1, device=device)

        # genre embedding
        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)

        # lyric
        encoder_lyric_hidden_states = self.forward_lyric_encoder(
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        encoder_hidden_states = torch.cat(
            [
                encoder_spk_hidden_states,
                encoder_text_hidden_states,
                encoder_lyric_hidden_states,
            ],
            dim=1,
        )
        encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
        return encoder_hidden_states, encoder_hidden_mask

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        """Run the transformer stack and the final layer.

        Optionally mixes in a control signal after patch embedding and, when
        ``ssl_hidden_states`` is given, computes one cosine projection loss
        per SSL target from the hidden states captured at the configured
        depths.

        Returns:
            ``Transformer2DModelOutput`` or, when ``return_dict`` is false,
            the tuple ``(sample, proj_losses)``.
        """
        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)

        hidden_states = self.proj_in(hidden_states)

        # controlnet logic
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale

        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

        # Hidden states are stored by depth so that every projector is later
        # paired with the depth configured for it, whatever the order of
        # `ssl_encoder_depths`.
        capture_depths = set(self.ssl_encoder_depths)
        captured_hidden_states = {}

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                    use_reentrant=False,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                )

            if index_block in capture_depths:
                captured_hidden_states[index_block] = hidden_states

        proj_losses = []
        if captured_hidden_states and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:
            for depth, projector, ssl_hidden_state, ssl_name in zip(self.ssl_encoder_depths, self.projectors, ssl_hidden_states, self.ssl_names):
                inner_hidden_state = captured_hidden_states.get(depth)
                if inner_hidden_state is None or ssl_hidden_state is None:
                    continue
                # 1. N x T x D1 -> N x D x D2
                est_ssl_hidden_state = projector(inner_hidden_state)
                # 3. projection loss
                bs = inner_hidden_state.shape[0]
                proj_loss = 0.0
                for z, z_tilde in zip(ssl_hidden_state, est_ssl_hidden_state):
                    # 2. interpolate the estimate to the target length
                    z_tilde = (
                        F.interpolate(
                            z_tilde.unsqueeze(0).transpose(1, 2),
                            size=len(z),
                            mode="linear",
                            align_corners=False,
                        )
                        .transpose(1, 2)
                        .squeeze(0)
                    )

                    z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
                    z = torch.nn.functional.normalize(z, dim=-1)
                    # T x d -> T x 1 -> 1
                    target = torch.ones(z.shape[0], device=z.device)
                    proj_loss += self.cosine_loss(z, z_tilde, target)
                proj_losses.append((ssl_name, proj_loss / bs))

        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        if not return_dict:
            return (output, proj_losses)

        return Transformer2DModelOutput(sample=output, proj_losses=proj_losses)

    # @torch.compile
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        timestep: Optional[torch.Tensor] = None,
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        """Encode the conditioning inputs, then decode ``hidden_states``.

        The output width is taken from the last dimension of the incoming
        ``hidden_states``.
        """
        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        output_length = hidden_states.shape[-1]

        output = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            ssl_hidden_states=ssl_hidden_states,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
            return_dict=return_dict,
        )

        return output
def cross_norm(hidden_states, controlnet_input):
    """Match the statistics of ``controlnet_input`` to those of ``hidden_states``.

    Mean and standard deviation are computed per batch element over dims 1
    and 2 (the original comment gives the input layout as N x T x C). The
    control signal is centred, rescaled by the ratio of the two standard
    deviations and shifted to the mean of ``hidden_states``.

    Returns:
        A tensor shaped like ``controlnet_input``.
    """
    reduce_dims = (1, 2)

    target_mean = hidden_states.mean(dim=reduce_dims, keepdim=True)
    target_std = hidden_states.std(dim=reduce_dims, keepdim=True)
    source_mean = controlnet_input.mean(dim=reduce_dims, keepdim=True)
    source_std = controlnet_input.std(dim=reduce_dims, keepdim=True)

    # Epsilon keeps the ratio finite when the control input is constant.
    gain = target_std / (source_std + 1e-12)
    centered = controlnet_input - source_mean
    return centered * gain + target_mean
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding table with a cos/sin cache that grows on demand."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Build the inverse frequencies and pre-fill the cache.

        Args:
            dim: Rotary dimension; ``dim // 2`` frequencies are created.
            max_position_embeddings: Number of positions cached up front.
            base: Base of the geometric frequency progression.
            device: Device on which the inverse frequencies are created.
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cached cos/sin tables for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor used only for its dtype/device (and its second-to-last
                dimension when ``seq_len`` is omitted).
            seq_len: Number of positions wanted. Defaults to ``x.shape[-2]``.

        Returns:
            Two tensors of shape ``(seq_len, dim)`` cast to ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to raise TypeError in the comparison
            # below; fall back to the sequence axis documented above.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
43 def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 44 super().__init__() 45 46 self.dim = dim 47 self.max_position_embeddings = max_position_embeddings 48 self.base = base 49 inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) 50 self.register_buffer("inv_freq", inv_freq, persistent=False) 51 52 # Build here to make `torch.jit.trace` work. 53 self._set_cos_sin_cache( 54 seq_len=max_position_embeddings, 55 device=self.inv_freq.device, 56 dtype=torch.get_default_dtype(), 57 )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
69 def forward(self, x, seq_len=None): 70 # x: [bs, num_attention_heads, seq_len, head_size] 71 if seq_len > self.max_seq_len_cached: 72 self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 73 74 return ( 75 self.cos_cached[:seq_len].to(dtype=x.dtype), 76 self.sin_cached[:seq_len].to(dtype=x.dtype), 77 )
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class T2IFinalLayer(nn.Module):
    """
    The final layer of Sana.

    Normalizes the token sequence, applies a timestep-conditioned shift/scale,
    projects every token to one patch and folds the patches back into a
    ``(N, out_channels, patch_h, width)`` map.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
        """Create the norm, the patch projection and the shift/scale table.

        Args:
            hidden_size: Width of the incoming token features.
            patch_size: ``[patch_h, patch_w]`` of each output patch.
            out_channels: Channels of the unpatchified output.
        """
        super().__init__()
        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        """Fold per-token patches back into a 2D map and fit it to ``width``.

        Args:
            hidden_states: ``(N, tokens, patch_h * patch_w * out_channels)``.
            width: Requested size of the last output dimension.

        Returns:
            Tensor of shape ``(N, out_channels, patch_h, width)``; the map is
            zero-padded on the right or cropped as needed.
        """
        # 4 unpatchify
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(
                hidden_states.shape[0],
                new_height,
                new_width,
                self.patch_size[0],
                self.patch_size[1],
                self.out_channels,
            )
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(
                hidden_states.shape[0],
                self.out_channels,
                new_height * self.patch_size[0],
                new_width * self.patch_size[1],
            )
        ).contiguous()
        # Compare against the real output width (tokens * patch_w), not the
        # token count: the two only coincide when patch_w == 1.
        current_width = output.shape[-1]
        if width > current_width:
            output = torch.nn.functional.pad(output, (0, width - current_width, 0, 0), "constant", 0)
        elif width < current_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        """Modulate, project and unpatchify ``x``.

        Args:
            x: Token sequence fed to the final norm.
            t: Timestep embedding added to the learned shift/scale table
                (presumably ``(N, hidden_size)`` — confirm against caller).
            output_length: Width passed on to :meth:`unpatchfy`.
        """
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output
The final layer of Sana.
85 def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256): 86 super().__init__() 87 self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) 88 self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True) 89 self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5) 90 self.out_channels = out_channels 91 self.patch_size = patch_size
Initialize internal Module state, shared by both nn.Module and ScriptModule.
93 def unpatchfy( 94 self, 95 hidden_states: torch.Tensor, 96 width: int, 97 ): 98 # 4 unpatchify 99 new_height, new_width = 1, hidden_states.size(1) 100 hidden_states = hidden_states.reshape( 101 shape=( 102 hidden_states.shape[0], 103 new_height, 104 new_width, 105 self.patch_size[0], 106 self.patch_size[1], 107 self.out_channels, 108 ) 109 ).contiguous() 110 hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 111 output = hidden_states.reshape( 112 shape=( 113 hidden_states.shape[0], 114 self.out_channels, 115 new_height * self.patch_size[0], 116 new_width * self.patch_size[1], 117 ) 118 ).contiguous() 119 if width > new_width: 120 output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), "constant", 0) 121 elif width < new_width: 122 output = output[:, :, :, :width] 123 return output
def forward(self, x, t, output_length):
    """Modulate, project and unpatchify the token sequence.

    Args:
        x: Token sequence fed to the final norm.
        t: Timestep embedding added to the learned shift/scale table
            (presumably ``(N, hidden_size)`` — confirm against caller).
        output_length: Width handed to ``unpatchfy``.

    Returns:
        The unpatchified output of ``unpatchfy``.
    """
    # Broadcast the table against the per-sample embedding, then split the
    # two rows into shift and scale.
    modulation = self.scale_shift_table[None] + t[:, None]
    shift, scale = modulation.chunk(2, dim=1)

    normed = self.norm_final(x)
    projected = self.linear(t2i_modulate(normed, shift, scale))
    return self.unpatchfy(projected, output_length)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding.

    A strided convolution cuts the input into non-overlapping patches, a
    group norm follows, and a 1x1 convolution maps to ``embed_dim``. The
    result is flattened into a token sequence.
    """

    def __init__(self, height=16, width=4096, patch_size=(16, 1), in_channels=8, embed_dim=1152, bias=True):
        """Build the convolution stack and record the patch-grid size.

        Args:
            height: Input height used to derive the patch-grid height.
            width: Input width used to derive the patch-grid width.
            patch_size: ``(patch_h, patch_w)``; kernel and stride of the
                first convolution.
            in_channels: Channels of the input.
            embed_dim: Channels of every output token.
            bias: Whether both convolutions carry a bias.
        """
        super().__init__()
        patch_h, patch_w = patch_size
        hidden_channels = in_channels * 256

        patchify = nn.Conv2d(in_channels, hidden_channels, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias)
        normalize = torch.nn.GroupNorm(num_groups=32, num_channels=hidden_channels, eps=1e-6, affine=True)
        project = nn.Conv2d(hidden_channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
        self.early_conv_layers = nn.Sequential(patchify, normalize, project)

        self.patch_size = patch_size
        self.height = height // patch_h
        self.width = width // patch_w
        self.base_size = self.width

    def forward(self, latent):
        """Turn ``(N, C, H, W)`` input into a ``(N, tokens, embed_dim)`` sequence."""
        features = self.early_conv_layers(latent)
        # BCHW -> BNC
        return features.flatten(2).transpose(1, 2)
2D Image to Patch Embedding
137 def __init__( 138 self, 139 height=16, 140 width=4096, 141 patch_size=(16, 1), 142 in_channels=8, 143 embed_dim=1152, 144 bias=True, 145 ): 146 super().__init__() 147 patch_size_h, patch_size_w = patch_size 148 self.early_conv_layers = nn.Sequential( 149 nn.Conv2d( 150 in_channels, 151 in_channels * 256, 152 kernel_size=patch_size, 153 stride=patch_size, 154 padding=0, 155 bias=bias, 156 ), 157 torch.nn.GroupNorm(num_groups=32, num_channels=in_channels * 256, eps=1e-6, affine=True), 158 nn.Conv2d( 159 in_channels * 256, 160 embed_dim, 161 kernel_size=1, 162 stride=1, 163 padding=0, 164 bias=bias, 165 ), 166 ) 167 self.patch_size = patch_size 168 self.height, self.width = height // patch_size_h, width // patch_size_w 169 self.base_size = self.width
Initialize internal Module state, shared by both nn.Module and ScriptModule.
171 def forward(self, latent): 172 # early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size 173 latent = self.early_conv_layers(latent) 174 latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 175 return latent
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
184class ACEStepTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 185 _supports_gradient_checkpointing = True 186 187 @register_to_config 188 def __init__( 189 self, 190 in_channels: Optional[int] = 8, 191 num_layers: int = 28, 192 inner_dim: int = 1536, 193 attention_head_dim: int = 64, 194 num_attention_heads: int = 24, 195 mlp_ratio: float = 4.0, 196 out_channels: int = 8, 197 max_position: int = 32768, 198 rope_theta: float = 1000000.0, 199 speaker_embedding_dim: int = 512, 200 text_embedding_dim: int = 768, 201 ssl_encoder_depths: List[int] = [9, 9], 202 ssl_names: List[str] = ["mert", "m-hubert"], 203 ssl_latent_dims: List[int] = [1024, 768], 204 lyric_encoder_vocab_size: int = 6681, 205 lyric_hidden_size: int = 1024, 206 patch_size: List[int] = [16, 1], 207 max_height: int = 16, 208 max_width: int = 4096, 209 **kwargs, 210 ): 211 super().__init__() 212 213 self.num_attention_heads = num_attention_heads 214 self.attention_head_dim = attention_head_dim 215 inner_dim = num_attention_heads * attention_head_dim 216 self.inner_dim = inner_dim 217 self.out_channels = out_channels 218 self.max_position = max_position 219 self.patch_size = patch_size 220 221 self.rope_theta = rope_theta 222 223 self.rotary_emb = Qwen2RotaryEmbedding( 224 dim=self.attention_head_dim, 225 max_position_embeddings=self.max_position, 226 base=self.rope_theta, 227 ) 228 229 # 2. Define input layers 230 self.in_channels = in_channels 231 232 # 3. 
Define transformers blocks 233 self.transformer_blocks = nn.ModuleList( 234 [ 235 LinearTransformerBlock( 236 dim=self.inner_dim, 237 num_attention_heads=self.num_attention_heads, 238 attention_head_dim=attention_head_dim, 239 mlp_ratio=mlp_ratio, 240 add_cross_attention=True, 241 add_cross_attention_dim=self.inner_dim, 242 ) 243 for i in range(self.config.num_layers) 244 ] 245 ) 246 self.num_layers = num_layers 247 248 self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) 249 self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) 250 self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True)) 251 252 # speaker 253 self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim) 254 255 # genre 256 self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim) 257 258 # lyric 259 self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size) 260 self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0) 261 self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim) 262 263 projector_dim = 2 * self.inner_dim 264 265 self.projectors = nn.ModuleList( 266 [ 267 nn.Sequential( 268 nn.Linear(self.inner_dim, projector_dim), 269 nn.SiLU(), 270 nn.Linear(projector_dim, projector_dim), 271 nn.SiLU(), 272 nn.Linear(projector_dim, ssl_dim), 273 ) 274 for ssl_dim in ssl_latent_dims 275 ] 276 ) 277 278 self.ssl_latent_dims = ssl_latent_dims 279 self.ssl_encoder_depths = ssl_encoder_depths 280 281 self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction="mean") 282 self.ssl_names = ssl_names 283 284 self.proj_in = PatchEmbed( 285 height=max_height, 286 width=max_width, 287 patch_size=patch_size, 288 embed_dim=self.inner_dim, 289 bias=True, 290 ) 291 292 self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels) 293 self.gradient_checkpointing = False 294 295 # 
Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 296 def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 297 """ 298 Sets the attention processor to use [feed forward 299 chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 300 301 Parameters: 302 chunk_size (`int`, *optional*): 303 The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 304 over each tensor of dim=`dim`. 305 dim (`int`, *optional*, defaults to `0`): 306 The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 307 or dim=1 (sequence length). 308 """ 309 if dim not in [0, 1]: 310 raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 311 312 # By default chunk size is 1 313 chunk_size = chunk_size or 1 314 315 def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 316 if hasattr(module, "set_chunk_feed_forward"): 317 module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 318 319 for child in module.children(): 320 fn_recursive_feed_forward(child, chunk_size, dim) 321 322 for module in self.children(): 323 fn_recursive_feed_forward(module, chunk_size, dim) 324 325 def forward_lyric_encoder( 326 self, 327 lyric_token_idx: Optional[torch.LongTensor] = None, 328 lyric_mask: Optional[torch.LongTensor] = None, 329 ): 330 # N x T x D 331 lyric_embs = self.lyric_embs(lyric_token_idx) 332 prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1) 333 prompt_prenet_out = self.lyric_proj(prompt_prenet_out) 334 return prompt_prenet_out 335 336 def encode( 337 self, 338 encoder_text_hidden_states: Optional[torch.Tensor] = None, 339 text_attention_mask: Optional[torch.LongTensor] = None, 340 speaker_embeds: Optional[torch.FloatTensor] = None, 341 lyric_token_idx: Optional[torch.LongTensor] = None, 
342 lyric_mask: Optional[torch.LongTensor] = None, 343 ): 344 bs = encoder_text_hidden_states.shape[0] 345 device = encoder_text_hidden_states.device 346 347 # speaker embedding 348 encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1) 349 speaker_mask = torch.ones(bs, 1, device=device) 350 351 # genre embedding 352 encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states) 353 354 # lyric 355 encoder_lyric_hidden_states = self.forward_lyric_encoder( 356 lyric_token_idx=lyric_token_idx, 357 lyric_mask=lyric_mask, 358 ) 359 360 encoder_hidden_states = torch.cat( 361 [ 362 encoder_spk_hidden_states, 363 encoder_text_hidden_states, 364 encoder_lyric_hidden_states, 365 ], 366 dim=1, 367 ) 368 encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1) 369 return encoder_hidden_states, encoder_hidden_mask 370 371 def decode( 372 self, 373 hidden_states: torch.Tensor, 374 attention_mask: torch.Tensor, 375 encoder_hidden_states: torch.Tensor, 376 encoder_hidden_mask: torch.Tensor, 377 timestep: Optional[torch.Tensor], 378 ssl_hidden_states: Optional[List[torch.Tensor]] = None, 379 output_length: int = 0, 380 block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, 381 controlnet_scale: Union[float, torch.Tensor] = 1.0, 382 return_dict: bool = True, 383 ): 384 embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype)) 385 temb = self.t_block(embedded_timestep) 386 387 hidden_states = self.proj_in(hidden_states) 388 389 # controlnet logic 390 if block_controlnet_hidden_states is not None: 391 control_condi = cross_norm(hidden_states, block_controlnet_hidden_states) 392 hidden_states = hidden_states + control_condi * controlnet_scale 393 394 inner_hidden_states = [] 395 396 rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1]) 397 encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, 
seq_len=encoder_hidden_states.shape[1]) 398 399 for index_block, block in enumerate(self.transformer_blocks): 400 if self.training and self.gradient_checkpointing: 401 hidden_states = torch.utils.checkpoint.checkpoint( 402 block, 403 hidden_states=hidden_states, 404 attention_mask=attention_mask, 405 encoder_hidden_states=encoder_hidden_states, 406 encoder_attention_mask=encoder_hidden_mask, 407 rotary_freqs_cis=rotary_freqs_cis, 408 rotary_freqs_cis_cross=encoder_rotary_freqs_cis, 409 temb=temb, 410 use_reentrant=False, 411 ) 412 413 else: 414 hidden_states = block( 415 hidden_states=hidden_states, 416 attention_mask=attention_mask, 417 encoder_hidden_states=encoder_hidden_states, 418 encoder_attention_mask=encoder_hidden_mask, 419 rotary_freqs_cis=rotary_freqs_cis, 420 rotary_freqs_cis_cross=encoder_rotary_freqs_cis, 421 temb=temb, 422 ) 423 424 for ssl_encoder_depth in self.ssl_encoder_depths: 425 if index_block == ssl_encoder_depth: 426 inner_hidden_states.append(hidden_states) 427 428 proj_losses = [] 429 if len(inner_hidden_states) > 0 and ssl_hidden_states is not None and len(ssl_hidden_states) > 0: 430 for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names): 431 if ssl_hidden_state is None: 432 continue 433 # 1. N x T x D1 -> N x D x D2 434 est_ssl_hidden_state = projector(inner_hidden_state) 435 # 3. projection loss 436 bs = inner_hidden_state.shape[0] 437 proj_loss = 0.0 438 for i, (z, z_tilde) in enumerate(zip(ssl_hidden_state, est_ssl_hidden_state)): 439 # 2. 
interpolate 440 z_tilde = ( 441 F.interpolate( 442 z_tilde.unsqueeze(0).transpose(1, 2), 443 size=len(z), 444 mode="linear", 445 align_corners=False, 446 ) 447 .transpose(1, 2) 448 .squeeze(0) 449 ) 450 451 z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1) 452 z = torch.nn.functional.normalize(z, dim=-1) 453 # T x d -> T x 1 -> 1 454 target = torch.ones(z.shape[0], device=z.device) 455 proj_loss += self.cosine_loss(z, z_tilde, target) 456 proj_losses.append((ssl_name, proj_loss / bs)) 457 458 output = self.final_layer(hidden_states, embedded_timestep, output_length) 459 if not return_dict: 460 return (output, proj_losses) 461 462 return Transformer2DModelOutput(sample=output, proj_losses=proj_losses) 463 464 # @torch.compile 465 def forward( 466 self, 467 hidden_states: torch.Tensor, 468 attention_mask: torch.Tensor, 469 encoder_text_hidden_states: Optional[torch.Tensor] = None, 470 text_attention_mask: Optional[torch.LongTensor] = None, 471 speaker_embeds: Optional[torch.FloatTensor] = None, 472 lyric_token_idx: Optional[torch.LongTensor] = None, 473 lyric_mask: Optional[torch.LongTensor] = None, 474 timestep: Optional[torch.Tensor] = None, 475 ssl_hidden_states: Optional[List[torch.Tensor]] = None, 476 block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, 477 controlnet_scale: Union[float, torch.Tensor] = 1.0, 478 return_dict: bool = True, 479 ): 480 encoder_hidden_states, encoder_hidden_mask = self.encode( 481 encoder_text_hidden_states=encoder_text_hidden_states, 482 text_attention_mask=text_attention_mask, 483 speaker_embeds=speaker_embeds, 484 lyric_token_idx=lyric_token_idx, 485 lyric_mask=lyric_mask, 486 ) 487 488 output_length = hidden_states.shape[-1] 489 490 output = self.decode( 491 hidden_states=hidden_states, 492 attention_mask=attention_mask, 493 encoder_hidden_states=encoder_hidden_states, 494 encoder_hidden_mask=encoder_hidden_mask, 495 timestep=timestep, 496 ssl_hidden_states=ssl_hidden_states, 497 
output_length=output_length, 498 block_controlnet_hidden_states=block_controlnet_hidden_states, 499 controlnet_scale=controlnet_scale, 500 return_dict=return_dict, 501 ) 502 503 return output
Base class for all models.
[ModelMixin] takes care of storing the model configuration and provides methods for loading, downloading and
saving models.
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
@register_to_config
def __init__(
    self,
    in_channels: Optional[int] = 8,
    num_layers: int = 28,
    inner_dim: int = 1536,
    attention_head_dim: int = 64,
    num_attention_heads: int = 24,
    mlp_ratio: float = 4.0,
    out_channels: int = 8,
    max_position: int = 32768,
    rope_theta: float = 1000000.0,
    speaker_embedding_dim: int = 512,
    text_embedding_dim: int = 768,
    ssl_encoder_depths: List[int] = [9, 9],
    ssl_names: List[str] = ["mert", "m-hubert"],
    ssl_latent_dims: List[int] = [1024, 768],
    lyric_encoder_vocab_size: int = 6681,
    lyric_hidden_size: int = 1024,
    patch_size: List[int] = [16, 1],
    max_height: int = 16,
    max_width: int = 4096,
    **kwargs,
):
    """
    Build every sub-module of the transformer.

    Parameters:
        in_channels (`int`, *optional*, defaults to 8):
            Stored on `self.in_channels`; not otherwise used in this constructor.
        num_layers (`int`, defaults to 28):
            Stored on `self.num_layers`. The number of transformer blocks is read from
            `self.config.num_layers`.
        inner_dim (`int`, defaults to 1536):
            Not consulted: the working width is recomputed as
            `num_attention_heads * attention_head_dim`.
        attention_head_dim (`int`, defaults to 64):
            Per-head width; also the `dim` given to the rotary embedding.
        num_attention_heads (`int`, defaults to 24):
            Number of attention heads in each transformer block.
        mlp_ratio (`float`, defaults to 4.0):
            Forwarded to every `LinearTransformerBlock`.
        out_channels (`int`, defaults to 8):
            Forwarded to the final layer.
        max_position (`int`, defaults to 32768):
            `max_position_embeddings` of the rotary embedding.
        rope_theta (`float`, defaults to 1000000.0):
            `base` of the rotary embedding.
        speaker_embedding_dim (`int`, defaults to 512):
            Input width of the speaker linear projection.
        text_embedding_dim (`int`, defaults to 768):
            Input width of the genre/text linear projection.
        ssl_encoder_depths (`List[int]`):
            Block indices whose outputs `decode` collects for the projection losses.
        ssl_names (`List[str]`):
            Labels attached to the projection losses.
        ssl_latent_dims (`List[int]`):
            Output width of each projector MLP (one projector per entry).
        lyric_encoder_vocab_size (`int`, defaults to 6681):
            Number of rows in the lyric token embedding table.
        lyric_hidden_size (`int`, defaults to 1024):
            Width of the lyric embeddings and of the lyric encoder input.
        patch_size (`List[int]`):
            Forwarded to the patch embedding and to the final layer.
        max_height (`int`, defaults to 16):
            `height` of the patch embedding.
        max_width (`int`, defaults to 4096):
            `width` of the patch embedding.
        **kwargs:
            Accepted but not used in this constructor.
    """
    # NOTE: the list defaults in the signature are left untouched on purpose;
    # the signature is decorated with `register_to_config`.
    super().__init__()

    # The model width is always heads * head_dim (the `inner_dim` argument is ignored).
    hidden_width = num_attention_heads * attention_head_dim

    # Plain (non-module) attributes.
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    self.inner_dim = hidden_width
    self.out_channels = out_channels
    self.max_position = max_position
    self.patch_size = patch_size
    self.rope_theta = rope_theta
    self.in_channels = in_channels
    self.num_layers = num_layers
    self.ssl_latent_dims = ssl_latent_dims
    self.ssl_encoder_depths = ssl_encoder_depths
    self.ssl_names = ssl_names
    self.gradient_checkpointing = False

    # Sub-modules below are created in a fixed order so that registration
    # order and random initialisation stay stable.
    self.rotary_emb = Qwen2RotaryEmbedding(
        dim=attention_head_dim,
        max_position_embeddings=max_position,
        base=rope_theta,
    )

    # Transformer stack.
    self.transformer_blocks = nn.ModuleList(
        [
            LinearTransformerBlock(
                dim=hidden_width,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_ratio=mlp_ratio,
                add_cross_attention=True,
                add_cross_attention_dim=hidden_width,
            )
            for _ in range(self.config.num_layers)
        ]
    )

    # Timestep conditioning.
    self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_width)
    self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_width, 6 * hidden_width, bias=True))

    # Speaker conditioning.
    self.speaker_embedder = nn.Linear(speaker_embedding_dim, hidden_width)

    # Genre / text conditioning.
    self.genre_embedder = nn.Linear(text_embedding_dim, hidden_width)

    # Lyric conditioning.
    self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
    self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
    self.lyric_proj = nn.Linear(lyric_hidden_size, hidden_width)

    # One three-layer MLP per SSL target, mapping model width -> SSL latent width.
    projector_dim = 2 * hidden_width

    def make_projector(ssl_dim):
        return nn.Sequential(
            nn.Linear(hidden_width, projector_dim),
            nn.SiLU(),
            nn.Linear(projector_dim, projector_dim),
            nn.SiLU(),
            nn.Linear(projector_dim, ssl_dim),
        )

    self.projectors = nn.ModuleList([make_projector(ssl_dim) for ssl_dim in ssl_latent_dims])

    self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction="mean")

    # Input patch embedding and output head.
    self.proj_in = PatchEmbed(
        height=max_height,
        width=max_width,
        patch_size=patch_size,
        embed_dim=hidden_width,
        bias=True,
    )
    self.final_layer = T2IFinalLayer(hidden_width, patch_size=patch_size, out_channels=out_channels)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
    """
    Turn on [feed forward
    chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on every descendant
    module that exposes `set_chunk_feed_forward`.

    Parameters:
        chunk_size (`int`, *optional*):
            The chunk size of the feed-forward layers. A missing (or otherwise falsy) value is replaced by 1.
        dim (`int`, *optional*, defaults to `0`):
            The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
            or dim=1 (sequence length).

    Raises:
        ValueError: if `dim` is neither 0 nor 1.
    """
    if dim not in [0, 1]:
        raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

    # Fall back to a chunk size of 1 when none is given.
    chunk_size = chunk_size or 1

    # Pre-order walk over all descendants (the model itself is not visited).
    # Children are pushed in reverse so they are popped in declaration order.
    pending = list(self.children())[::-1]
    while pending:
        module = pending.pop()
        if hasattr(module, "set_chunk_feed_forward"):
            module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
        pending.extend(list(module.children())[::-1])
Sets the attention processor to use feed forward chunking.
Parameters:
chunk_size (int, optional):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=dim.
dim (int, optional, defaults to 0):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
def forward_lyric_encoder(
    self,
    lyric_token_idx: Optional[torch.LongTensor] = None,
    lyric_mask: Optional[torch.LongTensor] = None,
):
    """
    Embed lyric tokens, run them through the lyric encoder and project the result to the model width.

    Parameters:
        lyric_token_idx (`torch.LongTensor`):
            Token ids looked up in `self.lyric_embs`.
        lyric_mask (`torch.LongTensor`):
            Mask handed to `self.lyric_encoder` together with the embeddings.

    Returns:
        The output of `self.lyric_proj` applied to the encoder output. The mask returned by the
        encoder is discarded.
    """
    # N x T x D (per the original author's note)
    token_embeddings = self.lyric_embs(lyric_token_idx)
    encoded = self.lyric_encoder(
        token_embeddings,
        lyric_mask,
        decoding_chunk_size=1,
        num_decoding_left_chunks=-1,
    )
    encoder_out = encoded[0]
    return self.lyric_proj(encoder_out)
def encode(
    self,
    encoder_text_hidden_states: Optional[torch.Tensor] = None,
    text_attention_mask: Optional[torch.LongTensor] = None,
    speaker_embeds: Optional[torch.FloatTensor] = None,
    lyric_token_idx: Optional[torch.LongTensor] = None,
    lyric_mask: Optional[torch.LongTensor] = None,
):
    """
    Assemble the conditioning sequence from speaker, text and lyric inputs.

    Parameters:
        encoder_text_hidden_states (`torch.Tensor`):
            Text features; projected with `self.genre_embedder`. Batch size and device are read from it.
        text_attention_mask (`torch.LongTensor`):
            Mask for the text positions.
        speaker_embeds (`torch.FloatTensor`):
            Speaker embedding; projected with `self.speaker_embedder` and given a new axis at dim 1,
            so it contributes a single position.
        lyric_token_idx (`torch.LongTensor`):
            Lyric token ids, passed to `self.forward_lyric_encoder`.
        lyric_mask (`torch.LongTensor`):
            Mask for the lyric positions.

    Returns:
        `(encoder_hidden_states, encoder_hidden_mask)`: both are concatenations along dim 1 in the
        order speaker, text, lyric. The speaker mask is all ones.
    """
    batch_size = encoder_text_hidden_states.shape[0]
    target_device = encoder_text_hidden_states.device

    # Speaker: one always-visible position per sample.
    speaker_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
    speaker_mask = torch.ones(batch_size, 1, device=target_device)

    # Genre / text.
    text_states = self.genre_embedder(encoder_text_hidden_states)

    # Lyrics.
    lyric_states = self.forward_lyric_encoder(
        lyric_token_idx=lyric_token_idx,
        lyric_mask=lyric_mask,
    )

    state_parts = [speaker_states, text_states, lyric_states]
    mask_parts = [speaker_mask, text_attention_mask, lyric_mask]
    encoder_hidden_states = torch.cat(state_parts, dim=1)
    encoder_hidden_mask = torch.cat(mask_parts, dim=1)
    return encoder_hidden_states, encoder_hidden_mask
def decode(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_mask: torch.Tensor,
    timestep: Optional[torch.Tensor],
    ssl_hidden_states: Optional[List[torch.Tensor]] = None,
    output_length: int = 0,
    block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
    controlnet_scale: Union[float, torch.Tensor] = 1.0,
    return_dict: bool = True,
):
    """
    Run the transformer stack on patch-embedded inputs and, optionally, compute SSL projection losses.

    Parameters:
        hidden_states (`torch.Tensor`):
            Input passed through `self.proj_in` before the transformer blocks.
        attention_mask (`torch.Tensor`):
            Forwarded to every block as `attention_mask`.
        encoder_hidden_states (`torch.Tensor`):
            Conditioning sequence, forwarded to every block.
        encoder_hidden_mask (`torch.Tensor`):
            Forwarded to every block as `encoder_attention_mask`.
        timestep (`torch.Tensor`):
            Embedded via `self.time_proj` and `self.timestep_embedder`.
        ssl_hidden_states (`List[torch.Tensor]`, *optional*):
            Per-SSL-model targets. `None` entries are skipped; when the list is `None` or empty no
            projection loss is computed.
        output_length (`int`, defaults to 0):
            Passed through to `self.final_layer`.
        block_controlnet_hidden_states (*optional*):
            When given, it is passed through `cross_norm` against the embedded input, scaled by
            `controlnet_scale` and added to it.
        controlnet_scale (`float` or `torch.Tensor`, defaults to 1.0):
            Multiplier for the control signal.
        return_dict (`bool`, defaults to `True`):
            Return a `Transformer2DModelOutput` instead of a plain tuple.

    Returns:
        `Transformer2DModelOutput(sample=..., proj_losses=...)`, or `(sample, proj_losses)` when
        `return_dict` is `False`. `proj_losses` is a list of `(ssl_name, loss)` pairs.
    """
    # Timestep conditioning: raw embedding for the final layer, modulation vector for the blocks.
    timestep_features = self.time_proj(timestep).to(dtype=hidden_states.dtype)
    embedded_timestep = self.timestep_embedder(timestep_features)
    temb = self.t_block(embedded_timestep)

    hidden_states = self.proj_in(hidden_states)

    # Optional control signal, re-normalised against the embedded input.
    if block_controlnet_hidden_states is not None:
        control_signal = cross_norm(hidden_states, block_controlnet_hidden_states)
        hidden_states = hidden_states + control_signal * controlnet_scale

    rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
    encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

    use_checkpointing = self.training and self.gradient_checkpointing

    # Outputs of the blocks listed in `ssl_encoder_depths` (a depth listed twice is collected twice).
    tapped_states = []
    for depth, block in enumerate(self.transformer_blocks):
        block_inputs = {
            "hidden_states": hidden_states,
            "attention_mask": attention_mask,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_hidden_mask,
            "rotary_freqs_cis": rotary_freqs_cis,
            "rotary_freqs_cis_cross": encoder_rotary_freqs_cis,
            "temb": temb,
        }
        if use_checkpointing:
            hidden_states = torch.utils.checkpoint.checkpoint(block, use_reentrant=False, **block_inputs)
        else:
            hidden_states = block(**block_inputs)

        tapped_states.extend(hidden_states for tap in self.ssl_encoder_depths if depth == tap)

    proj_losses = []
    if tapped_states and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:
        aligned = zip(tapped_states, self.projectors, ssl_hidden_states, self.ssl_names)
        for tapped, projector, targets, ssl_name in aligned:
            if targets is None:
                continue
            # Map the tapped activations to the SSL latent width.
            predictions = projector(tapped)
            batch_size = tapped.shape[0]
            proj_loss = 0.0
            for target_seq, predicted_seq in zip(targets, predictions):
                # Resample the prediction along its first axis to the target's length.
                resampled = F.interpolate(
                    predicted_seq.unsqueeze(0).transpose(1, 2),
                    size=len(target_seq),
                    mode="linear",
                    align_corners=False,
                )
                resampled = resampled.transpose(1, 2).squeeze(0)

                resampled = F.normalize(resampled, dim=-1)
                target_seq = F.normalize(target_seq, dim=-1)
                # A target of +1 for every row: the loss pulls each pair towards cosine similarity 1.
                all_similar = torch.ones(target_seq.shape[0], device=target_seq.device)
                proj_loss += self.cosine_loss(target_seq, resampled, all_similar)
            proj_losses.append((ssl_name, proj_loss / batch_size))

    output = self.final_layer(hidden_states, embedded_timestep, output_length)
    if return_dict:
        return Transformer2DModelOutput(sample=output, proj_losses=proj_losses)
    return (output, proj_losses)
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    encoder_text_hidden_states: Optional[torch.Tensor] = None,
    text_attention_mask: Optional[torch.LongTensor] = None,
    speaker_embeds: Optional[torch.FloatTensor] = None,
    lyric_token_idx: Optional[torch.LongTensor] = None,
    lyric_mask: Optional[torch.LongTensor] = None,
    timestep: Optional[torch.Tensor] = None,
    ssl_hidden_states: Optional[List[torch.Tensor]] = None,
    block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
    controlnet_scale: Union[float, torch.Tensor] = 1.0,
    return_dict: bool = True,
):
    """
    Full pass: build the conditioning sequence with `encode`, then run `decode` on it.

    The conditioning arguments (`encoder_text_hidden_states`, `text_attention_mask`, `speaker_embeds`,
    `lyric_token_idx`, `lyric_mask`) go to `encode`; everything else goes to `decode`, together with
    `output_length`, which is taken from the last dimension of `hidden_states`.

    Returns:
        Whatever `decode` returns for the given `return_dict`.
    """
    conditioning = self.encode(
        encoder_text_hidden_states=encoder_text_hidden_states,
        text_attention_mask=text_attention_mask,
        speaker_embeds=speaker_embeds,
        lyric_token_idx=lyric_token_idx,
        lyric_mask=lyric_mask,
    )
    encoder_hidden_states, encoder_hidden_mask = conditioning

    return self.decode(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_mask=encoder_hidden_mask,
        timestep=timestep,
        ssl_hidden_states=ssl_hidden_states,
        # The output is sized to the input's last dimension.
        output_length=hidden_states.shape[-1],
        block_controlnet_hidden_states=block_controlnet_hidden_states,
        controlnet_scale=controlnet_scale,
        return_dict=return_dict,
    )
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.