divisor.acestep.pipeline_ace_step
ACE-Step: A Step Towards Music Generation Foundation Model
1# SPDX-License-Identifier:Apache-2.0 2# adapted from https://github.com/ace-step/ACE-Step 3 4""" 5ACE-Step: A Step Towards Music Generation Foundation Model 6""" 7 8import json 9import math 10import os 11import random 12import re 13import time 14from typing import Literal 15 16import torch 17import torchaudio 18from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import ( 19 retrieve_timesteps, 20) 21from diffusers.utils.peft_utils import set_weights_and_activate_adapters 22from diffusers.utils.torch_utils import randn_tensor 23from huggingface_hub import snapshot_download 24from nnll.console import nfo 25from tqdm import tqdm 26from transformers import AutoTokenizer, UMT5EncoderModel 27 28from divisor.acestep.apg_guidance import ( 29 MomentumBuffer, 30 apg_forward, 31 cfg_double_condition_forward, 32 cfg_forward, 33 cfg_zero_star, 34) 35from divisor.acestep.cpu_offload import cpu_offload 36from divisor.acestep.language_segmentation import LangSegment, language_filters 37from divisor.acestep.models.ace_step_transformer import ACEStepTransformer2DModel 38from divisor.acestep.models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer 39from divisor.acestep.music_dcae.music_dcae_pipeline import MusicDCAE 40from divisor.acestep.schedulers.scheduling_flow_match_euler_discrete import ( 41 FlowMatchEulerDiscreteScheduler, 42) 43from divisor.acestep.schedulers.scheduling_flow_match_heun_discrete import ( 44 FlowMatchHeunDiscreteScheduler, 45) 46from divisor.acestep.schedulers.scheduling_flow_match_pingpong import FlowMatchPingPongScheduler 47from divisor.registry import gfx_device, empty_cache 48 49if gfx_device.type == "cuda": 50 torch.backends.cudnn.benchmark = False 51 torch.set_float32_matmul_precision("high") 52 torch.backends.cudnn.deterministic = True 53 torch.backends.cuda.matmul.allow_tf32 = False 54 os.environ["TOKENIZERS_PARALLELISM"] = "false" 55elif gfx_device.type == "mps": 56 os.environ["DYLD_FALLBACK_LIBRARY_PATH"] = "/opt/homebrew/lib" 57 58SUPPORT_LANGUAGES = { 59 "en": 259, 60 "de": 260, 61 "fr": 262, 62 "es": 284, 63 "it": 285, 64 "pt": 286, 65 "pl": 294, 66 "tr": 295, 67 "ru": 267, 68 "cs": 293, 69 "nl": 297, 70 "ar": 5022, 71 "zh": 5023, 72 "ja": 5412, 73 "hu": 5753, 74 "ko": 6152, 75 "hi": 6680, 76} 77 78structure_pattern = re.compile(r"\[.*?\]") 79 80 81def ensure_directory_exists(directory): 82 directory = str(directory) 83 if not os.path.exists(directory): 84 os.makedirs(directory) 85 86 87REPO_ID = "ACE-Step/ACE-Step-v1-3.5B" 88REPO_ID_QUANT = REPO_ID + "-q4-K-M" # ??? update this i guess 89 90 91# class ACEStepPipeline(DiffusionPipeline): 92class ACEStepPipeline: 93 def __init__( 94 self, 95 checkpoint_dir=None, 96 device_id=0, 97 dtype="bfloat16", 98 text_encoder_checkpoint_path=None, 99 persistent_storage_path=None, 100 torch_compile=False, 101 cpu_offload=False, 102 quantized=False, 103 overlapped_decode=False, 104 **kwargs, 105 ): 106 if not checkpoint_dir: 107 if persistent_storage_path is None: 108 checkpoint_dir = os.path.join(os.path.expanduser("~"), ".cache/ace-step/checkpoints") 109 os.makedirs(checkpoint_dir, exist_ok=True) 110 else: 111 checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints") 112 ensure_directory_exists(checkpoint_dir) 113 114 self.checkpoint_dir = checkpoint_dir 115 self.lora_path = "none" 116 self.lora_weight = 1 117 self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32 118 if gfx_device.type == "mps": 119 if self.dtype == torch.bfloat16: 120 self.dtype = torch.float16 121 122 if "ACE_PIPELINE_DTYPE" in os.environ and len(os.environ["ACE_PIPELINE_DTYPE"]): 123 self.dtype = getattr(torch, os.environ["ACE_PIPELINE_DTYPE"]) 124 self.device: torch.device = gfx_device 125 self.loaded = False 126 self.torch_compile = torch_compile 127 self.cpu_offload = cpu_offload 128 self.quantized = quantized 129 self.overlapped_decode = overlapped_decode 130 131 def get_checkpoint_path(self, checkpoint_dir, repo): 132 checkpoint_dir_models = None 133 134 if checkpoint_dir is not None: 135 required_dirs = ["music_dcae_f8c8", "music_vocoder", "ace_step_transformer", "umt5-base"] 136 all_dirs_exist = True 137 for dir_name in required_dirs: 138 dir_path = os.path.join(checkpoint_dir, dir_name) 139 if not os.path.exists(dir_path): 140 all_dirs_exist = False 141 break 142 143 if all_dirs_exist: 144 nfo(f"Load models from: {checkpoint_dir}") 145 checkpoint_dir_models = checkpoint_dir 146 147 if checkpoint_dir_models is None: 148 if checkpoint_dir is None: 149 nfo(f"Download models from Hugging Face: {repo}") 150 checkpoint_dir_models = snapshot_download(repo) 151 else: 152 nfo(f"Download models from Hugging Face: {repo}, cache to: {checkpoint_dir}") 153 checkpoint_dir_models = snapshot_download(repo, cache_dir=checkpoint_dir) 154 return checkpoint_dir_models 155 156 def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False): 157 checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID) 158 dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8") 159 vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder") 160 ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer") 161 text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base") 162 163 self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype) 164 # self.ace_step_transformer.to(self.device).eval().to(self.dtype) 165 if self.cpu_offload: 166 self.ace_step_transformer = self.ace_step_transformer.to("cpu").eval().to(self.dtype) 167 else: 168 self.ace_step_transformer = self.ace_step_transformer.to(self.device).eval().to(self.dtype) 169 if self.torch_compile: 170 self.ace_step_transformer = torch.compile(self.ace_step_transformer) 171 172 self.music_dcae = MusicDCAE( 173 dcae_checkpoint_path=dcae_checkpoint_path, 174 vocoder_checkpoint_path=vocoder_checkpoint_path, 175 ) 176 # self.music_dcae.to(self.device).eval().to(self.dtype) 177 if self.cpu_offload: # might be redundant 178 self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype) 179 else: 180 self.music_dcae = self.music_dcae.to(self.device).eval().to(self.dtype) 181 if self.torch_compile: 182 self.music_dcae = torch.compile(self.music_dcae) 183 184 lang_segment = LangSegment() 185 lang_segment.setfilters(language_filters.default) 186 self.lang_segment = lang_segment 187 self.lyric_tokenizer = VoiceBpeTokenizer() 188 189 text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval() 190 # text_encoder_model = text_encoder_model.to(self.device).to(self.dtype) 191 if self.cpu_offload: 192 text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype) 193 else: 194 text_encoder_model = text_encoder_model.to(self.device).eval().to(self.dtype) 195 text_encoder_model.requires_grad_(False) 196 self.text_encoder_model = text_encoder_model 197 if self.torch_compile: 198 self.text_encoder_model = torch.compile(self.text_encoder_model) 199 200 self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path) 201 self.loaded = True 202 203 # compile 204 if self.torch_compile: 205 if export_quantized_weights: 206 from torch.ao.quantization import ( 207 Int4WeightOnlyConfig, 208 quantize_, 209 ) 210 211 group_size = 128 212 use_hqq = True 213 quantize_( 214 self.ace_step_transformer, 215 Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq), 216 ) 217 quantize_( 218 self.text_encoder_model, 219 Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq), 220 ) 221 222 # save quantized weights 223 torch.save( 224 self.ace_step_transformer.state_dict(), 225 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 226 ) 227 print( 228 "Quantized Weights Saved to: ", 229 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 230 ) 231 torch.save( 232 self.text_encoder_model.state_dict(), 233 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 234 ) 235 print( 236 "Quantized Weights Saved to: ", 237 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 238 ) 239 240 def load_quantized_checkpoint(self, checkpoint_dir=None): 241 checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID_QUANT) 242 dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8") 243 vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder") 244 ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer") 245 text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base") 246 247 self.music_dcae = MusicDCAE( 248 dcae_checkpoint_path=dcae_checkpoint_path, 249 vocoder_checkpoint_path=vocoder_checkpoint_path, 250 ) 251 if self.cpu_offload: 252 self.music_dcae.eval().to(self.dtype).to(self.device) 253 else: 254 self.music_dcae.eval().to(self.dtype).to("cpu") 255 self.music_dcae = torch.compile(self.music_dcae) 256 257 self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path) 258 self.ace_step_transformer.eval().to(self.dtype).to("cpu") 259 self.ace_step_transformer = torch.compile(self.ace_step_transformer) 260 self.ace_step_transformer.load_state_dict( 261 torch.load( 262 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 263 map_location=self.device, 264 ), 265 assign=True, 266 ) 267 self.ace_step_transformer.torchao_quantized = True 268 269 self.text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path) 270 self.text_encoder_model.eval().to(self.dtype).to("cpu") 271 self.text_encoder_model = torch.compile(self.text_encoder_model) 272 self.text_encoder_model.load_state_dict( 273 torch.load( 274 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 275 map_location=self.device, 276 ), 277 assign=True, 278 ) 279 self.text_encoder_model.torchao_quantized = True 280 281 self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path) 282 283 lang_segment = LangSegment() 284 lang_segment.setfilters(language_filters.default) 285 self.lang_segment = lang_segment 286 self.lyric_tokenizer = VoiceBpeTokenizer() 287 288 self.loaded = True 289 290 @cpu_offload("text_encoder_model") 291 def get_text_embeddings(self, texts, text_max_length=256): 292 inputs = self.text_tokenizer( 293 texts, 294 return_tensors="pt", 295 padding=True, 296 truncation=True, 297 max_length=text_max_length, 298 ) 299 inputs = {key: value.to(self.device) for key, value in inputs.items()} 300 if self.text_encoder_model.device != self.device: 301 self.text_encoder_model.to(self.device) 302 with torch.no_grad(): 303 outputs = self.text_encoder_model(**inputs) 304 last_hidden_states = outputs.last_hidden_state 305 attention_mask = inputs["attention_mask"] 306 return last_hidden_states, attention_mask 307 308 @cpu_offload("text_encoder_model") 309 def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10): 310 inputs = self.text_tokenizer( 311 texts, 312 return_tensors="pt", 313 padding=True, 314 truncation=True, 315 max_length=text_max_length, 316 ) 317 inputs = {key: value.to(self.device) for key, value in inputs.items()} 318 if self.text_encoder_model.device != self.device: 319 self.text_encoder_model.to(self.device) 320 321 def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10): 322 handlers = [] 323 324 def hook(module, input, output): 325 output[:] *= tau 326 return output 327 328 for i in range(l_min, l_max): 329 handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook) 330 handlers.append(handler) 331 332 with torch.no_grad(): 333 outputs = self.text_encoder_model(**inputs) 334 last_hidden_states = outputs.last_hidden_state 335 336 for hook in handlers: 337 hook.remove() 338 339 return last_hidden_states 340 341 last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max) 342 return last_hidden_states 343 344 def set_seeds(self, batch_size, manual_seeds=None): 345 processed_input_seeds = None 346 if manual_seeds is not None: 347 if isinstance(manual_seeds, str): 348 if "," in manual_seeds: 349 processed_input_seeds = list(map(int, manual_seeds.split(","))) 350 elif manual_seeds.isdigit(): 351 processed_input_seeds = int(manual_seeds) 352 elif isinstance(manual_seeds, list) and all(isinstance(s, int) for s in manual_seeds): 353 if len(manual_seeds) > 0: 354 processed_input_seeds = list(manual_seeds) 355 elif isinstance(manual_seeds, int): 356 processed_input_seeds = manual_seeds 357 random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)] 358 actual_seeds = [] 359 for i in range(batch_size): 360 current_seed_for_generator = None 361 if processed_input_seeds is None: 362 current_seed_for_generator = torch.randint(0, 2**32, (1,)).item() 363 elif isinstance(processed_input_seeds, int): 364 current_seed_for_generator = processed_input_seeds 365 elif isinstance(processed_input_seeds, list): 366 if i < len(processed_input_seeds): 367 current_seed_for_generator = processed_input_seeds[i] 368 else: 369 current_seed_for_generator = processed_input_seeds[-1] 370 if current_seed_for_generator is None: 371 current_seed_for_generator = torch.randint(0, 2**32, (1,)).item() 372 random_generators[i].manual_seed(current_seed_for_generator) 373 actual_seeds.append(current_seed_for_generator) 374 return random_generators, actual_seeds 375 376 def get_lang(self, text): 377 language = "en" 378 try: 379 _ = self.lang_segment.getTexts(text) 380 langCounts = self.lang_segment.getCounts() 381 language = langCounts[0][0] 382 if len(langCounts) > 1 and language == "en": 383 language = langCounts[1][0] 384 except Exception as err: 385 language = "en" 386 return language 387 388 def tokenize_lyrics(self, lyrics, debug=False): 389 lines = lyrics.split("\n") 390 lyric_token_idx = [261] 391 for line in lines: 392 line = line.strip() 393 if not line: 394 lyric_token_idx += [2] 395 continue 396 397 lang = self.get_lang(line) 398 399 if lang not in SUPPORT_LANGUAGES: 400 lang = "en" 401 if "zh" in lang: 402 lang = "zh" 403 if "spa" in lang: 404 lang = "es" 405 406 try: 407 if structure_pattern.match(line): 408 token_idx = self.lyric_tokenizer.encode(line, "en") 409 else: 410 token_idx = self.lyric_tokenizer.encode(line, lang) 411 if debug: 412 toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx]) 413 nfo(f"debbug {line} --> {lang} --> {toks}") 414 lyric_token_idx = lyric_token_idx + token_idx + [2] 415 except Exception as e: 416 print("tokenize error", e, "for line", line, "major_language", lang) 417 return lyric_token_idx 418 419 @cpu_offload("ace_step_transformer") 420 def calc_v( 421 self, 422 zt_src, 423 zt_tar, 424 t, 425 encoder_text_hidden_states, 426 text_attention_mask, 427 target_encoder_text_hidden_states, 428 target_text_attention_mask, 429 speaker_embds, 430 target_speaker_embeds, 431 lyric_token_ids, 432 lyric_mask, 433 target_lyric_token_ids, 434 target_lyric_mask, 435 do_classifier_free_guidance=False, 436 guidance_scale=1.0, 437 target_guidance_scale=1.0, 438 cfg_type="apg", 439 attention_mask=None, 440 momentum_buffer=None, 441 momentum_buffer_tar=None, 442 return_src_pred=True, 443 ): 444 noise_pred_src = None 445 if return_src_pred: 446 src_latent_model_input = torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src 447 timestep = t.expand(src_latent_model_input.shape[0]) 448 # source 449 noise_pred_src = self.ace_step_transformer( 450 hidden_states=src_latent_model_input, 451 attention_mask=attention_mask, 452 encoder_text_hidden_states=encoder_text_hidden_states, 453 text_attention_mask=text_attention_mask, 454 speaker_embeds=speaker_embds, 455 lyric_token_idx=lyric_token_ids, 456 lyric_mask=lyric_mask, 457 timestep=timestep, 458 ).sample 459 460 if do_classifier_free_guidance: 461 noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(2) 462 if cfg_type == "apg": 463 noise_pred_src = apg_forward( 464 pred_cond=noise_pred_with_cond_src, 465 pred_uncond=noise_pred_uncond_src, 466 guidance_scale=guidance_scale, 467 momentum_buffer=momentum_buffer, 468 ) 469 elif cfg_type == "cfg": 470 noise_pred_src = cfg_forward( 471 cond_output=noise_pred_with_cond_src, 472 uncond_output=noise_pred_uncond_src, 473 cfg_strength=guidance_scale, 474 ) 475 476 tar_latent_model_input = torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar 477 timestep = t.expand(tar_latent_model_input.shape[0]) 478 # target 479 noise_pred_tar = self.ace_step_transformer( 480 hidden_states=tar_latent_model_input, 481 attention_mask=attention_mask, 482 encoder_text_hidden_states=target_encoder_text_hidden_states, 483 text_attention_mask=target_text_attention_mask, 484 speaker_embeds=target_speaker_embeds, 485 lyric_token_idx=target_lyric_token_ids, 486 lyric_mask=target_lyric_mask, 487 timestep=timestep, 488 ).sample 489 490 if do_classifier_free_guidance: 491 noise_pred_with_cond_tar, noise_pred_uncond_tar = noise_pred_tar.chunk(2) 492 if cfg_type == "apg": 493 noise_pred_tar = apg_forward( 494 pred_cond=noise_pred_with_cond_tar, 495 pred_uncond=noise_pred_uncond_tar, 496 guidance_scale=target_guidance_scale, 497 momentum_buffer=momentum_buffer_tar, 498 ) 499 elif cfg_type == "cfg": 500 noise_pred_tar = cfg_forward( 501 cond_output=noise_pred_with_cond_tar, 502 uncond_output=noise_pred_uncond_tar, 503 cfg_strength=target_guidance_scale, 504 ) 505 return noise_pred_src, noise_pred_tar 506 507 @torch.no_grad() 508 def flowedit_diffusion_process( 509 self, 510 encoder_text_hidden_states, 511 text_attention_mask, 512 speaker_embds, 513 lyric_token_ids, 514 lyric_mask, 515 target_encoder_text_hidden_states, 516 target_text_attention_mask, 517 target_speaker_embeds, 518 target_lyric_token_ids, 519 target_lyric_mask, 520 src_latents, 521 random_generators=None, 522 infer_steps=60, 523 guidance_scale=15.0, 524 n_min=0, 525 n_max=1.0, 526 n_avg=1, 527 scheduler_type="euler", 528 ): 529 do_classifier_free_guidance = True 530 if guidance_scale == 0.0 or guidance_scale == 1.0: 531 do_classifier_free_guidance = False 532 533 target_guidance_scale = guidance_scale 534 bsz = encoder_text_hidden_states.shape[0] 535 536 scheduler = FlowMatchEulerDiscreteScheduler( 537 num_train_timesteps=1000, 538 shift=3.0, 539 ) 540 541 T_steps = infer_steps 542 frame_length = src_latents.shape[-1] 543 attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype) 544 545 timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, self.device, timesteps=None) 546 547 if do_classifier_free_guidance: 548 attention_mask = torch.cat([attention_mask] * 2, dim=0) 549 550 encoder_text_hidden_states = torch.cat( 551 [ 552 encoder_text_hidden_states, 553 torch.zeros_like(encoder_text_hidden_states), 554 ], 555 0, 556 ) 557 text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0) 558 559 target_encoder_text_hidden_states = torch.cat( 560 [ 561 target_encoder_text_hidden_states, 562 torch.zeros_like(target_encoder_text_hidden_states), 563 ], 564 0, 565 ) 566 target_text_attention_mask = torch.cat([target_text_attention_mask] * 2, dim=0) 567 568 speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0) 569 target_speaker_embeds = torch.cat([target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0) 570 571 lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0) 572 lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0) 573 574 target_lyric_token_ids = torch.cat([target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0) 575 target_lyric_mask = torch.cat([target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0) 576 577 momentum_buffer = MomentumBuffer() 578 momentum_buffer_tar = MomentumBuffer() 579 x_src = src_latents 580 zt_edit = x_src.clone() 581 xt_tar = None 582 n_min = int(infer_steps * n_min) 583 n_max = int(infer_steps * n_max) 584 585 nfo("flowedit start from {} to {}".format(n_min, n_max)) 586 587 for i, t in tqdm(enumerate(timesteps), total=T_steps): 588 if i < n_min: 589 continue 590 591 t_i = t / 1000 592 593 if i + 1 < len(timesteps): 594 t_im1 = (timesteps[i + 1]) / 1000 595 else: 596 t_im1 = torch.zeros_like(t_i).to(self.device) 597 598 if i < n_max: 599 # Calculate the average of the V predictions 600 V_delta_avg = torch.zeros_like(x_src) 601 for k in range(n_avg): 602 fwd_noise = randn_tensor( 603 shape=x_src.shape, 604 generator=random_generators, 605 device=self.device, 606 dtype=self.dtype, 607 ) 608 609 zt_src = (1 - t_i) * x_src + (t_i) * fwd_noise 610 611 zt_tar = zt_edit + zt_src - x_src 612 613 Vt_src, Vt_tar = self.calc_v( 614 zt_src=zt_src, 615 zt_tar=zt_tar, 616 t=t, 617 encoder_text_hidden_states=encoder_text_hidden_states, 618 text_attention_mask=text_attention_mask, 619 target_encoder_text_hidden_states=target_encoder_text_hidden_states, 620 target_text_attention_mask=target_text_attention_mask, 621 speaker_embds=speaker_embds, 622 target_speaker_embeds=target_speaker_embeds, 623 lyric_token_ids=lyric_token_ids, 624 lyric_mask=lyric_mask, 625 target_lyric_token_ids=target_lyric_token_ids, 626 target_lyric_mask=target_lyric_mask, 627 do_classifier_free_guidance=do_classifier_free_guidance, 628 guidance_scale=guidance_scale, 629 target_guidance_scale=target_guidance_scale, 630 attention_mask=attention_mask, 631 momentum_buffer=momentum_buffer, 632 ) 633 V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src) # - (hfg - 1) * (x_src) 634 635 zt_edit = zt_edit.to(torch.float32) # arbitrary, should be settable for compatibility 636 if scheduler_type != "pingpong": 637 # propagate direct ODE 638 zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg 639 zt_edit = zt_edit.to(self.dtype) 640 else: 641 # propagate pingpong SDE 642 zt_edit_denoised = zt_edit - t_i * V_delta_avg 643 noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None) 644 prev_sample = (1 - t_im1) * zt_edit_denoised + t_im1 * noise 645 646 else: # i >= T_steps-n_min # regular sampling for last n_min steps 647 if i == n_max: 648 fwd_noise = randn_tensor( 649 shape=x_src.shape, 650 generator=random_generators, 651 device=self.device, 652 dtype=self.dtype, 653 ) 654 scheduler._init_step_index(t) 655 sigma = scheduler.sigmas[scheduler.step_index] 656 xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src 657 xt_tar = zt_edit + xt_src - x_src 658 659 _, Vt_tar = self.calc_v( 660 zt_src=None, 661 zt_tar=xt_tar, 662 t=t, 663 encoder_text_hidden_states=encoder_text_hidden_states, 664 text_attention_mask=text_attention_mask, 665 target_encoder_text_hidden_states=target_encoder_text_hidden_states, 666 target_text_attention_mask=target_text_attention_mask, 667 speaker_embds=speaker_embds, 668 target_speaker_embeds=target_speaker_embeds, 669 lyric_token_ids=lyric_token_ids, 670 lyric_mask=lyric_mask, 671 target_lyric_token_ids=target_lyric_token_ids, 672 target_lyric_mask=target_lyric_mask, 673 do_classifier_free_guidance=do_classifier_free_guidance, 674 guidance_scale=guidance_scale, 675 target_guidance_scale=target_guidance_scale, 676 attention_mask=attention_mask, 677 momentum_buffer_tar=momentum_buffer_tar, 678 return_src_pred=False, 679 ) 680 681 xt_tar = xt_tar.to(torch.float32) 682 if scheduler_type != "pingpong": 683 prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar 684 prev_sample = prev_sample.to(self.dtype) 685 xt_tar = prev_sample 686 else: 687 prev_sample = xt_tar - t_i * Vt_tar 688 noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None) 689 prev_sample = (1 - t_im1) * prev_sample + t_im1 * noise 690 xt_tar = prev_sample 691 692 target_latents = zt_edit if xt_tar is None else xt_tar 693 return target_latents 694 695 def add_latents_noise( 696 self, 697 gt_latents, 698 sigma_max, 699 noise, 700 scheduler_type, 701 infer_steps, 702 ): 703 bsz = gt_latents.shape[0] 704 if scheduler_type == "euler": 705 scheduler = FlowMatchEulerDiscreteScheduler( 706 num_train_timesteps=1000, 707 shift=3.0, 708 sigma_max=sigma_max, 709 ) 710 elif scheduler_type == "heun": 711 scheduler = FlowMatchHeunDiscreteScheduler( 712 num_train_timesteps=1000, 713 shift=3.0, 714 sigma_max=sigma_max, 715 ) 716 elif scheduler_type == "pingpong": 717 scheduler = FlowMatchPingPongScheduler(num_train_timesteps=1000, shift=3.0, sigma_max=sigma_max) 718 719 infer_steps = int(sigma_max * infer_steps) 720 timesteps, num_inference_steps = retrieve_timesteps( 721 scheduler, 722 num_inference_steps=infer_steps, 723 device=self.device, 724 timesteps=None, 725 ) 726 noisy_image = gt_latents * (1 - scheduler.sigma_max) + noise * scheduler.sigma_max 727 nfo(f"{scheduler.sigma_min=} {scheduler.sigma_max=} {timesteps=} {num_inference_steps=}") 728 return noisy_image, timesteps, scheduler, num_inference_steps 729 730 @cpu_offload("ace_step_transformer") 731 @torch.no_grad() 732 def text2music_diffusion_process( 733 self, 734 duration, 735 encoder_text_hidden_states, 736 text_attention_mask, 737 speaker_embds, 738 lyric_token_ids, 739 lyric_mask, 740 random_generators=None, 741 infer_steps=60, 742 guidance_scale=15.0, 743 omega_scale=10.0, 744 scheduler_type="euler", 745 cfg_type="apg", 746 zero_steps=1, 747 use_zero_init=True, 748 guidance_interval=0.5, 749 guidance_interval_decay=1.0, 750 min_guidance_scale=3.0, 751 oss_steps=[], 752 encoder_text_hidden_states_null=None, 753 use_erg_lyric=False, 754 use_erg_diffusion=False, 755 retake_random_generators=None, 756 retake_variance=0.5, 757 add_retake_noise=False, 758 guidance_scale_text=0.0, 759 guidance_scale_lyric=0.0, 760 repaint_start=0, 761 repaint_end=0, 762 src_latents=None, 763 audio2audio_enable=False, 764 ref_audio_strength=0.5, 765 ref_latents=None, 766 ): 767 nfo("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale)) 768 do_classifier_free_guidance = True 769 if guidance_scale == 0.0 or guidance_scale == 1.0: 770 do_classifier_free_guidance = False 771 772 do_double_condition_guidance = False 773 if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0: 774 do_double_condition_guidance = True 775 nfo( 776 "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format( 777 do_double_condition_guidance, 778 guidance_scale_text, 779 guidance_scale_lyric, 780 ) 781 ) 782 783 bsz = encoder_text_hidden_states.shape[0] 784 785 if scheduler_type == "euler": 786 scheduler = FlowMatchEulerDiscreteScheduler( 787 num_train_timesteps=1000, 788 shift=3.0, 789 ) 790 elif scheduler_type == "heun": 791 scheduler = FlowMatchHeunDiscreteScheduler( 792 num_train_timesteps=1000, 793 shift=3.0, 794 ) 795 elif scheduler_type == "pingpong": 796 scheduler = FlowMatchPingPongScheduler( 797 num_train_timesteps=1000, 798 shift=3.0, 799 ) 800 801 frame_length = int(duration * 44100 / 512 / 8) 802 if src_latents is not None: 803 frame_length = src_latents.shape[-1] 804 805 if ref_latents is not None: 806 frame_length = ref_latents.shape[-1] 807 808 if len(oss_steps) > 0: 809 infer_steps = max(oss_steps) 810 scheduler.set_timesteps 811 timesteps, num_inference_steps = retrieve_timesteps( 812 scheduler, 813 num_inference_steps=infer_steps, 814 device=self.device, 815 timesteps=None, 816 ) 817 new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device) 818 for idx in range(len(oss_steps)): 819 new_timesteps[idx] = timesteps[oss_steps[idx] - 1] 820 num_inference_steps = len(oss_steps) 821 sigmas = (new_timesteps / 1000).float().cpu().numpy() 822 timesteps, num_inference_steps = retrieve_timesteps( 823 scheduler, 824 num_inference_steps=num_inference_steps, 825 device=self.device, 826 sigmas=sigmas, 827 ) 828 nfo(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}") 829 else: 830 timesteps, num_inference_steps = retrieve_timesteps( 831 scheduler, 832 num_inference_steps=infer_steps, 833 device=self.device, 834 timesteps=None, 835 ) 836 837 target_latents = randn_tensor( 838 shape=(bsz, 8, 16, frame_length), 839 generator=random_generators, 840 device=self.device, 841 dtype=self.dtype, 842 ) 843 844 is_repaint = False 845 is_extend = False 846 847 if add_retake_noise: 848 n_min = int(infer_steps * (1 - retake_variance)) 849 retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype) 850 retake_latents = randn_tensor( 851 shape=(bsz, 8, 16, frame_length), 852 generator=retake_random_generators, 853 device=self.device, 854 dtype=self.dtype, 855 ) 856 repaint_start_frame = int(repaint_start * 44100 / 512 / 8) 857 repaint_end_frame = int(repaint_end * 44100 / 512 / 8) 858 x0 = src_latents 859 # retake 860 is_repaint = repaint_end_frame - repaint_start_frame != frame_length 861 862 is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length) 863 if is_extend: 864 is_repaint = True 865 866 # TODO: train a mask aware repainting controlnet 867 # to make sure mean = 0, std = 1 868 if not is_repaint: 869 target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents 870 elif not is_extend: 871 # if repaint_end_frame 872 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype) 873 repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0 874 repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents 875 repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents) 876 zt_edit = x0.clone() 877 z0 = repaint_noise 878 elif is_extend: 879 to_right_pad_gt_latents = None 880 to_left_pad_gt_latents = None 881 gt_latents = src_latents 882 src_latents_length = gt_latents.shape[-1] 883 max_infer_fame_length = int(240 * 44100 / 512 / 8) 884 left_pad_frame_length = 0 885 right_pad_frame_length = 0 886 right_trim_length = 0 887 left_trim_length = 0 888 if repaint_start_frame < 0: 889 left_pad_frame_length = abs(repaint_start_frame) 890 frame_length = left_pad_frame_length + gt_latents.shape[-1] 891 extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0) 892 if frame_length > max_infer_fame_length: 893 right_trim_length = frame_length - max_infer_fame_length 894 extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_fame_length] 895 to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:] 896 frame_length = max_infer_fame_length 897 repaint_start_frame = 0 898 gt_latents = extend_gt_latents 899 900 if repaint_end_frame > src_latents_length: 901 right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1] 902 frame_length = gt_latents.shape[-1] + right_pad_frame_length 903 extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0) 904 if frame_length > max_infer_fame_length: 905 left_trim_length = frame_length - max_infer_fame_length 906 extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_fame_length:] 907 to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length] 908 frame_length = max_infer_fame_length 909 repaint_end_frame = frame_length 910 gt_latents = extend_gt_latents 911 912 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype) 913 if left_pad_frame_length > 0: 914 repaint_mask[:, :, :, :left_pad_frame_length] = 1.0 915 if right_pad_frame_length > 0: 916 repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0 917 x0 = gt_latents 918 padd_list = [] 919 if left_pad_frame_length > 0: 920 padd_list.append(retake_latents[:, :, :, :left_pad_frame_length]) 921 padd_list.append( 922 target_latents[ 923 :, 924 :, 925 :, 926 left_trim_length : target_latents.shape[-1] - right_trim_length, 927 ] 928 ) 929 if right_pad_frame_length > 0: 930 padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:]) 931 target_latents = torch.cat(padd_list, dim=-1) 932 assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}" 933 zt_edit = x0.clone() 934 z0 = target_latents 935 936 if audio2audio_enable and ref_latents is not None: 937 nfo(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}") 938 target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise( 939 gt_latents=ref_latents, 940 sigma_max=(1 - ref_audio_strength), 941 noise=target_latents, 942 scheduler_type=scheduler_type, 943 infer_steps=infer_steps, 944 ) 945 946 attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype) 947 948 # guidance interval 949 start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2)) 950 end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5)) 951 nfo(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}") 952 953 momentum_buffer = MomentumBuffer() 954 955 def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6): 956 handlers = [] 957 958 def hook(module, input, output): 959 output[:] *= tau 960 return output 961 962 for i in range(l_min, l_max): 963 handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook) 964 handlers.append(handler) 965 966 encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs) 967 968 for hook in handlers: 969 hook.remove() 970 971 return encoder_hidden_states 972 973 # P(speaker, text, lyric) 974 encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode( 975 encoder_text_hidden_states, 976 text_attention_mask, 977 speaker_embds, 978 lyric_token_ids, 979 lyric_mask, 980 ) 981 982 if use_erg_lyric: 983 # P(null_speaker, text_weaker, lyric_weaker) 984 encoder_hidden_states_null = forward_encoder_with_temperature( 985 self, 986 inputs={ 987 "encoder_text_hidden_states": ( 988 encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states) 989 ), 990 "text_attention_mask": text_attention_mask, 991 "speaker_embeds": torch.zeros_like(speaker_embds), 992 "lyric_token_idx": lyric_token_ids, 993 "lyric_mask": lyric_mask, 994 }, 995 ) 996 else: 997 # P(null_speaker, null_text, null_lyric) 998 encoder_hidden_states_null, _ = self.ace_step_transformer.encode( 999 torch.zeros_like(encoder_text_hidden_states), 1000 text_attention_mask, 1001 torch.zeros_like(speaker_embds), 1002 torch.zeros_like(lyric_token_ids), 1003 lyric_mask, 1004 ) 1005 1006 encoder_hidden_states_no_lyric = None 1007 if do_double_condition_guidance: 1008 # P(null_speaker, text, lyric_weaker) 1009 if use_erg_lyric: 1010 encoder_hidden_states_no_lyric = forward_encoder_with_temperature( 1011 self, 1012 inputs={ 1013 "encoder_text_hidden_states": encoder_text_hidden_states, 1014 "text_attention_mask": text_attention_mask, 1015 "speaker_embeds": torch.zeros_like(speaker_embds), 1016 "lyric_token_idx": lyric_token_ids, 1017 "lyric_mask": lyric_mask, 1018 }, 1019 ) 1020 # P(null_speaker, text, no_lyric) 1021 else: 1022 encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode( 1023 encoder_text_hidden_states, 1024 text_attention_mask, 1025 torch.zeros_like(speaker_embds), 1026 torch.zeros_like(lyric_token_ids), 1027 lyric_mask, 1028 ) 1029 1030 def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20): 1031 handlers = [] 1032 1033 def hook(module, input, output): 1034 output[:] *= tau 1035 return output 1036 1037 for i in range(l_min, l_max): 1038 handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook) 1039 handlers.append(handler) 1040 handler = self.ace_step_transformer.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook) 1041 handlers.append(handler) 1042 1043 sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample 1044 1045 for hook in handlers: 1046 hook.remove() 1047 1048 return sample 1049 1050 for i, t in tqdm(enumerate(timesteps), total=num_inference_steps): 1051 if is_repaint: 1052 if i < n_min: 1053 continue 1054 elif i == n_min: 1055 t_i = t / 1000 1056 zt_src = (1 - t_i) * x0 + (t_i) * z0 1057 target_latents = zt_edit + zt_src - x0 1058 nfo(f"repaint start from {n_min} add {t_i} level of noise") 1059 1060 # expand the latents if we are doing classifier free guidance 1061 latents = target_latents 1062 1063 is_in_guidance_interval = start_idx <= i < end_idx 1064 if is_in_guidance_interval and do_classifier_free_guidance: 1065 # compute current guidance scale 1066 if guidance_interval_decay > 0: 1067 # Linearly interpolate to calculate the current guidance scale 1068 progress = (i - start_idx) / (end_idx - start_idx - 1) # 归一化到[0,1] 1069 current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay 1070 else: 1071 current_guidance_scale = guidance_scale 1072 1073 latent_model_input = latents 1074 timestep = t.expand(latent_model_input.shape[0]) 1075 output_length = latent_model_input.shape[-1] 1076 # P(x|speaker, text, lyric) 1077 noise_pred_with_cond = self.ace_step_transformer.decode( 1078 hidden_states=latent_model_input, 1079 attention_mask=attention_mask, 1080 encoder_hidden_states=encoder_hidden_states, 1081 encoder_hidden_mask=encoder_hidden_mask, 1082 output_length=output_length, 1083 timestep=timestep, 1084 ).sample 1085 1086 noise_pred_with_only_text_cond = None 1087 if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None: 1088 noise_pred_with_only_text_cond = self.ace_step_transformer.decode( 1089 hidden_states=latent_model_input, 1090 attention_mask=attention_mask, 1091 encoder_hidden_states=encoder_hidden_states_no_lyric, 1092 encoder_hidden_mask=encoder_hidden_mask, 1093 output_length=output_length, 1094 timestep=timestep, 1095 ).sample 1096 1097 if use_erg_diffusion: 1098 noise_pred_uncond = forward_diffusion_with_temperature( 1099 self, 1100 hidden_states=latent_model_input, 1101 timestep=timestep, 1102 inputs={ 1103 "encoder_hidden_states": encoder_hidden_states_null, 1104 "encoder_hidden_mask": encoder_hidden_mask, 1105 "output_length": output_length, 1106 "attention_mask": attention_mask, 1107 }, 1108 ) 1109 else: 1110 noise_pred_uncond = self.ace_step_transformer.decode( 1111 hidden_states=latent_model_input, 1112 attention_mask=attention_mask, 1113 encoder_hidden_states=encoder_hidden_states_null, 1114 encoder_hidden_mask=encoder_hidden_mask, 1115 output_length=output_length, 1116 timestep=timestep, 1117 ).sample 1118 1119 if do_double_condition_guidance and noise_pred_with_only_text_cond is not None: 1120 noise_pred = cfg_double_condition_forward( 1121 cond_output=noise_pred_with_cond, 1122 uncond_output=noise_pred_uncond, 1123 only_text_cond_output=noise_pred_with_only_text_cond, 1124 guidance_scale_text=guidance_scale_text, 1125 guidance_scale_lyric=guidance_scale_lyric, 1126 ) 1127 1128 elif cfg_type == "apg": 1129 noise_pred = apg_forward( 1130 pred_cond=noise_pred_with_cond, 1131 pred_uncond=noise_pred_uncond, 1132 guidance_scale=current_guidance_scale, 1133 momentum_buffer=momentum_buffer, 1134 ) 1135 elif cfg_type == "cfg": 1136 noise_pred = cfg_forward( 1137 cond_output=noise_pred_with_cond, 1138 uncond_output=noise_pred_uncond, 1139 cfg_strength=current_guidance_scale, 1140 ) 1141 elif cfg_type == "cfg_star": 1142 noise_pred = cfg_zero_star( 1143 noise_pred_with_cond=noise_pred_with_cond, 1144 noise_pred_uncond=noise_pred_uncond, 1145 guidance_scale=current_guidance_scale, 1146 i=i, 1147 zero_steps=zero_steps, 1148 use_zero_init=use_zero_init, 1149 ) 1150 else: 1151 latent_model_input = latents 1152 timestep = t.expand(latent_model_input.shape[0]) 1153 noise_pred = self.ace_step_transformer.decode( 1154 hidden_states=latent_model_input, 1155 attention_mask=attention_mask, 1156 encoder_hidden_states=encoder_hidden_states, 1157 encoder_hidden_mask=encoder_hidden_mask, 1158 output_length=latent_model_input.shape[-1], 1159 timestep=timestep, 1160 ).sample 1161 1162 if is_repaint and i >= n_min: 1163 t_i = t / 1000 1164 if i + 1 < len(timesteps): 1165 t_im1 = (timesteps[i + 1]) / 1000 1166 else: 1167 t_im1 = torch.zeros_like(t_i).to(self.device) 1168 target_latents = target_latents.to(torch.float32) 1169 prev_sample = target_latents + (t_im1 - t_i) * noise_pred 1170 prev_sample = prev_sample.to(self.dtype) 1171 target_latents = prev_sample 1172 zt_src = (1 - t_im1) * x0 + (t_im1) * z0 1173 target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src) 1174 else: 1175 target_latents = scheduler.step( 1176 model_output=noise_pred, 1177 timestep=t, 1178 sample=target_latents, 1179 return_dict=False, 1180 omega=omega_scale, 1181 generator=random_generators[0], 1182 )[0] 1183 1184 if is_extend: 1185 if to_right_pad_gt_latents is not None: 1186 target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1) 1187 if to_left_pad_gt_latents is not None: 1188 target_latents = torch.cat([to_right_pad_gt_latents, target_latents], dim=0) 1189 return target_latents 1190 1191 @cpu_offload("music_dcae") 1192 def latents2audio( 1193 self, 1194 latents, 1195 target_wav_duration_second=30, 1196 sample_rate=48000, 1197 save_path=None, 1198 format="wav", 1199 ): 1200 output_audio_paths = [] 1201 bs = latents.shape[0] 1202 pred_latents = latents 1203 with torch.no_grad(): 1204 if self.overlapped_decode and target_wav_duration_second > 48: 1205 _, pred_wavs = self.music_dcae.decode_overlap(pred_latents, sr=sample_rate) 1206 else: 1207 _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate) 1208 pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs] 1209 for i in tqdm(range(bs)): 1210 output_audio_path = self.save_wav_file( 1211 pred_wavs[i], 1212 i, 1213 save_path=save_path, 1214 sample_rate=sample_rate, 1215 format=format, 1216 ) 1217 output_audio_paths.append(output_audio_path) 1218 return output_audio_paths 1219 1220 def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="wav"): 1221 if save_path is None: 1222 nfo("save_path is None, using default path ./outputs/") 1223 base_path = "./outputs" 1224 ensure_directory_exists(base_path) 1225 output_path_wav = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format 1226 else: 1227 ensure_directory_exists(os.path.dirname(save_path)) 1228 if os.path.isdir(save_path): 1229 nfo(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.") 1230 output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format) 1231 else: 1232 output_path_wav = save_path 1233 1234 target_wav = target_wav.float() 1235 backend = "soundfile" 1236 if format == "ogg": 1237 backend = "sox" 1238 nfo(f"Saving audio to {output_path_wav} using backend {backend}") 1239 try: 1240 torchaudio.save(output_path_wav, target_wav, sample_rate=sample_rate, format=format, backend=backend) 1241 except (RuntimeError, ImportError, ModuleNotFoundError): 1242 import soundfile as sf # pyright: ignore[reportMissingImports] | pylint:disable=import-error 1243 1244 sf.write(output_path_wav, target_wav, sample_rate) 1245 return output_path_wav 1246 1247 @cpu_offload("music_dcae") 1248 def infer_latents(self, input_audio_path): 1249 if input_audio_path is None: 1250 return None 1251 input_audio, sr = self.music_dcae.load_audio(input_audio_path) 1252 input_audio = input_audio.unsqueeze(0) 1253 input_audio = input_audio.to(device=self.device, dtype=self.dtype) 1254 latents, _ = self.music_dcae.encode(input_audio, sr=sr) 1255 return latents 1256 1257 def load_lora(self, lora_name_or_path, lora_weight): 1258 if (lora_name_or_path != self.lora_path or lora_weight != self.lora_weight) and lora_name_or_path != "none": 1259 if not os.path.exists(lora_name_or_path): 1260 lora_download_path = snapshot_download(lora_name_or_path, cache_dir=self.checkpoint_dir) 1261 else: 1262 lora_download_path = lora_name_or_path 1263 if self.lora_path != "none": 1264 self.ace_step_transformer.unload_lora() 1265 self.ace_step_transformer.load_lora_adapter( 1266 os.path.join(lora_download_path, "pytorch_lora_weights.safetensors"), adapter_name="ace_step_lora", with_alpha=True, prefix=None 1267 ) 1268 nfo(f"Loading lora weights from: {lora_name_or_path} download path is: {lora_download_path} weight: {lora_weight}") 1269 set_weights_and_activate_adapters(self.ace_step_transformer, ["ace_step_lora"], [lora_weight]) 1270 self.lora_path = lora_name_or_path 1271 self.lora_weight = lora_weight 1272 elif self.lora_path != "none" and lora_name_or_path == "none": 1273 nfo("No lora weights to load.") 1274 self.ace_step_transformer.unload_lora() 1275 1276 def __call__( 1277 self, 1278 format: str = "wav", 1279 audio_duration: float = 60.0, 1280 prompt: str = None, 1281 lyrics: str = None, 1282 infer_step: int = 60, 1283 guidance_scale: float = 15.0, 1284 scheduler_type: str = "euler", 1285 cfg_type: str = "apg", 1286 omega_scale: int = 10.0, 1287 manual_seeds: list = None, 1288 guidance_interval: float = 0.5, 1289 guidance_interval_decay: float = 0.0, 1290 min_guidance_scale: float = 3.0, 1291 use_erg_tag: bool = True, 1292 use_erg_lyric: bool = True, 1293 use_erg_diffusion: bool = True, 1294 oss_steps: str = None, 1295 guidance_scale_text: float = 0.0, 1296 guidance_scale_lyric: float = 0.0, 1297 audio2audio_enable: bool = False, 1298 ref_audio_strength: float = 0.5, 1299 ref_audio_input: str = None, 1300 lora_name_or_path: str = "none", 1301 lora_weight: float = 1.0, 1302 retake_seeds: list = None, 1303 retake_variance: float = 0.5, 1304 task: str = "text2music", 1305 repaint_start: int = 0, 1306 repaint_end: int = 0, 1307 src_audio_path: str = None, 1308 edit_target_prompt: str = None, 1309 edit_target_lyrics: str = None, 1310 edit_n_min: float = 0.0, 1311 edit_n_max: float = 1.0, 1312 edit_n_avg: int = 1, 1313 save_path: str = None, 1314 batch_size: int = 1, 1315 debug: bool = False, 1316 ): 1317 start_time = time.time() 1318 1319 if audio2audio_enable and ref_audio_input is not None: 1320 task = "audio2audio" 1321 1322 if not self.loaded: 1323 nfo("Checkpoint not loaded, loading checkpoint...") 1324 if self.quantized: 1325 self.load_quantized_checkpoint(self.checkpoint_dir) 1326 else: 1327 self.load_checkpoint(self.checkpoint_dir) 1328 1329 self.load_lora(lora_name_or_path, lora_weight) 1330 load_model_cost = time.time() - start_time 1331 nfo(f"Model loaded in {load_model_cost:.2f} seconds.") 1332 1333 start_time = time.time() 1334 1335 random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds) 1336 retake_random_generators, actual_retake_seeds = self.set_seeds(batch_size, retake_seeds) 1337 1338 if isinstance(oss_steps, str) and len(oss_steps) > 0: 1339 oss_steps = list(map(int, oss_steps.split(","))) 1340 else: 1341 oss_steps = [] 1342 1343 texts = [prompt] 1344 encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts) 1345 encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1) 1346 text_attention_mask = text_attention_mask.repeat(batch_size, 1) 1347 1348 encoder_text_hidden_states_null = None 1349 if use_erg_tag: 1350 encoder_text_hidden_states_null = self.get_text_embeddings_null(texts) 1351 encoder_text_hidden_states_null = encoder_text_hidden_states_null.repeat(batch_size, 1, 1) 1352 1353 # not support for released checkpoint 1354 speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype) 1355 1356 # 6 lyric 1357 lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long() 1358 lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long() 1359 if len(lyrics) > 0: 1360 lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug) 1361 lyric_mask = [1] * len(lyric_token_idx) 1362 lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1) 1363 lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1) 1364 1365 if audio_duration <= 0: 1366 audio_duration = random.uniform(30.0, 240.0) 1367 nfo(f"random audio duration: {audio_duration}") 1368 1369 end_time = time.time() 1370 preprocess_time_cost = end_time - start_time 1371 start_time = end_time 1372 1373 add_retake_noise = task in ("retake", "repaint", "extend") 1374 # retake equal to repaint 1375 if task == "retake": 1376 repaint_start = 0 1377 repaint_end = audio_duration 1378 1379 src_latents = None 1380 if src_audio_path is not None: 1381 assert src_audio_path is not None and task in ( 1382 "repaint", 1383 "edit", 1384 "extend", 1385 ), "src_audio_path is required for retake/repaint/extend task" 1386 assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist" 1387 src_latents = self.infer_latents(src_audio_path) 1388 1389 ref_latents = None 1390 if ref_audio_input is not None and audio2audio_enable: 1391 assert ref_audio_input is not None, "ref_audio_input is required for audio2audio task" 1392 assert os.path.exists(ref_audio_input), f"ref_audio_input {ref_audio_input} does not exist" 1393 ref_latents = self.infer_latents(ref_audio_input) 1394 1395 if task == "edit": 1396 texts = [edit_target_prompt] 1397 target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts) 1398 target_encoder_text_hidden_states = target_encoder_text_hidden_states.repeat(batch_size, 1, 1) 1399 target_text_attention_mask = target_text_attention_mask.repeat(batch_size, 1) 1400 1401 target_lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long() 1402 target_lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long() 1403 if len(edit_target_lyrics) > 0: 1404 target_lyric_token_idx = self.tokenize_lyrics(edit_target_lyrics, debug=True) 1405 target_lyric_mask = [1] * len(target_lyric_token_idx) 1406 target_lyric_token_idx = torch.tensor(target_lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1) 1407 target_lyric_mask = torch.tensor(target_lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1) 1408 1409 target_speaker_embeds = speaker_embeds.clone() 1410 1411 target_latents = self.flowedit_diffusion_process( 1412 encoder_text_hidden_states=encoder_text_hidden_states, 1413 text_attention_mask=text_attention_mask, 1414 speaker_embds=speaker_embeds, 1415 lyric_token_ids=lyric_token_idx, 1416 lyric_mask=lyric_mask, 1417 target_encoder_text_hidden_states=target_encoder_text_hidden_states, 1418 target_text_attention_mask=target_text_attention_mask, 1419 target_speaker_embeds=target_speaker_embeds, 1420 target_lyric_token_ids=target_lyric_token_idx, 1421 target_lyric_mask=target_lyric_mask, 1422 src_latents=src_latents, 1423 random_generators=retake_random_generators, # more diversity 1424 infer_steps=infer_step, 1425 guidance_scale=guidance_scale, 1426 n_min=edit_n_min, 1427 n_max=edit_n_max, 1428 n_avg=edit_n_avg, 1429 scheduler_type=scheduler_type, 1430 ) 1431 else: 1432 target_latents = self.text2music_diffusion_process( 1433 duration=audio_duration, 1434 encoder_text_hidden_states=encoder_text_hidden_states, 1435 text_attention_mask=text_attention_mask, 1436 speaker_embds=speaker_embeds, 1437 lyric_token_ids=lyric_token_idx, 1438 lyric_mask=lyric_mask, 1439 guidance_scale=guidance_scale, 1440 omega_scale=omega_scale, 1441 infer_steps=infer_step, 1442 random_generators=random_generators, 1443 scheduler_type=scheduler_type, 1444 cfg_type=cfg_type, 1445 guidance_interval=guidance_interval, 1446 guidance_interval_decay=guidance_interval_decay, 1447 min_guidance_scale=min_guidance_scale, 1448 oss_steps=oss_steps, 1449 encoder_text_hidden_states_null=encoder_text_hidden_states_null, 1450 use_erg_lyric=use_erg_lyric, 1451 use_erg_diffusion=use_erg_diffusion, 1452 retake_random_generators=retake_random_generators, 1453 retake_variance=retake_variance, 1454 add_retake_noise=add_retake_noise, 1455 guidance_scale_text=guidance_scale_text, 1456 guidance_scale_lyric=guidance_scale_lyric, 1457 repaint_start=repaint_start, 1458 repaint_end=repaint_end, 1459 src_latents=src_latents, 1460 audio2audio_enable=audio2audio_enable, 1461 ref_audio_strength=ref_audio_strength, 1462 ref_latents=ref_latents, 1463 ) 1464 1465 end_time = time.time() 1466 diffusion_time_cost = end_time - start_time 1467 start_time = end_time 1468 1469 output_paths = self.latents2audio( 1470 latents=target_latents, 1471 target_wav_duration_second=audio_duration, 1472 save_path=save_path, 1473 format=format, 1474 ) 1475 1476 # Clean up memory after generation 1477 empty_cache 1478 1479 end_time = time.time() 1480 latent2audio_time_cost = end_time - start_time 1481 timecosts = { 1482 "preprocess": preprocess_time_cost, 1483 "diffusion": diffusion_time_cost, 1484 "latent2audio": latent2audio_time_cost, 1485 } 1486 1487 input_params_json = { 1488 "format": format, 1489 "lora_name_or_path": lora_name_or_path, 1490 "lora_weight": lora_weight, 1491 "task": task, 1492 "prompt": prompt if task != "edit" else edit_target_prompt, 1493 "lyrics": lyrics if task != "edit" else edit_target_lyrics, 1494 "audio_duration": audio_duration, 1495 "infer_step": infer_step, 1496 "guidance_scale": guidance_scale, 1497 "scheduler_type": scheduler_type, 1498 "cfg_type": cfg_type, 1499 "omega_scale": omega_scale, 1500 "guidance_interval": guidance_interval, 1501 "guidance_interval_decay": guidance_interval_decay, 1502 "min_guidance_scale": min_guidance_scale, 1503 "use_erg_tag": use_erg_tag, 1504 "use_erg_lyric": use_erg_lyric, 1505 "use_erg_diffusion": use_erg_diffusion, 1506 "oss_steps": oss_steps, 1507 "timecosts": timecosts, 1508 "actual_seeds": actual_seeds, 1509 "retake_seeds": actual_retake_seeds, 1510 "retake_variance": retake_variance, 1511 "guidance_scale_text": guidance_scale_text, 1512 "guidance_scale_lyric": guidance_scale_lyric, 1513 "repaint_start": repaint_start, 1514 "repaint_end": repaint_end, 1515 "edit_n_min": edit_n_min, 1516 "edit_n_max": edit_n_max, 1517 "edit_n_avg": edit_n_avg, 1518 "src_audio_path": src_audio_path, 1519 "edit_target_prompt": edit_target_prompt, 1520 "edit_target_lyrics": edit_target_lyrics, 1521 "audio2audio_enable": audio2audio_enable, 1522 "ref_audio_strength": ref_audio_strength, 1523 "ref_audio_input": ref_audio_input, 1524 } 1525 # save input_params_json 1526 for output_audio_path in output_paths: 1527 input_params_json_save_path = output_audio_path.replace(f".{format}", "_input_params.json") 1528 input_params_json["audio_path"] = output_audio_path 1529 with open(input_params_json_save_path, "w", encoding="utf-8") as f: 1530 json.dump(input_params_json, f, indent=4, ensure_ascii=False) 1531 1532 return output_paths + [input_params_json]
SUPPORT_LANGUAGES =
{'en': 259, 'de': 260, 'fr': 262, 'es': 284, 'it': 285, 'pt': 286, 'pl': 294, 'tr': 295, 'ru': 267, 'cs': 293, 'nl': 297, 'ar': 5022, 'zh': 5023, 'ja': 5412, 'hu': 5753, 'ko': 6152, 'hi': 6680}
structure_pattern =
re.compile('\\[.*?\\]')
def
ensure_directory_exists(directory):
REPO_ID =
'ACE-Step/ACE-Step-v1-3.5B'
REPO_ID_QUANT =
'ACE-Step/ACE-Step-v1-3.5B-q4-K-M'
class
ACEStepPipeline:
93class ACEStepPipeline: 94 def __init__( 95 self, 96 checkpoint_dir=None, 97 device_id=0, 98 dtype="bfloat16", 99 text_encoder_checkpoint_path=None, 100 persistent_storage_path=None, 101 torch_compile=False, 102 cpu_offload=False, 103 quantized=False, 104 overlapped_decode=False, 105 **kwargs, 106 ): 107 if not checkpoint_dir: 108 if persistent_storage_path is None: 109 checkpoint_dir = os.path.join(os.path.expanduser("~"), ".cache/ace-step/checkpoints") 110 os.makedirs(checkpoint_dir, exist_ok=True) 111 else: 112 checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints") 113 ensure_directory_exists(checkpoint_dir) 114 115 self.checkpoint_dir = checkpoint_dir 116 self.lora_path = "none" 117 self.lora_weight = 1 118 self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32 119 if gfx_device.type == "mps": 120 if self.dtype == torch.bfloat16: 121 self.dtype = torch.float16 122 123 if "ACE_PIPELINE_DTYPE" in os.environ and len(os.environ["ACE_PIPELINE_DTYPE"]): 124 self.dtype = getattr(torch, os.environ["ACE_PIPELINE_DTYPE"]) 125 self.device: torch.device = gfx_device 126 self.loaded = False 127 self.torch_compile = torch_compile 128 self.cpu_offload = cpu_offload 129 self.quantized = quantized 130 self.overlapped_decode = overlapped_decode 131 132 def get_checkpoint_path(self, checkpoint_dir, repo): 133 checkpoint_dir_models = None 134 135 if checkpoint_dir is not None: 136 required_dirs = ["music_dcae_f8c8", "music_vocoder", "ace_step_transformer", "umt5-base"] 137 all_dirs_exist = True 138 for dir_name in required_dirs: 139 dir_path = os.path.join(checkpoint_dir, dir_name) 140 if not os.path.exists(dir_path): 141 all_dirs_exist = False 142 break 143 144 if all_dirs_exist: 145 nfo(f"Load models from: {checkpoint_dir}") 146 checkpoint_dir_models = checkpoint_dir 147 148 if checkpoint_dir_models is None: 149 if checkpoint_dir is None: 150 nfo(f"Download models from Hugging Face: {repo}") 151 checkpoint_dir_models = snapshot_download(repo) 152 else: 153 nfo(f"Download models from Hugging Face: {repo}, cache to: {checkpoint_dir}") 154 checkpoint_dir_models = snapshot_download(repo, cache_dir=checkpoint_dir) 155 return checkpoint_dir_models 156 157 def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False): 158 checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID) 159 dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8") 160 vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder") 161 ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer") 162 text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base") 163 164 self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype) 165 # self.ace_step_transformer.to(self.device).eval().to(self.dtype) 166 if self.cpu_offload: 167 self.ace_step_transformer = self.ace_step_transformer.to("cpu").eval().to(self.dtype) 168 else: 169 self.ace_step_transformer = self.ace_step_transformer.to(self.device).eval().to(self.dtype) 170 if self.torch_compile: 171 self.ace_step_transformer = torch.compile(self.ace_step_transformer) 172 173 self.music_dcae = MusicDCAE( 174 dcae_checkpoint_path=dcae_checkpoint_path, 175 vocoder_checkpoint_path=vocoder_checkpoint_path, 176 ) 177 # self.music_dcae.to(self.device).eval().to(self.dtype) 178 if self.cpu_offload: # might be redundant 179 self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype) 180 else: 181 self.music_dcae = self.music_dcae.to(self.device).eval().to(self.dtype) 182 if self.torch_compile: 183 self.music_dcae = torch.compile(self.music_dcae) 184 185 lang_segment = LangSegment() 186 lang_segment.setfilters(language_filters.default) 187 self.lang_segment = lang_segment 188 self.lyric_tokenizer = VoiceBpeTokenizer() 189 190 text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval() 191 # text_encoder_model = text_encoder_model.to(self.device).to(self.dtype) 192 if self.cpu_offload: 193 text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype) 194 else: 195 text_encoder_model = text_encoder_model.to(self.device).eval().to(self.dtype) 196 text_encoder_model.requires_grad_(False) 197 self.text_encoder_model = text_encoder_model 198 if self.torch_compile: 199 self.text_encoder_model = torch.compile(self.text_encoder_model) 200 201 self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path) 202 self.loaded = True 203 204 # compile 205 if self.torch_compile: 206 if export_quantized_weights: 207 from torch.ao.quantization import ( 208 Int4WeightOnlyConfig, 209 quantize_, 210 ) 211 212 group_size = 128 213 use_hqq = True 214 quantize_( 215 self.ace_step_transformer, 216 Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq), 217 ) 218 quantize_( 219 self.text_encoder_model, 220 Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq), 221 ) 222 223 # save quantized weights 224 torch.save( 225 self.ace_step_transformer.state_dict(), 226 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 227 ) 228 print( 229 "Quantized Weights Saved to: ", 230 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 231 ) 232 torch.save( 233 self.text_encoder_model.state_dict(), 234 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 235 ) 236 print( 237 "Quantized Weights Saved to: ", 238 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 239 ) 240 241 def load_quantized_checkpoint(self, checkpoint_dir=None): 242 checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID_QUANT) 243 dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8") 244 vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder") 245 ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer") 246 text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base") 247 248 self.music_dcae = MusicDCAE( 249 dcae_checkpoint_path=dcae_checkpoint_path, 250 vocoder_checkpoint_path=vocoder_checkpoint_path, 251 ) 252 if self.cpu_offload: 253 self.music_dcae.eval().to(self.dtype).to(self.device) 254 else: 255 self.music_dcae.eval().to(self.dtype).to("cpu") 256 self.music_dcae = torch.compile(self.music_dcae) 257 258 self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path) 259 self.ace_step_transformer.eval().to(self.dtype).to("cpu") 260 self.ace_step_transformer = torch.compile(self.ace_step_transformer) 261 self.ace_step_transformer.load_state_dict( 262 torch.load( 263 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 264 map_location=self.device, 265 ), 266 assign=True, 267 ) 268 self.ace_step_transformer.torchao_quantized = True 269 270 self.text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path) 271 self.text_encoder_model.eval().to(self.dtype).to("cpu") 272 self.text_encoder_model = torch.compile(self.text_encoder_model) 273 self.text_encoder_model.load_state_dict( 274 torch.load( 275 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 276 map_location=self.device, 277 ), 278 assign=True, 279 ) 280 self.text_encoder_model.torchao_quantized = True 281 282 self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path) 283 284 lang_segment = LangSegment() 285 lang_segment.setfilters(language_filters.default) 286 self.lang_segment = lang_segment 287 self.lyric_tokenizer = VoiceBpeTokenizer() 288 289 self.loaded = True 290 291 @cpu_offload("text_encoder_model") 292 def get_text_embeddings(self, texts, text_max_length=256): 293 inputs = self.text_tokenizer( 294 texts, 295 return_tensors="pt", 296 padding=True, 297 truncation=True, 298 max_length=text_max_length, 299 ) 300 inputs = {key: value.to(self.device) for key, value in inputs.items()} 301 if self.text_encoder_model.device != self.device: 302 self.text_encoder_model.to(self.device) 303 with torch.no_grad(): 304 outputs = self.text_encoder_model(**inputs) 305 last_hidden_states = outputs.last_hidden_state 306 attention_mask = inputs["attention_mask"] 307 return last_hidden_states, attention_mask 308 309 @cpu_offload("text_encoder_model") 310 def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10): 311 inputs = self.text_tokenizer( 312 texts, 313 return_tensors="pt", 314 padding=True, 315 truncation=True, 316 max_length=text_max_length, 317 ) 318 inputs = {key: value.to(self.device) for key, value in inputs.items()} 319 if self.text_encoder_model.device != self.device: 320 self.text_encoder_model.to(self.device) 321 322 def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10): 323 handlers = [] 324 325 def hook(module, input, output): 326 output[:] *= tau 327 return output 328 329 for i in range(l_min, l_max): 330 handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook) 331 handlers.append(handler) 332 333 with torch.no_grad(): 334 outputs = self.text_encoder_model(**inputs) 335 last_hidden_states = outputs.last_hidden_state 336 337 for hook in handlers: 338 hook.remove() 339 340 return last_hidden_states 341 342 last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max) 343 return last_hidden_states 344 345 def set_seeds(self, batch_size, manual_seeds=None): 346 processed_input_seeds = None 347 if manual_seeds is not None: 348 if isinstance(manual_seeds, str): 349 if "," in manual_seeds: 350 processed_input_seeds = list(map(int, manual_seeds.split(","))) 351 elif manual_seeds.isdigit(): 352 processed_input_seeds = int(manual_seeds) 353 elif isinstance(manual_seeds, list) and all(isinstance(s, int) for s in manual_seeds): 354 if len(manual_seeds) > 0: 355 processed_input_seeds = list(manual_seeds) 356 elif isinstance(manual_seeds, int): 357 processed_input_seeds = manual_seeds 358 random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)] 359 actual_seeds = [] 360 for i in range(batch_size): 361 current_seed_for_generator = None 362 if processed_input_seeds is None: 363 current_seed_for_generator = torch.randint(0, 2**32, (1,)).item() 364 elif isinstance(processed_input_seeds, int): 365 current_seed_for_generator = processed_input_seeds 366 elif isinstance(processed_input_seeds, list): 367 if i < len(processed_input_seeds): 368 current_seed_for_generator = processed_input_seeds[i] 369 else: 370 current_seed_for_generator = processed_input_seeds[-1] 371 if current_seed_for_generator is None: 372 current_seed_for_generator = torch.randint(0, 2**32, (1,)).item() 373 random_generators[i].manual_seed(current_seed_for_generator) 374 actual_seeds.append(current_seed_for_generator) 375 return random_generators, actual_seeds 376 377 def get_lang(self, text): 378 language = "en" 379 try: 380 _ = self.lang_segment.getTexts(text) 381 langCounts = self.lang_segment.getCounts() 382 language = langCounts[0][0] 383 if len(langCounts) > 1 and language == "en": 384 language = langCounts[1][0] 385 except Exception as err: 386 language = "en" 387 return language 388 389 def tokenize_lyrics(self, lyrics, debug=False): 390 lines = lyrics.split("\n") 391 lyric_token_idx = [261] 392 for line in lines: 393 line = line.strip() 394 if not line: 395 lyric_token_idx += [2] 396 continue 397 398 lang = self.get_lang(line) 399 400 if lang not in SUPPORT_LANGUAGES: 401 lang = "en" 402 if "zh" in lang: 403 lang = "zh" 404 if "spa" in lang: 405 lang = "es" 406 407 try: 408 if structure_pattern.match(line): 409 token_idx = self.lyric_tokenizer.encode(line, "en") 410 else: 411 token_idx = self.lyric_tokenizer.encode(line, lang) 412 if debug: 413 toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx]) 414 nfo(f"debbug {line} --> {lang} --> {toks}") 415 lyric_token_idx = lyric_token_idx + token_idx + [2] 416 except Exception as e: 417 print("tokenize error", e, "for line", line, "major_language", lang) 418 return lyric_token_idx 419 420 @cpu_offload("ace_step_transformer") 421 def calc_v( 422 self, 423 zt_src, 424 zt_tar, 425 t, 426 encoder_text_hidden_states, 427 text_attention_mask, 428 target_encoder_text_hidden_states, 429 target_text_attention_mask, 430 speaker_embds, 431 target_speaker_embeds, 432 lyric_token_ids, 433 lyric_mask, 434 target_lyric_token_ids, 435 target_lyric_mask, 436 do_classifier_free_guidance=False, 437 guidance_scale=1.0, 438 target_guidance_scale=1.0, 439 cfg_type="apg", 440 attention_mask=None, 441 momentum_buffer=None, 442 momentum_buffer_tar=None, 443 return_src_pred=True, 444 ): 445 noise_pred_src = None 446 if return_src_pred: 447 src_latent_model_input = torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src 448 timestep = t.expand(src_latent_model_input.shape[0]) 449 # source 450 noise_pred_src = self.ace_step_transformer( 451 hidden_states=src_latent_model_input, 452 attention_mask=attention_mask, 453 encoder_text_hidden_states=encoder_text_hidden_states, 454 text_attention_mask=text_attention_mask, 455 speaker_embeds=speaker_embds, 456 lyric_token_idx=lyric_token_ids, 457 lyric_mask=lyric_mask, 458 timestep=timestep, 459 ).sample 460 461 if do_classifier_free_guidance: 462 noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(2) 463 if cfg_type == "apg": 464 noise_pred_src = apg_forward( 465 pred_cond=noise_pred_with_cond_src, 466 pred_uncond=noise_pred_uncond_src, 467 guidance_scale=guidance_scale, 468 momentum_buffer=momentum_buffer, 469 ) 470 elif cfg_type == "cfg": 471 noise_pred_src = cfg_forward( 472 cond_output=noise_pred_with_cond_src, 473 uncond_output=noise_pred_uncond_src, 474 cfg_strength=guidance_scale, 475 ) 476 477 tar_latent_model_input = torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar 478 timestep = t.expand(tar_latent_model_input.shape[0]) 479 # target 480 noise_pred_tar = self.ace_step_transformer( 481 hidden_states=tar_latent_model_input, 482 attention_mask=attention_mask, 483 encoder_text_hidden_states=target_encoder_text_hidden_states, 484 text_attention_mask=target_text_attention_mask, 485 speaker_embeds=target_speaker_embeds, 486 lyric_token_idx=target_lyric_token_ids, 487 lyric_mask=target_lyric_mask, 488 timestep=timestep, 489 ).sample 490 491 if do_classifier_free_guidance: 492 noise_pred_with_cond_tar, noise_pred_uncond_tar = noise_pred_tar.chunk(2) 493 if cfg_type == "apg": 494 noise_pred_tar = apg_forward( 495 pred_cond=noise_pred_with_cond_tar, 496 pred_uncond=noise_pred_uncond_tar, 497 guidance_scale=target_guidance_scale, 498 momentum_buffer=momentum_buffer_tar, 499 ) 500 elif cfg_type == "cfg": 501 noise_pred_tar = cfg_forward( 502 cond_output=noise_pred_with_cond_tar, 503 uncond_output=noise_pred_uncond_tar, 504 cfg_strength=target_guidance_scale, 505 ) 506 return noise_pred_src, noise_pred_tar 507 508 @torch.no_grad() 509 def flowedit_diffusion_process( 510 self, 511 encoder_text_hidden_states, 512 text_attention_mask, 513 speaker_embds, 514 lyric_token_ids, 515 lyric_mask, 516 target_encoder_text_hidden_states, 517 target_text_attention_mask, 518 target_speaker_embeds, 519 target_lyric_token_ids, 520 target_lyric_mask, 521 src_latents, 522 random_generators=None, 523 infer_steps=60, 524 guidance_scale=15.0, 525 n_min=0, 526 n_max=1.0, 527 n_avg=1, 528 scheduler_type="euler", 529 ): 530 do_classifier_free_guidance = True 531 if guidance_scale == 0.0 or guidance_scale == 1.0: 532 do_classifier_free_guidance = False 533 534 target_guidance_scale = guidance_scale 535 bsz = encoder_text_hidden_states.shape[0] 536 537 scheduler = FlowMatchEulerDiscreteScheduler( 538 num_train_timesteps=1000, 539 shift=3.0, 540 ) 541 542 T_steps = infer_steps 543 frame_length = src_latents.shape[-1] 544 attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype) 545 546 timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, self.device, timesteps=None) 547 548 if do_classifier_free_guidance: 549 attention_mask = torch.cat([attention_mask] * 2, dim=0) 550 551 encoder_text_hidden_states = torch.cat( 552 [ 553 encoder_text_hidden_states, 554 torch.zeros_like(encoder_text_hidden_states), 555 ], 556 0, 557 ) 558 text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0) 559 560 target_encoder_text_hidden_states = torch.cat( 561 [ 562 target_encoder_text_hidden_states, 563 torch.zeros_like(target_encoder_text_hidden_states), 564 ], 565 0, 566 ) 567 target_text_attention_mask = torch.cat([target_text_attention_mask] * 2, dim=0) 568 569 speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0) 570 target_speaker_embeds = torch.cat([target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0) 571 572 lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0) 573 lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0) 574 575 target_lyric_token_ids = torch.cat([target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0) 576 target_lyric_mask = torch.cat([target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0) 577 578 momentum_buffer = MomentumBuffer() 579 momentum_buffer_tar = MomentumBuffer() 580 x_src = src_latents 581 zt_edit = x_src.clone() 582 xt_tar = None 583 n_min = int(infer_steps * n_min) 584 n_max = int(infer_steps * n_max) 585 586 nfo("flowedit start from {} to {}".format(n_min, n_max)) 587 588 for i, t in tqdm(enumerate(timesteps), total=T_steps): 589 if i < n_min: 590 continue 591 592 t_i = t / 1000 593 594 if i + 1 < len(timesteps): 595 t_im1 = (timesteps[i + 1]) / 1000 596 else: 597 t_im1 = torch.zeros_like(t_i).to(self.device) 598 599 if i < n_max: 600 # Calculate the average of the V predictions 601 V_delta_avg = torch.zeros_like(x_src) 602 for k in range(n_avg): 603 fwd_noise = randn_tensor( 604 shape=x_src.shape, 605 generator=random_generators, 606 device=self.device, 607 dtype=self.dtype, 608 ) 609 610 zt_src = (1 - t_i) * x_src + (t_i) * fwd_noise 611 612 zt_tar = zt_edit + zt_src - x_src 613 614 Vt_src, Vt_tar = self.calc_v( 615 zt_src=zt_src, 616 zt_tar=zt_tar, 617 t=t, 618 encoder_text_hidden_states=encoder_text_hidden_states, 619 text_attention_mask=text_attention_mask, 620 target_encoder_text_hidden_states=target_encoder_text_hidden_states, 621 target_text_attention_mask=target_text_attention_mask, 622 speaker_embds=speaker_embds, 623 target_speaker_embeds=target_speaker_embeds, 624 lyric_token_ids=lyric_token_ids, 625 lyric_mask=lyric_mask, 626 target_lyric_token_ids=target_lyric_token_ids, 627 target_lyric_mask=target_lyric_mask, 628 do_classifier_free_guidance=do_classifier_free_guidance, 629 guidance_scale=guidance_scale, 630 target_guidance_scale=target_guidance_scale, 631 attention_mask=attention_mask, 632 momentum_buffer=momentum_buffer, 633 ) 634 V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src) # - (hfg - 1) * (x_src) 635 636 zt_edit = zt_edit.to(torch.float32) # arbitrary, should be settable for compatibility 637 if scheduler_type != "pingpong": 638 # propagate direct ODE 639 zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg 640 zt_edit = zt_edit.to(self.dtype) 641 else: 642 # propagate pingpong SDE 643 zt_edit_denoised = zt_edit - t_i * V_delta_avg 644 noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None) 645 prev_sample = (1 - t_im1) * zt_edit_denoised + t_im1 * noise 646 647 else: # i >= T_steps-n_min # regular sampling for last n_min steps 648 if i == n_max: 649 fwd_noise = randn_tensor( 650 shape=x_src.shape, 651 generator=random_generators, 652 device=self.device, 653 dtype=self.dtype, 654 ) 655 scheduler._init_step_index(t) 656 sigma = scheduler.sigmas[scheduler.step_index] 657 xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src 658 xt_tar = zt_edit + xt_src - x_src 659 660 _, Vt_tar = self.calc_v( 661 zt_src=None, 662 zt_tar=xt_tar, 663 t=t, 664 encoder_text_hidden_states=encoder_text_hidden_states, 665 text_attention_mask=text_attention_mask, 666 target_encoder_text_hidden_states=target_encoder_text_hidden_states, 667 target_text_attention_mask=target_text_attention_mask, 668 speaker_embds=speaker_embds, 669 target_speaker_embeds=target_speaker_embeds, 670 lyric_token_ids=lyric_token_ids, 671 lyric_mask=lyric_mask, 672 target_lyric_token_ids=target_lyric_token_ids, 673 target_lyric_mask=target_lyric_mask, 674 do_classifier_free_guidance=do_classifier_free_guidance, 675 guidance_scale=guidance_scale, 676 target_guidance_scale=target_guidance_scale, 677 attention_mask=attention_mask, 678 momentum_buffer_tar=momentum_buffer_tar, 679 return_src_pred=False, 680 ) 681 682 xt_tar = xt_tar.to(torch.float32) 683 if scheduler_type != "pingpong": 684 prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar 685 prev_sample = prev_sample.to(self.dtype) 686 xt_tar = prev_sample 687 else: 688 prev_sample = xt_tar - t_i * Vt_tar 689 noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None) 690 prev_sample = (1 - t_im1) * prev_sample + t_im1 * noise 691 xt_tar = prev_sample 692 693 target_latents = zt_edit if xt_tar is None else xt_tar 694 return target_latents 695 696 def add_latents_noise( 697 self, 698 gt_latents, 699 sigma_max, 700 noise, 701 scheduler_type, 702 infer_steps, 703 ): 704 bsz = gt_latents.shape[0] 705 if scheduler_type == "euler": 706 scheduler = FlowMatchEulerDiscreteScheduler( 707 num_train_timesteps=1000, 708 shift=3.0, 709 sigma_max=sigma_max, 710 ) 711 elif scheduler_type == "heun": 712 scheduler = FlowMatchHeunDiscreteScheduler( 713 num_train_timesteps=1000, 714 shift=3.0, 715 sigma_max=sigma_max, 716 ) 717 elif scheduler_type == "pingpong": 718 scheduler = FlowMatchPingPongScheduler(num_train_timesteps=1000, shift=3.0, sigma_max=sigma_max) 719 720 infer_steps = int(sigma_max * infer_steps) 721 timesteps, num_inference_steps = retrieve_timesteps( 722 scheduler, 723 num_inference_steps=infer_steps, 724 device=self.device, 725 timesteps=None, 726 ) 727 noisy_image = gt_latents * (1 - scheduler.sigma_max) + noise * scheduler.sigma_max 728 nfo(f"{scheduler.sigma_min=} {scheduler.sigma_max=} {timesteps=} {num_inference_steps=}") 729 return noisy_image, timesteps, scheduler, num_inference_steps 730 731 @cpu_offload("ace_step_transformer") 732 @torch.no_grad() 733 def text2music_diffusion_process( 734 self, 735 duration, 736 encoder_text_hidden_states, 737 text_attention_mask, 738 speaker_embds, 739 lyric_token_ids, 740 lyric_mask, 741 random_generators=None, 742 infer_steps=60, 743 guidance_scale=15.0, 744 omega_scale=10.0, 745 scheduler_type="euler", 746 cfg_type="apg", 747 zero_steps=1, 748 use_zero_init=True, 749 guidance_interval=0.5, 750 guidance_interval_decay=1.0, 751 min_guidance_scale=3.0, 752 oss_steps=[], 753 encoder_text_hidden_states_null=None, 754 use_erg_lyric=False, 755 use_erg_diffusion=False, 756 retake_random_generators=None, 757 retake_variance=0.5, 758 add_retake_noise=False, 759 guidance_scale_text=0.0, 760 guidance_scale_lyric=0.0, 761 repaint_start=0, 762 repaint_end=0, 763 src_latents=None, 764 audio2audio_enable=False, 765 ref_audio_strength=0.5, 766 ref_latents=None, 767 ): 768 nfo("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale)) 769 do_classifier_free_guidance = True 770 if guidance_scale == 0.0 or guidance_scale == 1.0: 771 do_classifier_free_guidance = False 772 773 do_double_condition_guidance = False 774 if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0: 775 do_double_condition_guidance = True 776 nfo( 777 "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format( 778 do_double_condition_guidance, 779 guidance_scale_text, 780 guidance_scale_lyric, 781 ) 782 ) 783 784 bsz = encoder_text_hidden_states.shape[0] 785 786 if scheduler_type == "euler": 787 scheduler = FlowMatchEulerDiscreteScheduler( 788 num_train_timesteps=1000, 789 shift=3.0, 790 ) 791 elif scheduler_type == "heun": 792 scheduler = FlowMatchHeunDiscreteScheduler( 793 num_train_timesteps=1000, 794 shift=3.0, 795 ) 796 elif scheduler_type == "pingpong": 797 scheduler = FlowMatchPingPongScheduler( 798 num_train_timesteps=1000, 799 shift=3.0, 800 ) 801 802 frame_length = int(duration * 44100 / 512 / 8) 803 if src_latents is not None: 804 frame_length = src_latents.shape[-1] 805 806 if ref_latents is not None: 807 frame_length = ref_latents.shape[-1] 808 809 if len(oss_steps) > 0: 810 infer_steps = max(oss_steps) 811 scheduler.set_timesteps 812 timesteps, num_inference_steps = retrieve_timesteps( 813 scheduler, 814 num_inference_steps=infer_steps, 815 device=self.device, 816 timesteps=None, 817 ) 818 new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device) 819 for idx in range(len(oss_steps)): 820 new_timesteps[idx] = timesteps[oss_steps[idx] - 1] 821 num_inference_steps = len(oss_steps) 822 sigmas = (new_timesteps / 1000).float().cpu().numpy() 823 timesteps, num_inference_steps = retrieve_timesteps( 824 scheduler, 825 num_inference_steps=num_inference_steps, 826 device=self.device, 827 sigmas=sigmas, 828 ) 829 nfo(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}") 830 else: 831 timesteps, num_inference_steps = retrieve_timesteps( 832 scheduler, 833 num_inference_steps=infer_steps, 834 device=self.device, 835 timesteps=None, 836 ) 837 838 target_latents = randn_tensor( 839 shape=(bsz, 8, 16, frame_length), 840 generator=random_generators, 841 device=self.device, 842 dtype=self.dtype, 843 ) 844 845 is_repaint = False 846 is_extend = False 847 848 if add_retake_noise: 849 n_min = int(infer_steps * (1 - retake_variance)) 850 retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype) 851 retake_latents = randn_tensor( 852 shape=(bsz, 8, 16, frame_length), 853 generator=retake_random_generators, 854 device=self.device, 855 dtype=self.dtype, 856 ) 857 repaint_start_frame = int(repaint_start * 44100 / 512 / 8) 858 repaint_end_frame = int(repaint_end * 44100 / 512 / 8) 859 x0 = src_latents 860 # retake 861 is_repaint = repaint_end_frame - repaint_start_frame != frame_length 862 863 is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length) 864 if is_extend: 865 is_repaint = True 866 867 # TODO: train a mask aware repainting controlnet 868 # to make sure mean = 0, std = 1 869 if not is_repaint: 870 target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents 871 elif not is_extend: 872 # if repaint_end_frame 873 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype) 874 repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0 875 repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents 876 repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents) 877 zt_edit = x0.clone() 878 z0 = repaint_noise 879 elif is_extend: 880 to_right_pad_gt_latents = None 881 to_left_pad_gt_latents = None 882 gt_latents = src_latents 883 src_latents_length = gt_latents.shape[-1] 884 max_infer_fame_length = int(240 * 44100 / 512 / 8) 885 left_pad_frame_length = 0 886 right_pad_frame_length = 0 887 right_trim_length = 0 888 left_trim_length = 0 889 if repaint_start_frame < 0: 890 left_pad_frame_length = abs(repaint_start_frame) 891 frame_length = left_pad_frame_length + gt_latents.shape[-1] 892 extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0) 893 if frame_length > max_infer_fame_length: 894 right_trim_length = frame_length - max_infer_fame_length 895 extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_fame_length] 896 to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:] 897 frame_length = max_infer_fame_length 898 repaint_start_frame = 0 899 gt_latents = extend_gt_latents 900 901 if repaint_end_frame > src_latents_length: 902 right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1] 903 frame_length = gt_latents.shape[-1] + right_pad_frame_length 904 extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0) 905 if frame_length > max_infer_fame_length: 906 left_trim_length = frame_length - max_infer_fame_length 907 extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_fame_length:] 908 to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length] 909 frame_length = max_infer_fame_length 910 repaint_end_frame = frame_length 911 gt_latents = extend_gt_latents 912 913 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype) 914 if left_pad_frame_length > 0: 915 repaint_mask[:, :, :, :left_pad_frame_length] = 1.0 916 if right_pad_frame_length > 0: 917 repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0 918 x0 = gt_latents 919 padd_list = [] 920 if left_pad_frame_length > 0: 921 padd_list.append(retake_latents[:, :, :, :left_pad_frame_length]) 922 padd_list.append( 923 target_latents[ 924 :, 925 :, 926 :, 927 left_trim_length : target_latents.shape[-1] - right_trim_length, 928 ] 929 ) 930 if right_pad_frame_length > 0: 931 padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:]) 932 target_latents = torch.cat(padd_list, dim=-1) 933 assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}" 934 zt_edit = x0.clone() 935 z0 = target_latents 936 937 if audio2audio_enable and ref_latents is not None: 938 nfo(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}") 939 target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise( 940 gt_latents=ref_latents, 941 sigma_max=(1 - ref_audio_strength), 942 noise=target_latents, 943 scheduler_type=scheduler_type, 944 infer_steps=infer_steps, 945 ) 946 947 attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype) 948 949 # guidance interval 950 start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2)) 951 end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5)) 952 nfo(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}") 953 954 momentum_buffer = MomentumBuffer() 955 956 def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6): 957 handlers = [] 958 959 def hook(module, input, output): 960 output[:] *= tau 961 return output 962 963 for i in range(l_min, l_max): 964 handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook) 965 handlers.append(handler) 966 967 encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs) 968 969 for hook in handlers: 970 hook.remove() 971 972 return encoder_hidden_states 973 974 # P(speaker, text, lyric) 975 encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode( 976 encoder_text_hidden_states, 977 text_attention_mask, 978 speaker_embds, 979 lyric_token_ids, 980 lyric_mask, 981 ) 982 983 if use_erg_lyric: 984 # P(null_speaker, text_weaker, lyric_weaker) 985 encoder_hidden_states_null = forward_encoder_with_temperature( 986 self, 987 inputs={ 988 "encoder_text_hidden_states": ( 989 encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states) 990 ), 991 "text_attention_mask": text_attention_mask, 992 "speaker_embeds": torch.zeros_like(speaker_embds), 993 "lyric_token_idx": lyric_token_ids, 994 "lyric_mask": lyric_mask, 995 }, 996 ) 997 else: 998 # P(null_speaker, null_text, null_lyric) 999 encoder_hidden_states_null, _ = self.ace_step_transformer.encode( 1000 torch.zeros_like(encoder_text_hidden_states), 1001 text_attention_mask, 1002 torch.zeros_like(speaker_embds), 1003 torch.zeros_like(lyric_token_ids), 1004 lyric_mask, 1005 ) 1006 1007 encoder_hidden_states_no_lyric = None 1008 if do_double_condition_guidance: 1009 # P(null_speaker, text, lyric_weaker) 1010 if use_erg_lyric: 1011 encoder_hidden_states_no_lyric = forward_encoder_with_temperature( 1012 self, 1013 inputs={ 1014 "encoder_text_hidden_states": encoder_text_hidden_states, 1015 "text_attention_mask": text_attention_mask, 1016 "speaker_embeds": torch.zeros_like(speaker_embds), 1017 "lyric_token_idx": lyric_token_ids, 1018 "lyric_mask": lyric_mask, 1019 }, 1020 ) 1021 # P(null_speaker, text, no_lyric) 1022 else: 1023 encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode( 1024 encoder_text_hidden_states, 1025 text_attention_mask, 1026 torch.zeros_like(speaker_embds), 1027 torch.zeros_like(lyric_token_ids), 1028 lyric_mask, 1029 ) 1030 1031 def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20): 1032 handlers = [] 1033 1034 def hook(module, input, output): 1035 output[:] *= tau 1036 return output 1037 1038 for i in range(l_min, l_max): 1039 handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook) 1040 handlers.append(handler) 1041 handler = self.ace_step_transformer.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook) 1042 handlers.append(handler) 1043 1044 sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample 1045 1046 for hook in handlers: 1047 hook.remove() 1048 1049 return sample 1050 1051 for i, t in tqdm(enumerate(timesteps), total=num_inference_steps): 1052 if is_repaint: 1053 if i < n_min: 1054 continue 1055 elif i == n_min: 1056 t_i = t / 1000 1057 zt_src = (1 - t_i) * x0 + (t_i) * z0 1058 target_latents = zt_edit + zt_src - x0 1059 nfo(f"repaint start from {n_min} add {t_i} level of noise") 1060 1061 # expand the latents if we are doing classifier free guidance 1062 latents = target_latents 1063 1064 is_in_guidance_interval = start_idx <= i < end_idx 1065 if is_in_guidance_interval and do_classifier_free_guidance: 1066 # compute current guidance scale 1067 if guidance_interval_decay > 0: 1068 # Linearly interpolate to calculate the current guidance scale 1069 progress = (i - start_idx) / (end_idx - start_idx - 1) # 归一化到[0,1] 1070 current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay 1071 else: 1072 current_guidance_scale = guidance_scale 1073 1074 latent_model_input = latents 1075 timestep = t.expand(latent_model_input.shape[0]) 1076 output_length = latent_model_input.shape[-1] 1077 # P(x|speaker, text, lyric) 1078 noise_pred_with_cond = self.ace_step_transformer.decode( 1079 hidden_states=latent_model_input, 1080 attention_mask=attention_mask, 1081 encoder_hidden_states=encoder_hidden_states, 1082 encoder_hidden_mask=encoder_hidden_mask, 1083 output_length=output_length, 1084 timestep=timestep, 1085 ).sample 1086 1087 noise_pred_with_only_text_cond = None 1088 if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None: 1089 noise_pred_with_only_text_cond = self.ace_step_transformer.decode( 1090 hidden_states=latent_model_input, 1091 attention_mask=attention_mask, 1092 encoder_hidden_states=encoder_hidden_states_no_lyric, 1093 encoder_hidden_mask=encoder_hidden_mask, 1094 output_length=output_length, 1095 timestep=timestep, 1096 ).sample 1097 1098 if use_erg_diffusion: 1099 noise_pred_uncond = forward_diffusion_with_temperature( 1100 self, 1101 hidden_states=latent_model_input, 1102 timestep=timestep, 1103 inputs={ 1104 "encoder_hidden_states": encoder_hidden_states_null, 1105 "encoder_hidden_mask": encoder_hidden_mask, 1106 "output_length": output_length, 1107 "attention_mask": attention_mask, 1108 }, 1109 ) 1110 else: 1111 noise_pred_uncond = self.ace_step_transformer.decode( 1112 hidden_states=latent_model_input, 1113 attention_mask=attention_mask, 1114 encoder_hidden_states=encoder_hidden_states_null, 1115 encoder_hidden_mask=encoder_hidden_mask, 1116 output_length=output_length, 1117 timestep=timestep, 1118 ).sample 1119 1120 if do_double_condition_guidance and noise_pred_with_only_text_cond is not None: 1121 noise_pred = cfg_double_condition_forward( 1122 cond_output=noise_pred_with_cond, 1123 uncond_output=noise_pred_uncond, 1124 only_text_cond_output=noise_pred_with_only_text_cond, 1125 guidance_scale_text=guidance_scale_text, 1126 guidance_scale_lyric=guidance_scale_lyric, 1127 ) 1128 1129 elif cfg_type == "apg": 1130 noise_pred = apg_forward( 1131 pred_cond=noise_pred_with_cond, 1132 pred_uncond=noise_pred_uncond, 1133 guidance_scale=current_guidance_scale, 1134 momentum_buffer=momentum_buffer, 1135 ) 1136 elif cfg_type == "cfg": 1137 noise_pred = cfg_forward( 1138 cond_output=noise_pred_with_cond, 1139 uncond_output=noise_pred_uncond, 1140 cfg_strength=current_guidance_scale, 1141 ) 1142 elif cfg_type == "cfg_star": 1143 noise_pred = cfg_zero_star( 1144 noise_pred_with_cond=noise_pred_with_cond, 1145 noise_pred_uncond=noise_pred_uncond, 1146 guidance_scale=current_guidance_scale, 1147 i=i, 1148 zero_steps=zero_steps, 1149 use_zero_init=use_zero_init, 1150 ) 1151 else: 1152 latent_model_input = latents 1153 timestep = t.expand(latent_model_input.shape[0]) 1154 noise_pred = self.ace_step_transformer.decode( 1155 hidden_states=latent_model_input, 1156 attention_mask=attention_mask, 1157 encoder_hidden_states=encoder_hidden_states, 1158 encoder_hidden_mask=encoder_hidden_mask, 1159 output_length=latent_model_input.shape[-1], 1160 timestep=timestep, 1161 ).sample 1162 1163 if is_repaint and i >= n_min: 1164 t_i = t / 1000 1165 if i + 1 < len(timesteps): 1166 t_im1 = (timesteps[i + 1]) / 1000 1167 else: 1168 t_im1 = torch.zeros_like(t_i).to(self.device) 1169 target_latents = target_latents.to(torch.float32) 1170 prev_sample = target_latents + (t_im1 - t_i) * noise_pred 1171 prev_sample = prev_sample.to(self.dtype) 1172 target_latents = prev_sample 1173 zt_src = (1 - t_im1) * x0 + (t_im1) * z0 1174 target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src) 1175 else: 1176 target_latents = scheduler.step( 1177 model_output=noise_pred, 1178 timestep=t, 1179 sample=target_latents, 1180 return_dict=False, 1181 omega=omega_scale, 1182 generator=random_generators[0], 1183 )[0] 1184 1185 if is_extend: 1186 if to_right_pad_gt_latents is not None: 1187 target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1) 1188 if to_left_pad_gt_latents is not None: 1189 target_latents = torch.cat([to_right_pad_gt_latents, target_latents], dim=0) 1190 return target_latents 1191 1192 @cpu_offload("music_dcae") 1193 def latents2audio( 1194 self, 1195 latents, 1196 target_wav_duration_second=30, 1197 sample_rate=48000, 1198 save_path=None, 1199 format="wav", 1200 ): 1201 output_audio_paths = [] 1202 bs = latents.shape[0] 1203 pred_latents = latents 1204 with torch.no_grad(): 1205 if self.overlapped_decode and target_wav_duration_second > 48: 1206 _, pred_wavs = self.music_dcae.decode_overlap(pred_latents, sr=sample_rate) 1207 else: 1208 _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate) 1209 pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs] 1210 for i in tqdm(range(bs)): 1211 output_audio_path = self.save_wav_file( 1212 pred_wavs[i], 1213 i, 1214 save_path=save_path, 1215 sample_rate=sample_rate, 1216 format=format, 1217 ) 1218 output_audio_paths.append(output_audio_path) 1219 return output_audio_paths 1220 1221 def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="wav"): 1222 if save_path is None: 1223 nfo("save_path is None, using default path ./outputs/") 1224 base_path = "./outputs" 1225 ensure_directory_exists(base_path) 1226 output_path_wav = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format 1227 else: 1228 ensure_directory_exists(os.path.dirname(save_path)) 1229 if os.path.isdir(save_path): 1230 nfo(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.") 1231 output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format) 1232 else: 1233 output_path_wav = save_path 1234 1235 target_wav = target_wav.float() 1236 backend = "soundfile" 1237 if format == "ogg": 1238 backend = "sox" 1239 nfo(f"Saving audio to {output_path_wav} using backend {backend}") 1240 try: 1241 torchaudio.save(output_path_wav, target_wav, sample_rate=sample_rate, format=format, backend=backend) 1242 except (RuntimeError, ImportError, ModuleNotFoundError): 1243 import soundfile as sf # pyright: ignore[reportMissingImports] | pylint:disable=import-error 1244 1245 sf.write(output_path_wav, target_wav, sample_rate) 1246 return output_path_wav 1247 1248 @cpu_offload("music_dcae") 1249 def infer_latents(self, input_audio_path): 1250 if input_audio_path is None: 1251 return None 1252 input_audio, sr = self.music_dcae.load_audio(input_audio_path) 1253 input_audio = input_audio.unsqueeze(0) 1254 input_audio = input_audio.to(device=self.device, dtype=self.dtype) 1255 latents, _ = self.music_dcae.encode(input_audio, sr=sr) 1256 return latents 1257 1258 def load_lora(self, lora_name_or_path, lora_weight): 1259 if (lora_name_or_path != self.lora_path or lora_weight != self.lora_weight) and lora_name_or_path != "none": 1260 if not os.path.exists(lora_name_or_path): 1261 lora_download_path = snapshot_download(lora_name_or_path, cache_dir=self.checkpoint_dir) 1262 else: 1263 lora_download_path = lora_name_or_path 1264 if self.lora_path != "none": 1265 self.ace_step_transformer.unload_lora() 1266 self.ace_step_transformer.load_lora_adapter( 1267 os.path.join(lora_download_path, "pytorch_lora_weights.safetensors"), adapter_name="ace_step_lora", with_alpha=True, prefix=None 1268 ) 1269 nfo(f"Loading lora weights from: {lora_name_or_path} download path is: {lora_download_path} weight: {lora_weight}") 1270 set_weights_and_activate_adapters(self.ace_step_transformer, ["ace_step_lora"], [lora_weight]) 1271 self.lora_path = lora_name_or_path 1272 self.lora_weight = lora_weight 1273 elif self.lora_path != "none" and lora_name_or_path == "none": 1274 nfo("No lora weights to load.") 1275 self.ace_step_transformer.unload_lora() 1276 1277 def __call__( 1278 self, 1279 format: str = "wav", 1280 audio_duration: float = 60.0, 1281 prompt: str = None, 1282 lyrics: str = None, 1283 infer_step: int = 60, 1284 guidance_scale: float = 15.0, 1285 scheduler_type: str = "euler", 1286 cfg_type: str = "apg", 1287 omega_scale: int = 10.0, 1288 manual_seeds: list = None, 1289 guidance_interval: float = 0.5, 1290 guidance_interval_decay: float = 0.0, 1291 min_guidance_scale: float = 3.0, 1292 use_erg_tag: bool = True, 1293 use_erg_lyric: bool = True, 1294 use_erg_diffusion: bool = True, 1295 oss_steps: str = None, 1296 guidance_scale_text: float = 0.0, 1297 guidance_scale_lyric: float = 0.0, 1298 audio2audio_enable: bool = False, 1299 ref_audio_strength: float = 0.5, 1300 ref_audio_input: str = None, 1301 lora_name_or_path: str = "none", 1302 lora_weight: float = 1.0, 1303 retake_seeds: list = None, 1304 retake_variance: float = 0.5, 1305 task: str = "text2music", 1306 repaint_start: int = 0, 1307 repaint_end: int = 0, 1308 src_audio_path: str = None, 1309 edit_target_prompt: str = None, 1310 edit_target_lyrics: str = None, 1311 edit_n_min: float = 0.0, 1312 edit_n_max: float = 1.0, 1313 edit_n_avg: int = 1, 1314 save_path: str = None, 1315 batch_size: int = 1, 1316 debug: bool = False, 1317 ): 1318 start_time = time.time() 1319 1320 if audio2audio_enable and ref_audio_input is not None: 1321 task = "audio2audio" 1322 1323 if not self.loaded: 1324 nfo("Checkpoint not loaded, loading checkpoint...") 1325 if self.quantized: 1326 self.load_quantized_checkpoint(self.checkpoint_dir) 1327 else: 1328 self.load_checkpoint(self.checkpoint_dir) 1329 1330 self.load_lora(lora_name_or_path, lora_weight) 1331 load_model_cost = time.time() - start_time 1332 nfo(f"Model loaded in {load_model_cost:.2f} seconds.") 1333 1334 start_time = time.time() 1335 1336 random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds) 1337 retake_random_generators, actual_retake_seeds = self.set_seeds(batch_size, retake_seeds) 1338 1339 if isinstance(oss_steps, str) and len(oss_steps) > 0: 1340 oss_steps = list(map(int, oss_steps.split(","))) 1341 else: 1342 oss_steps = [] 1343 1344 texts = [prompt] 1345 encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts) 1346 encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1) 1347 text_attention_mask = text_attention_mask.repeat(batch_size, 1) 1348 1349 encoder_text_hidden_states_null = None 1350 if use_erg_tag: 1351 encoder_text_hidden_states_null = self.get_text_embeddings_null(texts) 1352 encoder_text_hidden_states_null = encoder_text_hidden_states_null.repeat(batch_size, 1, 1) 1353 1354 # not support for released checkpoint 1355 speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype) 1356 1357 # 6 lyric 1358 lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long() 1359 lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long() 1360 if len(lyrics) > 0: 1361 lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug) 1362 lyric_mask = [1] * len(lyric_token_idx) 1363 lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1) 1364 lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1) 1365 1366 if audio_duration <= 0: 1367 audio_duration = random.uniform(30.0, 240.0) 1368 nfo(f"random audio duration: {audio_duration}") 1369 1370 end_time = time.time() 1371 preprocess_time_cost = end_time - start_time 1372 start_time = end_time 1373 1374 add_retake_noise = task in ("retake", "repaint", "extend") 1375 # retake equal to repaint 1376 if task == "retake": 1377 repaint_start = 0 1378 repaint_end = audio_duration 1379 1380 src_latents = None 1381 if src_audio_path is not None: 1382 assert src_audio_path is not None and task in ( 1383 "repaint", 1384 "edit", 1385 "extend", 1386 ), "src_audio_path is required for retake/repaint/extend task" 1387 assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist" 1388 src_latents = self.infer_latents(src_audio_path) 1389 1390 ref_latents = None 1391 if ref_audio_input is not None and audio2audio_enable: 1392 assert ref_audio_input is not None, "ref_audio_input is required for audio2audio task" 1393 assert os.path.exists(ref_audio_input), f"ref_audio_input {ref_audio_input} does not exist" 1394 ref_latents = self.infer_latents(ref_audio_input) 1395 1396 if task == "edit": 1397 texts = [edit_target_prompt] 1398 target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts) 1399 target_encoder_text_hidden_states = target_encoder_text_hidden_states.repeat(batch_size, 1, 1) 1400 target_text_attention_mask = target_text_attention_mask.repeat(batch_size, 1) 1401 1402 target_lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long() 1403 target_lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long() 1404 if len(edit_target_lyrics) > 0: 1405 target_lyric_token_idx = self.tokenize_lyrics(edit_target_lyrics, debug=True) 1406 target_lyric_mask = [1] * len(target_lyric_token_idx) 1407 target_lyric_token_idx = torch.tensor(target_lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1) 1408 target_lyric_mask = torch.tensor(target_lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1) 1409 1410 target_speaker_embeds = speaker_embeds.clone() 1411 1412 target_latents = self.flowedit_diffusion_process( 1413 encoder_text_hidden_states=encoder_text_hidden_states, 1414 text_attention_mask=text_attention_mask, 1415 speaker_embds=speaker_embeds, 1416 lyric_token_ids=lyric_token_idx, 1417 lyric_mask=lyric_mask, 1418 target_encoder_text_hidden_states=target_encoder_text_hidden_states, 1419 target_text_attention_mask=target_text_attention_mask, 1420 target_speaker_embeds=target_speaker_embeds, 1421 target_lyric_token_ids=target_lyric_token_idx, 1422 target_lyric_mask=target_lyric_mask, 1423 src_latents=src_latents, 1424 random_generators=retake_random_generators, # more diversity 1425 infer_steps=infer_step, 1426 guidance_scale=guidance_scale, 1427 n_min=edit_n_min, 1428 n_max=edit_n_max, 1429 n_avg=edit_n_avg, 1430 scheduler_type=scheduler_type, 1431 ) 1432 else: 1433 target_latents = self.text2music_diffusion_process( 1434 duration=audio_duration, 1435 encoder_text_hidden_states=encoder_text_hidden_states, 1436 text_attention_mask=text_attention_mask, 1437 speaker_embds=speaker_embeds, 1438 lyric_token_ids=lyric_token_idx, 1439 lyric_mask=lyric_mask, 1440 guidance_scale=guidance_scale, 1441 omega_scale=omega_scale, 1442 infer_steps=infer_step, 1443 random_generators=random_generators, 1444 scheduler_type=scheduler_type, 1445 cfg_type=cfg_type, 1446 guidance_interval=guidance_interval, 1447 guidance_interval_decay=guidance_interval_decay, 1448 min_guidance_scale=min_guidance_scale, 1449 oss_steps=oss_steps, 1450 encoder_text_hidden_states_null=encoder_text_hidden_states_null, 1451 use_erg_lyric=use_erg_lyric, 1452 use_erg_diffusion=use_erg_diffusion, 1453 retake_random_generators=retake_random_generators, 1454 retake_variance=retake_variance, 1455 add_retake_noise=add_retake_noise, 1456 guidance_scale_text=guidance_scale_text, 1457 guidance_scale_lyric=guidance_scale_lyric, 1458 repaint_start=repaint_start, 1459 repaint_end=repaint_end, 1460 src_latents=src_latents, 1461 audio2audio_enable=audio2audio_enable, 1462 ref_audio_strength=ref_audio_strength, 1463 ref_latents=ref_latents, 1464 ) 1465 1466 end_time = time.time() 1467 diffusion_time_cost = end_time - start_time 1468 start_time = end_time 1469 1470 output_paths = self.latents2audio( 1471 latents=target_latents, 1472 target_wav_duration_second=audio_duration, 1473 save_path=save_path, 1474 format=format, 1475 ) 1476 1477 # Clean up memory after generation 1478 empty_cache 1479 1480 end_time = time.time() 1481 latent2audio_time_cost = end_time - start_time 1482 timecosts = { 1483 "preprocess": preprocess_time_cost, 1484 "diffusion": diffusion_time_cost, 1485 "latent2audio": latent2audio_time_cost, 1486 } 1487 1488 input_params_json = { 1489 "format": format, 1490 "lora_name_or_path": lora_name_or_path, 1491 "lora_weight": lora_weight, 1492 "task": task, 1493 "prompt": prompt if task != "edit" else edit_target_prompt, 1494 "lyrics": lyrics if task != "edit" else edit_target_lyrics, 1495 "audio_duration": audio_duration, 1496 "infer_step": infer_step, 1497 "guidance_scale": guidance_scale, 1498 "scheduler_type": scheduler_type, 1499 "cfg_type": cfg_type, 1500 "omega_scale": omega_scale, 1501 "guidance_interval": guidance_interval, 1502 "guidance_interval_decay": guidance_interval_decay, 1503 "min_guidance_scale": min_guidance_scale, 1504 "use_erg_tag": use_erg_tag, 1505 "use_erg_lyric": use_erg_lyric, 1506 "use_erg_diffusion": use_erg_diffusion, 1507 "oss_steps": oss_steps, 1508 "timecosts": timecosts, 1509 "actual_seeds": actual_seeds, 1510 "retake_seeds": actual_retake_seeds, 1511 "retake_variance": retake_variance, 1512 "guidance_scale_text": guidance_scale_text, 1513 "guidance_scale_lyric": guidance_scale_lyric, 1514 "repaint_start": repaint_start, 1515 "repaint_end": repaint_end, 1516 "edit_n_min": edit_n_min, 1517 "edit_n_max": edit_n_max, 1518 "edit_n_avg": edit_n_avg, 1519 "src_audio_path": src_audio_path, 1520 "edit_target_prompt": edit_target_prompt, 1521 "edit_target_lyrics": edit_target_lyrics, 1522 "audio2audio_enable": audio2audio_enable, 1523 "ref_audio_strength": ref_audio_strength, 1524 "ref_audio_input": ref_audio_input, 1525 } 1526 # save input_params_json 1527 for output_audio_path in output_paths: 1528 input_params_json_save_path = output_audio_path.replace(f".{format}", "_input_params.json") 1529 input_params_json["audio_path"] = output_audio_path 1530 with open(input_params_json_save_path, "w", encoding="utf-8") as f: 1531 json.dump(input_params_json, f, indent=4, ensure_ascii=False) 1532 1533 return output_paths + [input_params_json]
ACEStepPipeline( checkpoint_dir=None, device_id=0, dtype='bfloat16', text_encoder_checkpoint_path=None, persistent_storage_path=None, torch_compile=False, cpu_offload=False, quantized=False, overlapped_decode=False, **kwargs)
94 def __init__( 95 self, 96 checkpoint_dir=None, 97 device_id=0, 98 dtype="bfloat16", 99 text_encoder_checkpoint_path=None, 100 persistent_storage_path=None, 101 torch_compile=False, 102 cpu_offload=False, 103 quantized=False, 104 overlapped_decode=False, 105 **kwargs, 106 ): 107 if not checkpoint_dir: 108 if persistent_storage_path is None: 109 checkpoint_dir = os.path.join(os.path.expanduser("~"), ".cache/ace-step/checkpoints") 110 os.makedirs(checkpoint_dir, exist_ok=True) 111 else: 112 checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints") 113 ensure_directory_exists(checkpoint_dir) 114 115 self.checkpoint_dir = checkpoint_dir 116 self.lora_path = "none" 117 self.lora_weight = 1 118 self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32 119 if gfx_device.type == "mps": 120 if self.dtype == torch.bfloat16: 121 self.dtype = torch.float16 122 123 if "ACE_PIPELINE_DTYPE" in os.environ and len(os.environ["ACE_PIPELINE_DTYPE"]): 124 self.dtype = getattr(torch, os.environ["ACE_PIPELINE_DTYPE"]) 125 self.device: torch.device = gfx_device 126 self.loaded = False 127 self.torch_compile = torch_compile 128 self.cpu_offload = cpu_offload 129 self.quantized = quantized 130 self.overlapped_decode = overlapped_decode
def
get_checkpoint_path(self, checkpoint_dir, repo):
132 def get_checkpoint_path(self, checkpoint_dir, repo): 133 checkpoint_dir_models = None 134 135 if checkpoint_dir is not None: 136 required_dirs = ["music_dcae_f8c8", "music_vocoder", "ace_step_transformer", "umt5-base"] 137 all_dirs_exist = True 138 for dir_name in required_dirs: 139 dir_path = os.path.join(checkpoint_dir, dir_name) 140 if not os.path.exists(dir_path): 141 all_dirs_exist = False 142 break 143 144 if all_dirs_exist: 145 nfo(f"Load models from: {checkpoint_dir}") 146 checkpoint_dir_models = checkpoint_dir 147 148 if checkpoint_dir_models is None: 149 if checkpoint_dir is None: 150 nfo(f"Download models from Hugging Face: {repo}") 151 checkpoint_dir_models = snapshot_download(repo) 152 else: 153 nfo(f"Download models from Hugging Face: {repo}, cache to: {checkpoint_dir}") 154 checkpoint_dir_models = snapshot_download(repo, cache_dir=checkpoint_dir) 155 return checkpoint_dir_models
def
load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
157 def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False): 158 checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID) 159 dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8") 160 vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder") 161 ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer") 162 text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base") 163 164 self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype) 165 # self.ace_step_transformer.to(self.device).eval().to(self.dtype) 166 if self.cpu_offload: 167 self.ace_step_transformer = self.ace_step_transformer.to("cpu").eval().to(self.dtype) 168 else: 169 self.ace_step_transformer = self.ace_step_transformer.to(self.device).eval().to(self.dtype) 170 if self.torch_compile: 171 self.ace_step_transformer = torch.compile(self.ace_step_transformer) 172 173 self.music_dcae = MusicDCAE( 174 dcae_checkpoint_path=dcae_checkpoint_path, 175 vocoder_checkpoint_path=vocoder_checkpoint_path, 176 ) 177 # self.music_dcae.to(self.device).eval().to(self.dtype) 178 if self.cpu_offload: # might be redundant 179 self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype) 180 else: 181 self.music_dcae = self.music_dcae.to(self.device).eval().to(self.dtype) 182 if self.torch_compile: 183 self.music_dcae = torch.compile(self.music_dcae) 184 185 lang_segment = LangSegment() 186 lang_segment.setfilters(language_filters.default) 187 self.lang_segment = lang_segment 188 self.lyric_tokenizer = VoiceBpeTokenizer() 189 190 text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval() 191 # text_encoder_model = text_encoder_model.to(self.device).to(self.dtype) 192 if self.cpu_offload: 193 text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype) 194 else: 195 text_encoder_model = text_encoder_model.to(self.device).eval().to(self.dtype) 196 text_encoder_model.requires_grad_(False) 197 self.text_encoder_model = text_encoder_model 198 if self.torch_compile: 199 self.text_encoder_model = torch.compile(self.text_encoder_model) 200 201 self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path) 202 self.loaded = True 203 204 # compile 205 if self.torch_compile: 206 if export_quantized_weights: 207 from torch.ao.quantization import ( 208 Int4WeightOnlyConfig, 209 quantize_, 210 ) 211 212 group_size = 128 213 use_hqq = True 214 quantize_( 215 self.ace_step_transformer, 216 Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq), 217 ) 218 quantize_( 219 self.text_encoder_model, 220 Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq), 221 ) 222 223 # save quantized weights 224 torch.save( 225 self.ace_step_transformer.state_dict(), 226 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 227 ) 228 print( 229 "Quantized Weights Saved to: ", 230 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 231 ) 232 torch.save( 233 self.text_encoder_model.state_dict(), 234 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 235 ) 236 print( 237 "Quantized Weights Saved to: ", 238 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 239 )
def
load_quantized_checkpoint(self, checkpoint_dir=None):
241 def load_quantized_checkpoint(self, checkpoint_dir=None): 242 checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID_QUANT) 243 dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8") 244 vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder") 245 ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer") 246 text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base") 247 248 self.music_dcae = MusicDCAE( 249 dcae_checkpoint_path=dcae_checkpoint_path, 250 vocoder_checkpoint_path=vocoder_checkpoint_path, 251 ) 252 if self.cpu_offload: 253 self.music_dcae.eval().to(self.dtype).to(self.device) 254 else: 255 self.music_dcae.eval().to(self.dtype).to("cpu") 256 self.music_dcae = torch.compile(self.music_dcae) 257 258 self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path) 259 self.ace_step_transformer.eval().to(self.dtype).to("cpu") 260 self.ace_step_transformer = torch.compile(self.ace_step_transformer) 261 self.ace_step_transformer.load_state_dict( 262 torch.load( 263 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"), 264 map_location=self.device, 265 ), 266 assign=True, 267 ) 268 self.ace_step_transformer.torchao_quantized = True 269 270 self.text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path) 271 self.text_encoder_model.eval().to(self.dtype).to("cpu") 272 self.text_encoder_model = torch.compile(self.text_encoder_model) 273 self.text_encoder_model.load_state_dict( 274 torch.load( 275 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"), 276 map_location=self.device, 277 ), 278 assign=True, 279 ) 280 self.text_encoder_model.torchao_quantized = True 281 282 self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path) 283 284 lang_segment = LangSegment() 285 lang_segment.setfilters(language_filters.default) 286 self.lang_segment = lang_segment 287 self.lyric_tokenizer = VoiceBpeTokenizer() 288 289 self.loaded = True
@cpu_offload('text_encoder_model')
def
get_text_embeddings(self, texts, text_max_length=256):
291 @cpu_offload("text_encoder_model") 292 def get_text_embeddings(self, texts, text_max_length=256): 293 inputs = self.text_tokenizer( 294 texts, 295 return_tensors="pt", 296 padding=True, 297 truncation=True, 298 max_length=text_max_length, 299 ) 300 inputs = {key: value.to(self.device) for key, value in inputs.items()} 301 if self.text_encoder_model.device != self.device: 302 self.text_encoder_model.to(self.device) 303 with torch.no_grad(): 304 outputs = self.text_encoder_model(**inputs) 305 last_hidden_states = outputs.last_hidden_state 306 attention_mask = inputs["attention_mask"] 307 return last_hidden_states, attention_mask
@cpu_offload('text_encoder_model')
def
get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10):
309 @cpu_offload("text_encoder_model") 310 def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10): 311 inputs = self.text_tokenizer( 312 texts, 313 return_tensors="pt", 314 padding=True, 315 truncation=True, 316 max_length=text_max_length, 317 ) 318 inputs = {key: value.to(self.device) for key, value in inputs.items()} 319 if self.text_encoder_model.device != self.device: 320 self.text_encoder_model.to(self.device) 321 322 def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10): 323 handlers = [] 324 325 def hook(module, input, output): 326 output[:] *= tau 327 return output 328 329 for i in range(l_min, l_max): 330 handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook) 331 handlers.append(handler) 332 333 with torch.no_grad(): 334 outputs = self.text_encoder_model(**inputs) 335 last_hidden_states = outputs.last_hidden_state 336 337 for hook in handlers: 338 hook.remove() 339 340 return last_hidden_states 341 342 last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max) 343 return last_hidden_states
def
set_seeds(self, batch_size, manual_seeds=None):
345 def set_seeds(self, batch_size, manual_seeds=None): 346 processed_input_seeds = None 347 if manual_seeds is not None: 348 if isinstance(manual_seeds, str): 349 if "," in manual_seeds: 350 processed_input_seeds = list(map(int, manual_seeds.split(","))) 351 elif manual_seeds.isdigit(): 352 processed_input_seeds = int(manual_seeds) 353 elif isinstance(manual_seeds, list) and all(isinstance(s, int) for s in manual_seeds): 354 if len(manual_seeds) > 0: 355 processed_input_seeds = list(manual_seeds) 356 elif isinstance(manual_seeds, int): 357 processed_input_seeds = manual_seeds 358 random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)] 359 actual_seeds = [] 360 for i in range(batch_size): 361 current_seed_for_generator = None 362 if processed_input_seeds is None: 363 current_seed_for_generator = torch.randint(0, 2**32, (1,)).item() 364 elif isinstance(processed_input_seeds, int): 365 current_seed_for_generator = processed_input_seeds 366 elif isinstance(processed_input_seeds, list): 367 if i < len(processed_input_seeds): 368 current_seed_for_generator = processed_input_seeds[i] 369 else: 370 current_seed_for_generator = processed_input_seeds[-1] 371 if current_seed_for_generator is None: 372 current_seed_for_generator = torch.randint(0, 2**32, (1,)).item() 373 random_generators[i].manual_seed(current_seed_for_generator) 374 actual_seeds.append(current_seed_for_generator) 375 return random_generators, actual_seeds
def
get_lang(self, text):
377 def get_lang(self, text): 378 language = "en" 379 try: 380 _ = self.lang_segment.getTexts(text) 381 langCounts = self.lang_segment.getCounts() 382 language = langCounts[0][0] 383 if len(langCounts) > 1 and language == "en": 384 language = langCounts[1][0] 385 except Exception as err: 386 language = "en" 387 return language
def
tokenize_lyrics(self, lyrics, debug=False):
389 def tokenize_lyrics(self, lyrics, debug=False): 390 lines = lyrics.split("\n") 391 lyric_token_idx = [261] 392 for line in lines: 393 line = line.strip() 394 if not line: 395 lyric_token_idx += [2] 396 continue 397 398 lang = self.get_lang(line) 399 400 if lang not in SUPPORT_LANGUAGES: 401 lang = "en" 402 if "zh" in lang: 403 lang = "zh" 404 if "spa" in lang: 405 lang = "es" 406 407 try: 408 if structure_pattern.match(line): 409 token_idx = self.lyric_tokenizer.encode(line, "en") 410 else: 411 token_idx = self.lyric_tokenizer.encode(line, lang) 412 if debug: 413 toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx]) 414 nfo(f"debbug {line} --> {lang} --> {toks}") 415 lyric_token_idx = lyric_token_idx + token_idx + [2] 416 except Exception as e: 417 print("tokenize error", e, "for line", line, "major_language", lang) 418 return lyric_token_idx
@cpu_offload('ace_step_transformer')
def
calc_v( self, zt_src, zt_tar, t, encoder_text_hidden_states, text_attention_mask, target_encoder_text_hidden_states, target_text_attention_mask, speaker_embds, target_speaker_embeds, lyric_token_ids, lyric_mask, target_lyric_token_ids, target_lyric_mask, do_classifier_free_guidance=False, guidance_scale=1.0, target_guidance_scale=1.0, cfg_type='apg', attention_mask=None, momentum_buffer=None, momentum_buffer_tar=None, return_src_pred=True):
420 @cpu_offload("ace_step_transformer") 421 def calc_v( 422 self, 423 zt_src, 424 zt_tar, 425 t, 426 encoder_text_hidden_states, 427 text_attention_mask, 428 target_encoder_text_hidden_states, 429 target_text_attention_mask, 430 speaker_embds, 431 target_speaker_embeds, 432 lyric_token_ids, 433 lyric_mask, 434 target_lyric_token_ids, 435 target_lyric_mask, 436 do_classifier_free_guidance=False, 437 guidance_scale=1.0, 438 target_guidance_scale=1.0, 439 cfg_type="apg", 440 attention_mask=None, 441 momentum_buffer=None, 442 momentum_buffer_tar=None, 443 return_src_pred=True, 444 ): 445 noise_pred_src = None 446 if return_src_pred: 447 src_latent_model_input = torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src 448 timestep = t.expand(src_latent_model_input.shape[0]) 449 # source 450 noise_pred_src = self.ace_step_transformer( 451 hidden_states=src_latent_model_input, 452 attention_mask=attention_mask, 453 encoder_text_hidden_states=encoder_text_hidden_states, 454 text_attention_mask=text_attention_mask, 455 speaker_embeds=speaker_embds, 456 lyric_token_idx=lyric_token_ids, 457 lyric_mask=lyric_mask, 458 timestep=timestep, 459 ).sample 460 461 if do_classifier_free_guidance: 462 noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(2) 463 if cfg_type == "apg": 464 noise_pred_src = apg_forward( 465 pred_cond=noise_pred_with_cond_src, 466 pred_uncond=noise_pred_uncond_src, 467 guidance_scale=guidance_scale, 468 momentum_buffer=momentum_buffer, 469 ) 470 elif cfg_type == "cfg": 471 noise_pred_src = cfg_forward( 472 cond_output=noise_pred_with_cond_src, 473 uncond_output=noise_pred_uncond_src, 474 cfg_strength=guidance_scale, 475 ) 476 477 tar_latent_model_input = torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar 478 timestep = t.expand(tar_latent_model_input.shape[0]) 479 # target 480 noise_pred_tar = self.ace_step_transformer( 481 hidden_states=tar_latent_model_input, 482 attention_mask=attention_mask, 483 encoder_text_hidden_states=target_encoder_text_hidden_states, 484 text_attention_mask=target_text_attention_mask, 485 speaker_embeds=target_speaker_embeds, 486 lyric_token_idx=target_lyric_token_ids, 487 lyric_mask=target_lyric_mask, 488 timestep=timestep, 489 ).sample 490 491 if do_classifier_free_guidance: 492 noise_pred_with_cond_tar, noise_pred_uncond_tar = noise_pred_tar.chunk(2) 493 if cfg_type == "apg": 494 noise_pred_tar = apg_forward( 495 pred_cond=noise_pred_with_cond_tar, 496 pred_uncond=noise_pred_uncond_tar, 497 guidance_scale=target_guidance_scale, 498 momentum_buffer=momentum_buffer_tar, 499 ) 500 elif cfg_type == "cfg": 501 noise_pred_tar = cfg_forward( 502 cond_output=noise_pred_with_cond_tar, 503 uncond_output=noise_pred_uncond_tar, 504 cfg_strength=target_guidance_scale, 505 ) 506 return noise_pred_src, noise_pred_tar
@torch.no_grad()
def
flowedit_diffusion_process( self, encoder_text_hidden_states, text_attention_mask, speaker_embds, lyric_token_ids, lyric_mask, target_encoder_text_hidden_states, target_text_attention_mask, target_speaker_embeds, target_lyric_token_ids, target_lyric_mask, src_latents, random_generators=None, infer_steps=60, guidance_scale=15.0, n_min=0, n_max=1.0, n_avg=1, scheduler_type='euler'):
508 @torch.no_grad() 509 def flowedit_diffusion_process( 510 self, 511 encoder_text_hidden_states, 512 text_attention_mask, 513 speaker_embds, 514 lyric_token_ids, 515 lyric_mask, 516 target_encoder_text_hidden_states, 517 target_text_attention_mask, 518 target_speaker_embeds, 519 target_lyric_token_ids, 520 target_lyric_mask, 521 src_latents, 522 random_generators=None, 523 infer_steps=60, 524 guidance_scale=15.0, 525 n_min=0, 526 n_max=1.0, 527 n_avg=1, 528 scheduler_type="euler", 529 ): 530 do_classifier_free_guidance = True 531 if guidance_scale == 0.0 or guidance_scale == 1.0: 532 do_classifier_free_guidance = False 533 534 target_guidance_scale = guidance_scale 535 bsz = encoder_text_hidden_states.shape[0] 536 537 scheduler = FlowMatchEulerDiscreteScheduler( 538 num_train_timesteps=1000, 539 shift=3.0, 540 ) 541 542 T_steps = infer_steps 543 frame_length = src_latents.shape[-1] 544 attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype) 545 546 timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, self.device, timesteps=None) 547 548 if do_classifier_free_guidance: 549 attention_mask = torch.cat([attention_mask] * 2, dim=0) 550 551 encoder_text_hidden_states = torch.cat( 552 [ 553 encoder_text_hidden_states, 554 torch.zeros_like(encoder_text_hidden_states), 555 ], 556 0, 557 ) 558 text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0) 559 560 target_encoder_text_hidden_states = torch.cat( 561 [ 562 target_encoder_text_hidden_states, 563 torch.zeros_like(target_encoder_text_hidden_states), 564 ], 565 0, 566 ) 567 target_text_attention_mask = torch.cat([target_text_attention_mask] * 2, dim=0) 568 569 speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0) 570 target_speaker_embeds = torch.cat([target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0) 571 572 lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0) 573 lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0) 574 575 target_lyric_token_ids = torch.cat([target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0) 576 target_lyric_mask = torch.cat([target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0) 577 578 momentum_buffer = MomentumBuffer() 579 momentum_buffer_tar = MomentumBuffer() 580 x_src = src_latents 581 zt_edit = x_src.clone() 582 xt_tar = None 583 n_min = int(infer_steps * n_min) 584 n_max = int(infer_steps * n_max) 585 586 nfo("flowedit start from {} to {}".format(n_min, n_max)) 587 588 for i, t in tqdm(enumerate(timesteps), total=T_steps): 589 if i < n_min: 590 continue 591 592 t_i = t / 1000 593 594 if i + 1 < len(timesteps): 595 t_im1 = (timesteps[i + 1]) / 1000 596 else: 597 t_im1 = torch.zeros_like(t_i).to(self.device) 598 599 if i < n_max: 600 # Calculate the average of the V predictions 601 V_delta_avg = torch.zeros_like(x_src) 602 for k in range(n_avg): 603 fwd_noise = randn_tensor( 604 shape=x_src.shape, 605 generator=random_generators, 606 device=self.device, 607 dtype=self.dtype, 608 ) 609 610 zt_src = (1 - t_i) * x_src + (t_i) * fwd_noise 611 612 zt_tar = zt_edit + zt_src - x_src 613 614 Vt_src, Vt_tar = self.calc_v( 615 zt_src=zt_src, 616 zt_tar=zt_tar, 617 t=t, 618 encoder_text_hidden_states=encoder_text_hidden_states, 619 text_attention_mask=text_attention_mask, 620 target_encoder_text_hidden_states=target_encoder_text_hidden_states, 621 target_text_attention_mask=target_text_attention_mask, 622 speaker_embds=speaker_embds, 623 target_speaker_embeds=target_speaker_embeds, 624 lyric_token_ids=lyric_token_ids, 625 lyric_mask=lyric_mask, 626 target_lyric_token_ids=target_lyric_token_ids, 627 target_lyric_mask=target_lyric_mask, 628 do_classifier_free_guidance=do_classifier_free_guidance, 629 guidance_scale=guidance_scale, 630 target_guidance_scale=target_guidance_scale, 631 attention_mask=attention_mask, 632 momentum_buffer=momentum_buffer, 633 ) 634 V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src) # - (hfg - 1) * (x_src) 635 636 zt_edit = zt_edit.to(torch.float32) # arbitrary, should be settable for compatibility 637 if scheduler_type != "pingpong": 638 # propagate direct ODE 639 zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg 640 zt_edit = zt_edit.to(self.dtype) 641 else: 642 # propagate pingpong SDE 643 zt_edit_denoised = zt_edit - t_i * V_delta_avg 644 noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None) 645 prev_sample = (1 - t_im1) * zt_edit_denoised + t_im1 * noise 646 647 else: # i >= T_steps-n_min # regular sampling for last n_min steps 648 if i == n_max: 649 fwd_noise = randn_tensor( 650 shape=x_src.shape, 651 generator=random_generators, 652 device=self.device, 653 dtype=self.dtype, 654 ) 655 scheduler._init_step_index(t) 656 sigma = scheduler.sigmas[scheduler.step_index] 657 xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src 658 xt_tar = zt_edit + xt_src - x_src 659 660 _, Vt_tar = self.calc_v( 661 zt_src=None, 662 zt_tar=xt_tar, 663 t=t, 664 encoder_text_hidden_states=encoder_text_hidden_states, 665 text_attention_mask=text_attention_mask, 666 target_encoder_text_hidden_states=target_encoder_text_hidden_states, 667 target_text_attention_mask=target_text_attention_mask, 668 speaker_embds=speaker_embds, 669 target_speaker_embeds=target_speaker_embeds, 670 lyric_token_ids=lyric_token_ids, 671 lyric_mask=lyric_mask, 672 target_lyric_token_ids=target_lyric_token_ids, 673 target_lyric_mask=target_lyric_mask, 674 do_classifier_free_guidance=do_classifier_free_guidance, 675 guidance_scale=guidance_scale, 676 target_guidance_scale=target_guidance_scale, 677 attention_mask=attention_mask, 678 momentum_buffer_tar=momentum_buffer_tar, 679 return_src_pred=False, 680 ) 681 682 xt_tar = xt_tar.to(torch.float32) 683 if scheduler_type != "pingpong": 684 prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar 685 prev_sample = prev_sample.to(self.dtype) 686 xt_tar = prev_sample 687 else: 688 prev_sample = xt_tar - t_i * Vt_tar 689 noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None) 690 prev_sample = (1 - t_im1) * prev_sample + t_im1 * noise 691 xt_tar = prev_sample 692 693 target_latents = zt_edit if xt_tar is None else xt_tar 694 return target_latents
def
add_latents_noise(self, gt_latents, sigma_max, noise, scheduler_type, infer_steps):
696 def add_latents_noise( 697 self, 698 gt_latents, 699 sigma_max, 700 noise, 701 scheduler_type, 702 infer_steps, 703 ): 704 bsz = gt_latents.shape[0] 705 if scheduler_type == "euler": 706 scheduler = FlowMatchEulerDiscreteScheduler( 707 num_train_timesteps=1000, 708 shift=3.0, 709 sigma_max=sigma_max, 710 ) 711 elif scheduler_type == "heun": 712 scheduler = FlowMatchHeunDiscreteScheduler( 713 num_train_timesteps=1000, 714 shift=3.0, 715 sigma_max=sigma_max, 716 ) 717 elif scheduler_type == "pingpong": 718 scheduler = FlowMatchPingPongScheduler(num_train_timesteps=1000, shift=3.0, sigma_max=sigma_max) 719 720 infer_steps = int(sigma_max * infer_steps) 721 timesteps, num_inference_steps = retrieve_timesteps( 722 scheduler, 723 num_inference_steps=infer_steps, 724 device=self.device, 725 timesteps=None, 726 ) 727 noisy_image = gt_latents * (1 - scheduler.sigma_max) + noise * scheduler.sigma_max 728 nfo(f"{scheduler.sigma_min=} {scheduler.sigma_max=} {timesteps=} {num_inference_steps=}") 729 return noisy_image, timesteps, scheduler, num_inference_steps
@cpu_offload('ace_step_transformer')
@torch.no_grad()
def
text2music_diffusion_process( self, duration, encoder_text_hidden_states, text_attention_mask, speaker_embds, lyric_token_ids, lyric_mask, random_generators=None, infer_steps=60, guidance_scale=15.0, omega_scale=10.0, scheduler_type='euler', cfg_type='apg', zero_steps=1, use_zero_init=True, guidance_interval=0.5, guidance_interval_decay=1.0, min_guidance_scale=3.0, oss_steps=[], encoder_text_hidden_states_null=None, use_erg_lyric=False, use_erg_diffusion=False, retake_random_generators=None, retake_variance=0.5, add_retake_noise=False, guidance_scale_text=0.0, guidance_scale_lyric=0.0, repaint_start=0, repaint_end=0, src_latents=None, audio2audio_enable=False, ref_audio_strength=0.5, ref_latents=None):
731 @cpu_offload("ace_step_transformer") 732 @torch.no_grad() 733 def text2music_diffusion_process( 734 self, 735 duration, 736 encoder_text_hidden_states, 737 text_attention_mask, 738 speaker_embds, 739 lyric_token_ids, 740 lyric_mask, 741 random_generators=None, 742 infer_steps=60, 743 guidance_scale=15.0, 744 omega_scale=10.0, 745 scheduler_type="euler", 746 cfg_type="apg", 747 zero_steps=1, 748 use_zero_init=True, 749 guidance_interval=0.5, 750 guidance_interval_decay=1.0, 751 min_guidance_scale=3.0, 752 oss_steps=[], 753 encoder_text_hidden_states_null=None, 754 use_erg_lyric=False, 755 use_erg_diffusion=False, 756 retake_random_generators=None, 757 retake_variance=0.5, 758 add_retake_noise=False, 759 guidance_scale_text=0.0, 760 guidance_scale_lyric=0.0, 761 repaint_start=0, 762 repaint_end=0, 763 src_latents=None, 764 audio2audio_enable=False, 765 ref_audio_strength=0.5, 766 ref_latents=None, 767 ): 768 nfo("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale)) 769 do_classifier_free_guidance = True 770 if guidance_scale == 0.0 or guidance_scale == 1.0: 771 do_classifier_free_guidance = False 772 773 do_double_condition_guidance = False 774 if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0: 775 do_double_condition_guidance = True 776 nfo( 777 "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format( 778 do_double_condition_guidance, 779 guidance_scale_text, 780 guidance_scale_lyric, 781 ) 782 ) 783 784 bsz = encoder_text_hidden_states.shape[0] 785 786 if scheduler_type == "euler": 787 scheduler = FlowMatchEulerDiscreteScheduler( 788 num_train_timesteps=1000, 789 shift=3.0, 790 ) 791 elif scheduler_type == "heun": 792 scheduler = FlowMatchHeunDiscreteScheduler( 793 num_train_timesteps=1000, 794 shift=3.0, 795 ) 796 elif scheduler_type == "pingpong": 797 scheduler = FlowMatchPingPongScheduler( 798 num_train_timesteps=1000, 799 shift=3.0, 800 ) 801 802 frame_length = int(duration * 44100 / 512 / 8) 803 if src_latents is not None: 804 frame_length = src_latents.shape[-1] 805 806 if ref_latents is not None: 807 frame_length = ref_latents.shape[-1] 808 809 if len(oss_steps) > 0: 810 infer_steps = max(oss_steps) 811 scheduler.set_timesteps 812 timesteps, num_inference_steps = retrieve_timesteps( 813 scheduler, 814 num_inference_steps=infer_steps, 815 device=self.device, 816 timesteps=None, 817 ) 818 new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device) 819 for idx in range(len(oss_steps)): 820 new_timesteps[idx] = timesteps[oss_steps[idx] - 1] 821 num_inference_steps = len(oss_steps) 822 sigmas = (new_timesteps / 1000).float().cpu().numpy() 823 timesteps, num_inference_steps = retrieve_timesteps( 824 scheduler, 825 num_inference_steps=num_inference_steps, 826 device=self.device, 827 sigmas=sigmas, 828 ) 829 nfo(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}") 830 else: 831 timesteps, num_inference_steps = retrieve_timesteps( 832 scheduler, 833 num_inference_steps=infer_steps, 834 device=self.device, 835 timesteps=None, 836 ) 837 838 target_latents = randn_tensor( 839 shape=(bsz, 8, 16, frame_length), 840 generator=random_generators, 841 device=self.device, 842 dtype=self.dtype, 843 ) 844 845 is_repaint = False 846 is_extend = False 847 848 if add_retake_noise: 849 n_min = int(infer_steps * (1 - retake_variance)) 850 retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype) 851 retake_latents = randn_tensor( 852 shape=(bsz, 8, 16, frame_length), 853 generator=retake_random_generators, 854 device=self.device, 855 dtype=self.dtype, 856 ) 857 repaint_start_frame = int(repaint_start * 44100 / 512 / 8) 858 repaint_end_frame = int(repaint_end * 44100 / 512 / 8) 859 x0 = src_latents 860 # retake 861 is_repaint = repaint_end_frame - repaint_start_frame != frame_length 862 863 is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length) 864 if is_extend: 865 is_repaint = True 866 867 # TODO: train a mask aware repainting controlnet 868 # to make sure mean = 0, std = 1 869 if not is_repaint: 870 target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents 871 elif not is_extend: 872 # if repaint_end_frame 873 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype) 874 repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0 875 repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents 876 repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents) 877 zt_edit = x0.clone() 878 z0 = repaint_noise 879 elif is_extend: 880 to_right_pad_gt_latents = None 881 to_left_pad_gt_latents = None 882 gt_latents = src_latents 883 src_latents_length = gt_latents.shape[-1] 884 max_infer_fame_length = int(240 * 44100 / 512 / 8) 885 left_pad_frame_length = 0 886 right_pad_frame_length = 0 887 right_trim_length = 0 888 left_trim_length = 0 889 if repaint_start_frame < 0: 890 left_pad_frame_length = abs(repaint_start_frame) 891 frame_length = left_pad_frame_length + gt_latents.shape[-1] 892 extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0) 893 if frame_length > max_infer_fame_length: 894 right_trim_length = frame_length - max_infer_fame_length 895 extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_fame_length] 896 to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:] 897 frame_length = max_infer_fame_length 898 repaint_start_frame = 0 899 gt_latents = extend_gt_latents 900 901 if repaint_end_frame > src_latents_length: 902 right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1] 903 frame_length = gt_latents.shape[-1] + right_pad_frame_length 904 extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0) 905 if frame_length > max_infer_fame_length: 906 left_trim_length = frame_length - max_infer_fame_length 907 extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_fame_length:] 908 to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length] 909 frame_length = max_infer_fame_length 910 repaint_end_frame = frame_length 911 gt_latents = extend_gt_latents 912 913 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype) 914 if left_pad_frame_length > 0: 915 repaint_mask[:, :, :, :left_pad_frame_length] = 1.0 916 if right_pad_frame_length > 0: 917 repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0 918 x0 = gt_latents 919 padd_list = [] 920 if left_pad_frame_length > 0: 921 padd_list.append(retake_latents[:, :, :, :left_pad_frame_length]) 922 padd_list.append( 923 target_latents[ 924 :, 925 :, 926 :, 927 left_trim_length : target_latents.shape[-1] - right_trim_length, 928 ] 929 ) 930 if right_pad_frame_length > 0: 931 padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:]) 932 target_latents = torch.cat(padd_list, dim=-1) 933 assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}" 934 zt_edit = x0.clone() 935 z0 = target_latents 936 937 if audio2audio_enable and ref_latents is not None: 938 nfo(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}") 939 target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise( 940 gt_latents=ref_latents, 941 sigma_max=(1 - ref_audio_strength), 942 noise=target_latents, 943 scheduler_type=scheduler_type, 944 infer_steps=infer_steps, 945 ) 946 947 attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype) 948 949 # guidance interval 950 start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2)) 951 end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5)) 952 nfo(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}") 953 954 momentum_buffer = MomentumBuffer() 955 956 def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6): 957 handlers = [] 958 959 def hook(module, input, output): 960 output[:] *= tau 961 return output 962 963 for i in range(l_min, l_max): 964 handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook) 965 handlers.append(handler) 966 967 encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs) 968 969 for hook in handlers: 970 hook.remove() 971 972 return encoder_hidden_states 973 974 # P(speaker, text, lyric) 975 encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode( 976 encoder_text_hidden_states, 977 text_attention_mask, 978 speaker_embds, 979 lyric_token_ids, 980 lyric_mask, 981 ) 982 983 if use_erg_lyric: 984 # P(null_speaker, text_weaker, lyric_weaker) 985 encoder_hidden_states_null = forward_encoder_with_temperature( 986 self, 987 inputs={ 988 "encoder_text_hidden_states": ( 989 encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states) 990 ), 991 "text_attention_mask": text_attention_mask, 992 "speaker_embeds": torch.zeros_like(speaker_embds), 993 "lyric_token_idx": lyric_token_ids, 994 "lyric_mask": lyric_mask, 995 }, 996 ) 997 else: 998 # P(null_speaker, null_text, null_lyric) 999 encoder_hidden_states_null, _ = self.ace_step_transformer.encode( 1000 torch.zeros_like(encoder_text_hidden_states), 1001 text_attention_mask, 1002 torch.zeros_like(speaker_embds), 1003 torch.zeros_like(lyric_token_ids), 1004 lyric_mask, 1005 ) 1006 1007 encoder_hidden_states_no_lyric = None 1008 if do_double_condition_guidance: 1009 # P(null_speaker, text, lyric_weaker) 1010 if use_erg_lyric: 1011 encoder_hidden_states_no_lyric = forward_encoder_with_temperature( 1012 self, 1013 inputs={ 1014 "encoder_text_hidden_states": encoder_text_hidden_states, 1015 "text_attention_mask": text_attention_mask, 1016 "speaker_embeds": torch.zeros_like(speaker_embds), 1017 "lyric_token_idx": lyric_token_ids, 1018 "lyric_mask": lyric_mask, 1019 }, 1020 ) 1021 # P(null_speaker, text, no_lyric) 1022 else: 1023 encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode( 1024 encoder_text_hidden_states, 1025 text_attention_mask, 1026 torch.zeros_like(speaker_embds), 1027 torch.zeros_like(lyric_token_ids), 1028 lyric_mask, 1029 ) 1030 1031 def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20): 1032 handlers = [] 1033 1034 def hook(module, input, output): 1035 output[:] *= tau 1036 return output 1037 1038 for i in range(l_min, l_max): 1039 handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook) 1040 handlers.append(handler) 1041 handler = self.ace_step_transformer.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook) 1042 handlers.append(handler) 1043 1044 sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample 1045 1046 for hook in handlers: 1047 hook.remove() 1048 1049 return sample 1050 1051 for i, t in tqdm(enumerate(timesteps), total=num_inference_steps): 1052 if is_repaint: 1053 if i < n_min: 1054 continue 1055 elif i == n_min: 1056 t_i = t / 1000 1057 zt_src = (1 - t_i) * x0 + (t_i) * z0 1058 target_latents = zt_edit + zt_src - x0 1059 nfo(f"repaint start from {n_min} add {t_i} level of noise") 1060 1061 # expand the latents if we are doing classifier free guidance 1062 latents = target_latents 1063 1064 is_in_guidance_interval = start_idx <= i < end_idx 1065 if is_in_guidance_interval and do_classifier_free_guidance: 1066 # compute current guidance scale 1067 if guidance_interval_decay > 0: 1068 # Linearly interpolate to calculate the current guidance scale 1069 progress = (i - start_idx) / (end_idx - start_idx - 1) # 归一化到[0,1] 1070 current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay 1071 else: 1072 current_guidance_scale = guidance_scale 1073 1074 latent_model_input = latents 1075 timestep = t.expand(latent_model_input.shape[0]) 1076 output_length = latent_model_input.shape[-1] 1077 # P(x|speaker, text, lyric) 1078 noise_pred_with_cond = self.ace_step_transformer.decode( 1079 hidden_states=latent_model_input, 1080 attention_mask=attention_mask, 1081 encoder_hidden_states=encoder_hidden_states, 1082 encoder_hidden_mask=encoder_hidden_mask, 1083 output_length=output_length, 1084 timestep=timestep, 1085 ).sample 1086 1087 noise_pred_with_only_text_cond = None 1088 if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None: 1089 noise_pred_with_only_text_cond = self.ace_step_transformer.decode( 1090 hidden_states=latent_model_input, 1091 attention_mask=attention_mask, 1092 encoder_hidden_states=encoder_hidden_states_no_lyric, 1093 encoder_hidden_mask=encoder_hidden_mask, 1094 output_length=output_length, 1095 timestep=timestep, 1096 ).sample 1097 1098 if use_erg_diffusion: 1099 noise_pred_uncond = forward_diffusion_with_temperature( 1100 self, 1101 hidden_states=latent_model_input, 1102 timestep=timestep, 1103 inputs={ 1104 "encoder_hidden_states": encoder_hidden_states_null, 1105 "encoder_hidden_mask": encoder_hidden_mask, 1106 "output_length": output_length, 1107 "attention_mask": attention_mask, 1108 }, 1109 ) 1110 else: 1111 noise_pred_uncond = self.ace_step_transformer.decode( 1112 hidden_states=latent_model_input, 1113 attention_mask=attention_mask, 1114 encoder_hidden_states=encoder_hidden_states_null, 1115 encoder_hidden_mask=encoder_hidden_mask, 1116 output_length=output_length, 1117 timestep=timestep, 1118 ).sample 1119 1120 if do_double_condition_guidance and noise_pred_with_only_text_cond is not None: 1121 noise_pred = cfg_double_condition_forward( 1122 cond_output=noise_pred_with_cond, 1123 uncond_output=noise_pred_uncond, 1124 only_text_cond_output=noise_pred_with_only_text_cond, 1125 guidance_scale_text=guidance_scale_text, 1126 guidance_scale_lyric=guidance_scale_lyric, 1127 ) 1128 1129 elif cfg_type == "apg": 1130 noise_pred = apg_forward( 1131 pred_cond=noise_pred_with_cond, 1132 pred_uncond=noise_pred_uncond, 1133 guidance_scale=current_guidance_scale, 1134 momentum_buffer=momentum_buffer, 1135 ) 1136 elif cfg_type == "cfg": 1137 noise_pred = cfg_forward( 1138 cond_output=noise_pred_with_cond, 1139 uncond_output=noise_pred_uncond, 1140 cfg_strength=current_guidance_scale, 1141 ) 1142 elif cfg_type == "cfg_star": 1143 noise_pred = cfg_zero_star( 1144 noise_pred_with_cond=noise_pred_with_cond, 1145 noise_pred_uncond=noise_pred_uncond, 1146 guidance_scale=current_guidance_scale, 1147 i=i, 1148 zero_steps=zero_steps, 1149 use_zero_init=use_zero_init, 1150 ) 1151 else: 1152 latent_model_input = latents 1153 timestep = t.expand(latent_model_input.shape[0]) 1154 noise_pred = self.ace_step_transformer.decode( 1155 hidden_states=latent_model_input, 1156 attention_mask=attention_mask, 1157 encoder_hidden_states=encoder_hidden_states, 1158 encoder_hidden_mask=encoder_hidden_mask, 1159 output_length=latent_model_input.shape[-1], 1160 timestep=timestep, 1161 ).sample 1162 1163 if is_repaint and i >= n_min: 1164 t_i = t / 1000 1165 if i + 1 < len(timesteps): 1166 t_im1 = (timesteps[i + 1]) / 1000 1167 else: 1168 t_im1 = torch.zeros_like(t_i).to(self.device) 1169 target_latents = target_latents.to(torch.float32) 1170 prev_sample = target_latents + (t_im1 - t_i) * noise_pred 1171 prev_sample = prev_sample.to(self.dtype) 1172 target_latents = prev_sample 1173 zt_src = (1 - t_im1) * x0 + (t_im1) * z0 1174 target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src) 1175 else: 1176 target_latents = scheduler.step( 1177 model_output=noise_pred, 1178 timestep=t, 1179 sample=target_latents, 1180 return_dict=False, 1181 omega=omega_scale, 1182 generator=random_generators[0], 1183 )[0] 1184 1185 if is_extend: 1186 if to_right_pad_gt_latents is not None: 1187 target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1) 1188 if to_left_pad_gt_latents is not None: 1189 target_latents = torch.cat([to_right_pad_gt_latents, target_latents], dim=0) 1190 return target_latents
@cpu_offload('music_dcae')
def
latents2audio( self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format='wav'):
1192 @cpu_offload("music_dcae") 1193 def latents2audio( 1194 self, 1195 latents, 1196 target_wav_duration_second=30, 1197 sample_rate=48000, 1198 save_path=None, 1199 format="wav", 1200 ): 1201 output_audio_paths = [] 1202 bs = latents.shape[0] 1203 pred_latents = latents 1204 with torch.no_grad(): 1205 if self.overlapped_decode and target_wav_duration_second > 48: 1206 _, pred_wavs = self.music_dcae.decode_overlap(pred_latents, sr=sample_rate) 1207 else: 1208 _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate) 1209 pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs] 1210 for i in tqdm(range(bs)): 1211 output_audio_path = self.save_wav_file( 1212 pred_wavs[i], 1213 i, 1214 save_path=save_path, 1215 sample_rate=sample_rate, 1216 format=format, 1217 ) 1218 output_audio_paths.append(output_audio_path) 1219 return output_audio_paths
def
save_wav_file( self, target_wav, idx, save_path=None, sample_rate=48000, format='wav'):
1221 def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="wav"): 1222 if save_path is None: 1223 nfo("save_path is None, using default path ./outputs/") 1224 base_path = "./outputs" 1225 ensure_directory_exists(base_path) 1226 output_path_wav = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format 1227 else: 1228 ensure_directory_exists(os.path.dirname(save_path)) 1229 if os.path.isdir(save_path): 1230 nfo(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.") 1231 output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format) 1232 else: 1233 output_path_wav = save_path 1234 1235 target_wav = target_wav.float() 1236 backend = "soundfile" 1237 if format == "ogg": 1238 backend = "sox" 1239 nfo(f"Saving audio to {output_path_wav} using backend {backend}") 1240 try: 1241 torchaudio.save(output_path_wav, target_wav, sample_rate=sample_rate, format=format, backend=backend) 1242 except (RuntimeError, ImportError, ModuleNotFoundError): 1243 import soundfile as sf # pyright: ignore[reportMissingImports] | pylint:disable=import-error 1244 1245 sf.write(output_path_wav, target_wav, sample_rate) 1246 return output_path_wav
@cpu_offload('music_dcae')
def
infer_latents(self, input_audio_path):
1248 @cpu_offload("music_dcae") 1249 def infer_latents(self, input_audio_path): 1250 if input_audio_path is None: 1251 return None 1252 input_audio, sr = self.music_dcae.load_audio(input_audio_path) 1253 input_audio = input_audio.unsqueeze(0) 1254 input_audio = input_audio.to(device=self.device, dtype=self.dtype) 1255 latents, _ = self.music_dcae.encode(input_audio, sr=sr) 1256 return latents
def
load_lora(self, lora_name_or_path, lora_weight):
1258 def load_lora(self, lora_name_or_path, lora_weight): 1259 if (lora_name_or_path != self.lora_path or lora_weight != self.lora_weight) and lora_name_or_path != "none": 1260 if not os.path.exists(lora_name_or_path): 1261 lora_download_path = snapshot_download(lora_name_or_path, cache_dir=self.checkpoint_dir) 1262 else: 1263 lora_download_path = lora_name_or_path 1264 if self.lora_path != "none": 1265 self.ace_step_transformer.unload_lora() 1266 self.ace_step_transformer.load_lora_adapter( 1267 os.path.join(lora_download_path, "pytorch_lora_weights.safetensors"), adapter_name="ace_step_lora", with_alpha=True, prefix=None 1268 ) 1269 nfo(f"Loading lora weights from: {lora_name_or_path} download path is: {lora_download_path} weight: {lora_weight}") 1270 set_weights_and_activate_adapters(self.ace_step_transformer, ["ace_step_lora"], [lora_weight]) 1271 self.lora_path = lora_name_or_path 1272 self.lora_weight = lora_weight 1273 elif self.lora_path != "none" and lora_name_or_path == "none": 1274 nfo("No lora weights to load.") 1275 self.ace_step_transformer.unload_lora()