divisor.acestep.pipeline_ace_step

ACE-Step: A Step Towards Music Generation Foundation Model

   1# SPDX-License-Identifier:Apache-2.0
   2# adapted from https://github.com/ace-step/ACE-Step
   3
   4"""
   5ACE-Step: A Step Towards Music Generation Foundation Model
   6"""
   7
   8import json
   9import math
  10import os
  11import random
  12import re
  13import time
  14from typing import Literal
  15
  16import torch
  17import torchaudio
  18from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
  19    retrieve_timesteps,
  20)
  21from diffusers.utils.peft_utils import set_weights_and_activate_adapters
  22from diffusers.utils.torch_utils import randn_tensor
  23from huggingface_hub import snapshot_download
  24from nnll.console import nfo
  25from tqdm import tqdm
  26from transformers import AutoTokenizer, UMT5EncoderModel
  27
  28from divisor.acestep.apg_guidance import (
  29    MomentumBuffer,
  30    apg_forward,
  31    cfg_double_condition_forward,
  32    cfg_forward,
  33    cfg_zero_star,
  34)
  35from divisor.acestep.cpu_offload import cpu_offload
  36from divisor.acestep.language_segmentation import LangSegment, language_filters
  37from divisor.acestep.models.ace_step_transformer import ACEStepTransformer2DModel
  38from divisor.acestep.models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
  39from divisor.acestep.music_dcae.music_dcae_pipeline import MusicDCAE
  40from divisor.acestep.schedulers.scheduling_flow_match_euler_discrete import (
  41    FlowMatchEulerDiscreteScheduler,
  42)
  43from divisor.acestep.schedulers.scheduling_flow_match_heun_discrete import (
  44    FlowMatchHeunDiscreteScheduler,
  45)
  46from divisor.acestep.schedulers.scheduling_flow_match_pingpong import FlowMatchPingPongScheduler
  47from divisor.registry import gfx_device, empty_cache
  48
  49if gfx_device.type == "cuda":
  50    torch.backends.cudnn.benchmark = False
  51    torch.set_float32_matmul_precision("high")
  52    torch.backends.cudnn.deterministic = True
  53    torch.backends.cuda.matmul.allow_tf32 = False
  54    os.environ["TOKENIZERS_PARALLELISM"] = "false"
  55elif gfx_device.type == "mps":
  56    os.environ["DYLD_FALLBACK_LIBRARY_PATH"] = "/opt/homebrew/lib"
  57
  58SUPPORT_LANGUAGES = {
  59    "en": 259,
  60    "de": 260,
  61    "fr": 262,
  62    "es": 284,
  63    "it": 285,
  64    "pt": 286,
  65    "pl": 294,
  66    "tr": 295,
  67    "ru": 267,
  68    "cs": 293,
  69    "nl": 297,
  70    "ar": 5022,
  71    "zh": 5023,
  72    "ja": 5412,
  73    "hu": 5753,
  74    "ko": 6152,
  75    "hi": 6680,
  76}
  77
  78structure_pattern = re.compile(r"\[.*?\]")
  79
  80
  81def ensure_directory_exists(directory):
  82    directory = str(directory)
  83    if not os.path.exists(directory):
  84        os.makedirs(directory)
  85
  86
  87REPO_ID = "ACE-Step/ACE-Step-v1-3.5B"
  88REPO_ID_QUANT = REPO_ID + "-q4-K-M"  # ??? update this i guess
  89
  90
  91# class ACEStepPipeline(DiffusionPipeline):
  92class ACEStepPipeline:
  93    def __init__(
  94        self,
  95        checkpoint_dir=None,
  96        device_id=0,
  97        dtype="bfloat16",
  98        text_encoder_checkpoint_path=None,
  99        persistent_storage_path=None,
 100        torch_compile=False,
 101        cpu_offload=False,
 102        quantized=False,
 103        overlapped_decode=False,
 104        **kwargs,
 105    ):
 106        if not checkpoint_dir:
 107            if persistent_storage_path is None:
 108                checkpoint_dir = os.path.join(os.path.expanduser("~"), ".cache/ace-step/checkpoints")
 109                os.makedirs(checkpoint_dir, exist_ok=True)
 110            else:
 111                checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints")
 112        ensure_directory_exists(checkpoint_dir)
 113
 114        self.checkpoint_dir = checkpoint_dir
 115        self.lora_path = "none"
 116        self.lora_weight = 1
 117        self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
 118        if gfx_device.type == "mps":
 119            if self.dtype == torch.bfloat16:
 120                self.dtype = torch.float16
 121
 122        if "ACE_PIPELINE_DTYPE" in os.environ and len(os.environ["ACE_PIPELINE_DTYPE"]):
 123            self.dtype = getattr(torch, os.environ["ACE_PIPELINE_DTYPE"])
 124        self.device: torch.device = gfx_device
 125        self.loaded = False
 126        self.torch_compile = torch_compile
 127        self.cpu_offload = cpu_offload
 128        self.quantized = quantized
 129        self.overlapped_decode = overlapped_decode
 130
 131    def get_checkpoint_path(self, checkpoint_dir, repo):
 132        checkpoint_dir_models = None
 133
 134        if checkpoint_dir is not None:
 135            required_dirs = ["music_dcae_f8c8", "music_vocoder", "ace_step_transformer", "umt5-base"]
 136            all_dirs_exist = True
 137            for dir_name in required_dirs:
 138                dir_path = os.path.join(checkpoint_dir, dir_name)
 139                if not os.path.exists(dir_path):
 140                    all_dirs_exist = False
 141                    break
 142
 143            if all_dirs_exist:
 144                nfo(f"Load models from: {checkpoint_dir}")
 145                checkpoint_dir_models = checkpoint_dir
 146
 147        if checkpoint_dir_models is None:
 148            if checkpoint_dir is None:
 149                nfo(f"Download models from Hugging Face: {repo}")
 150                checkpoint_dir_models = snapshot_download(repo)
 151            else:
 152                nfo(f"Download models from Hugging Face: {repo}, cache to: {checkpoint_dir}")
 153                checkpoint_dir_models = snapshot_download(repo, cache_dir=checkpoint_dir)
 154        return checkpoint_dir_models
 155
 156    def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
 157        checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID)
 158        dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
 159        vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder")
 160        ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer")
 161        text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base")
 162
 163        self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype)
 164        # self.ace_step_transformer.to(self.device).eval().to(self.dtype)
 165        if self.cpu_offload:
 166            self.ace_step_transformer = self.ace_step_transformer.to("cpu").eval().to(self.dtype)
 167        else:
 168            self.ace_step_transformer = self.ace_step_transformer.to(self.device).eval().to(self.dtype)
 169        if self.torch_compile:
 170            self.ace_step_transformer = torch.compile(self.ace_step_transformer)
 171
 172        self.music_dcae = MusicDCAE(
 173            dcae_checkpoint_path=dcae_checkpoint_path,
 174            vocoder_checkpoint_path=vocoder_checkpoint_path,
 175        )
 176        # self.music_dcae.to(self.device).eval().to(self.dtype)
 177        if self.cpu_offload:  # might be redundant
 178            self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype)
 179        else:
 180            self.music_dcae = self.music_dcae.to(self.device).eval().to(self.dtype)
 181        if self.torch_compile:
 182            self.music_dcae = torch.compile(self.music_dcae)
 183
 184        lang_segment = LangSegment()
 185        lang_segment.setfilters(language_filters.default)
 186        self.lang_segment = lang_segment
 187        self.lyric_tokenizer = VoiceBpeTokenizer()
 188
 189        text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval()
 190        # text_encoder_model = text_encoder_model.to(self.device).to(self.dtype)
 191        if self.cpu_offload:
 192            text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype)
 193        else:
 194            text_encoder_model = text_encoder_model.to(self.device).eval().to(self.dtype)
 195        text_encoder_model.requires_grad_(False)
 196        self.text_encoder_model = text_encoder_model
 197        if self.torch_compile:
 198            self.text_encoder_model = torch.compile(self.text_encoder_model)
 199
 200        self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)
 201        self.loaded = True
 202
 203        # compile
 204        if self.torch_compile:
 205            if export_quantized_weights:
 206                from torch.ao.quantization import (
 207                    Int4WeightOnlyConfig,
 208                    quantize_,
 209                )
 210
 211                group_size = 128
 212                use_hqq = True
 213                quantize_(
 214                    self.ace_step_transformer,
 215                    Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq),
 216                )
 217                quantize_(
 218                    self.text_encoder_model,
 219                    Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq),
 220                )
 221
 222                # save quantized weights
 223                torch.save(
 224                    self.ace_step_transformer.state_dict(),
 225                    os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"),
 226                )
 227                print(
 228                    "Quantized Weights Saved to: ",
 229                    os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"),
 230                )
 231                torch.save(
 232                    self.text_encoder_model.state_dict(),
 233                    os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"),
 234                )
 235                print(
 236                    "Quantized Weights Saved to: ",
 237                    os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"),
 238                )
 239
 240    def load_quantized_checkpoint(self, checkpoint_dir=None):
 241        checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID_QUANT)
 242        dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
 243        vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder")
 244        ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer")
 245        text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base")
 246
 247        self.music_dcae = MusicDCAE(
 248            dcae_checkpoint_path=dcae_checkpoint_path,
 249            vocoder_checkpoint_path=vocoder_checkpoint_path,
 250        )
 251        if self.cpu_offload:
 252            self.music_dcae.eval().to(self.dtype).to(self.device)
 253        else:
 254            self.music_dcae.eval().to(self.dtype).to("cpu")
 255        self.music_dcae = torch.compile(self.music_dcae)
 256
 257        self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
 258        self.ace_step_transformer.eval().to(self.dtype).to("cpu")
 259        self.ace_step_transformer = torch.compile(self.ace_step_transformer)
 260        self.ace_step_transformer.load_state_dict(
 261            torch.load(
 262                os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"),
 263                map_location=self.device,
 264            ),
 265            assign=True,
 266        )
 267        self.ace_step_transformer.torchao_quantized = True
 268
 269        self.text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path)
 270        self.text_encoder_model.eval().to(self.dtype).to("cpu")
 271        self.text_encoder_model = torch.compile(self.text_encoder_model)
 272        self.text_encoder_model.load_state_dict(
 273            torch.load(
 274                os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"),
 275                map_location=self.device,
 276            ),
 277            assign=True,
 278        )
 279        self.text_encoder_model.torchao_quantized = True
 280
 281        self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)
 282
 283        lang_segment = LangSegment()
 284        lang_segment.setfilters(language_filters.default)
 285        self.lang_segment = lang_segment
 286        self.lyric_tokenizer = VoiceBpeTokenizer()
 287
 288        self.loaded = True
 289
 290    @cpu_offload("text_encoder_model")
 291    def get_text_embeddings(self, texts, text_max_length=256):
 292        inputs = self.text_tokenizer(
 293            texts,
 294            return_tensors="pt",
 295            padding=True,
 296            truncation=True,
 297            max_length=text_max_length,
 298        )
 299        inputs = {key: value.to(self.device) for key, value in inputs.items()}
 300        if self.text_encoder_model.device != self.device:
 301            self.text_encoder_model.to(self.device)
 302        with torch.no_grad():
 303            outputs = self.text_encoder_model(**inputs)
 304            last_hidden_states = outputs.last_hidden_state
 305        attention_mask = inputs["attention_mask"]
 306        return last_hidden_states, attention_mask
 307
 308    @cpu_offload("text_encoder_model")
 309    def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10):
 310        inputs = self.text_tokenizer(
 311            texts,
 312            return_tensors="pt",
 313            padding=True,
 314            truncation=True,
 315            max_length=text_max_length,
 316        )
 317        inputs = {key: value.to(self.device) for key, value in inputs.items()}
 318        if self.text_encoder_model.device != self.device:
 319            self.text_encoder_model.to(self.device)
 320
 321        def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
 322            handlers = []
 323
 324            def hook(module, input, output):
 325                output[:] *= tau
 326                return output
 327
 328            for i in range(l_min, l_max):
 329                handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook)
 330                handlers.append(handler)
 331
 332            with torch.no_grad():
 333                outputs = self.text_encoder_model(**inputs)
 334                last_hidden_states = outputs.last_hidden_state
 335
 336            for hook in handlers:
 337                hook.remove()
 338
 339            return last_hidden_states
 340
 341        last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
 342        return last_hidden_states
 343
 344    def set_seeds(self, batch_size, manual_seeds=None):
 345        processed_input_seeds = None
 346        if manual_seeds is not None:
 347            if isinstance(manual_seeds, str):
 348                if "," in manual_seeds:
 349                    processed_input_seeds = list(map(int, manual_seeds.split(",")))
 350                elif manual_seeds.isdigit():
 351                    processed_input_seeds = int(manual_seeds)
 352            elif isinstance(manual_seeds, list) and all(isinstance(s, int) for s in manual_seeds):
 353                if len(manual_seeds) > 0:
 354                    processed_input_seeds = list(manual_seeds)
 355            elif isinstance(manual_seeds, int):
 356                processed_input_seeds = manual_seeds
 357        random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)]
 358        actual_seeds = []
 359        for i in range(batch_size):
 360            current_seed_for_generator = None
 361            if processed_input_seeds is None:
 362                current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
 363            elif isinstance(processed_input_seeds, int):
 364                current_seed_for_generator = processed_input_seeds
 365            elif isinstance(processed_input_seeds, list):
 366                if i < len(processed_input_seeds):
 367                    current_seed_for_generator = processed_input_seeds[i]
 368                else:
 369                    current_seed_for_generator = processed_input_seeds[-1]
 370            if current_seed_for_generator is None:
 371                current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
 372            random_generators[i].manual_seed(current_seed_for_generator)
 373            actual_seeds.append(current_seed_for_generator)
 374        return random_generators, actual_seeds
 375
 376    def get_lang(self, text):
 377        language = "en"
 378        try:
 379            _ = self.lang_segment.getTexts(text)
 380            langCounts = self.lang_segment.getCounts()
 381            language = langCounts[0][0]
 382            if len(langCounts) > 1 and language == "en":
 383                language = langCounts[1][0]
 384        except Exception as err:
 385            language = "en"
 386        return language
 387
 388    def tokenize_lyrics(self, lyrics, debug=False):
 389        lines = lyrics.split("\n")
 390        lyric_token_idx = [261]
 391        for line in lines:
 392            line = line.strip()
 393            if not line:
 394                lyric_token_idx += [2]
 395                continue
 396
 397            lang = self.get_lang(line)
 398
 399            if lang not in SUPPORT_LANGUAGES:
 400                lang = "en"
 401            if "zh" in lang:
 402                lang = "zh"
 403            if "spa" in lang:
 404                lang = "es"
 405
 406            try:
 407                if structure_pattern.match(line):
 408                    token_idx = self.lyric_tokenizer.encode(line, "en")
 409                else:
 410                    token_idx = self.lyric_tokenizer.encode(line, lang)
 411                if debug:
 412                    toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx])
 413                    nfo(f"debbug {line} --> {lang} --> {toks}")
 414                lyric_token_idx = lyric_token_idx + token_idx + [2]
 415            except Exception as e:
 416                print("tokenize error", e, "for line", line, "major_language", lang)
 417        return lyric_token_idx
 418
 419    @cpu_offload("ace_step_transformer")
 420    def calc_v(
 421        self,
 422        zt_src,
 423        zt_tar,
 424        t,
 425        encoder_text_hidden_states,
 426        text_attention_mask,
 427        target_encoder_text_hidden_states,
 428        target_text_attention_mask,
 429        speaker_embds,
 430        target_speaker_embeds,
 431        lyric_token_ids,
 432        lyric_mask,
 433        target_lyric_token_ids,
 434        target_lyric_mask,
 435        do_classifier_free_guidance=False,
 436        guidance_scale=1.0,
 437        target_guidance_scale=1.0,
 438        cfg_type="apg",
 439        attention_mask=None,
 440        momentum_buffer=None,
 441        momentum_buffer_tar=None,
 442        return_src_pred=True,
 443    ):
 444        noise_pred_src = None
 445        if return_src_pred:
 446            src_latent_model_input = torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src
 447            timestep = t.expand(src_latent_model_input.shape[0])
 448            # source
 449            noise_pred_src = self.ace_step_transformer(
 450                hidden_states=src_latent_model_input,
 451                attention_mask=attention_mask,
 452                encoder_text_hidden_states=encoder_text_hidden_states,
 453                text_attention_mask=text_attention_mask,
 454                speaker_embeds=speaker_embds,
 455                lyric_token_idx=lyric_token_ids,
 456                lyric_mask=lyric_mask,
 457                timestep=timestep,
 458            ).sample
 459
 460            if do_classifier_free_guidance:
 461                noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(2)
 462                if cfg_type == "apg":
 463                    noise_pred_src = apg_forward(
 464                        pred_cond=noise_pred_with_cond_src,
 465                        pred_uncond=noise_pred_uncond_src,
 466                        guidance_scale=guidance_scale,
 467                        momentum_buffer=momentum_buffer,
 468                    )
 469                elif cfg_type == "cfg":
 470                    noise_pred_src = cfg_forward(
 471                        cond_output=noise_pred_with_cond_src,
 472                        uncond_output=noise_pred_uncond_src,
 473                        cfg_strength=guidance_scale,
 474                    )
 475
 476        tar_latent_model_input = torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar
 477        timestep = t.expand(tar_latent_model_input.shape[0])
 478        # target
 479        noise_pred_tar = self.ace_step_transformer(
 480            hidden_states=tar_latent_model_input,
 481            attention_mask=attention_mask,
 482            encoder_text_hidden_states=target_encoder_text_hidden_states,
 483            text_attention_mask=target_text_attention_mask,
 484            speaker_embeds=target_speaker_embeds,
 485            lyric_token_idx=target_lyric_token_ids,
 486            lyric_mask=target_lyric_mask,
 487            timestep=timestep,
 488        ).sample
 489
 490        if do_classifier_free_guidance:
 491            noise_pred_with_cond_tar, noise_pred_uncond_tar = noise_pred_tar.chunk(2)
 492            if cfg_type == "apg":
 493                noise_pred_tar = apg_forward(
 494                    pred_cond=noise_pred_with_cond_tar,
 495                    pred_uncond=noise_pred_uncond_tar,
 496                    guidance_scale=target_guidance_scale,
 497                    momentum_buffer=momentum_buffer_tar,
 498                )
 499            elif cfg_type == "cfg":
 500                noise_pred_tar = cfg_forward(
 501                    cond_output=noise_pred_with_cond_tar,
 502                    uncond_output=noise_pred_uncond_tar,
 503                    cfg_strength=target_guidance_scale,
 504                )
 505        return noise_pred_src, noise_pred_tar
 506
 507    @torch.no_grad()
 508    def flowedit_diffusion_process(
 509        self,
 510        encoder_text_hidden_states,
 511        text_attention_mask,
 512        speaker_embds,
 513        lyric_token_ids,
 514        lyric_mask,
 515        target_encoder_text_hidden_states,
 516        target_text_attention_mask,
 517        target_speaker_embeds,
 518        target_lyric_token_ids,
 519        target_lyric_mask,
 520        src_latents,
 521        random_generators=None,
 522        infer_steps=60,
 523        guidance_scale=15.0,
 524        n_min=0,
 525        n_max=1.0,
 526        n_avg=1,
 527        scheduler_type="euler",
 528    ):
 529        do_classifier_free_guidance = True
 530        if guidance_scale == 0.0 or guidance_scale == 1.0:
 531            do_classifier_free_guidance = False
 532
 533        target_guidance_scale = guidance_scale
 534        bsz = encoder_text_hidden_states.shape[0]
 535
 536        scheduler = FlowMatchEulerDiscreteScheduler(
 537            num_train_timesteps=1000,
 538            shift=3.0,
 539        )
 540
 541        T_steps = infer_steps
 542        frame_length = src_latents.shape[-1]
 543        attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)
 544
 545        timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, self.device, timesteps=None)
 546
 547        if do_classifier_free_guidance:
 548            attention_mask = torch.cat([attention_mask] * 2, dim=0)
 549
 550            encoder_text_hidden_states = torch.cat(
 551                [
 552                    encoder_text_hidden_states,
 553                    torch.zeros_like(encoder_text_hidden_states),
 554                ],
 555                0,
 556            )
 557            text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)
 558
 559            target_encoder_text_hidden_states = torch.cat(
 560                [
 561                    target_encoder_text_hidden_states,
 562                    torch.zeros_like(target_encoder_text_hidden_states),
 563                ],
 564                0,
 565            )
 566            target_text_attention_mask = torch.cat([target_text_attention_mask] * 2, dim=0)
 567
 568            speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)
 569            target_speaker_embeds = torch.cat([target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0)
 570
 571            lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
 572            lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)
 573
 574            target_lyric_token_ids = torch.cat([target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0)
 575            target_lyric_mask = torch.cat([target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0)
 576
 577        momentum_buffer = MomentumBuffer()
 578        momentum_buffer_tar = MomentumBuffer()
 579        x_src = src_latents
 580        zt_edit = x_src.clone()
 581        xt_tar = None
 582        n_min = int(infer_steps * n_min)
 583        n_max = int(infer_steps * n_max)
 584
 585        nfo("flowedit start from {} to {}".format(n_min, n_max))
 586
 587        for i, t in tqdm(enumerate(timesteps), total=T_steps):
 588            if i < n_min:
 589                continue
 590
 591            t_i = t / 1000
 592
 593            if i + 1 < len(timesteps):
 594                t_im1 = (timesteps[i + 1]) / 1000
 595            else:
 596                t_im1 = torch.zeros_like(t_i).to(self.device)
 597
 598            if i < n_max:
 599                # Calculate the average of the V predictions
 600                V_delta_avg = torch.zeros_like(x_src)
 601                for k in range(n_avg):
 602                    fwd_noise = randn_tensor(
 603                        shape=x_src.shape,
 604                        generator=random_generators,
 605                        device=self.device,
 606                        dtype=self.dtype,
 607                    )
 608
 609                    zt_src = (1 - t_i) * x_src + (t_i) * fwd_noise
 610
 611                    zt_tar = zt_edit + zt_src - x_src
 612
 613                    Vt_src, Vt_tar = self.calc_v(
 614                        zt_src=zt_src,
 615                        zt_tar=zt_tar,
 616                        t=t,
 617                        encoder_text_hidden_states=encoder_text_hidden_states,
 618                        text_attention_mask=text_attention_mask,
 619                        target_encoder_text_hidden_states=target_encoder_text_hidden_states,
 620                        target_text_attention_mask=target_text_attention_mask,
 621                        speaker_embds=speaker_embds,
 622                        target_speaker_embeds=target_speaker_embeds,
 623                        lyric_token_ids=lyric_token_ids,
 624                        lyric_mask=lyric_mask,
 625                        target_lyric_token_ids=target_lyric_token_ids,
 626                        target_lyric_mask=target_lyric_mask,
 627                        do_classifier_free_guidance=do_classifier_free_guidance,
 628                        guidance_scale=guidance_scale,
 629                        target_guidance_scale=target_guidance_scale,
 630                        attention_mask=attention_mask,
 631                        momentum_buffer=momentum_buffer,
 632                    )
 633                    V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src)  # - (hfg - 1) * (x_src)
 634
 635                zt_edit = zt_edit.to(torch.float32)  # arbitrary, should be settable for compatibility
 636                if scheduler_type != "pingpong":
 637                    # propagate direct ODE
 638                    zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
 639                    zt_edit = zt_edit.to(self.dtype)
 640                else:
 641                    # propagate pingpong SDE
 642                    zt_edit_denoised = zt_edit - t_i * V_delta_avg
 643                    noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None)
 644                    prev_sample = (1 - t_im1) * zt_edit_denoised + t_im1 * noise
 645
 646            else:  # i >= T_steps-n_min # regular sampling for last n_min steps
 647                if i == n_max:
 648                    fwd_noise = randn_tensor(
 649                        shape=x_src.shape,
 650                        generator=random_generators,
 651                        device=self.device,
 652                        dtype=self.dtype,
 653                    )
 654                    scheduler._init_step_index(t)
 655                    sigma = scheduler.sigmas[scheduler.step_index]
 656                    xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src
 657                    xt_tar = zt_edit + xt_src - x_src
 658
 659                _, Vt_tar = self.calc_v(
 660                    zt_src=None,
 661                    zt_tar=xt_tar,
 662                    t=t,
 663                    encoder_text_hidden_states=encoder_text_hidden_states,
 664                    text_attention_mask=text_attention_mask,
 665                    target_encoder_text_hidden_states=target_encoder_text_hidden_states,
 666                    target_text_attention_mask=target_text_attention_mask,
 667                    speaker_embds=speaker_embds,
 668                    target_speaker_embeds=target_speaker_embeds,
 669                    lyric_token_ids=lyric_token_ids,
 670                    lyric_mask=lyric_mask,
 671                    target_lyric_token_ids=target_lyric_token_ids,
 672                    target_lyric_mask=target_lyric_mask,
 673                    do_classifier_free_guidance=do_classifier_free_guidance,
 674                    guidance_scale=guidance_scale,
 675                    target_guidance_scale=target_guidance_scale,
 676                    attention_mask=attention_mask,
 677                    momentum_buffer_tar=momentum_buffer_tar,
 678                    return_src_pred=False,
 679                )
 680
 681                xt_tar = xt_tar.to(torch.float32)
 682                if scheduler_type != "pingpong":
 683                    prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
 684                    prev_sample = prev_sample.to(self.dtype)
 685                    xt_tar = prev_sample
 686                else:
 687                    prev_sample = xt_tar - t_i * Vt_tar
 688                    noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None)
 689                    prev_sample = (1 - t_im1) * prev_sample + t_im1 * noise
 690                    xt_tar = prev_sample
 691
 692        target_latents = zt_edit if xt_tar is None else xt_tar
 693        return target_latents
 694
 695    def add_latents_noise(
 696        self,
 697        gt_latents,
 698        sigma_max,
 699        noise,
 700        scheduler_type,
 701        infer_steps,
 702    ):
 703        bsz = gt_latents.shape[0]
 704        if scheduler_type == "euler":
 705            scheduler = FlowMatchEulerDiscreteScheduler(
 706                num_train_timesteps=1000,
 707                shift=3.0,
 708                sigma_max=sigma_max,
 709            )
 710        elif scheduler_type == "heun":
 711            scheduler = FlowMatchHeunDiscreteScheduler(
 712                num_train_timesteps=1000,
 713                shift=3.0,
 714                sigma_max=sigma_max,
 715            )
 716        elif scheduler_type == "pingpong":
 717            scheduler = FlowMatchPingPongScheduler(num_train_timesteps=1000, shift=3.0, sigma_max=sigma_max)
 718
 719        infer_steps = int(sigma_max * infer_steps)
 720        timesteps, num_inference_steps = retrieve_timesteps(
 721            scheduler,
 722            num_inference_steps=infer_steps,
 723            device=self.device,
 724            timesteps=None,
 725        )
 726        noisy_image = gt_latents * (1 - scheduler.sigma_max) + noise * scheduler.sigma_max
 727        nfo(f"{scheduler.sigma_min=} {scheduler.sigma_max=} {timesteps=} {num_inference_steps=}")
 728        return noisy_image, timesteps, scheduler, num_inference_steps
 729
 730    @cpu_offload("ace_step_transformer")
 731    @torch.no_grad()
 732    def text2music_diffusion_process(
 733        self,
 734        duration,
 735        encoder_text_hidden_states,
 736        text_attention_mask,
 737        speaker_embds,
 738        lyric_token_ids,
 739        lyric_mask,
 740        random_generators=None,
 741        infer_steps=60,
 742        guidance_scale=15.0,
 743        omega_scale=10.0,
 744        scheduler_type="euler",
 745        cfg_type="apg",
 746        zero_steps=1,
 747        use_zero_init=True,
 748        guidance_interval=0.5,
 749        guidance_interval_decay=1.0,
 750        min_guidance_scale=3.0,
 751        oss_steps=[],
 752        encoder_text_hidden_states_null=None,
 753        use_erg_lyric=False,
 754        use_erg_diffusion=False,
 755        retake_random_generators=None,
 756        retake_variance=0.5,
 757        add_retake_noise=False,
 758        guidance_scale_text=0.0,
 759        guidance_scale_lyric=0.0,
 760        repaint_start=0,
 761        repaint_end=0,
 762        src_latents=None,
 763        audio2audio_enable=False,
 764        ref_audio_strength=0.5,
 765        ref_latents=None,
 766    ):
 767        nfo("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
 768        do_classifier_free_guidance = True
 769        if guidance_scale == 0.0 or guidance_scale == 1.0:
 770            do_classifier_free_guidance = False
 771
 772        do_double_condition_guidance = False
 773        if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
 774            do_double_condition_guidance = True
 775            nfo(
 776                "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(
 777                    do_double_condition_guidance,
 778                    guidance_scale_text,
 779                    guidance_scale_lyric,
 780                )
 781            )
 782
 783        bsz = encoder_text_hidden_states.shape[0]
 784
 785        if scheduler_type == "euler":
 786            scheduler = FlowMatchEulerDiscreteScheduler(
 787                num_train_timesteps=1000,
 788                shift=3.0,
 789            )
 790        elif scheduler_type == "heun":
 791            scheduler = FlowMatchHeunDiscreteScheduler(
 792                num_train_timesteps=1000,
 793                shift=3.0,
 794            )
 795        elif scheduler_type == "pingpong":
 796            scheduler = FlowMatchPingPongScheduler(
 797                num_train_timesteps=1000,
 798                shift=3.0,
 799            )
 800
 801        frame_length = int(duration * 44100 / 512 / 8)
 802        if src_latents is not None:
 803            frame_length = src_latents.shape[-1]
 804
 805        if ref_latents is not None:
 806            frame_length = ref_latents.shape[-1]
 807
 808        if len(oss_steps) > 0:
 809            infer_steps = max(oss_steps)
 810            scheduler.set_timesteps
 811            timesteps, num_inference_steps = retrieve_timesteps(
 812                scheduler,
 813                num_inference_steps=infer_steps,
 814                device=self.device,
 815                timesteps=None,
 816            )
 817            new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device)
 818            for idx in range(len(oss_steps)):
 819                new_timesteps[idx] = timesteps[oss_steps[idx] - 1]
 820            num_inference_steps = len(oss_steps)
 821            sigmas = (new_timesteps / 1000).float().cpu().numpy()
 822            timesteps, num_inference_steps = retrieve_timesteps(
 823                scheduler,
 824                num_inference_steps=num_inference_steps,
 825                device=self.device,
 826                sigmas=sigmas,
 827            )
 828            nfo(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
 829        else:
 830            timesteps, num_inference_steps = retrieve_timesteps(
 831                scheduler,
 832                num_inference_steps=infer_steps,
 833                device=self.device,
 834                timesteps=None,
 835            )
 836
 837        target_latents = randn_tensor(
 838            shape=(bsz, 8, 16, frame_length),
 839            generator=random_generators,
 840            device=self.device,
 841            dtype=self.dtype,
 842        )
 843
 844        is_repaint = False
 845        is_extend = False
 846
 847        if add_retake_noise:
 848            n_min = int(infer_steps * (1 - retake_variance))
 849            retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype)
 850            retake_latents = randn_tensor(
 851                shape=(bsz, 8, 16, frame_length),
 852                generator=retake_random_generators,
 853                device=self.device,
 854                dtype=self.dtype,
 855            )
 856            repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
 857            repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
 858            x0 = src_latents
 859            # retake
 860            is_repaint = repaint_end_frame - repaint_start_frame != frame_length
 861
 862            is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
 863            if is_extend:
 864                is_repaint = True
 865
 866            # TODO: train a mask aware repainting controlnet
 867            # to make sure mean = 0, std = 1
 868            if not is_repaint:
 869                target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
 870            elif not is_extend:
 871                # if repaint_end_frame
 872                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
 873                repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
 874                repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
 875                repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
 876                zt_edit = x0.clone()
 877                z0 = repaint_noise
 878            elif is_extend:
 879                to_right_pad_gt_latents = None
 880                to_left_pad_gt_latents = None
 881                gt_latents = src_latents
 882                src_latents_length = gt_latents.shape[-1]
 883                max_infer_fame_length = int(240 * 44100 / 512 / 8)
 884                left_pad_frame_length = 0
 885                right_pad_frame_length = 0
 886                right_trim_length = 0
 887                left_trim_length = 0
 888                if repaint_start_frame < 0:
 889                    left_pad_frame_length = abs(repaint_start_frame)
 890                    frame_length = left_pad_frame_length + gt_latents.shape[-1]
 891                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
 892                    if frame_length > max_infer_fame_length:
 893                        right_trim_length = frame_length - max_infer_fame_length
 894                        extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_fame_length]
 895                        to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:]
 896                        frame_length = max_infer_fame_length
 897                    repaint_start_frame = 0
 898                    gt_latents = extend_gt_latents
 899
 900                if repaint_end_frame > src_latents_length:
 901                    right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
 902                    frame_length = gt_latents.shape[-1] + right_pad_frame_length
 903                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
 904                    if frame_length > max_infer_fame_length:
 905                        left_trim_length = frame_length - max_infer_fame_length
 906                        extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_fame_length:]
 907                        to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length]
 908                        frame_length = max_infer_fame_length
 909                    repaint_end_frame = frame_length
 910                    gt_latents = extend_gt_latents
 911
 912                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
 913                if left_pad_frame_length > 0:
 914                    repaint_mask[:, :, :, :left_pad_frame_length] = 1.0
 915                if right_pad_frame_length > 0:
 916                    repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0
 917                x0 = gt_latents
 918                padd_list = []
 919                if left_pad_frame_length > 0:
 920                    padd_list.append(retake_latents[:, :, :, :left_pad_frame_length])
 921                padd_list.append(
 922                    target_latents[
 923                        :,
 924                        :,
 925                        :,
 926                        left_trim_length : target_latents.shape[-1] - right_trim_length,
 927                    ]
 928                )
 929                if right_pad_frame_length > 0:
 930                    padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
 931                target_latents = torch.cat(padd_list, dim=-1)
 932                assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
 933                zt_edit = x0.clone()
 934                z0 = target_latents
 935
 936        if audio2audio_enable and ref_latents is not None:
 937            nfo(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}")
 938            target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise(
 939                gt_latents=ref_latents,
 940                sigma_max=(1 - ref_audio_strength),
 941                noise=target_latents,
 942                scheduler_type=scheduler_type,
 943                infer_steps=infer_steps,
 944            )
 945
 946        attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)
 947
 948        # guidance interval
 949        start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
 950        end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
 951        nfo(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
 952
 953        momentum_buffer = MomentumBuffer()
 954
 955        def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
 956            handlers = []
 957
 958            def hook(module, input, output):
 959                output[:] *= tau
 960                return output
 961
 962            for i in range(l_min, l_max):
 963                handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
 964                handlers.append(handler)
 965
 966            encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
 967
 968            for hook in handlers:
 969                hook.remove()
 970
 971            return encoder_hidden_states
 972
 973        # P(speaker, text, lyric)
 974        encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(
 975            encoder_text_hidden_states,
 976            text_attention_mask,
 977            speaker_embds,
 978            lyric_token_ids,
 979            lyric_mask,
 980        )
 981
 982        if use_erg_lyric:
 983            # P(null_speaker, text_weaker, lyric_weaker)
 984            encoder_hidden_states_null = forward_encoder_with_temperature(
 985                self,
 986                inputs={
 987                    "encoder_text_hidden_states": (
 988                        encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states)
 989                    ),
 990                    "text_attention_mask": text_attention_mask,
 991                    "speaker_embeds": torch.zeros_like(speaker_embds),
 992                    "lyric_token_idx": lyric_token_ids,
 993                    "lyric_mask": lyric_mask,
 994                },
 995            )
 996        else:
 997            # P(null_speaker, null_text, null_lyric)
 998            encoder_hidden_states_null, _ = self.ace_step_transformer.encode(
 999                torch.zeros_like(encoder_text_hidden_states),
1000                text_attention_mask,
1001                torch.zeros_like(speaker_embds),
1002                torch.zeros_like(lyric_token_ids),
1003                lyric_mask,
1004            )
1005
1006        encoder_hidden_states_no_lyric = None
1007        if do_double_condition_guidance:
1008            # P(null_speaker, text, lyric_weaker)
1009            if use_erg_lyric:
1010                encoder_hidden_states_no_lyric = forward_encoder_with_temperature(
1011                    self,
1012                    inputs={
1013                        "encoder_text_hidden_states": encoder_text_hidden_states,
1014                        "text_attention_mask": text_attention_mask,
1015                        "speaker_embeds": torch.zeros_like(speaker_embds),
1016                        "lyric_token_idx": lyric_token_ids,
1017                        "lyric_mask": lyric_mask,
1018                    },
1019                )
1020            # P(null_speaker, text, no_lyric)
1021            else:
1022                encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode(
1023                    encoder_text_hidden_states,
1024                    text_attention_mask,
1025                    torch.zeros_like(speaker_embds),
1026                    torch.zeros_like(lyric_token_ids),
1027                    lyric_mask,
1028                )
1029
1030        def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
1031            handlers = []
1032
1033            def hook(module, input, output):
1034                output[:] *= tau
1035                return output
1036
1037            for i in range(l_min, l_max):
1038                handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
1039                handlers.append(handler)
1040                handler = self.ace_step_transformer.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook)
1041                handlers.append(handler)
1042
1043            sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
1044
1045            for hook in handlers:
1046                hook.remove()
1047
1048            return sample
1049
1050        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
1051            if is_repaint:
1052                if i < n_min:
1053                    continue
1054                elif i == n_min:
1055                    t_i = t / 1000
1056                    zt_src = (1 - t_i) * x0 + (t_i) * z0
1057                    target_latents = zt_edit + zt_src - x0
1058                    nfo(f"repaint start from {n_min} add {t_i} level of noise")
1059
1060            # expand the latents if we are doing classifier free guidance
1061            latents = target_latents
1062
1063            is_in_guidance_interval = start_idx <= i < end_idx
1064            if is_in_guidance_interval and do_classifier_free_guidance:
1065                # compute current guidance scale
1066                if guidance_interval_decay > 0:
1067                    # Linearly interpolate to calculate the current guidance scale
1068                    progress = (i - start_idx) / (end_idx - start_idx - 1)  # 归一化到[0,1]
1069                    current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
1070                else:
1071                    current_guidance_scale = guidance_scale
1072
1073                latent_model_input = latents
1074                timestep = t.expand(latent_model_input.shape[0])
1075                output_length = latent_model_input.shape[-1]
1076                # P(x|speaker, text, lyric)
1077                noise_pred_with_cond = self.ace_step_transformer.decode(
1078                    hidden_states=latent_model_input,
1079                    attention_mask=attention_mask,
1080                    encoder_hidden_states=encoder_hidden_states,
1081                    encoder_hidden_mask=encoder_hidden_mask,
1082                    output_length=output_length,
1083                    timestep=timestep,
1084                ).sample
1085
1086                noise_pred_with_only_text_cond = None
1087                if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None:
1088                    noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
1089                        hidden_states=latent_model_input,
1090                        attention_mask=attention_mask,
1091                        encoder_hidden_states=encoder_hidden_states_no_lyric,
1092                        encoder_hidden_mask=encoder_hidden_mask,
1093                        output_length=output_length,
1094                        timestep=timestep,
1095                    ).sample
1096
1097                if use_erg_diffusion:
1098                    noise_pred_uncond = forward_diffusion_with_temperature(
1099                        self,
1100                        hidden_states=latent_model_input,
1101                        timestep=timestep,
1102                        inputs={
1103                            "encoder_hidden_states": encoder_hidden_states_null,
1104                            "encoder_hidden_mask": encoder_hidden_mask,
1105                            "output_length": output_length,
1106                            "attention_mask": attention_mask,
1107                        },
1108                    )
1109                else:
1110                    noise_pred_uncond = self.ace_step_transformer.decode(
1111                        hidden_states=latent_model_input,
1112                        attention_mask=attention_mask,
1113                        encoder_hidden_states=encoder_hidden_states_null,
1114                        encoder_hidden_mask=encoder_hidden_mask,
1115                        output_length=output_length,
1116                        timestep=timestep,
1117                    ).sample
1118
1119                if do_double_condition_guidance and noise_pred_with_only_text_cond is not None:
1120                    noise_pred = cfg_double_condition_forward(
1121                        cond_output=noise_pred_with_cond,
1122                        uncond_output=noise_pred_uncond,
1123                        only_text_cond_output=noise_pred_with_only_text_cond,
1124                        guidance_scale_text=guidance_scale_text,
1125                        guidance_scale_lyric=guidance_scale_lyric,
1126                    )
1127
1128                elif cfg_type == "apg":
1129                    noise_pred = apg_forward(
1130                        pred_cond=noise_pred_with_cond,
1131                        pred_uncond=noise_pred_uncond,
1132                        guidance_scale=current_guidance_scale,
1133                        momentum_buffer=momentum_buffer,
1134                    )
1135                elif cfg_type == "cfg":
1136                    noise_pred = cfg_forward(
1137                        cond_output=noise_pred_with_cond,
1138                        uncond_output=noise_pred_uncond,
1139                        cfg_strength=current_guidance_scale,
1140                    )
1141                elif cfg_type == "cfg_star":
1142                    noise_pred = cfg_zero_star(
1143                        noise_pred_with_cond=noise_pred_with_cond,
1144                        noise_pred_uncond=noise_pred_uncond,
1145                        guidance_scale=current_guidance_scale,
1146                        i=i,
1147                        zero_steps=zero_steps,
1148                        use_zero_init=use_zero_init,
1149                    )
1150            else:
1151                latent_model_input = latents
1152                timestep = t.expand(latent_model_input.shape[0])
1153                noise_pred = self.ace_step_transformer.decode(
1154                    hidden_states=latent_model_input,
1155                    attention_mask=attention_mask,
1156                    encoder_hidden_states=encoder_hidden_states,
1157                    encoder_hidden_mask=encoder_hidden_mask,
1158                    output_length=latent_model_input.shape[-1],
1159                    timestep=timestep,
1160                ).sample
1161
1162            if is_repaint and i >= n_min:
1163                t_i = t / 1000
1164                if i + 1 < len(timesteps):
1165                    t_im1 = (timesteps[i + 1]) / 1000
1166                else:
1167                    t_im1 = torch.zeros_like(t_i).to(self.device)
1168                target_latents = target_latents.to(torch.float32)
1169                prev_sample = target_latents + (t_im1 - t_i) * noise_pred
1170                prev_sample = prev_sample.to(self.dtype)
1171                target_latents = prev_sample
1172                zt_src = (1 - t_im1) * x0 + (t_im1) * z0
1173                target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
1174            else:
1175                target_latents = scheduler.step(
1176                    model_output=noise_pred,
1177                    timestep=t,
1178                    sample=target_latents,
1179                    return_dict=False,
1180                    omega=omega_scale,
1181                    generator=random_generators[0],
1182                )[0]
1183
1184        if is_extend:
1185            if to_right_pad_gt_latents is not None:
1186                target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
1187            if to_left_pad_gt_latents is not None:
1188                target_latents = torch.cat([to_right_pad_gt_latents, target_latents], dim=0)
1189        return target_latents
1190
1191    @cpu_offload("music_dcae")
1192    def latents2audio(
1193        self,
1194        latents,
1195        target_wav_duration_second=30,
1196        sample_rate=48000,
1197        save_path=None,
1198        format="wav",
1199    ):
1200        output_audio_paths = []
1201        bs = latents.shape[0]
1202        pred_latents = latents
1203        with torch.no_grad():
1204            if self.overlapped_decode and target_wav_duration_second > 48:
1205                _, pred_wavs = self.music_dcae.decode_overlap(pred_latents, sr=sample_rate)
1206            else:
1207                _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
1208        pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
1209        for i in tqdm(range(bs)):
1210            output_audio_path = self.save_wav_file(
1211                pred_wavs[i],
1212                i,
1213                save_path=save_path,
1214                sample_rate=sample_rate,
1215                format=format,
1216            )
1217            output_audio_paths.append(output_audio_path)
1218        return output_audio_paths
1219
1220    def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="wav"):
1221        if save_path is None:
1222            nfo("save_path is None, using default path ./outputs/")
1223            base_path = "./outputs"
1224            ensure_directory_exists(base_path)
1225            output_path_wav = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format
1226        else:
1227            ensure_directory_exists(os.path.dirname(save_path))
1228            if os.path.isdir(save_path):
1229                nfo(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.")
1230                output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format)
1231            else:
1232                output_path_wav = save_path
1233
1234        target_wav = target_wav.float()
1235        backend = "soundfile"
1236        if format == "ogg":
1237            backend = "sox"
1238        nfo(f"Saving audio to {output_path_wav} using backend {backend}")
1239        try:
1240            torchaudio.save(output_path_wav, target_wav, sample_rate=sample_rate, format=format, backend=backend)
1241        except (RuntimeError, ImportError, ModuleNotFoundError):
1242            import soundfile as sf  # pyright: ignore[reportMissingImports] | pylint:disable=import-error
1243
1244            sf.write(output_path_wav, target_wav, sample_rate)
1245        return output_path_wav
1246
1247    @cpu_offload("music_dcae")
1248    def infer_latents(self, input_audio_path):
1249        if input_audio_path is None:
1250            return None
1251        input_audio, sr = self.music_dcae.load_audio(input_audio_path)
1252        input_audio = input_audio.unsqueeze(0)
1253        input_audio = input_audio.to(device=self.device, dtype=self.dtype)
1254        latents, _ = self.music_dcae.encode(input_audio, sr=sr)
1255        return latents
1256
1257    def load_lora(self, lora_name_or_path, lora_weight):
1258        if (lora_name_or_path != self.lora_path or lora_weight != self.lora_weight) and lora_name_or_path != "none":
1259            if not os.path.exists(lora_name_or_path):
1260                lora_download_path = snapshot_download(lora_name_or_path, cache_dir=self.checkpoint_dir)
1261            else:
1262                lora_download_path = lora_name_or_path
1263            if self.lora_path != "none":
1264                self.ace_step_transformer.unload_lora()
1265            self.ace_step_transformer.load_lora_adapter(
1266                os.path.join(lora_download_path, "pytorch_lora_weights.safetensors"), adapter_name="ace_step_lora", with_alpha=True, prefix=None
1267            )
1268            nfo(f"Loading lora weights from: {lora_name_or_path} download path is: {lora_download_path} weight: {lora_weight}")
1269            set_weights_and_activate_adapters(self.ace_step_transformer, ["ace_step_lora"], [lora_weight])
1270            self.lora_path = lora_name_or_path
1271            self.lora_weight = lora_weight
1272        elif self.lora_path != "none" and lora_name_or_path == "none":
1273            nfo("No lora weights to load.")
1274            self.ace_step_transformer.unload_lora()
1275
1276    def __call__(
1277        self,
1278        format: str = "wav",
1279        audio_duration: float = 60.0,
1280        prompt: str = None,
1281        lyrics: str = None,
1282        infer_step: int = 60,
1283        guidance_scale: float = 15.0,
1284        scheduler_type: str = "euler",
1285        cfg_type: str = "apg",
1286        omega_scale: int = 10.0,
1287        manual_seeds: list = None,
1288        guidance_interval: float = 0.5,
1289        guidance_interval_decay: float = 0.0,
1290        min_guidance_scale: float = 3.0,
1291        use_erg_tag: bool = True,
1292        use_erg_lyric: bool = True,
1293        use_erg_diffusion: bool = True,
1294        oss_steps: str = None,
1295        guidance_scale_text: float = 0.0,
1296        guidance_scale_lyric: float = 0.0,
1297        audio2audio_enable: bool = False,
1298        ref_audio_strength: float = 0.5,
1299        ref_audio_input: str = None,
1300        lora_name_or_path: str = "none",
1301        lora_weight: float = 1.0,
1302        retake_seeds: list = None,
1303        retake_variance: float = 0.5,
1304        task: str = "text2music",
1305        repaint_start: int = 0,
1306        repaint_end: int = 0,
1307        src_audio_path: str = None,
1308        edit_target_prompt: str = None,
1309        edit_target_lyrics: str = None,
1310        edit_n_min: float = 0.0,
1311        edit_n_max: float = 1.0,
1312        edit_n_avg: int = 1,
1313        save_path: str = None,
1314        batch_size: int = 1,
1315        debug: bool = False,
1316    ):
1317        start_time = time.time()
1318
1319        if audio2audio_enable and ref_audio_input is not None:
1320            task = "audio2audio"
1321
1322        if not self.loaded:
1323            nfo("Checkpoint not loaded, loading checkpoint...")
1324            if self.quantized:
1325                self.load_quantized_checkpoint(self.checkpoint_dir)
1326            else:
1327                self.load_checkpoint(self.checkpoint_dir)
1328
1329        self.load_lora(lora_name_or_path, lora_weight)
1330        load_model_cost = time.time() - start_time
1331        nfo(f"Model loaded in {load_model_cost:.2f} seconds.")
1332
1333        start_time = time.time()
1334
1335        random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds)
1336        retake_random_generators, actual_retake_seeds = self.set_seeds(batch_size, retake_seeds)
1337
1338        if isinstance(oss_steps, str) and len(oss_steps) > 0:
1339            oss_steps = list(map(int, oss_steps.split(",")))
1340        else:
1341            oss_steps = []
1342
1343        texts = [prompt]
1344        encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts)
1345        encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
1346        text_attention_mask = text_attention_mask.repeat(batch_size, 1)
1347
1348        encoder_text_hidden_states_null = None
1349        if use_erg_tag:
1350            encoder_text_hidden_states_null = self.get_text_embeddings_null(texts)
1351            encoder_text_hidden_states_null = encoder_text_hidden_states_null.repeat(batch_size, 1, 1)
1352
1353        # not support for released checkpoint
1354        speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype)
1355
1356        # 6 lyric
1357        lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
1358        lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
1359        if len(lyrics) > 0:
1360            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug)
1361            lyric_mask = [1] * len(lyric_token_idx)
1362            lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
1363            lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
1364
1365        if audio_duration <= 0:
1366            audio_duration = random.uniform(30.0, 240.0)
1367            nfo(f"random audio duration: {audio_duration}")
1368
1369        end_time = time.time()
1370        preprocess_time_cost = end_time - start_time
1371        start_time = end_time
1372
1373        add_retake_noise = task in ("retake", "repaint", "extend")
1374        # retake equal to repaint
1375        if task == "retake":
1376            repaint_start = 0
1377            repaint_end = audio_duration
1378
1379        src_latents = None
1380        if src_audio_path is not None:
1381            assert src_audio_path is not None and task in (
1382                "repaint",
1383                "edit",
1384                "extend",
1385            ), "src_audio_path is required for retake/repaint/extend task"
1386            assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
1387            src_latents = self.infer_latents(src_audio_path)
1388
1389        ref_latents = None
1390        if ref_audio_input is not None and audio2audio_enable:
1391            assert ref_audio_input is not None, "ref_audio_input is required for audio2audio task"
1392            assert os.path.exists(ref_audio_input), f"ref_audio_input {ref_audio_input} does not exist"
1393            ref_latents = self.infer_latents(ref_audio_input)
1394
1395        if task == "edit":
1396            texts = [edit_target_prompt]
1397            target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts)
1398            target_encoder_text_hidden_states = target_encoder_text_hidden_states.repeat(batch_size, 1, 1)
1399            target_text_attention_mask = target_text_attention_mask.repeat(batch_size, 1)
1400
1401            target_lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
1402            target_lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
1403            if len(edit_target_lyrics) > 0:
1404                target_lyric_token_idx = self.tokenize_lyrics(edit_target_lyrics, debug=True)
1405                target_lyric_mask = [1] * len(target_lyric_token_idx)
1406                target_lyric_token_idx = torch.tensor(target_lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
1407                target_lyric_mask = torch.tensor(target_lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
1408
1409            target_speaker_embeds = speaker_embeds.clone()
1410
1411            target_latents = self.flowedit_diffusion_process(
1412                encoder_text_hidden_states=encoder_text_hidden_states,
1413                text_attention_mask=text_attention_mask,
1414                speaker_embds=speaker_embeds,
1415                lyric_token_ids=lyric_token_idx,
1416                lyric_mask=lyric_mask,
1417                target_encoder_text_hidden_states=target_encoder_text_hidden_states,
1418                target_text_attention_mask=target_text_attention_mask,
1419                target_speaker_embeds=target_speaker_embeds,
1420                target_lyric_token_ids=target_lyric_token_idx,
1421                target_lyric_mask=target_lyric_mask,
1422                src_latents=src_latents,
1423                random_generators=retake_random_generators,  # more diversity
1424                infer_steps=infer_step,
1425                guidance_scale=guidance_scale,
1426                n_min=edit_n_min,
1427                n_max=edit_n_max,
1428                n_avg=edit_n_avg,
1429                scheduler_type=scheduler_type,
1430            )
1431        else:
1432            target_latents = self.text2music_diffusion_process(
1433                duration=audio_duration,
1434                encoder_text_hidden_states=encoder_text_hidden_states,
1435                text_attention_mask=text_attention_mask,
1436                speaker_embds=speaker_embeds,
1437                lyric_token_ids=lyric_token_idx,
1438                lyric_mask=lyric_mask,
1439                guidance_scale=guidance_scale,
1440                omega_scale=omega_scale,
1441                infer_steps=infer_step,
1442                random_generators=random_generators,
1443                scheduler_type=scheduler_type,
1444                cfg_type=cfg_type,
1445                guidance_interval=guidance_interval,
1446                guidance_interval_decay=guidance_interval_decay,
1447                min_guidance_scale=min_guidance_scale,
1448                oss_steps=oss_steps,
1449                encoder_text_hidden_states_null=encoder_text_hidden_states_null,
1450                use_erg_lyric=use_erg_lyric,
1451                use_erg_diffusion=use_erg_diffusion,
1452                retake_random_generators=retake_random_generators,
1453                retake_variance=retake_variance,
1454                add_retake_noise=add_retake_noise,
1455                guidance_scale_text=guidance_scale_text,
1456                guidance_scale_lyric=guidance_scale_lyric,
1457                repaint_start=repaint_start,
1458                repaint_end=repaint_end,
1459                src_latents=src_latents,
1460                audio2audio_enable=audio2audio_enable,
1461                ref_audio_strength=ref_audio_strength,
1462                ref_latents=ref_latents,
1463            )
1464
1465        end_time = time.time()
1466        diffusion_time_cost = end_time - start_time
1467        start_time = end_time
1468
1469        output_paths = self.latents2audio(
1470            latents=target_latents,
1471            target_wav_duration_second=audio_duration,
1472            save_path=save_path,
1473            format=format,
1474        )
1475
1476        # Clean up memory after generation
1477        empty_cache
1478
1479        end_time = time.time()
1480        latent2audio_time_cost = end_time - start_time
1481        timecosts = {
1482            "preprocess": preprocess_time_cost,
1483            "diffusion": diffusion_time_cost,
1484            "latent2audio": latent2audio_time_cost,
1485        }
1486
1487        input_params_json = {
1488            "format": format,
1489            "lora_name_or_path": lora_name_or_path,
1490            "lora_weight": lora_weight,
1491            "task": task,
1492            "prompt": prompt if task != "edit" else edit_target_prompt,
1493            "lyrics": lyrics if task != "edit" else edit_target_lyrics,
1494            "audio_duration": audio_duration,
1495            "infer_step": infer_step,
1496            "guidance_scale": guidance_scale,
1497            "scheduler_type": scheduler_type,
1498            "cfg_type": cfg_type,
1499            "omega_scale": omega_scale,
1500            "guidance_interval": guidance_interval,
1501            "guidance_interval_decay": guidance_interval_decay,
1502            "min_guidance_scale": min_guidance_scale,
1503            "use_erg_tag": use_erg_tag,
1504            "use_erg_lyric": use_erg_lyric,
1505            "use_erg_diffusion": use_erg_diffusion,
1506            "oss_steps": oss_steps,
1507            "timecosts": timecosts,
1508            "actual_seeds": actual_seeds,
1509            "retake_seeds": actual_retake_seeds,
1510            "retake_variance": retake_variance,
1511            "guidance_scale_text": guidance_scale_text,
1512            "guidance_scale_lyric": guidance_scale_lyric,
1513            "repaint_start": repaint_start,
1514            "repaint_end": repaint_end,
1515            "edit_n_min": edit_n_min,
1516            "edit_n_max": edit_n_max,
1517            "edit_n_avg": edit_n_avg,
1518            "src_audio_path": src_audio_path,
1519            "edit_target_prompt": edit_target_prompt,
1520            "edit_target_lyrics": edit_target_lyrics,
1521            "audio2audio_enable": audio2audio_enable,
1522            "ref_audio_strength": ref_audio_strength,
1523            "ref_audio_input": ref_audio_input,
1524        }
1525        # save input_params_json
1526        for output_audio_path in output_paths:
1527            input_params_json_save_path = output_audio_path.replace(f".{format}", "_input_params.json")
1528            input_params_json["audio_path"] = output_audio_path
1529            with open(input_params_json_save_path, "w", encoding="utf-8") as f:
1530                json.dump(input_params_json, f, indent=4, ensure_ascii=False)
1531
1532        return output_paths + [input_params_json]
# Language codes the lyric tokenizer is asked to handle; in tokenize_lyrics any
# other detected code falls back to "en". This module only consults the keys;
# the integer values presumably identify per-language tokens in the lyric
# tokenizer's vocabulary — TODO confirm against VoiceBpeTokenizer.
SUPPORT_LANGUAGES = {
    "en": 259,
    "de": 260,
    "fr": 262,
    "es": 284,
    "it": 285,
    "pt": 286,
    "pl": 294,
    "tr": 295,
    "ru": 267,
    "cs": 293,
    "nl": 297,
    "ar": 5022,
    "zh": 5023,
    "ja": 5412,
    "hu": 5753,
    "ko": 6152,
    "hi": 6680,
}

# Bracketed structure tag such as "[verse]" or "[chorus]". The match is
# non-greedy, so "[a] x [b]" matches only "[a]".
structure_pattern = re.compile(r"\[.*?\]")
def ensure_directory_exists(directory):
    """Create ``directory`` (and any missing parents) if it does not already exist.

    Args:
        directory: Target directory; anything whose ``str()`` is a path
            (e.g. ``str`` or ``pathlib.Path``) is accepted.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test: another process creating the directory in between
    # no longer makes this call fail.
    os.makedirs(str(directory), exist_ok=True)
# Hugging Face Hub repository id fetched by ACEStepPipeline.load_checkpoint.
REPO_ID = "ACE-Step/ACE-Step-v1-3.5B"
# Hugging Face Hub repository id fetched by
# ACEStepPipeline.load_quantized_checkpoint (quantized weights).
REPO_ID_QUANT = "ACE-Step/ACE-Step-v1-3.5B-q4-K-M"
class ACEStepPipeline:
  93class ACEStepPipeline:
  94    def __init__(
  95        self,
  96        checkpoint_dir=None,
  97        device_id=0,
  98        dtype="bfloat16",
  99        text_encoder_checkpoint_path=None,
 100        persistent_storage_path=None,
 101        torch_compile=False,
 102        cpu_offload=False,
 103        quantized=False,
 104        overlapped_decode=False,
 105        **kwargs,
 106    ):
 107        if not checkpoint_dir:
 108            if persistent_storage_path is None:
 109                checkpoint_dir = os.path.join(os.path.expanduser("~"), ".cache/ace-step/checkpoints")
 110                os.makedirs(checkpoint_dir, exist_ok=True)
 111            else:
 112                checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints")
 113        ensure_directory_exists(checkpoint_dir)
 114
 115        self.checkpoint_dir = checkpoint_dir
 116        self.lora_path = "none"
 117        self.lora_weight = 1
 118        self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
 119        if gfx_device.type == "mps":
 120            if self.dtype == torch.bfloat16:
 121                self.dtype = torch.float16
 122
 123        if "ACE_PIPELINE_DTYPE" in os.environ and len(os.environ["ACE_PIPELINE_DTYPE"]):
 124            self.dtype = getattr(torch, os.environ["ACE_PIPELINE_DTYPE"])
 125        self.device: torch.device = gfx_device
 126        self.loaded = False
 127        self.torch_compile = torch_compile
 128        self.cpu_offload = cpu_offload
 129        self.quantized = quantized
 130        self.overlapped_decode = overlapped_decode
 131
 132    def get_checkpoint_path(self, checkpoint_dir, repo):
 133        checkpoint_dir_models = None
 134
 135        if checkpoint_dir is not None:
 136            required_dirs = ["music_dcae_f8c8", "music_vocoder", "ace_step_transformer", "umt5-base"]
 137            all_dirs_exist = True
 138            for dir_name in required_dirs:
 139                dir_path = os.path.join(checkpoint_dir, dir_name)
 140                if not os.path.exists(dir_path):
 141                    all_dirs_exist = False
 142                    break
 143
 144            if all_dirs_exist:
 145                nfo(f"Load models from: {checkpoint_dir}")
 146                checkpoint_dir_models = checkpoint_dir
 147
 148        if checkpoint_dir_models is None:
 149            if checkpoint_dir is None:
 150                nfo(f"Download models from Hugging Face: {repo}")
 151                checkpoint_dir_models = snapshot_download(repo)
 152            else:
 153                nfo(f"Download models from Hugging Face: {repo}, cache to: {checkpoint_dir}")
 154                checkpoint_dir_models = snapshot_download(repo, cache_dir=checkpoint_dir)
 155        return checkpoint_dir_models
 156
 157    def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
 158        checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID)
 159        dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
 160        vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder")
 161        ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer")
 162        text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base")
 163
 164        self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype)
 165        # self.ace_step_transformer.to(self.device).eval().to(self.dtype)
 166        if self.cpu_offload:
 167            self.ace_step_transformer = self.ace_step_transformer.to("cpu").eval().to(self.dtype)
 168        else:
 169            self.ace_step_transformer = self.ace_step_transformer.to(self.device).eval().to(self.dtype)
 170        if self.torch_compile:
 171            self.ace_step_transformer = torch.compile(self.ace_step_transformer)
 172
 173        self.music_dcae = MusicDCAE(
 174            dcae_checkpoint_path=dcae_checkpoint_path,
 175            vocoder_checkpoint_path=vocoder_checkpoint_path,
 176        )
 177        # self.music_dcae.to(self.device).eval().to(self.dtype)
 178        if self.cpu_offload:  # might be redundant
 179            self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype)
 180        else:
 181            self.music_dcae = self.music_dcae.to(self.device).eval().to(self.dtype)
 182        if self.torch_compile:
 183            self.music_dcae = torch.compile(self.music_dcae)
 184
 185        lang_segment = LangSegment()
 186        lang_segment.setfilters(language_filters.default)
 187        self.lang_segment = lang_segment
 188        self.lyric_tokenizer = VoiceBpeTokenizer()
 189
 190        text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval()
 191        # text_encoder_model = text_encoder_model.to(self.device).to(self.dtype)
 192        if self.cpu_offload:
 193            text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype)
 194        else:
 195            text_encoder_model = text_encoder_model.to(self.device).eval().to(self.dtype)
 196        text_encoder_model.requires_grad_(False)
 197        self.text_encoder_model = text_encoder_model
 198        if self.torch_compile:
 199            self.text_encoder_model = torch.compile(self.text_encoder_model)
 200
 201        self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)
 202        self.loaded = True
 203
 204        # compile
 205        if self.torch_compile:
 206            if export_quantized_weights:
 207                from torch.ao.quantization import (
 208                    Int4WeightOnlyConfig,
 209                    quantize_,
 210                )
 211
 212                group_size = 128
 213                use_hqq = True
 214                quantize_(
 215                    self.ace_step_transformer,
 216                    Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq),
 217                )
 218                quantize_(
 219                    self.text_encoder_model,
 220                    Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq),
 221                )
 222
 223                # save quantized weights
 224                torch.save(
 225                    self.ace_step_transformer.state_dict(),
 226                    os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"),
 227                )
 228                print(
 229                    "Quantized Weights Saved to: ",
 230                    os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"),
 231                )
 232                torch.save(
 233                    self.text_encoder_model.state_dict(),
 234                    os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"),
 235                )
 236                print(
 237                    "Quantized Weights Saved to: ",
 238                    os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"),
 239                )
 240
 241    def load_quantized_checkpoint(self, checkpoint_dir=None):
 242        checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID_QUANT)
 243        dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
 244        vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder")
 245        ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer")
 246        text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base")
 247
 248        self.music_dcae = MusicDCAE(
 249            dcae_checkpoint_path=dcae_checkpoint_path,
 250            vocoder_checkpoint_path=vocoder_checkpoint_path,
 251        )
 252        if self.cpu_offload:
 253            self.music_dcae.eval().to(self.dtype).to(self.device)
 254        else:
 255            self.music_dcae.eval().to(self.dtype).to("cpu")
 256        self.music_dcae = torch.compile(self.music_dcae)
 257
 258        self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
 259        self.ace_step_transformer.eval().to(self.dtype).to("cpu")
 260        self.ace_step_transformer = torch.compile(self.ace_step_transformer)
 261        self.ace_step_transformer.load_state_dict(
 262            torch.load(
 263                os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"),
 264                map_location=self.device,
 265            ),
 266            assign=True,
 267        )
 268        self.ace_step_transformer.torchao_quantized = True
 269
 270        self.text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path)
 271        self.text_encoder_model.eval().to(self.dtype).to("cpu")
 272        self.text_encoder_model = torch.compile(self.text_encoder_model)
 273        self.text_encoder_model.load_state_dict(
 274            torch.load(
 275                os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"),
 276                map_location=self.device,
 277            ),
 278            assign=True,
 279        )
 280        self.text_encoder_model.torchao_quantized = True
 281
 282        self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)
 283
 284        lang_segment = LangSegment()
 285        lang_segment.setfilters(language_filters.default)
 286        self.lang_segment = lang_segment
 287        self.lyric_tokenizer = VoiceBpeTokenizer()
 288
 289        self.loaded = True
 290
 291    @cpu_offload("text_encoder_model")
 292    def get_text_embeddings(self, texts, text_max_length=256):
 293        inputs = self.text_tokenizer(
 294            texts,
 295            return_tensors="pt",
 296            padding=True,
 297            truncation=True,
 298            max_length=text_max_length,
 299        )
 300        inputs = {key: value.to(self.device) for key, value in inputs.items()}
 301        if self.text_encoder_model.device != self.device:
 302            self.text_encoder_model.to(self.device)
 303        with torch.no_grad():
 304            outputs = self.text_encoder_model(**inputs)
 305            last_hidden_states = outputs.last_hidden_state
 306        attention_mask = inputs["attention_mask"]
 307        return last_hidden_states, attention_mask
 308
 309    @cpu_offload("text_encoder_model")
 310    def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10):
 311        inputs = self.text_tokenizer(
 312            texts,
 313            return_tensors="pt",
 314            padding=True,
 315            truncation=True,
 316            max_length=text_max_length,
 317        )
 318        inputs = {key: value.to(self.device) for key, value in inputs.items()}
 319        if self.text_encoder_model.device != self.device:
 320            self.text_encoder_model.to(self.device)
 321
 322        def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
 323            handlers = []
 324
 325            def hook(module, input, output):
 326                output[:] *= tau
 327                return output
 328
 329            for i in range(l_min, l_max):
 330                handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook)
 331                handlers.append(handler)
 332
 333            with torch.no_grad():
 334                outputs = self.text_encoder_model(**inputs)
 335                last_hidden_states = outputs.last_hidden_state
 336
 337            for hook in handlers:
 338                hook.remove()
 339
 340            return last_hidden_states
 341
 342        last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
 343        return last_hidden_states
 344
 345    def set_seeds(self, batch_size, manual_seeds=None):
 346        processed_input_seeds = None
 347        if manual_seeds is not None:
 348            if isinstance(manual_seeds, str):
 349                if "," in manual_seeds:
 350                    processed_input_seeds = list(map(int, manual_seeds.split(",")))
 351                elif manual_seeds.isdigit():
 352                    processed_input_seeds = int(manual_seeds)
 353            elif isinstance(manual_seeds, list) and all(isinstance(s, int) for s in manual_seeds):
 354                if len(manual_seeds) > 0:
 355                    processed_input_seeds = list(manual_seeds)
 356            elif isinstance(manual_seeds, int):
 357                processed_input_seeds = manual_seeds
 358        random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)]
 359        actual_seeds = []
 360        for i in range(batch_size):
 361            current_seed_for_generator = None
 362            if processed_input_seeds is None:
 363                current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
 364            elif isinstance(processed_input_seeds, int):
 365                current_seed_for_generator = processed_input_seeds
 366            elif isinstance(processed_input_seeds, list):
 367                if i < len(processed_input_seeds):
 368                    current_seed_for_generator = processed_input_seeds[i]
 369                else:
 370                    current_seed_for_generator = processed_input_seeds[-1]
 371            if current_seed_for_generator is None:
 372                current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
 373            random_generators[i].manual_seed(current_seed_for_generator)
 374            actual_seeds.append(current_seed_for_generator)
 375        return random_generators, actual_seeds
 376
 377    def get_lang(self, text):
 378        language = "en"
 379        try:
 380            _ = self.lang_segment.getTexts(text)
 381            langCounts = self.lang_segment.getCounts()
 382            language = langCounts[0][0]
 383            if len(langCounts) > 1 and language == "en":
 384                language = langCounts[1][0]
 385        except Exception as err:
 386            language = "en"
 387        return language
 388
 389    def tokenize_lyrics(self, lyrics, debug=False):
 390        lines = lyrics.split("\n")
 391        lyric_token_idx = [261]
 392        for line in lines:
 393            line = line.strip()
 394            if not line:
 395                lyric_token_idx += [2]
 396                continue
 397
 398            lang = self.get_lang(line)
 399
 400            if lang not in SUPPORT_LANGUAGES:
 401                lang = "en"
 402            if "zh" in lang:
 403                lang = "zh"
 404            if "spa" in lang:
 405                lang = "es"
 406
 407            try:
 408                if structure_pattern.match(line):
 409                    token_idx = self.lyric_tokenizer.encode(line, "en")
 410                else:
 411                    token_idx = self.lyric_tokenizer.encode(line, lang)
 412                if debug:
 413                    toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx])
 414                    nfo(f"debbug {line} --> {lang} --> {toks}")
 415                lyric_token_idx = lyric_token_idx + token_idx + [2]
 416            except Exception as e:
 417                print("tokenize error", e, "for line", line, "major_language", lang)
 418        return lyric_token_idx
 419
 420    @cpu_offload("ace_step_transformer")
 421    def calc_v(
 422        self,
 423        zt_src,
 424        zt_tar,
 425        t,
 426        encoder_text_hidden_states,
 427        text_attention_mask,
 428        target_encoder_text_hidden_states,
 429        target_text_attention_mask,
 430        speaker_embds,
 431        target_speaker_embeds,
 432        lyric_token_ids,
 433        lyric_mask,
 434        target_lyric_token_ids,
 435        target_lyric_mask,
 436        do_classifier_free_guidance=False,
 437        guidance_scale=1.0,
 438        target_guidance_scale=1.0,
 439        cfg_type="apg",
 440        attention_mask=None,
 441        momentum_buffer=None,
 442        momentum_buffer_tar=None,
 443        return_src_pred=True,
 444    ):
 445        noise_pred_src = None
 446        if return_src_pred:
 447            src_latent_model_input = torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src
 448            timestep = t.expand(src_latent_model_input.shape[0])
 449            # source
 450            noise_pred_src = self.ace_step_transformer(
 451                hidden_states=src_latent_model_input,
 452                attention_mask=attention_mask,
 453                encoder_text_hidden_states=encoder_text_hidden_states,
 454                text_attention_mask=text_attention_mask,
 455                speaker_embeds=speaker_embds,
 456                lyric_token_idx=lyric_token_ids,
 457                lyric_mask=lyric_mask,
 458                timestep=timestep,
 459            ).sample
 460
 461            if do_classifier_free_guidance:
 462                noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(2)
 463                if cfg_type == "apg":
 464                    noise_pred_src = apg_forward(
 465                        pred_cond=noise_pred_with_cond_src,
 466                        pred_uncond=noise_pred_uncond_src,
 467                        guidance_scale=guidance_scale,
 468                        momentum_buffer=momentum_buffer,
 469                    )
 470                elif cfg_type == "cfg":
 471                    noise_pred_src = cfg_forward(
 472                        cond_output=noise_pred_with_cond_src,
 473                        uncond_output=noise_pred_uncond_src,
 474                        cfg_strength=guidance_scale,
 475                    )
 476
 477        tar_latent_model_input = torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar
 478        timestep = t.expand(tar_latent_model_input.shape[0])
 479        # target
 480        noise_pred_tar = self.ace_step_transformer(
 481            hidden_states=tar_latent_model_input,
 482            attention_mask=attention_mask,
 483            encoder_text_hidden_states=target_encoder_text_hidden_states,
 484            text_attention_mask=target_text_attention_mask,
 485            speaker_embeds=target_speaker_embeds,
 486            lyric_token_idx=target_lyric_token_ids,
 487            lyric_mask=target_lyric_mask,
 488            timestep=timestep,
 489        ).sample
 490
 491        if do_classifier_free_guidance:
 492            noise_pred_with_cond_tar, noise_pred_uncond_tar = noise_pred_tar.chunk(2)
 493            if cfg_type == "apg":
 494                noise_pred_tar = apg_forward(
 495                    pred_cond=noise_pred_with_cond_tar,
 496                    pred_uncond=noise_pred_uncond_tar,
 497                    guidance_scale=target_guidance_scale,
 498                    momentum_buffer=momentum_buffer_tar,
 499                )
 500            elif cfg_type == "cfg":
 501                noise_pred_tar = cfg_forward(
 502                    cond_output=noise_pred_with_cond_tar,
 503                    uncond_output=noise_pred_uncond_tar,
 504                    cfg_strength=target_guidance_scale,
 505                )
 506        return noise_pred_src, noise_pred_tar
 507
 508    @torch.no_grad()
 509    def flowedit_diffusion_process(
 510        self,
 511        encoder_text_hidden_states,
 512        text_attention_mask,
 513        speaker_embds,
 514        lyric_token_ids,
 515        lyric_mask,
 516        target_encoder_text_hidden_states,
 517        target_text_attention_mask,
 518        target_speaker_embeds,
 519        target_lyric_token_ids,
 520        target_lyric_mask,
 521        src_latents,
 522        random_generators=None,
 523        infer_steps=60,
 524        guidance_scale=15.0,
 525        n_min=0,
 526        n_max=1.0,
 527        n_avg=1,
 528        scheduler_type="euler",
 529    ):
 530        do_classifier_free_guidance = True
 531        if guidance_scale == 0.0 or guidance_scale == 1.0:
 532            do_classifier_free_guidance = False
 533
 534        target_guidance_scale = guidance_scale
 535        bsz = encoder_text_hidden_states.shape[0]
 536
 537        scheduler = FlowMatchEulerDiscreteScheduler(
 538            num_train_timesteps=1000,
 539            shift=3.0,
 540        )
 541
 542        T_steps = infer_steps
 543        frame_length = src_latents.shape[-1]
 544        attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)
 545
 546        timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, self.device, timesteps=None)
 547
 548        if do_classifier_free_guidance:
 549            attention_mask = torch.cat([attention_mask] * 2, dim=0)
 550
 551            encoder_text_hidden_states = torch.cat(
 552                [
 553                    encoder_text_hidden_states,
 554                    torch.zeros_like(encoder_text_hidden_states),
 555                ],
 556                0,
 557            )
 558            text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)
 559
 560            target_encoder_text_hidden_states = torch.cat(
 561                [
 562                    target_encoder_text_hidden_states,
 563                    torch.zeros_like(target_encoder_text_hidden_states),
 564                ],
 565                0,
 566            )
 567            target_text_attention_mask = torch.cat([target_text_attention_mask] * 2, dim=0)
 568
 569            speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)
 570            target_speaker_embeds = torch.cat([target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0)
 571
 572            lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
 573            lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)
 574
 575            target_lyric_token_ids = torch.cat([target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0)
 576            target_lyric_mask = torch.cat([target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0)
 577
 578        momentum_buffer = MomentumBuffer()
 579        momentum_buffer_tar = MomentumBuffer()
 580        x_src = src_latents
 581        zt_edit = x_src.clone()
 582        xt_tar = None
 583        n_min = int(infer_steps * n_min)
 584        n_max = int(infer_steps * n_max)
 585
 586        nfo("flowedit start from {} to {}".format(n_min, n_max))
 587
 588        for i, t in tqdm(enumerate(timesteps), total=T_steps):
 589            if i < n_min:
 590                continue
 591
 592            t_i = t / 1000
 593
 594            if i + 1 < len(timesteps):
 595                t_im1 = (timesteps[i + 1]) / 1000
 596            else:
 597                t_im1 = torch.zeros_like(t_i).to(self.device)
 598
 599            if i < n_max:
 600                # Calculate the average of the V predictions
 601                V_delta_avg = torch.zeros_like(x_src)
 602                for k in range(n_avg):
 603                    fwd_noise = randn_tensor(
 604                        shape=x_src.shape,
 605                        generator=random_generators,
 606                        device=self.device,
 607                        dtype=self.dtype,
 608                    )
 609
 610                    zt_src = (1 - t_i) * x_src + (t_i) * fwd_noise
 611
 612                    zt_tar = zt_edit + zt_src - x_src
 613
 614                    Vt_src, Vt_tar = self.calc_v(
 615                        zt_src=zt_src,
 616                        zt_tar=zt_tar,
 617                        t=t,
 618                        encoder_text_hidden_states=encoder_text_hidden_states,
 619                        text_attention_mask=text_attention_mask,
 620                        target_encoder_text_hidden_states=target_encoder_text_hidden_states,
 621                        target_text_attention_mask=target_text_attention_mask,
 622                        speaker_embds=speaker_embds,
 623                        target_speaker_embeds=target_speaker_embeds,
 624                        lyric_token_ids=lyric_token_ids,
 625                        lyric_mask=lyric_mask,
 626                        target_lyric_token_ids=target_lyric_token_ids,
 627                        target_lyric_mask=target_lyric_mask,
 628                        do_classifier_free_guidance=do_classifier_free_guidance,
 629                        guidance_scale=guidance_scale,
 630                        target_guidance_scale=target_guidance_scale,
 631                        attention_mask=attention_mask,
 632                        momentum_buffer=momentum_buffer,
 633                    )
 634                    V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src)  # - (hfg - 1) * (x_src)
 635
 636                zt_edit = zt_edit.to(torch.float32)  # arbitrary, should be settable for compatibility
 637                if scheduler_type != "pingpong":
 638                    # propagate direct ODE
 639                    zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
 640                    zt_edit = zt_edit.to(self.dtype)
 641                else:
 642                    # propagate pingpong SDE
 643                    zt_edit_denoised = zt_edit - t_i * V_delta_avg
 644                    noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None)
 645                    prev_sample = (1 - t_im1) * zt_edit_denoised + t_im1 * noise
 646
 647            else:  # i >= T_steps-n_min # regular sampling for last n_min steps
 648                if i == n_max:
 649                    fwd_noise = randn_tensor(
 650                        shape=x_src.shape,
 651                        generator=random_generators,
 652                        device=self.device,
 653                        dtype=self.dtype,
 654                    )
 655                    scheduler._init_step_index(t)
 656                    sigma = scheduler.sigmas[scheduler.step_index]
 657                    xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src
 658                    xt_tar = zt_edit + xt_src - x_src
 659
 660                _, Vt_tar = self.calc_v(
 661                    zt_src=None,
 662                    zt_tar=xt_tar,
 663                    t=t,
 664                    encoder_text_hidden_states=encoder_text_hidden_states,
 665                    text_attention_mask=text_attention_mask,
 666                    target_encoder_text_hidden_states=target_encoder_text_hidden_states,
 667                    target_text_attention_mask=target_text_attention_mask,
 668                    speaker_embds=speaker_embds,
 669                    target_speaker_embeds=target_speaker_embeds,
 670                    lyric_token_ids=lyric_token_ids,
 671                    lyric_mask=lyric_mask,
 672                    target_lyric_token_ids=target_lyric_token_ids,
 673                    target_lyric_mask=target_lyric_mask,
 674                    do_classifier_free_guidance=do_classifier_free_guidance,
 675                    guidance_scale=guidance_scale,
 676                    target_guidance_scale=target_guidance_scale,
 677                    attention_mask=attention_mask,
 678                    momentum_buffer_tar=momentum_buffer_tar,
 679                    return_src_pred=False,
 680                )
 681
 682                xt_tar = xt_tar.to(torch.float32)
 683                if scheduler_type != "pingpong":
 684                    prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
 685                    prev_sample = prev_sample.to(self.dtype)
 686                    xt_tar = prev_sample
 687                else:
 688                    prev_sample = xt_tar - t_i * Vt_tar
 689                    noise = torch.empty_like(zt_edit).normal_(generator=random_generators[0] if random_generators else None)
 690                    prev_sample = (1 - t_im1) * prev_sample + t_im1 * noise
 691                    xt_tar = prev_sample
 692
 693        target_latents = zt_edit if xt_tar is None else xt_tar
 694        return target_latents
 695
 696    def add_latents_noise(
 697        self,
 698        gt_latents,
 699        sigma_max,
 700        noise,
 701        scheduler_type,
 702        infer_steps,
 703    ):
 704        bsz = gt_latents.shape[0]
 705        if scheduler_type == "euler":
 706            scheduler = FlowMatchEulerDiscreteScheduler(
 707                num_train_timesteps=1000,
 708                shift=3.0,
 709                sigma_max=sigma_max,
 710            )
 711        elif scheduler_type == "heun":
 712            scheduler = FlowMatchHeunDiscreteScheduler(
 713                num_train_timesteps=1000,
 714                shift=3.0,
 715                sigma_max=sigma_max,
 716            )
 717        elif scheduler_type == "pingpong":
 718            scheduler = FlowMatchPingPongScheduler(num_train_timesteps=1000, shift=3.0, sigma_max=sigma_max)
 719
 720        infer_steps = int(sigma_max * infer_steps)
 721        timesteps, num_inference_steps = retrieve_timesteps(
 722            scheduler,
 723            num_inference_steps=infer_steps,
 724            device=self.device,
 725            timesteps=None,
 726        )
 727        noisy_image = gt_latents * (1 - scheduler.sigma_max) + noise * scheduler.sigma_max
 728        nfo(f"{scheduler.sigma_min=} {scheduler.sigma_max=} {timesteps=} {num_inference_steps=}")
 729        return noisy_image, timesteps, scheduler, num_inference_steps
 730
 731    @cpu_offload("ace_step_transformer")
 732    @torch.no_grad()
 733    def text2music_diffusion_process(
 734        self,
 735        duration,
 736        encoder_text_hidden_states,
 737        text_attention_mask,
 738        speaker_embds,
 739        lyric_token_ids,
 740        lyric_mask,
 741        random_generators=None,
 742        infer_steps=60,
 743        guidance_scale=15.0,
 744        omega_scale=10.0,
 745        scheduler_type="euler",
 746        cfg_type="apg",
 747        zero_steps=1,
 748        use_zero_init=True,
 749        guidance_interval=0.5,
 750        guidance_interval_decay=1.0,
 751        min_guidance_scale=3.0,
 752        oss_steps=[],
 753        encoder_text_hidden_states_null=None,
 754        use_erg_lyric=False,
 755        use_erg_diffusion=False,
 756        retake_random_generators=None,
 757        retake_variance=0.5,
 758        add_retake_noise=False,
 759        guidance_scale_text=0.0,
 760        guidance_scale_lyric=0.0,
 761        repaint_start=0,
 762        repaint_end=0,
 763        src_latents=None,
 764        audio2audio_enable=False,
 765        ref_audio_strength=0.5,
 766        ref_latents=None,
 767    ):
 768        nfo("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
 769        do_classifier_free_guidance = True
 770        if guidance_scale == 0.0 or guidance_scale == 1.0:
 771            do_classifier_free_guidance = False
 772
 773        do_double_condition_guidance = False
 774        if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
 775            do_double_condition_guidance = True
 776            nfo(
 777                "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(
 778                    do_double_condition_guidance,
 779                    guidance_scale_text,
 780                    guidance_scale_lyric,
 781                )
 782            )
 783
 784        bsz = encoder_text_hidden_states.shape[0]
 785
 786        if scheduler_type == "euler":
 787            scheduler = FlowMatchEulerDiscreteScheduler(
 788                num_train_timesteps=1000,
 789                shift=3.0,
 790            )
 791        elif scheduler_type == "heun":
 792            scheduler = FlowMatchHeunDiscreteScheduler(
 793                num_train_timesteps=1000,
 794                shift=3.0,
 795            )
 796        elif scheduler_type == "pingpong":
 797            scheduler = FlowMatchPingPongScheduler(
 798                num_train_timesteps=1000,
 799                shift=3.0,
 800            )
 801
 802        frame_length = int(duration * 44100 / 512 / 8)
 803        if src_latents is not None:
 804            frame_length = src_latents.shape[-1]
 805
 806        if ref_latents is not None:
 807            frame_length = ref_latents.shape[-1]
 808
 809        if len(oss_steps) > 0:
 810            infer_steps = max(oss_steps)
 811            scheduler.set_timesteps
 812            timesteps, num_inference_steps = retrieve_timesteps(
 813                scheduler,
 814                num_inference_steps=infer_steps,
 815                device=self.device,
 816                timesteps=None,
 817            )
 818            new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device)
 819            for idx in range(len(oss_steps)):
 820                new_timesteps[idx] = timesteps[oss_steps[idx] - 1]
 821            num_inference_steps = len(oss_steps)
 822            sigmas = (new_timesteps / 1000).float().cpu().numpy()
 823            timesteps, num_inference_steps = retrieve_timesteps(
 824                scheduler,
 825                num_inference_steps=num_inference_steps,
 826                device=self.device,
 827                sigmas=sigmas,
 828            )
 829            nfo(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
 830        else:
 831            timesteps, num_inference_steps = retrieve_timesteps(
 832                scheduler,
 833                num_inference_steps=infer_steps,
 834                device=self.device,
 835                timesteps=None,
 836            )
 837
 838        target_latents = randn_tensor(
 839            shape=(bsz, 8, 16, frame_length),
 840            generator=random_generators,
 841            device=self.device,
 842            dtype=self.dtype,
 843        )
 844
 845        is_repaint = False
 846        is_extend = False
 847
 848        if add_retake_noise:
 849            n_min = int(infer_steps * (1 - retake_variance))
 850            retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype)
 851            retake_latents = randn_tensor(
 852                shape=(bsz, 8, 16, frame_length),
 853                generator=retake_random_generators,
 854                device=self.device,
 855                dtype=self.dtype,
 856            )
 857            repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
 858            repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
 859            x0 = src_latents
 860            # retake
 861            is_repaint = repaint_end_frame - repaint_start_frame != frame_length
 862
 863            is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
 864            if is_extend:
 865                is_repaint = True
 866
 867            # TODO: train a mask aware repainting controlnet
 868            # to make sure mean = 0, std = 1
 869            if not is_repaint:
 870                target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
 871            elif not is_extend:
 872                # if repaint_end_frame
 873                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
 874                repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
 875                repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
 876                repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
 877                zt_edit = x0.clone()
 878                z0 = repaint_noise
 879            elif is_extend:
 880                to_right_pad_gt_latents = None
 881                to_left_pad_gt_latents = None
 882                gt_latents = src_latents
 883                src_latents_length = gt_latents.shape[-1]
 884                max_infer_fame_length = int(240 * 44100 / 512 / 8)
 885                left_pad_frame_length = 0
 886                right_pad_frame_length = 0
 887                right_trim_length = 0
 888                left_trim_length = 0
 889                if repaint_start_frame < 0:
 890                    left_pad_frame_length = abs(repaint_start_frame)
 891                    frame_length = left_pad_frame_length + gt_latents.shape[-1]
 892                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
 893                    if frame_length > max_infer_fame_length:
 894                        right_trim_length = frame_length - max_infer_fame_length
 895                        extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_fame_length]
 896                        to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:]
 897                        frame_length = max_infer_fame_length
 898                    repaint_start_frame = 0
 899                    gt_latents = extend_gt_latents
 900
 901                if repaint_end_frame > src_latents_length:
 902                    right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
 903                    frame_length = gt_latents.shape[-1] + right_pad_frame_length
 904                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
 905                    if frame_length > max_infer_fame_length:
 906                        left_trim_length = frame_length - max_infer_fame_length
 907                        extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_fame_length:]
 908                        to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length]
 909                        frame_length = max_infer_fame_length
 910                    repaint_end_frame = frame_length
 911                    gt_latents = extend_gt_latents
 912
 913                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
 914                if left_pad_frame_length > 0:
 915                    repaint_mask[:, :, :, :left_pad_frame_length] = 1.0
 916                if right_pad_frame_length > 0:
 917                    repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0
 918                x0 = gt_latents
 919                padd_list = []
 920                if left_pad_frame_length > 0:
 921                    padd_list.append(retake_latents[:, :, :, :left_pad_frame_length])
 922                padd_list.append(
 923                    target_latents[
 924                        :,
 925                        :,
 926                        :,
 927                        left_trim_length : target_latents.shape[-1] - right_trim_length,
 928                    ]
 929                )
 930                if right_pad_frame_length > 0:
 931                    padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
 932                target_latents = torch.cat(padd_list, dim=-1)
 933                assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
 934                zt_edit = x0.clone()
 935                z0 = target_latents
 936
 937        if audio2audio_enable and ref_latents is not None:
 938            nfo(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}")
 939            target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise(
 940                gt_latents=ref_latents,
 941                sigma_max=(1 - ref_audio_strength),
 942                noise=target_latents,
 943                scheduler_type=scheduler_type,
 944                infer_steps=infer_steps,
 945            )
 946
 947        attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)
 948
 949        # guidance interval
 950        start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
 951        end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
 952        nfo(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
 953
 954        momentum_buffer = MomentumBuffer()
 955
 956        def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
 957            handlers = []
 958
 959            def hook(module, input, output):
 960                output[:] *= tau
 961                return output
 962
 963            for i in range(l_min, l_max):
 964                handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
 965                handlers.append(handler)
 966
 967            encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
 968
 969            for hook in handlers:
 970                hook.remove()
 971
 972            return encoder_hidden_states
 973
 974        # P(speaker, text, lyric)
 975        encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(
 976            encoder_text_hidden_states,
 977            text_attention_mask,
 978            speaker_embds,
 979            lyric_token_ids,
 980            lyric_mask,
 981        )
 982
 983        if use_erg_lyric:
 984            # P(null_speaker, text_weaker, lyric_weaker)
 985            encoder_hidden_states_null = forward_encoder_with_temperature(
 986                self,
 987                inputs={
 988                    "encoder_text_hidden_states": (
 989                        encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states)
 990                    ),
 991                    "text_attention_mask": text_attention_mask,
 992                    "speaker_embeds": torch.zeros_like(speaker_embds),
 993                    "lyric_token_idx": lyric_token_ids,
 994                    "lyric_mask": lyric_mask,
 995                },
 996            )
 997        else:
 998            # P(null_speaker, null_text, null_lyric)
 999            encoder_hidden_states_null, _ = self.ace_step_transformer.encode(
1000                torch.zeros_like(encoder_text_hidden_states),
1001                text_attention_mask,
1002                torch.zeros_like(speaker_embds),
1003                torch.zeros_like(lyric_token_ids),
1004                lyric_mask,
1005            )
1006
1007        encoder_hidden_states_no_lyric = None
1008        if do_double_condition_guidance:
1009            # P(null_speaker, text, lyric_weaker)
1010            if use_erg_lyric:
1011                encoder_hidden_states_no_lyric = forward_encoder_with_temperature(
1012                    self,
1013                    inputs={
1014                        "encoder_text_hidden_states": encoder_text_hidden_states,
1015                        "text_attention_mask": text_attention_mask,
1016                        "speaker_embeds": torch.zeros_like(speaker_embds),
1017                        "lyric_token_idx": lyric_token_ids,
1018                        "lyric_mask": lyric_mask,
1019                    },
1020                )
1021            # P(null_speaker, text, no_lyric)
1022            else:
1023                encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode(
1024                    encoder_text_hidden_states,
1025                    text_attention_mask,
1026                    torch.zeros_like(speaker_embds),
1027                    torch.zeros_like(lyric_token_ids),
1028                    lyric_mask,
1029                )
1030
1031        def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
1032            handlers = []
1033
1034            def hook(module, input, output):
1035                output[:] *= tau
1036                return output
1037
1038            for i in range(l_min, l_max):
1039                handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
1040                handlers.append(handler)
1041                handler = self.ace_step_transformer.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook)
1042                handlers.append(handler)
1043
1044            sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
1045
1046            for hook in handlers:
1047                hook.remove()
1048
1049            return sample
1050
1051        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
1052            if is_repaint:
1053                if i < n_min:
1054                    continue
1055                elif i == n_min:
1056                    t_i = t / 1000
1057                    zt_src = (1 - t_i) * x0 + (t_i) * z0
1058                    target_latents = zt_edit + zt_src - x0
1059                    nfo(f"repaint start from {n_min} add {t_i} level of noise")
1060
1061            # expand the latents if we are doing classifier free guidance
1062            latents = target_latents
1063
1064            is_in_guidance_interval = start_idx <= i < end_idx
1065            if is_in_guidance_interval and do_classifier_free_guidance:
1066                # compute current guidance scale
1067                if guidance_interval_decay > 0:
1068                    # Linearly interpolate to calculate the current guidance scale
1069                    progress = (i - start_idx) / (end_idx - start_idx - 1)  # 归一化到[0,1]
1070                    current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
1071                else:
1072                    current_guidance_scale = guidance_scale
1073
1074                latent_model_input = latents
1075                timestep = t.expand(latent_model_input.shape[0])
1076                output_length = latent_model_input.shape[-1]
1077                # P(x|speaker, text, lyric)
1078                noise_pred_with_cond = self.ace_step_transformer.decode(
1079                    hidden_states=latent_model_input,
1080                    attention_mask=attention_mask,
1081                    encoder_hidden_states=encoder_hidden_states,
1082                    encoder_hidden_mask=encoder_hidden_mask,
1083                    output_length=output_length,
1084                    timestep=timestep,
1085                ).sample
1086
1087                noise_pred_with_only_text_cond = None
1088                if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None:
1089                    noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
1090                        hidden_states=latent_model_input,
1091                        attention_mask=attention_mask,
1092                        encoder_hidden_states=encoder_hidden_states_no_lyric,
1093                        encoder_hidden_mask=encoder_hidden_mask,
1094                        output_length=output_length,
1095                        timestep=timestep,
1096                    ).sample
1097
1098                if use_erg_diffusion:
1099                    noise_pred_uncond = forward_diffusion_with_temperature(
1100                        self,
1101                        hidden_states=latent_model_input,
1102                        timestep=timestep,
1103                        inputs={
1104                            "encoder_hidden_states": encoder_hidden_states_null,
1105                            "encoder_hidden_mask": encoder_hidden_mask,
1106                            "output_length": output_length,
1107                            "attention_mask": attention_mask,
1108                        },
1109                    )
1110                else:
1111                    noise_pred_uncond = self.ace_step_transformer.decode(
1112                        hidden_states=latent_model_input,
1113                        attention_mask=attention_mask,
1114                        encoder_hidden_states=encoder_hidden_states_null,
1115                        encoder_hidden_mask=encoder_hidden_mask,
1116                        output_length=output_length,
1117                        timestep=timestep,
1118                    ).sample
1119
1120                if do_double_condition_guidance and noise_pred_with_only_text_cond is not None:
1121                    noise_pred = cfg_double_condition_forward(
1122                        cond_output=noise_pred_with_cond,
1123                        uncond_output=noise_pred_uncond,
1124                        only_text_cond_output=noise_pred_with_only_text_cond,
1125                        guidance_scale_text=guidance_scale_text,
1126                        guidance_scale_lyric=guidance_scale_lyric,
1127                    )
1128
1129                elif cfg_type == "apg":
1130                    noise_pred = apg_forward(
1131                        pred_cond=noise_pred_with_cond,
1132                        pred_uncond=noise_pred_uncond,
1133                        guidance_scale=current_guidance_scale,
1134                        momentum_buffer=momentum_buffer,
1135                    )
1136                elif cfg_type == "cfg":
1137                    noise_pred = cfg_forward(
1138                        cond_output=noise_pred_with_cond,
1139                        uncond_output=noise_pred_uncond,
1140                        cfg_strength=current_guidance_scale,
1141                    )
1142                elif cfg_type == "cfg_star":
1143                    noise_pred = cfg_zero_star(
1144                        noise_pred_with_cond=noise_pred_with_cond,
1145                        noise_pred_uncond=noise_pred_uncond,
1146                        guidance_scale=current_guidance_scale,
1147                        i=i,
1148                        zero_steps=zero_steps,
1149                        use_zero_init=use_zero_init,
1150                    )
1151            else:
1152                latent_model_input = latents
1153                timestep = t.expand(latent_model_input.shape[0])
1154                noise_pred = self.ace_step_transformer.decode(
1155                    hidden_states=latent_model_input,
1156                    attention_mask=attention_mask,
1157                    encoder_hidden_states=encoder_hidden_states,
1158                    encoder_hidden_mask=encoder_hidden_mask,
1159                    output_length=latent_model_input.shape[-1],
1160                    timestep=timestep,
1161                ).sample
1162
1163            if is_repaint and i >= n_min:
1164                t_i = t / 1000
1165                if i + 1 < len(timesteps):
1166                    t_im1 = (timesteps[i + 1]) / 1000
1167                else:
1168                    t_im1 = torch.zeros_like(t_i).to(self.device)
1169                target_latents = target_latents.to(torch.float32)
1170                prev_sample = target_latents + (t_im1 - t_i) * noise_pred
1171                prev_sample = prev_sample.to(self.dtype)
1172                target_latents = prev_sample
1173                zt_src = (1 - t_im1) * x0 + (t_im1) * z0
1174                target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
1175            else:
1176                target_latents = scheduler.step(
1177                    model_output=noise_pred,
1178                    timestep=t,
1179                    sample=target_latents,
1180                    return_dict=False,
1181                    omega=omega_scale,
1182                    generator=random_generators[0],
1183                )[0]
1184
1185        if is_extend:
1186            if to_right_pad_gt_latents is not None:
1187                target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
1188            if to_left_pad_gt_latents is not None:
1189                target_latents = torch.cat([to_right_pad_gt_latents, target_latents], dim=0)
1190        return target_latents
1191
1192    @cpu_offload("music_dcae")
1193    def latents2audio(
1194        self,
1195        latents,
1196        target_wav_duration_second=30,
1197        sample_rate=48000,
1198        save_path=None,
1199        format="wav",
1200    ):
1201        output_audio_paths = []
1202        bs = latents.shape[0]
1203        pred_latents = latents
1204        with torch.no_grad():
1205            if self.overlapped_decode and target_wav_duration_second > 48:
1206                _, pred_wavs = self.music_dcae.decode_overlap(pred_latents, sr=sample_rate)
1207            else:
1208                _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
1209        pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
1210        for i in tqdm(range(bs)):
1211            output_audio_path = self.save_wav_file(
1212                pred_wavs[i],
1213                i,
1214                save_path=save_path,
1215                sample_rate=sample_rate,
1216                format=format,
1217            )
1218            output_audio_paths.append(output_audio_path)
1219        return output_audio_paths
1220
1221    def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="wav"):
1222        if save_path is None:
1223            nfo("save_path is None, using default path ./outputs/")
1224            base_path = "./outputs"
1225            ensure_directory_exists(base_path)
1226            output_path_wav = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format
1227        else:
1228            ensure_directory_exists(os.path.dirname(save_path))
1229            if os.path.isdir(save_path):
1230                nfo(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.")
1231                output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format)
1232            else:
1233                output_path_wav = save_path
1234
1235        target_wav = target_wav.float()
1236        backend = "soundfile"
1237        if format == "ogg":
1238            backend = "sox"
1239        nfo(f"Saving audio to {output_path_wav} using backend {backend}")
1240        try:
1241            torchaudio.save(output_path_wav, target_wav, sample_rate=sample_rate, format=format, backend=backend)
1242        except (RuntimeError, ImportError, ModuleNotFoundError):
1243            import soundfile as sf  # pyright: ignore[reportMissingImports] | pylint:disable=import-error
1244
1245            sf.write(output_path_wav, target_wav, sample_rate)
1246        return output_path_wav
1247
@cpu_offload("music_dcae")
def infer_latents(self, input_audio_path):
    """Encode an audio file into DCAE latents.

    Returns ``None`` when no path is given. Otherwise the file is loaded through
    ``self.music_dcae.load_audio``, given a leading batch dimension, moved to
    ``self.device`` / ``self.dtype``, and passed to ``self.music_dcae.encode``;
    the first element of the encoder's result is returned.
    """
    if input_audio_path is None:
        return None
    waveform, sample_rate = self.music_dcae.load_audio(input_audio_path)
    batch = waveform.unsqueeze(0).to(device=self.device, dtype=self.dtype)
    encoded, _ = self.music_dcae.encode(batch, sr=sample_rate)
    return encoded
1257
def load_lora(self, lora_name_or_path, lora_weight):
    """Activate, swap, re-weight or remove the LoRA adapter on the transformer.

    ``"none"`` is the sentinel for "no adapter". The active adapter is tracked in
    ``self.lora_path`` / ``self.lora_weight`` so repeated calls with the same
    arguments do nothing.

    Args:
        lora_name_or_path: A local directory containing
            ``pytorch_lora_weights.safetensors``, a Hugging Face repo id to download
            it from (cached under ``self.checkpoint_dir``), or ``"none"``.
        lora_weight: Adapter scale passed to ``set_weights_and_activate_adapters``.
    """
    if lora_name_or_path == "none":
        if self.lora_path != "none":
            nfo("No lora weights to load.")
            self.ace_step_transformer.unload_lora()
            # Reset the bookkeeping: otherwise a later request for the same
            # adapter looks "already loaded" and is skipped, and the next swap
            # would try to unload an adapter that is no longer there.
            self.lora_path = "none"
        return

    if lora_name_or_path == self.lora_path and lora_weight == self.lora_weight:
        return  # requested adapter and weight are already active

    if os.path.exists(lora_name_or_path):
        lora_download_path = lora_name_or_path
    else:
        lora_download_path = snapshot_download(lora_name_or_path, cache_dir=self.checkpoint_dir)
    if self.lora_path != "none":
        # Drop the previous adapter before loading under the same adapter name.
        self.ace_step_transformer.unload_lora()
    self.ace_step_transformer.load_lora_adapter(
        os.path.join(lora_download_path, "pytorch_lora_weights.safetensors"),
        adapter_name="ace_step_lora",
        with_alpha=True,
        prefix=None,
    )
    nfo(f"Loading lora weights from: {lora_name_or_path} download path is: {lora_download_path} weight: {lora_weight}")
    set_weights_and_activate_adapters(self.ace_step_transformer, ["ace_step_lora"], [lora_weight])
    self.lora_path = lora_name_or_path
    self.lora_weight = lora_weight
1276
def __call__(
    self,
    format: str = "wav",
    audio_duration: float = 60.0,
    prompt: str = None,
    lyrics: str = None,
    infer_step: int = 60,
    guidance_scale: float = 15.0,
    scheduler_type: str = "euler",
    cfg_type: str = "apg",
    omega_scale: float = 10.0,
    manual_seeds: list = None,
    guidance_interval: float = 0.5,
    guidance_interval_decay: float = 0.0,
    min_guidance_scale: float = 3.0,
    use_erg_tag: bool = True,
    use_erg_lyric: bool = True,
    use_erg_diffusion: bool = True,
    oss_steps: str = None,
    guidance_scale_text: float = 0.0,
    guidance_scale_lyric: float = 0.0,
    audio2audio_enable: bool = False,
    ref_audio_strength: float = 0.5,
    ref_audio_input: str = None,
    lora_name_or_path: str = "none",
    lora_weight: float = 1.0,
    retake_seeds: list = None,
    retake_variance: float = 0.5,
    task: str = "text2music",
    repaint_start: int = 0,
    repaint_end: int = 0,
    src_audio_path: str = None,
    edit_target_prompt: str = None,
    edit_target_lyrics: str = None,
    edit_n_min: float = 0.0,
    edit_n_max: float = 1.0,
    edit_n_avg: int = 1,
    save_path: str = None,
    batch_size: int = 1,
    debug: bool = False,
):
    """Run one generation/editing job end to end and write the results to disk.

    Pipeline: lazily load the checkpoint and the requested LoRA, seed the
    per-sample generators, encode prompt and lyrics, run either the FlowEdit
    process (``task == "edit"``) or the text2music diffusion process (all other
    tasks), decode the latents to audio via ``latents2audio``, and write an
    ``<audio stem>_input_params.json`` sidecar next to every audio file.

    Selected arguments:
        format: Audio container passed to ``latents2audio``.
        audio_duration: Target length in seconds; ``<= 0`` draws a random
            duration with ``random.uniform(30.0, 240.0)``.
        prompt, lyrics: Text conditioning. Empty or ``None`` lyrics fall back to
            a single token id 0 with a zero mask.
        oss_steps: Optional comma-separated integers, parsed into a list.
        task: ``"text2music"``, ``"retake"``, ``"repaint"``, ``"extend"`` or
            ``"edit"``; it is forced to ``"audio2audio"`` when
            ``audio2audio_enable`` is set and ``ref_audio_input`` is given.
            ``"retake"`` repaints the whole clip.
        src_audio_path: Source audio, only accepted for repaint/edit/extend.
        ref_audio_input: Reference audio, used only with ``audio2audio_enable``.
        manual_seeds, retake_seeds: Forwarded to ``set_seeds``.
        batch_size: Number of samples generated in one pass.
        debug: Log per-line lyric tokenization.

    Returns:
        The written audio paths followed by the parameter dict (whose
        ``"audio_path"`` key holds the last written path).
    """
    start_time = time.time()

    if audio2audio_enable and ref_audio_input is not None:
        task = "audio2audio"

    if not self.loaded:
        nfo("Checkpoint not loaded, loading checkpoint...")
        if self.quantized:
            self.load_quantized_checkpoint(self.checkpoint_dir)
        else:
            self.load_checkpoint(self.checkpoint_dir)

    self.load_lora(lora_name_or_path, lora_weight)
    load_model_cost = time.time() - start_time
    nfo(f"Model loaded in {load_model_cost:.2f} seconds.")

    start_time = time.time()

    random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds)
    retake_random_generators, actual_retake_seeds = self.set_seeds(batch_size, retake_seeds)

    if isinstance(oss_steps, str) and len(oss_steps) > 0:
        oss_steps = list(map(int, oss_steps.split(",")))
    else:
        oss_steps = []

    texts = [prompt]
    encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts)
    encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
    text_attention_mask = text_attention_mask.repeat(batch_size, 1)

    encoder_text_hidden_states_null = None
    if use_erg_tag:
        encoder_text_hidden_states_null = self.get_text_embeddings_null(texts)
        encoder_text_hidden_states_null = encoder_text_hidden_states_null.repeat(batch_size, 1, 1)

    # Speaker conditioning is not supported by the released checkpoint: feed zeros.
    speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype)

    # Lyrics: a single masked-out token unless lyrics are provided. A truthiness
    # check also covers the declared ``None`` default (len(None) would raise).
    lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
    lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
    if lyrics:
        lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug)
        lyric_mask = [1] * len(lyric_token_idx)
        lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
        lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)

    if audio_duration <= 0:
        audio_duration = random.uniform(30.0, 240.0)
        nfo(f"random audio duration: {audio_duration}")

    end_time = time.time()
    preprocess_time_cost = end_time - start_time
    start_time = end_time

    add_retake_noise = task in ("retake", "repaint", "extend")
    # A retake is a repaint over the whole clip.
    if task == "retake":
        repaint_start = 0
        repaint_end = audio_duration

    src_latents = None
    if src_audio_path is not None:
        assert task in (
            "repaint",
            "edit",
            "extend",
        ), "src_audio_path is only supported for the repaint/edit/extend tasks"
        assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
        src_latents = self.infer_latents(src_audio_path)

    ref_latents = None
    if ref_audio_input is not None and audio2audio_enable:
        assert os.path.exists(ref_audio_input), f"ref_audio_input {ref_audio_input} does not exist"
        ref_latents = self.infer_latents(ref_audio_input)

    if task == "edit":
        texts = [edit_target_prompt]
        target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts)
        target_encoder_text_hidden_states = target_encoder_text_hidden_states.repeat(batch_size, 1, 1)
        target_text_attention_mask = target_text_attention_mask.repeat(batch_size, 1)

        target_lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        target_lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        if edit_target_lyrics:
            # Honour the caller's debug flag for the target lyrics as well.
            target_lyric_token_idx = self.tokenize_lyrics(edit_target_lyrics, debug=debug)
            target_lyric_mask = [1] * len(target_lyric_token_idx)
            target_lyric_token_idx = torch.tensor(target_lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
            target_lyric_mask = torch.tensor(target_lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)

        target_speaker_embeds = speaker_embeds.clone()

        target_latents = self.flowedit_diffusion_process(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embds=speaker_embeds,
            lyric_token_ids=lyric_token_idx,
            lyric_mask=lyric_mask,
            target_encoder_text_hidden_states=target_encoder_text_hidden_states,
            target_text_attention_mask=target_text_attention_mask,
            target_speaker_embeds=target_speaker_embeds,
            target_lyric_token_ids=target_lyric_token_idx,
            target_lyric_mask=target_lyric_mask,
            src_latents=src_latents,
            random_generators=retake_random_generators,  # more diversity
            infer_steps=infer_step,
            guidance_scale=guidance_scale,
            n_min=edit_n_min,
            n_max=edit_n_max,
            n_avg=edit_n_avg,
            scheduler_type=scheduler_type,
        )
    else:
        target_latents = self.text2music_diffusion_process(
            duration=audio_duration,
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embds=speaker_embeds,
            lyric_token_ids=lyric_token_idx,
            lyric_mask=lyric_mask,
            guidance_scale=guidance_scale,
            omega_scale=omega_scale,
            infer_steps=infer_step,
            random_generators=random_generators,
            scheduler_type=scheduler_type,
            cfg_type=cfg_type,
            guidance_interval=guidance_interval,
            guidance_interval_decay=guidance_interval_decay,
            min_guidance_scale=min_guidance_scale,
            oss_steps=oss_steps,
            encoder_text_hidden_states_null=encoder_text_hidden_states_null,
            use_erg_lyric=use_erg_lyric,
            use_erg_diffusion=use_erg_diffusion,
            retake_random_generators=retake_random_generators,
            retake_variance=retake_variance,
            add_retake_noise=add_retake_noise,
            guidance_scale_text=guidance_scale_text,
            guidance_scale_lyric=guidance_scale_lyric,
            repaint_start=repaint_start,
            repaint_end=repaint_end,
            src_latents=src_latents,
            audio2audio_enable=audio2audio_enable,
            ref_audio_strength=ref_audio_strength,
            ref_latents=ref_latents,
        )

    end_time = time.time()
    diffusion_time_cost = end_time - start_time
    start_time = end_time

    output_paths = self.latents2audio(
        latents=target_latents,
        target_wav_duration_second=audio_duration,
        save_path=save_path,
        format=format,
    )

    # Clean up memory after generation (the helper must actually be called).
    empty_cache()

    end_time = time.time()
    latent2audio_time_cost = end_time - start_time
    timecosts = {
        "preprocess": preprocess_time_cost,
        "diffusion": diffusion_time_cost,
        "latent2audio": latent2audio_time_cost,
    }

    input_params_json = {
        "format": format,
        "lora_name_or_path": lora_name_or_path,
        "lora_weight": lora_weight,
        "task": task,
        "prompt": prompt if task != "edit" else edit_target_prompt,
        "lyrics": lyrics if task != "edit" else edit_target_lyrics,
        "audio_duration": audio_duration,
        "infer_step": infer_step,
        "guidance_scale": guidance_scale,
        "scheduler_type": scheduler_type,
        "cfg_type": cfg_type,
        "omega_scale": omega_scale,
        "guidance_interval": guidance_interval,
        "guidance_interval_decay": guidance_interval_decay,
        "min_guidance_scale": min_guidance_scale,
        "use_erg_tag": use_erg_tag,
        "use_erg_lyric": use_erg_lyric,
        "use_erg_diffusion": use_erg_diffusion,
        "oss_steps": oss_steps,
        "timecosts": timecosts,
        "actual_seeds": actual_seeds,
        "retake_seeds": actual_retake_seeds,
        "retake_variance": retake_variance,
        "guidance_scale_text": guidance_scale_text,
        "guidance_scale_lyric": guidance_scale_lyric,
        "repaint_start": repaint_start,
        "repaint_end": repaint_end,
        "edit_n_min": edit_n_min,
        "edit_n_max": edit_n_max,
        "edit_n_avg": edit_n_avg,
        "src_audio_path": src_audio_path,
        "edit_target_prompt": edit_target_prompt,
        "edit_target_lyrics": edit_target_lyrics,
        "audio2audio_enable": audio2audio_enable,
        "ref_audio_strength": ref_audio_strength,
        "ref_audio_input": ref_audio_input,
    }
    # Write one parameter sidecar per audio file. The name is derived from the
    # path stem so it can never coincide with the audio file itself, even when
    # the file's extension differs from ``format`` (e.g. a user-supplied path).
    for output_audio_path in output_paths:
        input_params_json_save_path = os.path.splitext(output_audio_path)[0] + "_input_params.json"
        input_params_json["audio_path"] = output_audio_path
        with open(input_params_json_save_path, "w", encoding="utf-8") as f:
            json.dump(input_params_json, f, indent=4, ensure_ascii=False)

    return output_paths + [input_params_json]
ACEStepPipeline( checkpoint_dir=None, device_id=0, dtype='bfloat16', text_encoder_checkpoint_path=None, persistent_storage_path=None, torch_compile=False, cpu_offload=False, quantized=False, overlapped_decode=False, **kwargs)
 94    def __init__(
 95        self,
 96        checkpoint_dir=None,
 97        device_id=0,
 98        dtype="bfloat16",
 99        text_encoder_checkpoint_path=None,
100        persistent_storage_path=None,
101        torch_compile=False,
102        cpu_offload=False,
103        quantized=False,
104        overlapped_decode=False,
105        **kwargs,
106    ):
107        if not checkpoint_dir:
108            if persistent_storage_path is None:
109                checkpoint_dir = os.path.join(os.path.expanduser("~"), ".cache/ace-step/checkpoints")
110                os.makedirs(checkpoint_dir, exist_ok=True)
111            else:
112                checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints")
113        ensure_directory_exists(checkpoint_dir)
114
115        self.checkpoint_dir = checkpoint_dir
116        self.lora_path = "none"
117        self.lora_weight = 1
118        self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
119        if gfx_device.type == "mps":
120            if self.dtype == torch.bfloat16:
121                self.dtype = torch.float16
122
123        if "ACE_PIPELINE_DTYPE" in os.environ and len(os.environ["ACE_PIPELINE_DTYPE"]):
124            self.dtype = getattr(torch, os.environ["ACE_PIPELINE_DTYPE"])
125        self.device: torch.device = gfx_device
126        self.loaded = False
127        self.torch_compile = torch_compile
128        self.cpu_offload = cpu_offload
129        self.quantized = quantized
130        self.overlapped_decode = overlapped_decode
checkpoint_dir
lora_path
lora_weight
dtype
device: torch.device
loaded
torch_compile
cpu_offload
quantized
overlapped_decode
def get_checkpoint_path(self, checkpoint_dir, repo):
def get_checkpoint_path(self, checkpoint_dir, repo):
    """Resolve the directory that contains the model sub-folders.

    ``checkpoint_dir`` is used directly when it already contains all of
    ``music_dcae_f8c8``, ``music_vocoder``, ``ace_step_transformer`` and
    ``umt5-base``. Otherwise ``repo`` is fetched with ``snapshot_download``
    (cached under ``checkpoint_dir`` when one is given) and the returned
    snapshot path is used.
    """
    required_dirs = ("music_dcae_f8c8", "music_vocoder", "ace_step_transformer", "umt5-base")
    if checkpoint_dir is not None and all(
        os.path.exists(os.path.join(checkpoint_dir, dir_name)) for dir_name in required_dirs
    ):
        nfo(f"Load models from: {checkpoint_dir}")
        return checkpoint_dir

    if checkpoint_dir is None:
        nfo(f"Download models from Hugging Face: {repo}")
        return snapshot_download(repo)

    nfo(f"Download models from Hugging Face: {repo}, cache to: {checkpoint_dir}")
    return snapshot_download(repo, cache_dir=checkpoint_dir)
def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
    """Load the full-precision transformer, DCAE/vocoder, text encoder and tokenizers.

    Each model is placed on CPU when ``self.cpu_offload`` is set and on
    ``self.device`` otherwise, put in eval mode, cast to ``self.dtype``, and
    wrapped with ``torch.compile`` when ``self.torch_compile`` is set. When both
    ``torch_compile`` and ``export_quantized_weights`` are set, the transformer
    and text encoder are additionally int4 weight-only quantized, and their
    state dicts are saved as the ``*_int4wo.bin`` files that
    ``load_quantized_checkpoint`` reads.

    Args:
        checkpoint_dir: Resolved through ``get_checkpoint_path`` with ``REPO_ID``.
        export_quantized_weights: See above; ignored unless ``torch_compile`` is on.
    """
    checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID)
    dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
    vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder")
    ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer")
    text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base")

    # With CPU offload the models start on the host (presumably moved to the
    # device on demand by the ``cpu_offload`` decorator on inference methods).
    placement = "cpu" if self.cpu_offload else self.device

    self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype)
    self.ace_step_transformer = self.ace_step_transformer.to(placement).eval().to(self.dtype)
    if self.torch_compile:
        self.ace_step_transformer = torch.compile(self.ace_step_transformer)

    self.music_dcae = MusicDCAE(
        dcae_checkpoint_path=dcae_checkpoint_path,
        vocoder_checkpoint_path=vocoder_checkpoint_path,
    )
    self.music_dcae = self.music_dcae.to(placement).eval().to(self.dtype)
    if self.torch_compile:
        self.music_dcae = torch.compile(self.music_dcae)

    lang_segment = LangSegment()
    lang_segment.setfilters(language_filters.default)
    self.lang_segment = lang_segment
    self.lyric_tokenizer = VoiceBpeTokenizer()

    text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval()
    text_encoder_model = text_encoder_model.to(placement).eval().to(self.dtype)
    text_encoder_model.requires_grad_(False)
    self.text_encoder_model = text_encoder_model
    if self.torch_compile:
        self.text_encoder_model = torch.compile(self.text_encoder_model)

    self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)
    self.loaded = True

    if self.torch_compile and export_quantized_weights:
        # NOTE(review): ``quantize_`` and ``Int4WeightOnlyConfig`` are torchao APIs
        # (``torchao.quantization``); ``torch.ao.quantization`` does not appear to
        # export them, so this import likely raises ImportError — confirm and
        # switch to torchao once it is an accepted dependency.
        from torch.ao.quantization import (
            Int4WeightOnlyConfig,
            quantize_,
        )

        group_size = 128
        use_hqq = True
        quantize_(
            self.ace_step_transformer,
            Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq),
        )
        quantize_(
            self.text_encoder_model,
            Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq),
        )

        transformer_out = os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin")
        torch.save(self.ace_step_transformer.state_dict(), transformer_out)
        nfo(f"Quantized Weights Saved to: {transformer_out}")

        text_encoder_out = os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin")
        torch.save(self.text_encoder_model.state_dict(), text_encoder_out)
        nfo(f"Quantized Weights Saved to: {text_encoder_out}")
def load_quantized_checkpoint(self, checkpoint_dir=None):
def load_quantized_checkpoint(self, checkpoint_dir=None):
    """Load the int4 weight-only quantized checkpoint resolved from ``REPO_ID_QUANT``.

    All three models are wrapped with ``torch.compile`` unconditionally. The
    transformer and text encoder are created with ``from_pretrained`` and
    compiled. Their parameters are then replaced by the ``*_int4wo.bin`` state
    dicts, loaded with ``map_location=self.device`` and ``assign=True``, so those
    tensors live on ``self.device`` regardless of the earlier ``.to("cpu")``.
    """
    checkpoint_dir = self.get_checkpoint_path(checkpoint_dir, REPO_ID_QUANT)
    dcae_checkpoint_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
    vocoder_checkpoint_path = os.path.join(checkpoint_dir, "music_vocoder")
    ace_step_checkpoint_path = os.path.join(checkpoint_dir, "ace_step_transformer")
    text_encoder_checkpoint_path = os.path.join(checkpoint_dir, "umt5-base")

    self.music_dcae = MusicDCAE(
        dcae_checkpoint_path=dcae_checkpoint_path,
        vocoder_checkpoint_path=vocoder_checkpoint_path,
    )
    # Same placement rule as load_checkpoint: host memory when offloading,
    # the accelerator otherwise (these two branches were previously swapped).
    if self.cpu_offload:
        self.music_dcae.eval().to(self.dtype).to("cpu")
    else:
        self.music_dcae.eval().to(self.dtype).to(self.device)
    self.music_dcae = torch.compile(self.music_dcae)

    self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
    self.ace_step_transformer.eval().to(self.dtype).to("cpu")
    self.ace_step_transformer = torch.compile(self.ace_step_transformer)
    self.ace_step_transformer.load_state_dict(
        torch.load(
            os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"),
            map_location=self.device,
        ),
        assign=True,
    )
    self.ace_step_transformer.torchao_quantized = True

    self.text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path)
    self.text_encoder_model.eval().to(self.dtype).to("cpu")
    self.text_encoder_model = torch.compile(self.text_encoder_model)
    self.text_encoder_model.load_state_dict(
        torch.load(
            os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"),
            map_location=self.device,
        ),
        assign=True,
    )
    self.text_encoder_model.torchao_quantized = True

    self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)

    lang_segment = LangSegment()
    lang_segment.setfilters(language_filters.default)
    self.lang_segment = lang_segment
    self.lyric_tokenizer = VoiceBpeTokenizer()

    self.loaded = True
@cpu_offload('text_encoder_model')
def get_text_embeddings(self, texts, text_max_length=256):
@cpu_offload("text_encoder_model")
def get_text_embeddings(self, texts, text_max_length=256):
    """Encode prompts with the UMT5 text encoder.

    Args:
        texts: List of prompt strings, tokenized together as one padded batch.
        text_max_length: Maximum token count; longer prompts are truncated.

    Returns:
        ``(last_hidden_state, attention_mask)``: the encoder output and the
        tokenizer's attention mask, with the inputs moved to ``self.device``.
    """
    encoded = self.text_tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=text_max_length,
    )
    model_inputs = {name: tensor.to(self.device) for name, tensor in encoded.items()}
    encoder = self.text_encoder_model
    if encoder.device != self.device:
        encoder.to(self.device)
    with torch.no_grad():
        hidden = encoder(**model_inputs).last_hidden_state
    return hidden, model_inputs["attention_mask"]
@cpu_offload('text_encoder_model')
def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10):
@cpu_offload("text_encoder_model")
def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10):
    """Encode prompts with the query projections of some encoder blocks scaled by ``tau``.

    ``__call__`` uses this when ``use_erg_tag`` is set. Forward hooks multiply
    the output of ``SelfAttention.q`` in encoder blocks ``l_min`` to
    ``l_max - 1`` by ``tau`` in place. Shrinking the queries shrinks those
    attention logits, which pushes the attention maps towards uniform and
    yields a weakened text embedding.

    Args:
        texts: List of prompt strings.
        text_max_length: Tokenizer truncation length.
        tau: Multiplier applied to the hooked query outputs.
        l_min, l_max: Half-open range of encoder block indices to hook.

    Returns:
        The encoder's ``last_hidden_state``.
    """
    inputs = self.text_tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=text_max_length,
    )
    inputs = {key: value.to(self.device) for key, value in inputs.items()}
    if self.text_encoder_model.device != self.device:
        self.text_encoder_model.to(self.device)

    def scale_query(module, args, output):
        output[:] *= tau
        return output

    handles = []
    try:
        for i in range(l_min, l_max):
            query_proj = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q
            handles.append(query_proj.register_forward_hook(scale_query))
        with torch.no_grad():
            last_hidden_states = self.text_encoder_model(**inputs).last_hidden_state
    finally:
        # Always detach the hooks, even if the forward pass raised. A leftover
        # hook would keep scaling the queries in every later encoder call,
        # including the plain get_text_embeddings.
        for handle in handles:
            handle.remove()
    return last_hidden_states
def set_seeds(self, batch_size, manual_seeds=None):
def set_seeds(self, batch_size, manual_seeds=None):
    """Create one seeded ``torch.Generator`` per batch item.

    ``manual_seeds`` may be:
      * an int,
      * a non-empty list of ints,
      * a digit-only string (``"42"``), or
      * a comma-separated string (``"1,2,3"``); a part that is not an integer
        raises ``ValueError``.

    When fewer seeds than ``batch_size`` are supplied, the last one is reused.
    Any other value (``None``, an empty list, a list containing non-ints, a
    non-digit string) gives every item a fresh random seed in ``[0, 2**32)``.

    Returns:
        ``(generators, seeds)``: generators on ``self.device`` and the seed used
        for each.
    """
    parsed = None
    if isinstance(manual_seeds, str):
        if "," in manual_seeds:
            parsed = [int(part) for part in manual_seeds.split(",")]
        elif manual_seeds.isdigit():
            parsed = int(manual_seeds)
    elif isinstance(manual_seeds, list):
        if manual_seeds and all(isinstance(s, int) for s in manual_seeds):
            parsed = list(manual_seeds)
    elif isinstance(manual_seeds, int):
        parsed = manual_seeds

    generators = []
    seeds = []
    for position in range(batch_size):
        if isinstance(parsed, int):
            seed = parsed
        elif isinstance(parsed, list):
            seed = parsed[min(position, len(parsed) - 1)]
        else:
            seed = None
        if seed is None:
            seed = torch.randint(0, 2**32, (1,)).item()
        generator = torch.Generator(device=self.device)
        generator.manual_seed(seed)
        generators.append(generator)
        seeds.append(seed)
    return generators, seeds
def get_lang(self, text):
def get_lang(self, text):
    """Return the language code that ``self.lang_segment`` detects for *text*.

    The first entry of ``getCounts()`` (presumably the most frequent language)
    wins, except that ``"en"`` yields to the second entry when there is one.
    Any failure during detection falls back to ``"en"``.
    """
    try:
        self.lang_segment.getTexts(text)
        counts = self.lang_segment.getCounts()
        detected = counts[0][0]
        if len(counts) > 1 and detected == "en":
            detected = counts[1][0]
        return detected
    except Exception:
        return "en"
def tokenize_lyrics(self, lyrics, debug=False):
def tokenize_lyrics(self, lyrics, debug=False):
    """Convert multi-line lyrics into a flat list of lyric token ids.

    The list starts with token id ``261``. Every non-blank line contributes its
    encoded tokens followed by token id ``2``; a blank line contributes only
    ``2``. Each line is encoded in the language returned by ``get_lang``,
    normalised as follows:
      * a language not in ``SUPPORT_LANGUAGES`` becomes ``"en"``;
      * a code containing ``"zh"`` becomes ``"zh"``;
      * a code containing ``"spa"`` becomes ``"es"``.

    Lines matching ``structure_pattern`` (presumably section tags such as
    ``[verse]`` — confirm) are always encoded as ``"en"``. A line that fails to
    encode is logged and skipped.

    Args:
        lyrics: Lyrics text, one lyric line per ``"\\n"``-separated line.
        debug: Log the decoded tokens of each line.

    Returns:
        The token id list.
    """
    lyric_token_idx = [261]
    for raw_line in lyrics.split("\n"):
        line = raw_line.strip()
        if not line:
            lyric_token_idx.append(2)
            continue

        lang = self.get_lang(line)
        if lang not in SUPPORT_LANGUAGES:
            lang = "en"
        if "zh" in lang:
            lang = "zh"
        if "spa" in lang:
            lang = "es"

        try:
            encode_lang = "en" if structure_pattern.match(line) else lang
            token_idx = self.lyric_tokenizer.encode(line, encode_lang)
            if debug:
                toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx])
                nfo(f"debbug {line} --> {lang} --> {toks}")
            # Extend in place; rebuilding the list for every line is quadratic.
            lyric_token_idx.extend(token_idx)
            lyric_token_idx.append(2)
        except Exception as e:
            # Best effort: drop the offending line but keep tokenizing the rest.
            nfo(f"tokenize error {e} for line {line} major_language {lang}")
    return lyric_token_idx
@cpu_offload('ace_step_transformer')
def calc_v( self, zt_src, zt_tar, t, encoder_text_hidden_states, text_attention_mask, target_encoder_text_hidden_states, target_text_attention_mask, speaker_embds, target_speaker_embeds, lyric_token_ids, lyric_mask, target_lyric_token_ids, target_lyric_mask, do_classifier_free_guidance=False, guidance_scale=1.0, target_guidance_scale=1.0, cfg_type='apg', attention_mask=None, momentum_buffer=None, momentum_buffer_tar=None, return_src_pred=True):
@cpu_offload("ace_step_transformer")
def calc_v(
    self,
    zt_src,
    zt_tar,
    t,
    encoder_text_hidden_states,
    text_attention_mask,
    target_encoder_text_hidden_states,
    target_text_attention_mask,
    speaker_embds,
    target_speaker_embeds,
    lyric_token_ids,
    lyric_mask,
    target_lyric_token_ids,
    target_lyric_mask,
    do_classifier_free_guidance=False,
    guidance_scale=1.0,
    target_guidance_scale=1.0,
    cfg_type="apg",
    attention_mask=None,
    momentum_buffer=None,
    momentum_buffer_tar=None,
    return_src_pred=True,
):
    """Run the transformer on the source and target branches of a FlowEdit step.

    Each branch feeds its own latents and conditioning through
    ``self.ace_step_transformer`` with timestep ``t`` broadcast over the batch.

    With ``do_classifier_free_guidance``:
      * the latents are duplicated along dim 0 (the conditioning tensors are
        passed through unchanged, so they must already match that doubled
        batch);
      * the prediction is split into conditional/unconditional halves;
      * the halves are combined with ``apg_forward`` (``cfg_type == "apg"``) or
        ``cfg_forward`` (``cfg_type == "cfg"``); any other ``cfg_type`` returns
        the unsplit prediction.

    Returns:
        ``(noise_pred_src, noise_pred_tar)``; ``noise_pred_src`` is ``None``
        when ``return_src_pred`` is false.
    """

    def predict(latents, text_states, text_mask, speakers, lyric_ids, lyric_msk, scale, buffer):
        # One transformer pass for a single branch, with optional guidance.
        model_input = torch.cat([latents, latents]) if do_classifier_free_guidance else latents
        prediction = self.ace_step_transformer(
            hidden_states=model_input,
            attention_mask=attention_mask,
            encoder_text_hidden_states=text_states,
            text_attention_mask=text_mask,
            speaker_embeds=speakers,
            lyric_token_idx=lyric_ids,
            lyric_mask=lyric_msk,
            timestep=t.expand(model_input.shape[0]),
        ).sample
        if not do_classifier_free_guidance:
            return prediction
        cond, uncond = prediction.chunk(2)
        if cfg_type == "apg":
            return apg_forward(
                pred_cond=cond,
                pred_uncond=uncond,
                guidance_scale=scale,
                momentum_buffer=buffer,
            )
        if cfg_type == "cfg":
            return cfg_forward(
                cond_output=cond,
                uncond_output=uncond,
                cfg_strength=scale,
            )
        return prediction

    noise_pred_src = None
    if return_src_pred:
        noise_pred_src = predict(
            zt_src,
            encoder_text_hidden_states,
            text_attention_mask,
            speaker_embds,
            lyric_token_ids,
            lyric_mask,
            guidance_scale,
            momentum_buffer,
        )
    noise_pred_tar = predict(
        zt_tar,
        target_encoder_text_hidden_states,
        target_text_attention_mask,
        target_speaker_embeds,
        target_lyric_token_ids,
        target_lyric_mask,
        target_guidance_scale,
        momentum_buffer_tar,
    )
    return noise_pred_src, noise_pred_tar
@torch.no_grad()
def flowedit_diffusion_process(
    self,
    encoder_text_hidden_states,
    text_attention_mask,
    speaker_embds,
    lyric_token_ids,
    lyric_mask,
    target_encoder_text_hidden_states,
    target_text_attention_mask,
    target_speaker_embeds,
    target_lyric_token_ids,
    target_lyric_mask,
    src_latents,
    random_generators=None,
    infer_steps=60,
    guidance_scale=15.0,
    n_min=0,
    n_max=1.0,
    n_avg=1,
    scheduler_type="euler",
):
    """Edit ``src_latents`` toward the target conditioning with FlowEdit.

    Timesteps always come from a ``FlowMatchEulerDiscreteScheduler`` (shift=3.0);
    ``scheduler_type`` only selects the update rule used in the final
    regular-sampling phase ("pingpong" = renoise SDE step, anything else = Euler ODE step).

    Phases (indices are over the ``infer_steps`` timesteps):
      * ``i < n_min * infer_steps``: skipped.
      * ``n_min * infer_steps <= i < n_max * infer_steps``: FlowEdit. ``zt_edit`` starts at
        ``src_latents`` and is advanced by the average (over ``n_avg`` noise draws) of
        ``V_target - V_source``.
      * remaining steps: regular sampling of the target condition, starting from a
        re-noised copy of ``zt_edit``.

    Args:
        encoder_text_hidden_states, text_attention_mask, speaker_embds, lyric_token_ids,
        lyric_mask: source conditioning.
        target_*: target conditioning, same roles as the source tensors.
        src_latents: latents to edit; the last dimension is the frame count.
        random_generators: generator(s) for ``randn_tensor``; a single generator or the first
            of a list also drives the pingpong noise. May be None.
        infer_steps: number of scheduler steps.
        guidance_scale: CFG scale for both source and target; 0.0 or 1.0 disables CFG.
        n_min, n_max: fractions of ``infer_steps`` delimiting the FlowEdit phase.
        n_avg: number of noise draws averaged per FlowEdit step.
        scheduler_type: "pingpong" selects the SDE update in the regular-sampling phase.

    Returns:
        The target latents from the regular-sampling phase if it ran, else ``zt_edit``.
    """
    do_classifier_free_guidance = True
    if guidance_scale == 0.0 or guidance_scale == 1.0:
        do_classifier_free_guidance = False

    target_guidance_scale = guidance_scale
    bsz = encoder_text_hidden_states.shape[0]

    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=3.0,
    )

    frame_length = src_latents.shape[-1]
    attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)

    timesteps, T_steps = retrieve_timesteps(scheduler, infer_steps, self.device, timesteps=None)

    if do_classifier_free_guidance:
        # Batch layout for CFG: [conditional, unconditional(zeroed)].
        attention_mask = torch.cat([attention_mask] * 2, dim=0)

        encoder_text_hidden_states = torch.cat(
            [encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states)],
            0,
        )
        text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)

        target_encoder_text_hidden_states = torch.cat(
            [target_encoder_text_hidden_states, torch.zeros_like(target_encoder_text_hidden_states)],
            0,
        )
        target_text_attention_mask = torch.cat([target_text_attention_mask] * 2, dim=0)

        speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)
        target_speaker_embeds = torch.cat([target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0)

        lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
        lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)

        target_lyric_token_ids = torch.cat([target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0)
        target_lyric_mask = torch.cat([target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0)

    # Generator for the pingpong renoise: accept a single generator or a list of them.
    if isinstance(random_generators, (list, tuple)):
        noise_generator = random_generators[0] if random_generators else None
    else:
        noise_generator = random_generators

    momentum_buffer = MomentumBuffer()
    momentum_buffer_tar = MomentumBuffer()
    x_src = src_latents
    zt_edit = x_src.clone()
    xt_tar = None
    n_min = int(infer_steps * n_min)
    n_max = int(infer_steps * n_max)

    nfo(f"flowedit start from {n_min} to {n_max}")

    for i, t in tqdm(enumerate(timesteps), total=T_steps):
        if i < n_min:
            continue

        # Flow-matching time in [0, 1] for this step and the next one (0 after the last).
        t_i = t / 1000
        if i + 1 < len(timesteps):
            t_im1 = timesteps[i + 1] / 1000
        else:
            t_im1 = torch.zeros_like(t_i).to(self.device)

        if i < n_max:
            # FlowEdit phase: average the target-minus-source velocity over n_avg noise draws.
            V_delta_avg = torch.zeros_like(x_src)
            for _ in range(n_avg):
                fwd_noise = randn_tensor(
                    shape=x_src.shape,
                    generator=random_generators,
                    device=self.device,
                    dtype=self.dtype,
                )
                zt_src = (1 - t_i) * x_src + t_i * fwd_noise
                zt_tar = zt_edit + zt_src - x_src

                Vt_src, Vt_tar = self.calc_v(
                    zt_src=zt_src,
                    zt_tar=zt_tar,
                    t=t,
                    encoder_text_hidden_states=encoder_text_hidden_states,
                    text_attention_mask=text_attention_mask,
                    target_encoder_text_hidden_states=target_encoder_text_hidden_states,
                    target_text_attention_mask=target_text_attention_mask,
                    speaker_embds=speaker_embds,
                    target_speaker_embeds=target_speaker_embeds,
                    lyric_token_ids=lyric_token_ids,
                    lyric_mask=lyric_mask,
                    target_lyric_token_ids=target_lyric_token_ids,
                    target_lyric_mask=target_lyric_mask,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guidance_scale=guidance_scale,
                    target_guidance_scale=target_guidance_scale,
                    attention_mask=attention_mask,
                    momentum_buffer=momentum_buffer,
                )
                V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src)

            # zt_edit lives in clean space (it starts at x_src), so it is always advanced
            # with the deterministic ODE step. A renoising (pingpong) step here would mix
            # noise into the clean edit, which is then noised again via zt_src next step.
            zt_edit = zt_edit.to(torch.float32)  # accumulate in float32 for precision
            zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
            zt_edit = zt_edit.to(self.dtype)

        else:  # i >= n_max: regular sampling of the target for the remaining steps
            if xt_tar is None:
                # First regular-sampling step: re-noise zt_edit to the current sigma.
                fwd_noise = randn_tensor(
                    shape=x_src.shape,
                    generator=random_generators,
                    device=self.device,
                    dtype=self.dtype,
                )
                scheduler._init_step_index(t)
                sigma = scheduler.sigmas[scheduler.step_index]
                xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src
                xt_tar = zt_edit + xt_src - x_src

            _, Vt_tar = self.calc_v(
                zt_src=None,
                zt_tar=xt_tar,
                t=t,
                encoder_text_hidden_states=encoder_text_hidden_states,
                text_attention_mask=text_attention_mask,
                target_encoder_text_hidden_states=target_encoder_text_hidden_states,
                target_text_attention_mask=target_text_attention_mask,
                speaker_embds=speaker_embds,
                target_speaker_embeds=target_speaker_embeds,
                lyric_token_ids=lyric_token_ids,
                lyric_mask=lyric_mask,
                target_lyric_token_ids=target_lyric_token_ids,
                target_lyric_mask=target_lyric_mask,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guidance_scale=guidance_scale,
                target_guidance_scale=target_guidance_scale,
                attention_mask=attention_mask,
                momentum_buffer_tar=momentum_buffer_tar,
                return_src_pred=False,
            )

            xt_tar = xt_tar.to(torch.float32)
            if scheduler_type != "pingpong":
                # Euler ODE step.
                xt_tar = (xt_tar + (t_im1 - t_i) * Vt_tar).to(self.dtype)
            else:
                # Pingpong SDE step: jump to the clean estimate, then re-noise to t_im1.
                denoised = xt_tar - t_i * Vt_tar
                noise = torch.empty_like(xt_tar).normal_(generator=noise_generator)
                # Cast back so the next transformer call sees the model dtype.
                xt_tar = ((1 - t_im1) * denoised + t_im1 * noise).to(self.dtype)

    target_latents = zt_edit if xt_tar is None else xt_tar
    return target_latents
def add_latents_noise(
    self,
    gt_latents,
    sigma_max,
    noise,
    scheduler_type,
    infer_steps,
):
    """Noise clean latents for audio-to-audio and build a matching shortened schedule.

    A scheduler of the requested type is built with ``sigma_max`` as its upper noise
    level, and the step count is scaled to ``int(sigma_max * infer_steps)``.

    Args:
        gt_latents: clean latents to start from.
        sigma_max: noise level to start sampling at (the caller passes
            ``1 - ref_audio_strength``).
        noise: noise tensor, same shape as ``gt_latents``.
        scheduler_type: "euler", "heun" or "pingpong".
        infer_steps: step count for a full schedule; scaled by ``sigma_max``.

    Returns:
        A tuple ``(noisy_latents, timesteps, scheduler, num_inference_steps)``.
        ``noisy_latents`` is ``gt_latents * (1 - scheduler.sigma_max) + noise * scheduler.sigma_max``.

    Raises:
        ValueError: if ``scheduler_type`` is not supported.
    """
    # Validate before building anything; an unknown type previously surfaced as an
    # UnboundLocalError on ``scheduler``.
    if scheduler_type not in ("euler", "heun", "pingpong"):
        raise ValueError(f"Unsupported scheduler_type: {scheduler_type!r}")

    if scheduler_type == "euler":
        scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=1000,
            shift=3.0,
            sigma_max=sigma_max,
        )
    elif scheduler_type == "heun":
        scheduler = FlowMatchHeunDiscreteScheduler(
            num_train_timesteps=1000,
            shift=3.0,
            sigma_max=sigma_max,
        )
    else:  # "pingpong"
        scheduler = FlowMatchPingPongScheduler(
            num_train_timesteps=1000,
            shift=3.0,
            sigma_max=sigma_max,
        )

    # Only the part of the schedule below sigma_max is sampled.
    infer_steps = int(sigma_max * infer_steps)
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps=infer_steps,
        device=self.device,
        timesteps=None,
    )
    noisy_image = gt_latents * (1 - scheduler.sigma_max) + noise * scheduler.sigma_max
    nfo(f"{scheduler.sigma_min=} {scheduler.sigma_max=} {timesteps=} {num_inference_steps=}")
    return noisy_image, timesteps, scheduler, num_inference_steps
@cpu_offload("ace_step_transformer")
@torch.no_grad()
def text2music_diffusion_process(
    self,
    duration,
    encoder_text_hidden_states,
    text_attention_mask,
    speaker_embds,
    lyric_token_ids,
    lyric_mask,
    random_generators=None,
    infer_steps=60,
    guidance_scale=15.0,
    omega_scale=10.0,
    scheduler_type="euler",
    cfg_type="apg",
    zero_steps=1,
    use_zero_init=True,
    guidance_interval=0.5,
    guidance_interval_decay=1.0,
    min_guidance_scale=3.0,
    oss_steps=None,
    encoder_text_hidden_states_null=None,
    use_erg_lyric=False,
    use_erg_diffusion=False,
    retake_random_generators=None,
    retake_variance=0.5,
    add_retake_noise=False,
    guidance_scale_text=0.0,
    guidance_scale_lyric=0.0,
    repaint_start=0,
    repaint_end=0,
    src_latents=None,
    audio2audio_enable=False,
    ref_audio_strength=0.5,
    ref_latents=None,
):
    """Sample music latents with the ACE-Step transformer.

    Covers plain text-to-music, retake (mixing in a second noise draw), repaint/extend of
    ``src_latents`` between ``repaint_start`` and ``repaint_end`` seconds, and
    audio-to-audio starting from a noised copy of ``ref_latents``.

    Args:
        duration: target length in seconds, converted to ``int(duration * 44100 / 512 / 8)``
            latent frames; overridden by the last dim of ``src_latents``/``ref_latents``.
        encoder_text_hidden_states, text_attention_mask, speaker_embds, lyric_token_ids,
        lyric_mask: conditioning passed to ``ace_step_transformer.encode``.
        random_generators: generator(s) for ``randn_tensor``; a single generator (or the
            first of a list) is also passed to ``scheduler.step``. May be None.
        infer_steps: number of scheduler steps.
        guidance_scale: CFG scale; 0.0 or 1.0 disables CFG.
        omega_scale: forwarded to ``scheduler.step`` as ``omega``.
        scheduler_type: "euler", "heun" or "pingpong".
        cfg_type: "apg", "cfg" or "cfg_star"; ignored when double-condition guidance runs.
        zero_steps, use_zero_init: forwarded to ``cfg_zero_star``.
        guidance_interval: fraction of steps, centred in the schedule, that use guidance.
        guidance_interval_decay: if > 0, the scale decays linearly from ``guidance_scale``
            toward ``min_guidance_scale`` across the interval.
        min_guidance_scale: lower end of the decay.
        oss_steps: optional 1-based step indices into a ``max(oss_steps)``-step schedule;
            when non-empty only those timesteps are sampled.
        encoder_text_hidden_states_null: text states for the unconditional branch under ERG.
        use_erg_lyric, use_erg_diffusion: use temperature-scaled attention queries for the
            weaker/unconditional encoder and decoder passes.
        retake_random_generators, retake_variance, add_retake_noise: retake/repaint noise.
        guidance_scale_text, guidance_scale_lyric: double-condition guidance, active only
            when both are > 1.0.
        repaint_start, repaint_end: repaint window in seconds; values outside the source
            length extend it.
        src_latents: source latents for retake/repaint/extend.
        audio2audio_enable, ref_audio_strength, ref_latents: audio-to-audio settings.

    Returns:
        Latents shaped like the initial noise ``(bsz, 8, 16, frames)``. In extend mode,
        source frames trimmed to fit the 240 s window are re-attached.

    Raises:
        ValueError: for an unsupported ``scheduler_type`` or ``cfg_type``.
    """
    nfo("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
    if scheduler_type not in ("euler", "heun", "pingpong"):
        raise ValueError(f"Unsupported scheduler_type: {scheduler_type!r}")
    oss_steps = [] if oss_steps is None else oss_steps

    do_classifier_free_guidance = True
    if guidance_scale == 0.0 or guidance_scale == 1.0:
        do_classifier_free_guidance = False

    do_double_condition_guidance = False
    if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
        do_double_condition_guidance = True
        nfo(
            "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(
                do_double_condition_guidance,
                guidance_scale_text,
                guidance_scale_lyric,
            )
        )

    bsz = encoder_text_hidden_states.shape[0]

    if scheduler_type == "euler":
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
    elif scheduler_type == "heun":
        scheduler = FlowMatchHeunDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
    else:  # "pingpong"
        scheduler = FlowMatchPingPongScheduler(num_train_timesteps=1000, shift=3.0)

    # 44.1 kHz audio, hop 512, then 8x downsampling in the latent space.
    frame_length = int(duration * 44100 / 512 / 8)
    if src_latents is not None:
        frame_length = src_latents.shape[-1]
    if ref_latents is not None:
        frame_length = ref_latents.shape[-1]

    if len(oss_steps) > 0:
        # Build a full schedule, pick the requested (1-based) steps, and re-derive the
        # scheduler state from those sigmas.
        infer_steps = max(oss_steps)
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps=infer_steps,
            device=self.device,
            timesteps=None,
        )
        new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device)
        for idx, step in enumerate(oss_steps):
            new_timesteps[idx] = timesteps[step - 1]
        num_inference_steps = len(oss_steps)
        sigmas = (new_timesteps / 1000).float().cpu().numpy()
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps=num_inference_steps,
            device=self.device,
            sigmas=sigmas,
        )
        nfo(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
    else:
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps=infer_steps,
            device=self.device,
            timesteps=None,
        )

    target_latents = randn_tensor(
        shape=(bsz, 8, 16, frame_length),
        generator=random_generators,
        device=self.device,
        dtype=self.dtype,
    )

    is_repaint = False
    is_extend = False
    to_right_pad_gt_latents = None
    to_left_pad_gt_latents = None

    if add_retake_noise:
        n_min = int(infer_steps * (1 - retake_variance))
        retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype)
        retake_latents = randn_tensor(
            shape=(bsz, 8, 16, frame_length),
            generator=retake_random_generators,
            device=self.device,
            dtype=self.dtype,
        )
        repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
        repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
        x0 = src_latents
        # retake
        is_repaint = repaint_end_frame - repaint_start_frame != frame_length

        is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
        if is_extend:
            is_repaint = True

        # TODO: train a mask aware repainting controlnet
        # cos/sin mixing keeps mean = 0, std = 1
        if not is_repaint:
            target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
        elif not is_extend:
            repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
            repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
            repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
            repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
            zt_edit = x0.clone()
            z0 = repaint_noise
        else:  # is_extend
            gt_latents = src_latents
            src_latents_length = gt_latents.shape[-1]
            max_infer_frame_length = int(240 * 44100 / 512 / 8)
            left_pad_frame_length = 0
            right_pad_frame_length = 0
            right_trim_length = 0
            left_trim_length = 0
            if repaint_start_frame < 0:
                left_pad_frame_length = abs(repaint_start_frame)
                frame_length = left_pad_frame_length + gt_latents.shape[-1]
                extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
                if frame_length > max_infer_frame_length:
                    right_trim_length = frame_length - max_infer_frame_length
                    # Save the frames cut off the right BEFORE truncating, so they can be
                    # re-attached to the output unchanged.
                    to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:]
                    extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_frame_length]
                    frame_length = max_infer_frame_length
                repaint_start_frame = 0
                gt_latents = extend_gt_latents

            if repaint_end_frame > src_latents_length:
                right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
                frame_length = gt_latents.shape[-1] + right_pad_frame_length
                extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
                if frame_length > max_infer_frame_length:
                    left_trim_length = frame_length - max_infer_frame_length
                    # Save the frames cut off the left BEFORE truncating.
                    to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length]
                    extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_frame_length:]
                    frame_length = max_infer_frame_length
                repaint_end_frame = frame_length
                gt_latents = extend_gt_latents

            repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
            if left_pad_frame_length > 0:
                repaint_mask[:, :, :, :left_pad_frame_length] = 1.0
            if right_pad_frame_length > 0:
                repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0
            x0 = gt_latents
            # Initial noise: fresh retake noise over padded regions, original noise elsewhere.
            padd_list = []
            if left_pad_frame_length > 0:
                padd_list.append(retake_latents[:, :, :, :left_pad_frame_length])
            padd_list.append(target_latents[:, :, :, left_trim_length : target_latents.shape[-1] - right_trim_length])
            if right_pad_frame_length > 0:
                padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
            target_latents = torch.cat(padd_list, dim=-1)
            assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
            zt_edit = x0.clone()
            z0 = target_latents

    if audio2audio_enable and ref_latents is not None:
        nfo(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}")
        target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise(
            gt_latents=ref_latents,
            sigma_max=(1 - ref_audio_strength),
            noise=target_latents,
            scheduler_type=scheduler_type,
            infer_steps=infer_steps,
        )

    attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)

    # Guidance interval: a window of steps centred in the schedule.
    start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
    end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
    nfo(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
    # Denominator for the decay progress; at least 1 so a one-step interval cannot divide by zero.
    decay_span = max(end_idx - start_idx - 1, 1)

    # Generator for scheduler.step: accept a single generator, a list of them, or None.
    if isinstance(random_generators, (list, tuple)):
        step_generator = random_generators[0] if random_generators else None
    else:
        step_generator = random_generators

    momentum_buffer = MomentumBuffer()

    def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
        """Encode with the query projections of lyric-encoder layers [l_min, l_max) scaled by tau."""
        handles = []

        def scale_output(_module, _args, output):
            output[:] *= tau
            return output

        try:
            for layer_idx in range(l_min, l_max):
                handles.append(self.ace_step_transformer.lyric_encoder.encoders[layer_idx].self_attn.linear_q.register_forward_hook(scale_output))
            encoder_hidden_states, _ = self.ace_step_transformer.encode(**inputs)
        finally:
            # Always detach the hooks so later passes are not affected, even on error.
            for handle in handles:
                handle.remove()
        return encoder_hidden_states

    # P(speaker, text, lyric)
    encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
    )

    if use_erg_lyric:
        # P(null_speaker, text_weaker, lyric_weaker)
        encoder_hidden_states_null = forward_encoder_with_temperature(
            self,
            inputs={
                "encoder_text_hidden_states": (
                    encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states)
                ),
                "text_attention_mask": text_attention_mask,
                "speaker_embeds": torch.zeros_like(speaker_embds),
                "lyric_token_idx": lyric_token_ids,
                "lyric_mask": lyric_mask,
            },
        )
    else:
        # P(null_speaker, null_text, null_lyric)
        encoder_hidden_states_null, _ = self.ace_step_transformer.encode(
            torch.zeros_like(encoder_text_hidden_states),
            text_attention_mask,
            torch.zeros_like(speaker_embds),
            torch.zeros_like(lyric_token_ids),
            lyric_mask,
        )

    encoder_hidden_states_no_lyric = None
    if do_double_condition_guidance:
        if use_erg_lyric:
            # P(null_speaker, text, lyric_weaker)
            encoder_hidden_states_no_lyric = forward_encoder_with_temperature(
                self,
                inputs={
                    "encoder_text_hidden_states": encoder_text_hidden_states,
                    "text_attention_mask": text_attention_mask,
                    "speaker_embeds": torch.zeros_like(speaker_embds),
                    "lyric_token_idx": lyric_token_ids,
                    "lyric_mask": lyric_mask,
                },
            )
        else:
            # P(null_speaker, text, no_lyric)
            encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode(
                encoder_text_hidden_states,
                text_attention_mask,
                torch.zeros_like(speaker_embds),
                torch.zeros_like(lyric_token_ids),
                lyric_mask,
            )

    def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
        """Decode with self/cross-attention queries of blocks [l_min, l_max) scaled by tau."""
        handles = []

        def scale_output(_module, _args, output):
            output[:] *= tau
            return output

        try:
            for block_idx in range(l_min, l_max):
                block = self.ace_step_transformer.transformer_blocks[block_idx]
                handles.append(block.attn.to_q.register_forward_hook(scale_output))
                handles.append(block.cross_attn.to_q.register_forward_hook(scale_output))
            sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
        finally:
            for handle in handles:
                handle.remove()
        return sample

    for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
        if is_repaint:
            if i < n_min:
                continue
            elif i == n_min:
                # Enter the repaint trajectory: noise the source to level t_i and splice in zt_edit.
                t_i = t / 1000
                zt_src = (1 - t_i) * x0 + t_i * z0
                target_latents = zt_edit + zt_src - x0
                nfo(f"repaint start from {n_min} add {t_i} level of noise")

        latent_model_input = target_latents
        timestep = t.expand(latent_model_input.shape[0])

        is_in_guidance_interval = start_idx <= i < end_idx
        if is_in_guidance_interval and do_classifier_free_guidance:
            if guidance_interval_decay > 0:
                # Linear decay of the guidance scale; progress is normalized to [0, 1].
                progress = (i - start_idx) / decay_span
                current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
            else:
                current_guidance_scale = guidance_scale

            output_length = latent_model_input.shape[-1]
            # P(x|speaker, text, lyric)
            noise_pred_with_cond = self.ace_step_transformer.decode(
                hidden_states=latent_model_input,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_mask=encoder_hidden_mask,
                output_length=output_length,
                timestep=timestep,
            ).sample

            noise_pred_with_only_text_cond = None
            if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None:
                noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
                    hidden_states=latent_model_input,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states_no_lyric,
                    encoder_hidden_mask=encoder_hidden_mask,
                    output_length=output_length,
                    timestep=timestep,
                ).sample

            if use_erg_diffusion:
                noise_pred_uncond = forward_diffusion_with_temperature(
                    self,
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    inputs={
                        "encoder_hidden_states": encoder_hidden_states_null,
                        "encoder_hidden_mask": encoder_hidden_mask,
                        "output_length": output_length,
                        "attention_mask": attention_mask,
                    },
                )
            else:
                noise_pred_uncond = self.ace_step_transformer.decode(
                    hidden_states=latent_model_input,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states_null,
                    encoder_hidden_mask=encoder_hidden_mask,
                    output_length=output_length,
                    timestep=timestep,
                ).sample

            if do_double_condition_guidance and noise_pred_with_only_text_cond is not None:
                noise_pred = cfg_double_condition_forward(
                    cond_output=noise_pred_with_cond,
                    uncond_output=noise_pred_uncond,
                    only_text_cond_output=noise_pred_with_only_text_cond,
                    guidance_scale_text=guidance_scale_text,
                    guidance_scale_lyric=guidance_scale_lyric,
                )
            elif cfg_type == "apg":
                noise_pred = apg_forward(
                    pred_cond=noise_pred_with_cond,
                    pred_uncond=noise_pred_uncond,
                    guidance_scale=current_guidance_scale,
                    momentum_buffer=momentum_buffer,
                )
            elif cfg_type == "cfg":
                noise_pred = cfg_forward(
                    cond_output=noise_pred_with_cond,
                    uncond_output=noise_pred_uncond,
                    cfg_strength=current_guidance_scale,
                )
            elif cfg_type == "cfg_star":
                noise_pred = cfg_zero_star(
                    noise_pred_with_cond=noise_pred_with_cond,
                    noise_pred_uncond=noise_pred_uncond,
                    guidance_scale=current_guidance_scale,
                    i=i,
                    zero_steps=zero_steps,
                    use_zero_init=use_zero_init,
                )
            else:
                # Without this, noise_pred would be unbound or stale from a previous step.
                raise ValueError(f"Unsupported cfg_type: {cfg_type!r}")
        else:
            noise_pred = self.ace_step_transformer.decode(
                hidden_states=latent_model_input,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_mask=encoder_hidden_mask,
                output_length=latent_model_input.shape[-1],
                timestep=timestep,
            ).sample

        if is_repaint and i >= n_min:
            # Euler step inside the repaint mask; outside it, follow the noised source.
            t_i = t / 1000
            if i + 1 < len(timesteps):
                t_im1 = timesteps[i + 1] / 1000
            else:
                t_im1 = torch.zeros_like(t_i).to(self.device)
            target_latents = target_latents.to(torch.float32)
            prev_sample = target_latents + (t_im1 - t_i) * noise_pred
            target_latents = prev_sample.to(self.dtype)
            zt_src = (1 - t_im1) * x0 + t_im1 * z0
            target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
        else:
            target_latents = scheduler.step(
                model_output=noise_pred,
                timestep=t,
                sample=target_latents,
                return_dict=False,
                omega=omega_scale,
                generator=step_generator,
            )[0]

    if is_extend:
        # Re-attach source frames that were trimmed to fit the maximum window.
        if to_right_pad_gt_latents is not None:
            target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
        if to_left_pad_gt_latents is not None:
            target_latents = torch.cat([to_left_pad_gt_latents, target_latents], dim=-1)
    return target_latents
@cpu_offload("music_dcae")
def latents2audio(
    self,
    latents,
    target_wav_duration_second=30,
    sample_rate=48000,
    save_path=None,
    format="wav",
):
    """Decode latents with the music DCAE and write one audio file per batch item.

    Args:
        latents: Batch of latents; ``latents.shape[0]`` is used as the batch size.
        target_wav_duration_second: Requested duration in seconds. When it exceeds
            48 and ``self.overlapped_decode`` is truthy, ``music_dcae.decode_overlap``
            is used instead of ``music_dcae.decode``.
        sample_rate: Forwarded to the decoder (``sr=``) and to ``save_wav_file``.
        save_path: Forwarded to ``save_wav_file`` (``None`` selects its default path).
        format: Forwarded to ``save_wav_file``.

    Returns:
        list: The paths returned by ``save_wav_file``, in batch order.
    """
    batch_size = latents.shape[0]
    use_overlap = self.overlapped_decode and target_wav_duration_second > 48
    decode_fn = self.music_dcae.decode_overlap if use_overlap else self.music_dcae.decode

    with torch.no_grad():
        _, decoded_wavs = decode_fn(latents, sr=sample_rate)

    # Move every waveform to CPU float32 before handing it to the file writer.
    waves = [wav.cpu().float() for wav in decoded_wavs]

    return [
        self.save_wav_file(
            waves[item],
            item,
            save_path=save_path,
            sample_rate=sample_rate,
            format=format,
        )
        for item in tqdm(range(batch_size))
    ]
def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="wav"):
    """Write one waveform to disk and return the path that was written.

    Args:
        target_wav: Waveform tensor. It is passed to ``torchaudio.save`` with that
            function's default channels-first layout, i.e. ``(channels, samples)``.
        idx: Batch index, appended to generated filenames to keep them distinct.
        save_path: ``None`` writes ``./outputs/output_<timestamp>_<idx>.<format>``.
            An existing directory gets the same timestamped filename inside it.
            Any other value is used verbatim as the output file path.
        sample_rate: Sample rate in Hz.
        format: Container passed to ``torchaudio.save`` and used as the extension
            of generated filenames. ``"ogg"`` switches the backend to ``sox``.

    Returns:
        str: The output path.
    """
    if save_path is None:
        nfo("save_path is None, using default path ./outputs/")
        base_path = "./outputs"
        ensure_directory_exists(base_path)
        output_path_wav = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format
    else:
        parent_dir = os.path.dirname(save_path)
        # A bare filename ("song.wav") has an empty dirname, meaning the current
        # directory: nothing to create, and an empty path must not reach a
        # makedirs-style helper (os.makedirs("") raises).
        if parent_dir:
            ensure_directory_exists(parent_dir)
        if os.path.isdir(save_path):
            nfo(f"Provided save_path '{save_path}' is a directory. Appending timestamped filename.")
            output_path_wav = os.path.join(save_path, f"output_{time.strftime('%Y%m%d%H%M%S')}_{idx}." + format)
        else:
            output_path_wav = save_path

    target_wav = target_wav.float()
    backend = "sox" if format == "ogg" else "soundfile"
    nfo(f"Saving audio to {output_path_wav} using backend {backend}")
    try:
        torchaudio.save(output_path_wav, target_wav, sample_rate=sample_rate, format=format, backend=backend)
    except (RuntimeError, ImportError):  # ImportError also covers ModuleNotFoundError
        import soundfile as sf  # pyright: ignore[reportMissingImports] | pylint:disable=import-error

        # torchaudio.save consumed (channels, samples); soundfile.write expects
        # (frames, channels), so transpose 2-D audio before writing.
        audio = target_wav.detach().cpu()
        if audio.dim() == 2:
            audio = audio.transpose(0, 1).contiguous()
        sf.write(output_path_wav, audio.numpy(), sample_rate)
    return output_path_wav
@cpu_offload("music_dcae")
def infer_latents(self, input_audio_path):
    """Load an audio file and encode it with the music DCAE.

    Args:
        input_audio_path: Path given to ``self.music_dcae.load_audio``. ``None``
            short-circuits and returns ``None``.

    Returns:
        The first value returned by ``self.music_dcae.encode``, or ``None`` when
        no path was given.
    """
    if input_audio_path is None:
        return None

    waveform, source_sr = self.music_dcae.load_audio(input_audio_path)
    # Prepend a leading (batch) axis, then match the pipeline's device and dtype.
    batched = waveform.unsqueeze(0).to(device=self.device, dtype=self.dtype)
    encoded, _unused = self.music_dcae.encode(batched, sr=source_sr)
    return encoded
def load_lora(self, lora_name_or_path, lora_weight):
1258    def load_lora(self, lora_name_or_path, lora_weight):
1259        if (lora_name_or_path != self.lora_path or lora_weight != self.lora_weight) and lora_name_or_path != "none":
1260            if not os.path.exists(lora_name_or_path):
1261                lora_download_path = snapshot_download(lora_name_or_path, cache_dir=self.checkpoint_dir)
1262            else:
1263                lora_download_path = lora_name_or_path
1264            if self.lora_path != "none":
1265                self.ace_step_transformer.unload_lora()
1266            self.ace_step_transformer.load_lora_adapter(
1267                os.path.join(lora_download_path, "pytorch_lora_weights.safetensors"), adapter_name="ace_step_lora", with_alpha=True, prefix=None
1268            )
1269            nfo(f"Loading lora weights from: {lora_name_or_path} download path is: {lora_download_path} weight: {lora_weight}")
1270            set_weights_and_activate_adapters(self.ace_step_transformer, ["ace_step_lora"], [lora_weight])
1271            self.lora_path = lora_name_or_path
1272            self.lora_weight = lora_weight
1273        elif self.lora_path != "none" and lora_name_or_path == "none":
1274            nfo("No lora weights to load.")
1275            self.ace_step_transformer.unload_lora()