divisor.acestep.music_dcae.music_dcae_pipeline
ACE-Step: A Step Towards Music Generation Foundation Model
https://github.com/ace-step/ACE-Step
Apache 2.0 License
"""
ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License
"""

import os
import torch
from diffusers import AutoencoderDC
import torchaudio
import torchvision.transforms as transforms
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import FromOriginalModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from tqdm import tqdm

try:
    from divisor.acestep.music_dcae.music_vocoder import ADaMoSHiFiGANV1
except ImportError:
    from music_vocoder import ADaMoSHiFiGANV1


root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")


class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """Waveform <-> mel spectrogram <-> latent pipeline.

    Combines a pretrained ``AutoencoderDC`` (mel <-> latent) with an
    ``ADaMoSHiFiGANV1`` vocoder (``mel_transform`` for waveform -> mel,
    ``decode`` for mel -> waveform). Audio is resampled to 44.1 kHz before
    encoding and decoded at 44.1 kHz before optional resampling.
    """

    @register_to_config
    def __init__(
        self,
        source_sample_rate=None,
        dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
        vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
    ):
        """Load both pretrained sub-models and set normalization constants.

        Args:
            source_sample_rate: Sample rate assumed for ``encode`` input when
                no explicit ``sr`` is given; ``None`` means 48000.
            dcae_checkpoint_path: Location passed to
                ``AutoencoderDC.from_pretrained``.
            vocoder_checkpoint_path: Location passed to
                ``ADaMoSHiFiGANV1.from_pretrained``.
        """
        super().__init__()

        self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
        self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)

        if source_sample_rate is None:
            source_sample_rate = 48000

        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)

        # Maps mels scaled to [0, 1] onto [-1, 1]; decode inverts with * 0.5 + 0.5.
        self.transform = transforms.Compose(
            [
                transforms.Normalize(0.5, 0.5),
            ]
        )
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
        self.mel_chunk_size = 1024
        self.time_dimention_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091

    def load_audio(self, audio_path):
        """Load an audio file; a single-channel file is duplicated to two channels.

        Returns:
            ``(audio, sr)`` as returned by ``torchaudio.load``.
        """
        audio, sr = torchaudio.load(audio_path)
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        return audio, sr

    def forward_mel(self, audios):
        """Apply the vocoder's ``mel_transform`` to each item and stack the results."""
        return torch.stack([self.vocoder.mel_transform(audio) for audio in audios])

    @torch.no_grad()
    def encode(self, audios, audio_lengths=None, sr=None):
        """Encode waveforms into scaled latents.

        Args:
            audios: Batch of waveforms, ``N x 2 x T``.
            audio_lengths: Per-item sample counts at the input rate; defaults
                to ``T`` for every item.
            sr: Input sample rate. ``None`` uses the resampler built in
                ``__init__`` and its source rate.

        Returns:
            ``(latents, latent_lengths)``.
        """
        if audio_lengths is None:
            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
            audio_lengths = audio_lengths.to(audios.device)

        device = audios.device
        dtype = audios.dtype

        if sr is None:
            resampler = self.resampler
            # Use the rate the default resampler was actually built for, so
            # latent_lengths stays right when source_sample_rate != 48000.
            sr = getattr(resampler, "orig_freq", 48000)
        else:
            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)

        audio = resampler(audios)

        # Right-pad to a multiple of 8 * 512 samples.
        pad_multiple = 8 * 512
        remainder = audio.shape[-1] % pad_multiple
        if remainder != 0:
            audio = torch.nn.functional.pad(audio, (0, pad_multiple - remainder))

        mels = self.forward_mel(audio)
        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        mels = self.transform(mels)
        # Encode one item at a time, as the original loop did.
        latents = torch.cat([self.dcae.encoder(mel.unsqueeze(0)) for mel in mels], dim=0)
        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
        latents = (latents - self.shift_factor) * self.scale_factor
        return latents, latent_lengths

    @torch.no_grad()
    def decode(self, latents, audio_lengths=None, sr=None):
        """Decode latents into two-channel waveforms, one item at a time.

        Args:
            latents: Batch of scaled latents as produced by ``encode``.
            audio_lengths: Optional per-item lengths used to truncate output.
            sr: Output sample rate; ``None`` keeps 44100 without resampling.

        Returns:
            ``(sr, pred_wavs)`` where ``pred_wavs`` is a list of CPU tensors.
        """
        latents = latents / self.scale_factor + self.shift_factor

        # Decide resampling once so every item in the batch is treated alike.
        output_sr = 44100 if sr is None else sr
        resampler = None if sr is None else torchaudio.transforms.Resample(44100, sr)

        pred_wavs = []
        for latent in latents:
            mels = self.dcae.decoder(latent.unsqueeze(0))
            mels = mels * 0.5 + 0.5
            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value

            # Decode each channel separately to reduce the VRAM footprint.
            wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
            wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
            wav = torch.cat([wav_ch1, wav_ch2], dim=0)

            if resampler is not None:
                wav = resampler(wav.cpu().float())
            pred_wavs.append(wav)

        if audio_lengths is not None:
            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
        return output_sr, pred_wavs

    @torch.no_grad()
    def decode_overlap(self, latents, audio_lengths=None, sr=None):
        """
        Decodes latents into waveforms using an overlapped DCAE and Vocoder.

        The latent is decoded in windows of 512 latent frames with a hop of
        256; a quarter window (1024 mel frames) is trimmed at every interior
        seam. The mel is then vocoded in blocks of 512 mel frames with a
        1024-sample overlap and a 128-sample linear crossfade between blocks.

        Args:
            latents: Iterable of scaled latents, each indexed ``(C, H, W)``.
            audio_lengths: Optional per-item output lengths (upper bound).
            sr: Output sample rate; ``None`` keeps 44100.

        Returns:
            ``(sr, wavs)`` where ``wavs`` is a list of CPU tensors.
        """
        print("Using Overlapped DCAE and Vocoder")

        MODEL_INTERNAL_SR = 44100
        DCAE_LATENT_TO_MEL_STRIDE = 8
        VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME = 512

        pred_wavs = []
        final_output_sr = sr if sr is not None else MODEL_INTERNAL_SR

        # --- DCAE windowing (latent frames unless stated otherwise) ---
        dcae_win_len_latent = 512
        dcae_mel_win_len = dcae_win_len_latent * DCAE_LATENT_TO_MEL_STRIDE
        dcae_anchor_offset = dcae_win_len_latent // 4
        dcae_anchor_hop = dcae_win_len_latent // 2
        dcae_mel_overlap_len = dcae_mel_win_len // 4

        # --- Vocoder windowing (audio samples unless stated otherwise) ---
        vocoder_win_len_audio = 512 * 512
        vocoder_overlap_len_audio = 1024
        vocoder_hop_len_audio = vocoder_win_len_audio - 2 * vocoder_overlap_len_audio
        vocoder_input_mel_frames_per_block = vocoder_win_len_audio // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME

        crossfade_len_audio = 128
        cf_win_tail = torch.linspace(1, 0, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
        cf_win_head = torch.linspace(0, 1, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)

        for latent_item in latents:
            latent_item = latent_item.to(self.device)
            current_latent = (latent_item / self.scale_factor + self.shift_factor).unsqueeze(0)
            latent_len = current_latent.shape[3]

            # 1. DCAE: latent -> mel spectrogram, window by window.
            mels_segments = []
            if latent_len > 0:
                # Each window starts at (anchor - offset); the first starts at 0.
                dcae_anchors = list(range(dcae_anchor_offset, latent_len - dcae_anchor_offset, dcae_anchor_hop))
                if not dcae_anchors:  # latent shorter than one hop: single window
                    dcae_anchors = [dcae_anchor_offset]
                last_anchor_idx = len(dcae_anchors) - 1

                for i, anchor in enumerate(dcae_anchors):
                    win_start_idx = max(0, anchor - dcae_anchor_offset)
                    win_end_idx = min(latent_len, win_start_idx + dcae_win_len_latent)

                    dcae_input_segment = current_latent[:, :, :, win_start_idx:win_end_idx]
                    if dcae_input_segment.shape[3] == 0:
                        continue

                    mel_output_full = self.dcae.decoder(dcae_input_segment)

                    is_first = i == 0
                    is_last = i == last_anchor_idx

                    if is_first and is_last:
                        # Single window: keep only the mel that maps to real latent frames.
                        true_mel_content_len = dcae_input_segment.shape[3] * DCAE_LATENT_TO_MEL_STRIDE
                        mel_to_keep = mel_output_full[:, :, :, : min(true_mel_content_len, mel_output_full.shape[3])]
                    elif is_first:
                        mel_to_keep = mel_output_full[:, :, :, :-dcae_mel_overlap_len]
                    elif is_last:
                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:]
                    else:
                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:-dcae_mel_overlap_len]

                    if mel_to_keep.shape[3] > 0:
                        mels_segments.append(mel_to_keep)

            if not mels_segments:
                # Nothing to vocode. Emit an empty waveform directly instead of
                # building an empty mel from a decoder attribute that may not exist.
                pred_wavs.append(torch.zeros((1, 0), device=self.device, dtype=torch.float32))
                continue

            concatenated_mels = torch.cat(mels_segments, dim=3)

            # Denormalize mels.
            concatenated_mels = concatenated_mels * 0.5 + 0.5
            concatenated_mels = concatenated_mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value

            mel_total_frames = concatenated_mels.shape[3]

            # 2. Vocoder: mel -> waveform, block by block.
            mel_block = concatenated_mels[0, :, :, :vocoder_input_mel_frames_per_block].to(self.device)
            if 0 < mel_block.shape[2] < vocoder_input_mel_frames_per_block:
                pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
                mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)

            current_audio_output = self.vocoder.decode(mel_block)
            current_audio_output = current_audio_output[:, :, :-vocoder_overlap_len_audio]

            # p_audio_samples is the nominal start sample of the next block.
            p_audio_samples = vocoder_hop_len_audio
            conceptual_total_audio_len_native_sr = mel_total_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME

            while p_audio_samples < conceptual_total_audio_len_native_sr:
                mel_frame_start = p_audio_samples // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
                if mel_frame_start >= mel_total_frames:
                    break
                mel_frame_end = min(mel_frame_start + vocoder_input_mel_frames_per_block, mel_total_frames)

                mel_block = concatenated_mels[0, :, :, mel_frame_start:mel_frame_end].to(self.device)
                if mel_block.shape[2] < vocoder_input_mel_frames_per_block:
                    pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
                    mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)

                new_audio_win = self.vocoder.decode(mel_block)

                # Crossfade the tail of the running output with the matching
                # samples just before the new block's overlap boundary.
                actual_cf_len = min(
                    crossfade_len_audio,
                    current_audio_output.shape[2],
                    new_audio_win.shape[2] - (vocoder_overlap_len_audio - crossfade_len_audio),
                )
                if actual_cf_len > 0:
                    tail_part = current_audio_output[:, :, -actual_cf_len:]
                    head_part = new_audio_win[:, :, vocoder_overlap_len_audio - actual_cf_len : vocoder_overlap_len_audio]
                    crossfaded_segment = tail_part * cf_win_tail[:, :, :actual_cf_len] + head_part * cf_win_head[:, :, :actual_cf_len]
                    current_audio_output = torch.cat([current_audio_output[:, :, :-actual_cf_len], crossfaded_segment], dim=2)

                # The last block keeps its tail; earlier blocks drop the end overlap.
                is_final_append = p_audio_samples + vocoder_hop_len_audio >= conceptual_total_audio_len_native_sr
                if is_final_append:
                    segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:]
                else:
                    segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:-vocoder_overlap_len_audio]

                current_audio_output = torch.cat([current_audio_output, segment_to_append], dim=2)
                p_audio_samples += vocoder_hop_len_audio

            final_wav = current_audio_output.squeeze(1)

            # 3. Resampling (if necessary); done on CPU, then moved back.
            if final_output_sr != MODEL_INTERNAL_SR and final_wav.numel() > 0:
                resampler = torchaudio.transforms.Resample(MODEL_INTERNAL_SR, final_output_sr, dtype=final_wav.dtype)
                final_wav = resampler(final_wav.cpu()).to(self.device)

            pred_wavs.append(final_wav)

        # 4. Final truncation to the expected / requested length.
        processed_pred_wavs = []
        for i, wav in enumerate(pred_wavs):
            num_latent_frames = latents[i].shape[-1]
            native_audio_len = num_latent_frames * DCAE_LATENT_TO_MEL_STRIDE * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
            max_possible_len = int(native_audio_len * final_output_sr / MODEL_INTERNAL_SR)

            target_len = min(max_possible_len, wav.shape[1])
            if audio_lengths is not None:
                # int() so tensor-valued lengths compare and slice as plain ints.
                target_len = min(int(audio_lengths[i]), target_len)

            processed_pred_wavs.append(wav[:, : max(0, target_len)].cpu())

        return final_output_sr, processed_pred_wavs

    def forward(self, audios, audio_lengths=None, sr=None):
        """Encode then decode; returns ``(sr, pred_wavs, latents, latent_lengths)``."""
        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return sr, pred_wavs, latents, latent_lengths


if __name__ == "__main__":
    audio, sr = torchaudio.load("test.wav")
    audio_lengths = torch.tensor([audio.shape[1]])
    audios = audio.unsqueeze(0)

    # test encode only
    model = MusicDCAE()
    # latents, latent_lengths = model.encode(audios, audio_lengths)
    # print("latents shape: ", latents.shape)
    # print("latent_lengths: ", latent_lengths)

    # test encode and decode
    sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
    print("reconstructed wavs: ", pred_wavs[0].shape)
    print("latents shape: ", latents.shape)
    print("latent_lengths: ", latent_lengths)
    print("sr: ", sr)
    torchaudio.save("test_reconstructed.wav", pred_wavs[0], sr)
    print("test_reconstructed.wav")
root_dir =
'/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep'
DEFAULT_PRETRAINED_PATH =
'/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep/checkpoints/music_dcae_f8c8'
VOCODER_PRETRAINED_PATH =
'/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep/checkpoints/music_vocoder'
class
MusicDCAE(diffusers.models.modeling_utils.ModelMixin, diffusers.configuration_utils.ConfigMixin, diffusers.loaders.single_file_model.FromOriginalModelMixin):
class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """Waveform <-> mel spectrogram <-> latent pipeline.

    Combines a pretrained ``AutoencoderDC`` (mel <-> latent) with an
    ``ADaMoSHiFiGANV1`` vocoder (``mel_transform`` for waveform -> mel,
    ``decode`` for mel -> waveform). Audio is resampled to 44.1 kHz before
    encoding and decoded at 44.1 kHz before optional resampling.
    """

    @register_to_config
    def __init__(
        self,
        source_sample_rate=None,
        dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
        vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
    ):
        """Load both pretrained sub-models and set normalization constants.

        Args:
            source_sample_rate: Sample rate assumed for ``encode`` input when
                no explicit ``sr`` is given; ``None`` means 48000.
            dcae_checkpoint_path: Location passed to
                ``AutoencoderDC.from_pretrained``.
            vocoder_checkpoint_path: Location passed to
                ``ADaMoSHiFiGANV1.from_pretrained``.
        """
        super().__init__()

        self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
        self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)

        if source_sample_rate is None:
            source_sample_rate = 48000

        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)

        # Maps mels scaled to [0, 1] onto [-1, 1]; decode inverts with * 0.5 + 0.5.
        self.transform = transforms.Compose(
            [
                transforms.Normalize(0.5, 0.5),
            ]
        )
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
        self.mel_chunk_size = 1024
        self.time_dimention_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091

    def load_audio(self, audio_path):
        """Load an audio file; a single-channel file is duplicated to two channels.

        Returns:
            ``(audio, sr)`` as returned by ``torchaudio.load``.
        """
        audio, sr = torchaudio.load(audio_path)
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        return audio, sr

    def forward_mel(self, audios):
        """Apply the vocoder's ``mel_transform`` to each item and stack the results."""
        return torch.stack([self.vocoder.mel_transform(audio) for audio in audios])

    @torch.no_grad()
    def encode(self, audios, audio_lengths=None, sr=None):
        """Encode waveforms into scaled latents.

        Args:
            audios: Batch of waveforms, ``N x 2 x T``.
            audio_lengths: Per-item sample counts at the input rate; defaults
                to ``T`` for every item.
            sr: Input sample rate. ``None`` uses the resampler built in
                ``__init__`` and its source rate.

        Returns:
            ``(latents, latent_lengths)``.
        """
        if audio_lengths is None:
            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
            audio_lengths = audio_lengths.to(audios.device)

        device = audios.device
        dtype = audios.dtype

        if sr is None:
            resampler = self.resampler
            # Use the rate the default resampler was actually built for, so
            # latent_lengths stays right when source_sample_rate != 48000.
            sr = getattr(resampler, "orig_freq", 48000)
        else:
            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)

        audio = resampler(audios)

        # Right-pad to a multiple of 8 * 512 samples.
        pad_multiple = 8 * 512
        remainder = audio.shape[-1] % pad_multiple
        if remainder != 0:
            audio = torch.nn.functional.pad(audio, (0, pad_multiple - remainder))

        mels = self.forward_mel(audio)
        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        mels = self.transform(mels)
        # Encode one item at a time, as the original loop did.
        latents = torch.cat([self.dcae.encoder(mel.unsqueeze(0)) for mel in mels], dim=0)
        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
        latents = (latents - self.shift_factor) * self.scale_factor
        return latents, latent_lengths

    @torch.no_grad()
    def decode(self, latents, audio_lengths=None, sr=None):
        """Decode latents into two-channel waveforms, one item at a time.

        Args:
            latents: Batch of scaled latents as produced by ``encode``.
            audio_lengths: Optional per-item lengths used to truncate output.
            sr: Output sample rate; ``None`` keeps 44100 without resampling.

        Returns:
            ``(sr, pred_wavs)`` where ``pred_wavs`` is a list of CPU tensors.
        """
        latents = latents / self.scale_factor + self.shift_factor

        # Decide resampling once so every item in the batch is treated alike.
        output_sr = 44100 if sr is None else sr
        resampler = None if sr is None else torchaudio.transforms.Resample(44100, sr)

        pred_wavs = []
        for latent in latents:
            mels = self.dcae.decoder(latent.unsqueeze(0))
            mels = mels * 0.5 + 0.5
            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value

            # Decode each channel separately to reduce the VRAM footprint.
            wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
            wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
            wav = torch.cat([wav_ch1, wav_ch2], dim=0)

            if resampler is not None:
                wav = resampler(wav.cpu().float())
            pred_wavs.append(wav)

        if audio_lengths is not None:
            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
        return output_sr, pred_wavs

    @torch.no_grad()
    def decode_overlap(self, latents, audio_lengths=None, sr=None):
        """
        Decodes latents into waveforms using an overlapped DCAE and Vocoder.

        The latent is decoded in windows of 512 latent frames with a hop of
        256; a quarter window (1024 mel frames) is trimmed at every interior
        seam. The mel is then vocoded in blocks of 512 mel frames with a
        1024-sample overlap and a 128-sample linear crossfade between blocks.

        Args:
            latents: Iterable of scaled latents, each indexed ``(C, H, W)``.
            audio_lengths: Optional per-item output lengths (upper bound).
            sr: Output sample rate; ``None`` keeps 44100.

        Returns:
            ``(sr, wavs)`` where ``wavs`` is a list of CPU tensors.
        """
        print("Using Overlapped DCAE and Vocoder")

        MODEL_INTERNAL_SR = 44100
        DCAE_LATENT_TO_MEL_STRIDE = 8
        VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME = 512

        pred_wavs = []
        final_output_sr = sr if sr is not None else MODEL_INTERNAL_SR

        # --- DCAE windowing (latent frames unless stated otherwise) ---
        dcae_win_len_latent = 512
        dcae_mel_win_len = dcae_win_len_latent * DCAE_LATENT_TO_MEL_STRIDE
        dcae_anchor_offset = dcae_win_len_latent // 4
        dcae_anchor_hop = dcae_win_len_latent // 2
        dcae_mel_overlap_len = dcae_mel_win_len // 4

        # --- Vocoder windowing (audio samples unless stated otherwise) ---
        vocoder_win_len_audio = 512 * 512
        vocoder_overlap_len_audio = 1024
        vocoder_hop_len_audio = vocoder_win_len_audio - 2 * vocoder_overlap_len_audio
        vocoder_input_mel_frames_per_block = vocoder_win_len_audio // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME

        crossfade_len_audio = 128
        cf_win_tail = torch.linspace(1, 0, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
        cf_win_head = torch.linspace(0, 1, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)

        for latent_item in latents:
            latent_item = latent_item.to(self.device)
            current_latent = (latent_item / self.scale_factor + self.shift_factor).unsqueeze(0)
            latent_len = current_latent.shape[3]

            # 1. DCAE: latent -> mel spectrogram, window by window.
            mels_segments = []
            if latent_len > 0:
                # Each window starts at (anchor - offset); the first starts at 0.
                dcae_anchors = list(range(dcae_anchor_offset, latent_len - dcae_anchor_offset, dcae_anchor_hop))
                if not dcae_anchors:  # latent shorter than one hop: single window
                    dcae_anchors = [dcae_anchor_offset]
                last_anchor_idx = len(dcae_anchors) - 1

                for i, anchor in enumerate(dcae_anchors):
                    win_start_idx = max(0, anchor - dcae_anchor_offset)
                    win_end_idx = min(latent_len, win_start_idx + dcae_win_len_latent)

                    dcae_input_segment = current_latent[:, :, :, win_start_idx:win_end_idx]
                    if dcae_input_segment.shape[3] == 0:
                        continue

                    mel_output_full = self.dcae.decoder(dcae_input_segment)

                    is_first = i == 0
                    is_last = i == last_anchor_idx

                    if is_first and is_last:
                        # Single window: keep only the mel that maps to real latent frames.
                        true_mel_content_len = dcae_input_segment.shape[3] * DCAE_LATENT_TO_MEL_STRIDE
                        mel_to_keep = mel_output_full[:, :, :, : min(true_mel_content_len, mel_output_full.shape[3])]
                    elif is_first:
                        mel_to_keep = mel_output_full[:, :, :, :-dcae_mel_overlap_len]
                    elif is_last:
                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:]
                    else:
                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:-dcae_mel_overlap_len]

                    if mel_to_keep.shape[3] > 0:
                        mels_segments.append(mel_to_keep)

            if not mels_segments:
                # Nothing to vocode. Emit an empty waveform directly instead of
                # building an empty mel from a decoder attribute that may not exist.
                pred_wavs.append(torch.zeros((1, 0), device=self.device, dtype=torch.float32))
                continue

            concatenated_mels = torch.cat(mels_segments, dim=3)

            # Denormalize mels.
            concatenated_mels = concatenated_mels * 0.5 + 0.5
            concatenated_mels = concatenated_mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value

            mel_total_frames = concatenated_mels.shape[3]

            # 2. Vocoder: mel -> waveform, block by block.
            mel_block = concatenated_mels[0, :, :, :vocoder_input_mel_frames_per_block].to(self.device)
            if 0 < mel_block.shape[2] < vocoder_input_mel_frames_per_block:
                pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
                mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)

            current_audio_output = self.vocoder.decode(mel_block)
            current_audio_output = current_audio_output[:, :, :-vocoder_overlap_len_audio]

            # p_audio_samples is the nominal start sample of the next block.
            p_audio_samples = vocoder_hop_len_audio
            conceptual_total_audio_len_native_sr = mel_total_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME

            while p_audio_samples < conceptual_total_audio_len_native_sr:
                mel_frame_start = p_audio_samples // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
                if mel_frame_start >= mel_total_frames:
                    break
                mel_frame_end = min(mel_frame_start + vocoder_input_mel_frames_per_block, mel_total_frames)

                mel_block = concatenated_mels[0, :, :, mel_frame_start:mel_frame_end].to(self.device)
                if mel_block.shape[2] < vocoder_input_mel_frames_per_block:
                    pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
                    mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)

                new_audio_win = self.vocoder.decode(mel_block)

                # Crossfade the tail of the running output with the matching
                # samples just before the new block's overlap boundary.
                actual_cf_len = min(
                    crossfade_len_audio,
                    current_audio_output.shape[2],
                    new_audio_win.shape[2] - (vocoder_overlap_len_audio - crossfade_len_audio),
                )
                if actual_cf_len > 0:
                    tail_part = current_audio_output[:, :, -actual_cf_len:]
                    head_part = new_audio_win[:, :, vocoder_overlap_len_audio - actual_cf_len : vocoder_overlap_len_audio]
                    crossfaded_segment = tail_part * cf_win_tail[:, :, :actual_cf_len] + head_part * cf_win_head[:, :, :actual_cf_len]
                    current_audio_output = torch.cat([current_audio_output[:, :, :-actual_cf_len], crossfaded_segment], dim=2)

                # The last block keeps its tail; earlier blocks drop the end overlap.
                is_final_append = p_audio_samples + vocoder_hop_len_audio >= conceptual_total_audio_len_native_sr
                if is_final_append:
                    segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:]
                else:
                    segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:-vocoder_overlap_len_audio]

                current_audio_output = torch.cat([current_audio_output, segment_to_append], dim=2)
                p_audio_samples += vocoder_hop_len_audio

            final_wav = current_audio_output.squeeze(1)

            # 3. Resampling (if necessary); done on CPU, then moved back.
            if final_output_sr != MODEL_INTERNAL_SR and final_wav.numel() > 0:
                resampler = torchaudio.transforms.Resample(MODEL_INTERNAL_SR, final_output_sr, dtype=final_wav.dtype)
                final_wav = resampler(final_wav.cpu()).to(self.device)

            pred_wavs.append(final_wav)

        # 4. Final truncation to the expected / requested length.
        processed_pred_wavs = []
        for i, wav in enumerate(pred_wavs):
            num_latent_frames = latents[i].shape[-1]
            native_audio_len = num_latent_frames * DCAE_LATENT_TO_MEL_STRIDE * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
            max_possible_len = int(native_audio_len * final_output_sr / MODEL_INTERNAL_SR)

            target_len = min(max_possible_len, wav.shape[1])
            if audio_lengths is not None:
                # int() so tensor-valued lengths compare and slice as plain ints.
                target_len = min(int(audio_lengths[i]), target_len)

            processed_pred_wavs.append(wav[:, : max(0, target_len)].cpu())

        return final_output_sr, processed_pred_wavs

    def forward(self, audios, audio_lengths=None, sr=None):
        """Encode then decode; returns ``(sr, pred_wavs, latents, latent_lengths)``."""
        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return sr, pred_wavs, latents, latent_lengths
Base class for all models.
[ModelMixin] takes care of storing the model configuration and provides methods for loading, downloading and
saving models.
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
@register_to_config
MusicDCAE( source_sample_rate=None, dcae_checkpoint_path='/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep/checkpoints/music_dcae_f8c8', vocoder_checkpoint_path='/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep/checkpoints/music_vocoder')
@register_to_config
def __init__(
    self,
    source_sample_rate=None,
    dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
    vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
):
    """Load the pretrained sub-models and set the pipeline constants.

    Args:
        source_sample_rate: Input rate the default resampler converts from
            (to 44100); ``None`` is treated as 48000.
        dcae_checkpoint_path: Location handed to
            ``AutoencoderDC.from_pretrained``.
        vocoder_checkpoint_path: Location handed to
            ``ADaMoSHiFiGANV1.from_pretrained``.
    """
    super(MusicDCAE, self).__init__()

    # Sub-modules are assigned in the same order as before: dcae, vocoder, resampler.
    self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
    self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)

    input_rate = 48000 if source_sample_rate is None else source_sample_rate
    self.resampler = torchaudio.transforms.Resample(input_rate, 44100)

    # Normalize(0.5, 0.5) maps values in [0, 1] onto [-1, 1].
    normalize_step = transforms.Normalize(0.5, 0.5)
    self.transform = transforms.Compose([normalize_step])

    # Range used to rescale mel values before/after the autoencoder.
    self.min_mel_value, self.max_mel_value = -11.0, 3.0

    # Chunk sizes in audio samples, mel frames and latent frames.
    self.audio_chunk_size = int(round(1024 * 512 / 44100 * 48000))
    self.mel_chunk_size = 1024
    self.time_dimention_multiple = 8
    self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple

    # Affine constants applied to latents in encode and undone in decode.
    self.scale_factor, self.shift_factor = 0.1786, -1.9091
Initialize internal Module state, shared by both nn.Module and ScriptModule.
@torch.no_grad()
def
encode(self, audios, audio_lengths=None, sr=None):
@torch.no_grad()
def encode(self, audios, audio_lengths=None, sr=None):
    """Encode a batch of waveforms into shifted and scaled DCAE latents.

    Args:
        audios: Waveform tensor indexed as ``N x 2 x T`` (the original
            comment notes 48 kHz input).
        audio_lengths: Optional per-item lengths in samples at the input
            rate. May be a tensor or any sequence accepted by
            ``torch.as_tensor``; defaults to ``T`` for every item.
        sr: Input sample rate. ``None`` means 48000 and reuses
            ``self.resampler``; otherwise a resampler ``sr -> 44100`` is
            built for this call.

    Returns:
        Tuple ``(latents, latent_lengths)``: the concatenated encoder outputs
        after ``(latent - shift_factor) * scale_factor`` and the per-item
        latent lengths as a long tensor.
    """
    # audios: N x 2 x T, 48kHz
    device = audios.device
    dtype = audios.dtype

    if audio_lengths is None:
        audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0]).to(device)
    else:
        # Accept plain lists/tuples too (decode() and forward() already do);
        # tensors pass through unchanged.
        audio_lengths = torch.as_tensor(audio_lengths)

    if sr is None:
        sr = 48000
        resampler = self.resampler
    else:
        resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)

    audio = resampler(audios)

    # Right-pad so the sample count is a multiple of 8 * 512.
    frame_multiple = 8 * 512
    remainder = audio.shape[-1] % frame_multiple
    if remainder != 0:
        audio = torch.nn.functional.pad(audio, (0, frame_multiple - remainder))

    mels = self.forward_mel(audio)
    mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
    mels = self.transform(mels)

    # Encode one item at a time rather than the whole batch at once.
    latents = torch.cat([self.dcae.encoder(mel.unsqueeze(0)) for mel in mels], dim=0)

    latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
    latents = (latents - self.shift_factor) * self.scale_factor
    return latents, latent_lengths
@torch.no_grad()
def
decode(self, latents, audio_lengths=None, sr=None):
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
    """Decode a batch of latents into stereo waveforms.

    Args:
        latents: Batch of scaled latents; iterated item by item.
        audio_lengths: Optional per-item sample counts used to truncate the
            decoded waveforms (any iterable zipped against the batch).
        sr: Desired output sample rate. ``None`` keeps the vocoder's
            44100 Hz output untouched.

    Returns:
        Tuple ``(sr, pred_wavs)``: the output sample rate and a list with one
        2-row waveform tensor per batch item.
    """
    model_sr = 44100
    # Resolve the output rate and the resampler once, before the loop.
    # Reassigning ``sr`` inside the loop made every item after the first take
    # the resampling branch (float cast + no-op resampler), and left the
    # returned rate as None for an empty batch.
    if sr is None:
        output_sr = model_sr
        resampler = None
    else:
        output_sr = sr
        resampler = torchaudio.transforms.Resample(model_sr, sr)

    latents = latents / self.scale_factor + self.shift_factor

    pred_wavs = []
    for latent in latents:
        mels = self.dcae.decoder(latent.unsqueeze(0))
        # Undo the Normalize(0.5, 0.5) step, then the min/max mel scaling.
        mels = mels * 0.5 + 0.5
        mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value

        # decode waveform for each channels to reduce vram footprint
        wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
        wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
        wav = torch.cat([wav_ch1, wav_ch2], dim=0)

        if resampler is not None:
            wav = resampler(wav.cpu().float())
        pred_wavs.append(wav)

    if audio_lengths is not None:
        pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
    return output_sr, pred_wavs
@torch.no_grad()
def
decode_overlap(self, latents, audio_lengths=None, sr=None):
@torch.no_grad()
def decode_overlap(self, latents, audio_lengths=None, sr=None):
    """
    Decodes latents into waveforms using an overlapped DCAE and Vocoder.

    Each latent is decoded to a mel spectrogram in overlapping windows whose
    overlaps are trimmed, then the mel is vocoded in overlapping blocks whose
    seams are linearly crossfaded. The result is optionally resampled and
    finally truncated.

    Args:
        latents: Indexable batch of scaled latents; each item is treated as a
            3-D tensor whose last axis is the latent time axis.
        audio_lengths: Optional per-item target lengths in samples at the
            output rate. Each is capped by the produced length and by the
            length implied by the latent width.
        sr: Output sample rate; ``None`` keeps the internal 44100 Hz.

    Returns:
        Tuple ``(final_output_sr, wavs)`` where ``wavs`` is a list of CPU
        tensors shaped ``(channels, samples)``.
    """
    print("Using Overlapped DCAE and Vocoder")

    MODEL_INTERNAL_SR = 44100
    DCAE_LATENT_TO_MEL_STRIDE = 8
    VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME = 512

    pred_wavs = []
    final_output_sr = sr if sr is not None else MODEL_INTERNAL_SR

    # --- DCAE Parameters ---
    # Window length in the latent domain for DCAE processing.
    dcae_win_len_latent = 512
    # Mel window length expected from the DCAE decoder (latent_win * stride).
    dcae_mel_win_len = dcae_win_len_latent * 8
    # Offset from an anchor point to the start of its latent window slice.
    dcae_anchor_offset = dcae_win_len_latent // 4
    # Hop between anchor points in the latent domain.
    dcae_anchor_hop = dcae_win_len_latent // 2
    # Overlap length in the mel domain that gets trimmed.
    dcae_mel_overlap_len = dcae_mel_win_len // 4

    # --- Vocoder Parameters ---
    # Audio samples per vocoder processing window.
    vocoder_win_len_audio = 512 * 512
    # Audio samples of overlap between vocoder windows.
    vocoder_overlap_len_audio = 1024
    # Hop size in audio samples between vocoder windows.
    vocoder_hop_len_audio = vocoder_win_len_audio - 2 * vocoder_overlap_len_audio
    # Number of mel frames fed to the vocoder in one call.
    vocoder_input_mel_frames_per_block = vocoder_win_len_audio // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME

    # Linear fade-out / fade-in ramps used at vocoder window seams.
    crossfade_len_audio = 128
    cf_win_tail = torch.linspace(1, 0, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
    cf_win_head = torch.linspace(0, 1, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)

    for latent_item in latents:
        latent_item = latent_item.to(self.device)
        current_latent = (latent_item / self.scale_factor + self.shift_factor).unsqueeze(0)  # (1, C, H, W_latent)
        latent_len = current_latent.shape[3]

        # 1. DCAE: Latent to Mel Spectrogram (Overlapped)
        mels_segments = []
        if latent_len > 0:
            # Window slice: current_latent[..., anchor - offset : anchor - offset + win_len].
            # The first anchor makes the first window start at 0.
            dcae_anchors = list(range(dcae_anchor_offset, latent_len - dcae_anchor_offset, dcae_anchor_hop))
            if not dcae_anchors:  # Latent too short for the range: use one anchor
                dcae_anchors = [dcae_anchor_offset]
            last_anchor_idx = len(dcae_anchors) - 1

            for i, anchor in enumerate(dcae_anchors):
                win_start_idx = max(0, anchor - dcae_anchor_offset)
                win_end_idx = min(latent_len, win_start_idx + dcae_win_len_latent)

                dcae_input_segment = current_latent[:, :, :, win_start_idx:win_end_idx]
                if dcae_input_segment.shape[3] == 0:
                    continue

                mel_output_full = self.dcae.decoder(dcae_input_segment)

                is_first = i == 0
                is_last = i == last_anchor_idx

                if is_first and is_last:  # Only one segment
                    # Keep the mel frames matching the actual input latent length.
                    true_mel_content_len = dcae_input_segment.shape[3] * DCAE_LATENT_TO_MEL_STRIDE
                    mel_to_keep = mel_output_full[:, :, :, : min(true_mel_content_len, mel_output_full.shape[3])]
                elif is_first:  # First segment, trim end overlap
                    mel_to_keep = mel_output_full[:, :, :, :-dcae_mel_overlap_len]
                elif is_last:  # Last segment, trim start overlap and keep the tail
                    mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:]
                else:  # Middle segment, trim both overlaps
                    mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:-dcae_mel_overlap_len]

                if mel_to_keep.shape[3] > 0:
                    mels_segments.append(mel_to_keep)

        if not mels_segments:
            # Nothing was decoded (e.g. zero-width latent): emit the same empty
            # single-channel waveform the zero-frame branch produced before.
            # NOTE(review): the previous code first built a zero-width mel
            # placeholder from ``self.dcae.decoder_output_mel_height``; that
            # attribute does not appear to exist on AutoencoderDC, so this path
            # would have raised - confirm.
            pred_wavs.append(torch.zeros((1, 0), device=self.device, dtype=torch.float32))
            continue

        concatenated_mels = torch.cat(mels_segments, dim=3)

        # Denormalize mels
        concatenated_mels = concatenated_mels * 0.5 + 0.5
        concatenated_mels = concatenated_mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value

        mel_total_frames = concatenated_mels.shape[3]

        # 2. Vocoder: Mel Spectrogram to Waveform (Overlapped)
        # Initial vocoder window; the vocoder is fed (C_mel, H_mel, W_mel_block).
        mel_block = concatenated_mels[0, :, :, :vocoder_input_mel_frames_per_block].to(self.device)

        # Zero-pad the time axis when the block is shorter than a full window.
        if 0 < mel_block.shape[2] < vocoder_input_mel_frames_per_block:
            pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
            mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)

        current_audio_output = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
        current_audio_output = current_audio_output[:, :, :-vocoder_overlap_len_audio]  # Remove end overlap

        # p_audio_samples is the start (in native-rate samples) of the next window.
        p_audio_samples = vocoder_hop_len_audio
        conceptual_total_audio_len_native_sr = mel_total_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME

        while p_audio_samples < conceptual_total_audio_len_native_sr:
            mel_frame_start = p_audio_samples // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
            mel_frame_end = mel_frame_start + vocoder_input_mel_frames_per_block

            if mel_frame_start >= mel_total_frames:
                break  # No more mel frames

            mel_block = concatenated_mels[0, :, :, mel_frame_start : min(mel_frame_end, mel_total_frames)].to(self.device)

            if mel_block.shape[2] == 0:
                break  # Should not happen if mel_frame_start is valid

            # Pad if the current mel_block is too short (end of sequence)
            if mel_block.shape[2] < vocoder_input_mel_frames_per_block:
                pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
                mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)

            new_audio_win = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)

            # Crossfade the tail of what we have with the matching head of the
            # new window; the length is limited by the available audio.
            actual_cf_len = min(
                crossfade_len_audio,
                current_audio_output.shape[2],
                new_audio_win.shape[2] - (vocoder_overlap_len_audio - crossfade_len_audio),
            )
            if actual_cf_len > 0:
                tail_part = current_audio_output[:, :, -actual_cf_len:]
                head_part = new_audio_win[:, :, vocoder_overlap_len_audio - actual_cf_len : vocoder_overlap_len_audio]

                crossfaded_segment = tail_part * cf_win_tail[:, :, :actual_cf_len] + head_part * cf_win_head[:, :, :actual_cf_len]

                current_audio_output = torch.cat([current_audio_output[:, :, :-actual_cf_len], crossfaded_segment], dim=2)

            # Append the non-overlapping part of new_audio_win; the final
            # window keeps its tail.
            is_final_append = p_audio_samples + vocoder_hop_len_audio >= conceptual_total_audio_len_native_sr
            if is_final_append:
                segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:]
            else:
                segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:-vocoder_overlap_len_audio]

            current_audio_output = torch.cat([current_audio_output, segment_to_append], dim=2)

            p_audio_samples += vocoder_hop_len_audio

        final_wav = current_audio_output.squeeze(1)  # (C_audio, Samples)

        # 3. Resampling (if necessary); done on CPU, then moved back.
        if final_output_sr != MODEL_INTERNAL_SR and final_wav.numel() > 0:
            resampler = torchaudio.transforms.Resample(MODEL_INTERNAL_SR, final_output_sr, dtype=final_wav.dtype)
            final_wav = resampler(final_wav.cpu()).to(self.device)

        pred_wavs.append(final_wav)

    # 4. Final Truncation
    processed_pred_wavs = []
    for i, wav in enumerate(pred_wavs):
        # Expected length from the original latent width, at the FINAL output rate.
        _num_latent_frames = latents[i].shape[-1]
        _num_mel_frames = _num_latent_frames * DCAE_LATENT_TO_MEL_STRIDE
        _conceptual_native_audio_len = _num_mel_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
        max_possible_len = int(_conceptual_native_audio_len * final_output_sr / MODEL_INTERNAL_SR)

        current_wav_len = wav.shape[1]

        if audio_lengths is not None:
            # User-provided length is the primary target, capped by actual and max possible
            target_len = min(audio_lengths[i], current_wav_len, max_possible_len)
        else:
            # No user length, use max possible capped by actual
            target_len = min(max_possible_len, current_wav_len)

        processed_pred_wavs.append(wav[:, : max(0, target_len)].cpu())  # Ensure length is non-negative

    return final_output_sr, processed_pred_wavs
Decodes latents into waveforms using an overlapped DCAE and Vocoder.
def
forward(self, audios, audio_lengths=None, sr=None):
def forward(self, audios, audio_lengths=None, sr=None):
    """Run an encode/decode round trip.

    Args:
        audios: Waveform batch forwarded to ``encode``.
        audio_lengths: Optional per-item lengths, forwarded to both
            ``encode`` and ``decode``.
        sr: Optional sample rate, forwarded to both ``encode`` and ``decode``.

    Returns:
        Tuple ``(sr, pred_wavs, latents, latent_lengths)`` where ``sr`` and
        ``pred_wavs`` come from ``decode`` and the latents from ``encode``.
    """
    encoded = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
    latents, latent_lengths = encoded
    output_sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
    return output_sr, pred_wavs, latents, latent_lengths
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.