divisor.acestep.music_dcae.music_dcae_pipeline

ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License

  1"""
  2ACE-Step: A Step Towards Music Generation Foundation Model
  3
  4https://github.com/ace-step/ACE-Step
  5
  6Apache 2.0 License
  7"""
  8
  9import os
 10import torch
 11from diffusers import AutoencoderDC
 12import torchaudio
 13import torchvision.transforms as transforms
 14from diffusers.models.modeling_utils import ModelMixin
 15from diffusers.loaders import FromOriginalModelMixin
 16from diffusers.configuration_utils import ConfigMixin, register_to_config
 17from tqdm import tqdm
 18
 19try:
 20    from divisor.acestep.music_dcae.music_vocoder import ADaMoSHiFiGANV1
 21except ImportError:
 22    from music_vocoder import ADaMoSHiFiGANV1
 23
 24
 25root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 26DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
 27VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")
 28
 29
 30class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 31    @register_to_config
 32    def __init__(
 33        self,
 34        source_sample_rate=None,
 35        dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
 36        vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
 37    ):
 38        super(MusicDCAE, self).__init__()
 39
 40        self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
 41        self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)
 42
 43        if source_sample_rate is None:
 44            source_sample_rate = 48000
 45
 46        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
 47
 48        self.transform = transforms.Compose(
 49            [
 50                transforms.Normalize(0.5, 0.5),
 51            ]
 52        )
 53        self.min_mel_value = -11.0
 54        self.max_mel_value = 3.0
 55        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
 56        self.mel_chunk_size = 1024
 57        self.time_dimention_multiple = 8
 58        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
 59        self.scale_factor = 0.1786
 60        self.shift_factor = -1.9091
 61
 62    def load_audio(self, audio_path):
 63        audio, sr = torchaudio.load(audio_path)
 64        if audio.shape[0] == 1:
 65            audio = audio.repeat(2, 1)
 66        return audio, sr
 67
 68    def forward_mel(self, audios):
 69        mels = []
 70        for i in range(len(audios)):
 71            image = self.vocoder.mel_transform(audios[i])
 72            mels.append(image)
 73        mels = torch.stack(mels)
 74        return mels
 75
 76    @torch.no_grad()
 77    def encode(self, audios, audio_lengths=None, sr=None):
 78        if audio_lengths is None:
 79            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
 80            audio_lengths = audio_lengths.to(audios.device)
 81
 82        # audios: N x 2 x T, 48kHz
 83        device = audios.device
 84        dtype = audios.dtype
 85
 86        if sr is None:
 87            sr = 48000
 88            resampler = self.resampler
 89        else:
 90            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
 91
 92        audio = resampler(audios)
 93
 94        max_audio_len = audio.shape[-1]
 95        if max_audio_len % (8 * 512) != 0:
 96            audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))
 97
 98        mels = self.forward_mel(audio)
 99        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
100        mels = self.transform(mels)
101        latents = []
102        for mel in mels:
103            latent = self.dcae.encoder(mel.unsqueeze(0))
104            latents.append(latent)
105        latents = torch.cat(latents, dim=0)
106        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
107        latents = (latents - self.shift_factor) * self.scale_factor
108        return latents, latent_lengths
109
110    @torch.no_grad()
111    def decode(self, latents, audio_lengths=None, sr=None):
112        latents = latents / self.scale_factor + self.shift_factor
113
114        pred_wavs = []
115
116        for latent in latents:
117            mels = self.dcae.decoder(latent.unsqueeze(0))
118            mels = mels * 0.5 + 0.5
119            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
120
121            # wav = self.vocoder.decode(mels[0]).squeeze(1)
122            # decode waveform for each channels to reduce vram footprint
123            wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
124            wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
125            wav = torch.cat([wav_ch1, wav_ch2], dim=0)
126
127            if sr is not None:
128                resampler = torchaudio.transforms.Resample(44100, sr)
129                wav = resampler(wav.cpu().float())
130            else:
131                sr = 44100
132            pred_wavs.append(wav)
133
134        if audio_lengths is not None:
135            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
136        return sr, pred_wavs
137
138    @torch.no_grad()
139    def decode_overlap(self, latents, audio_lengths=None, sr=None):
140        """
141        Decodes latents into waveforms using an overlapped DCAE and Vocoder.
142        """
143        print("Using Overlapped DCAE and Vocoder")
144
145        MODEL_INTERNAL_SR = 44100
146        DCAE_LATENT_TO_MEL_STRIDE = 8
147        VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME = 512
148
149        pred_wavs = []
150        final_output_sr = sr if sr is not None else MODEL_INTERNAL_SR
151
152        # --- DCAE Parameters ---
153        # dcae_win_len_latent: Window length in the latent domain for DCAE processing
154        dcae_win_len_latent = 512
155        # dcae_mel_win_len: Expected mel window length from DCAE decoder output (latent_win * stride)
156        dcae_mel_win_len = dcae_win_len_latent * 8
157        # dcae_anchor_offset: Offset from anchor point to actual start of latent window slice
158        dcae_anchor_offset = dcae_win_len_latent // 4
159        # dcae_anchor_hop: Hop size for anchor points in latent domain
160        dcae_anchor_hop = dcae_win_len_latent // 2
161        # dcae_mel_overlap_len: Overlap length in the mel domain to be trimmed/blended
162        dcae_mel_overlap_len = dcae_mel_win_len // 4
163
164        # --- Vocoder Parameters ---
165        # vocoder_win_len_audio: Audio samples per vocoder processing window
166        vocoder_win_len_audio = 512 * 512  # Example: 262144 samples
167        # vocoder_overlap_len_audio: Audio samples for overlap between vocoder windows
168        vocoder_overlap_len_audio = 1024
169        # vocoder_hop_len_audio: Hop size in audio samples for vocoder processing
170        vocoder_hop_len_audio = vocoder_win_len_audio - 2 * vocoder_overlap_len_audio
171        # vocoder_input_mel_frames_per_block: Number of mel frames fed to vocoder in one go
172        vocoder_input_mel_frames_per_block = vocoder_win_len_audio // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
173
174        crossfade_len_audio = 128  # Audio samples for crossfading vocoder outputs
175        cf_win_tail = torch.linspace(1, 0, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
176        cf_win_head = torch.linspace(0, 1, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
177
178        for latent_idx, latent_item in enumerate(latents):
179            latent_item = latent_item.to(self.device)
180            current_latent = (latent_item / self.scale_factor + self.shift_factor).unsqueeze(0)  # (1, C, H, W_latent)
181            latent_len = current_latent.shape[3]
182
183            # 1. DCAE: Latent to Mel Spectrogram (Overlapped)
184            mels_segments = []
185            if latent_len == 0:
186                pass  # No mel segments to generate
187            else:
188                # Determine anchor points for DCAE windows
189                # An anchor marks a reference point for a window slice.
190                # Window slice: current_latent[..., anchor - offset : anchor - offset + win_len]
191                # First anchor ensures window starts at 0. Last anchor ensures tail is covered.
192                dcae_anchors = list(range(dcae_anchor_offset, latent_len - dcae_anchor_offset, dcae_anchor_hop))
193                if not dcae_anchors:  # If latent is too short for the range, use one anchor
194                    dcae_anchors = [dcae_anchor_offset]
195
196                for i, anchor in enumerate(dcae_anchors):
197                    win_start_idx = max(0, anchor - dcae_anchor_offset)
198                    win_end_idx = min(latent_len, win_start_idx + dcae_win_len_latent)
199
200                    dcae_input_segment = current_latent[:, :, :, win_start_idx:win_end_idx]
201                    if dcae_input_segment.shape[3] == 0:
202                        continue
203
204                    mel_output_full = self.dcae.decoder(dcae_input_segment)  # (1, C, H_mel, W_mel_fixed_from_dcae)
205
206                    is_first = i == 0
207                    is_last = i == len(dcae_anchors) - 1
208
209                    if is_first and is_last:  # Only one segment
210                        # Use mel corresponding to actual input latent length
211                        true_mel_content_len = dcae_input_segment.shape[3] * DCAE_LATENT_TO_MEL_STRIDE
212                        mel_to_keep = mel_output_full[:, :, :, : min(true_mel_content_len, mel_output_full.shape[3])]
213                    elif is_first:  # First segment, trim end overlap
214                        mel_to_keep = mel_output_full[:, :, :, :-dcae_mel_overlap_len]
215                    elif is_last:  # Last segment, trim start overlap
216                        # And ensure we only take content relevant to the (potentially partial) last latent window
217                        # The mel_output_full is fixed length. The useful part starts after overlap.
218                        # The length of the useful part depends on how much of dcae_input_segment was actual content.
219                        # For simplicity in overlap-add, typically trim fixed overlap.
220                        # If dcae_input_segment was shorter than dcae_win_len_latent, mel_output_full might contain padding effects.
221                        # Standard OLA keeps the corresponding tail.
222                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:]
223                    else:  # Middle segment, trim both overlaps
224                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:-dcae_mel_overlap_len]
225
226                    if mel_to_keep.shape[3] > 0:
227                        mels_segments.append(mel_to_keep)
228
229            if not mels_segments:
230                num_mel_channels = current_latent.shape[1]
231                mel_height = self.dcae.decoder_output_mel_height
232                concatenated_mels = torch.empty((1, num_mel_channels, mel_height, 0), device=current_latent.device, dtype=current_latent.dtype)
233            else:
234                concatenated_mels = torch.cat(mels_segments, dim=3)
235
236            # Denormalize mels
237            concatenated_mels = concatenated_mels * 0.5 + 0.5
238            concatenated_mels = concatenated_mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
239
240            mel_total_frames = concatenated_mels.shape[3]
241
242            # 2. Vocoder: Mel Spectrogram to Waveform (Overlapped)
243            if mel_total_frames == 0:
244                # Assuming mono or stereo output based on mel channels (typically mono for vocoder from single mel)
245                num_audio_channels = 1  # Or determine from vocoder capabilities / mel channels
246                final_wav = torch.zeros((num_audio_channels, 0), device=self.device, dtype=torch.float32)
247            else:
248                # Initial vocoder window
249                # Vocoder expects (C_mel, H_mel, W_mel_block)
250                mel_block = concatenated_mels[0, :, :, :vocoder_input_mel_frames_per_block].to(self.device)
251
252                # Pad mel_block if it's shorter than vocoder_input_mel_frames_per_block (e.g. very short audio)
253                if 0 < mel_block.shape[2] < vocoder_input_mel_frames_per_block:
254                    pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
255                    mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)  # Pad last dim
256
257                current_audio_output = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
258                current_audio_output = current_audio_output[:, :, :-vocoder_overlap_len_audio]  # Remove end overlap
259
260                # p_audio_samples tracks the start of the *next* audio segment to generate (in conceptual total audio samples)
261                p_audio_samples = vocoder_hop_len_audio
262                conceptual_total_audio_len_native_sr = mel_total_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
263
264                pbar_total = 1 + max(0, (conceptual_total_audio_len_native_sr - (vocoder_win_len_audio - vocoder_overlap_len_audio))) // vocoder_hop_len_audio
265
266                # Use tqdm if you want a progress bar for the vocoder part
267                # with tqdm(total=pbar_total, desc=f"Vocoder {latent_idx+1}/{len(latents)}", leave=False) as pbar:
268                # pbar.update(1) # For initial window
269                # The loop for subsequent windows
270                while p_audio_samples < conceptual_total_audio_len_native_sr:
271                    mel_frame_start = p_audio_samples // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
272                    mel_frame_end = mel_frame_start + vocoder_input_mel_frames_per_block
273
274                    if mel_frame_start >= mel_total_frames:
275                        break  # No more mel frames
276
277                    mel_block = concatenated_mels[0, :, :, mel_frame_start : min(mel_frame_end, mel_total_frames)].to(self.device)
278
279                    if mel_block.shape[2] == 0:
280                        break  # Should not happen if mel_frame_start is valid
281
282                    # Pad if current mel_block is too short (end of sequence)
283                    if mel_block.shape[2] < vocoder_input_mel_frames_per_block:
284                        pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
285                        mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)
286
287                    new_audio_win = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
288
289                    # Crossfade
290                    # Determine actual crossfade length based on available audio
291                    actual_cf_len = min(crossfade_len_audio, current_audio_output.shape[2], new_audio_win.shape[2] - (vocoder_overlap_len_audio - crossfade_len_audio))
292                    if actual_cf_len > 0:  # Ensure valid slice lengths for crossfade
293                        tail_part = current_audio_output[:, :, -actual_cf_len:]
294                        head_part = new_audio_win[:, :, vocoder_overlap_len_audio - actual_cf_len : vocoder_overlap_len_audio]
295
296                        crossfaded_segment = tail_part * cf_win_tail[:, :, :actual_cf_len] + head_part * cf_win_head[:, :, :actual_cf_len]
297
298                        current_audio_output = torch.cat([current_audio_output[:, :, :-actual_cf_len], crossfaded_segment], dim=2)
299
300                    # Append non-overlapping part of new_audio_win
301                    is_final_append = p_audio_samples + vocoder_hop_len_audio >= conceptual_total_audio_len_native_sr
302                    if is_final_append:
303                        segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:]
304                    else:
305                        segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:-vocoder_overlap_len_audio]
306
307                    current_audio_output = torch.cat([current_audio_output, segment_to_append], dim=2)
308
309                    p_audio_samples += vocoder_hop_len_audio
310                    # pbar.update(1) # if using tqdm
311
312                final_wav = current_audio_output.squeeze(1)  # (C_audio, Samples)
313
314            # 3. Resampling (if necessary)
315            if final_output_sr != MODEL_INTERNAL_SR and final_wav.numel() > 0:
316                # Resample expects CPU tensor if using torchaudio.transforms on older versions or for some backends
317                resampler = torchaudio.transforms.Resample(MODEL_INTERNAL_SR, final_output_sr, dtype=final_wav.dtype)
318                final_wav = resampler(final_wav.cpu()).to(self.device)  # Move back to device if needed later
319
320            pred_wavs.append(final_wav)
321
322        # 4. Final Truncation
323        processed_pred_wavs = []
324        for i, wav in enumerate(pred_wavs):
325            # Calculate expected length based on original latent, at the FINAL output sample rate
326            _num_latent_frames = latents[i].shape[-1]  # Use original latent item for shape
327            _num_mel_frames = _num_latent_frames * DCAE_LATENT_TO_MEL_STRIDE
328            _conceptual_native_audio_len = _num_mel_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
329            max_possible_len = int(_conceptual_native_audio_len * final_output_sr / MODEL_INTERNAL_SR)
330
331            current_wav_len = wav.shape[1]
332
333            if audio_lengths is not None:
334                # User-provided length is the primary target, capped by actual and max possible
335                target_len = min(audio_lengths[i], current_wav_len, max_possible_len)
336            else:
337                # No user length, use max possible capped by actual
338                target_len = min(max_possible_len, current_wav_len)
339
340            processed_pred_wavs.append(wav[:, : max(0, target_len)].cpu())  # Ensure length is non-negative
341
342        return final_output_sr, processed_pred_wavs
343
344    def forward(self, audios, audio_lengths=None, sr=None):
345        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
346        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
347        return sr, pred_wavs, latents, latent_lengths
348
349
if __name__ == "__main__":
    # Smoke test: round-trip "test.wav" through the model and save the result.
    audio, sr = torchaudio.load("test.wav")
    audios = audio.unsqueeze(0)
    audio_lengths = torch.tensor([audio.shape[1]])

    model = MusicDCAE()

    # To test encoding alone, call model.encode(audios, audio_lengths) and
    # inspect the returned latents and latent_lengths.

    # Encode and decode in one pass.
    sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
    reconstructed = pred_wavs[0]
    print("reconstructed wavs: ", reconstructed.shape)
    print("latents shape: ", latents.shape)
    print("latent_lengths: ", latent_lengths)
    print("sr: ", sr)
    torchaudio.save("test_reconstructed.wav", reconstructed, sr)
    print("test_reconstructed.wav")
root_dir = '/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep'
DEFAULT_PRETRAINED_PATH = '/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep/checkpoints/music_dcae_f8c8'
VOCODER_PRETRAINED_PATH = '/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep/checkpoints/music_vocoder'
class MusicDCAE(diffusers.models.modeling_utils.ModelMixin, diffusers.configuration_utils.ConfigMixin, diffusers.loaders.single_file_model.FromOriginalModelMixin):
 31class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 32    @register_to_config
 33    def __init__(
 34        self,
 35        source_sample_rate=None,
 36        dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
 37        vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
 38    ):
 39        super(MusicDCAE, self).__init__()
 40
 41        self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
 42        self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)
 43
 44        if source_sample_rate is None:
 45            source_sample_rate = 48000
 46
 47        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
 48
 49        self.transform = transforms.Compose(
 50            [
 51                transforms.Normalize(0.5, 0.5),
 52            ]
 53        )
 54        self.min_mel_value = -11.0
 55        self.max_mel_value = 3.0
 56        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
 57        self.mel_chunk_size = 1024
 58        self.time_dimention_multiple = 8
 59        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
 60        self.scale_factor = 0.1786
 61        self.shift_factor = -1.9091
 62
 63    def load_audio(self, audio_path):
 64        audio, sr = torchaudio.load(audio_path)
 65        if audio.shape[0] == 1:
 66            audio = audio.repeat(2, 1)
 67        return audio, sr
 68
 69    def forward_mel(self, audios):
 70        mels = []
 71        for i in range(len(audios)):
 72            image = self.vocoder.mel_transform(audios[i])
 73            mels.append(image)
 74        mels = torch.stack(mels)
 75        return mels
 76
 77    @torch.no_grad()
 78    def encode(self, audios, audio_lengths=None, sr=None):
 79        if audio_lengths is None:
 80            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
 81            audio_lengths = audio_lengths.to(audios.device)
 82
 83        # audios: N x 2 x T, 48kHz
 84        device = audios.device
 85        dtype = audios.dtype
 86
 87        if sr is None:
 88            sr = 48000
 89            resampler = self.resampler
 90        else:
 91            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
 92
 93        audio = resampler(audios)
 94
 95        max_audio_len = audio.shape[-1]
 96        if max_audio_len % (8 * 512) != 0:
 97            audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))
 98
 99        mels = self.forward_mel(audio)
100        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
101        mels = self.transform(mels)
102        latents = []
103        for mel in mels:
104            latent = self.dcae.encoder(mel.unsqueeze(0))
105            latents.append(latent)
106        latents = torch.cat(latents, dim=0)
107        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
108        latents = (latents - self.shift_factor) * self.scale_factor
109        return latents, latent_lengths
110
111    @torch.no_grad()
112    def decode(self, latents, audio_lengths=None, sr=None):
113        latents = latents / self.scale_factor + self.shift_factor
114
115        pred_wavs = []
116
117        for latent in latents:
118            mels = self.dcae.decoder(latent.unsqueeze(0))
119            mels = mels * 0.5 + 0.5
120            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
121
122            # wav = self.vocoder.decode(mels[0]).squeeze(1)
123            # decode waveform for each channels to reduce vram footprint
124            wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
125            wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
126            wav = torch.cat([wav_ch1, wav_ch2], dim=0)
127
128            if sr is not None:
129                resampler = torchaudio.transforms.Resample(44100, sr)
130                wav = resampler(wav.cpu().float())
131            else:
132                sr = 44100
133            pred_wavs.append(wav)
134
135        if audio_lengths is not None:
136            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
137        return sr, pred_wavs
138
139    @torch.no_grad()
140    def decode_overlap(self, latents, audio_lengths=None, sr=None):
141        """
142        Decodes latents into waveforms using an overlapped DCAE and Vocoder.
143        """
144        print("Using Overlapped DCAE and Vocoder")
145
146        MODEL_INTERNAL_SR = 44100
147        DCAE_LATENT_TO_MEL_STRIDE = 8
148        VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME = 512
149
150        pred_wavs = []
151        final_output_sr = sr if sr is not None else MODEL_INTERNAL_SR
152
153        # --- DCAE Parameters ---
154        # dcae_win_len_latent: Window length in the latent domain for DCAE processing
155        dcae_win_len_latent = 512
156        # dcae_mel_win_len: Expected mel window length from DCAE decoder output (latent_win * stride)
157        dcae_mel_win_len = dcae_win_len_latent * 8
158        # dcae_anchor_offset: Offset from anchor point to actual start of latent window slice
159        dcae_anchor_offset = dcae_win_len_latent // 4
160        # dcae_anchor_hop: Hop size for anchor points in latent domain
161        dcae_anchor_hop = dcae_win_len_latent // 2
162        # dcae_mel_overlap_len: Overlap length in the mel domain to be trimmed/blended
163        dcae_mel_overlap_len = dcae_mel_win_len // 4
164
165        # --- Vocoder Parameters ---
166        # vocoder_win_len_audio: Audio samples per vocoder processing window
167        vocoder_win_len_audio = 512 * 512  # Example: 262144 samples
168        # vocoder_overlap_len_audio: Audio samples for overlap between vocoder windows
169        vocoder_overlap_len_audio = 1024
170        # vocoder_hop_len_audio: Hop size in audio samples for vocoder processing
171        vocoder_hop_len_audio = vocoder_win_len_audio - 2 * vocoder_overlap_len_audio
172        # vocoder_input_mel_frames_per_block: Number of mel frames fed to vocoder in one go
173        vocoder_input_mel_frames_per_block = vocoder_win_len_audio // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
174
175        crossfade_len_audio = 128  # Audio samples for crossfading vocoder outputs
176        cf_win_tail = torch.linspace(1, 0, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
177        cf_win_head = torch.linspace(0, 1, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
178
179        for latent_idx, latent_item in enumerate(latents):
180            latent_item = latent_item.to(self.device)
181            current_latent = (latent_item / self.scale_factor + self.shift_factor).unsqueeze(0)  # (1, C, H, W_latent)
182            latent_len = current_latent.shape[3]
183
184            # 1. DCAE: Latent to Mel Spectrogram (Overlapped)
185            mels_segments = []
186            if latent_len == 0:
187                pass  # No mel segments to generate
188            else:
189                # Determine anchor points for DCAE windows
190                # An anchor marks a reference point for a window slice.
191                # Window slice: current_latent[..., anchor - offset : anchor - offset + win_len]
192                # First anchor ensures window starts at 0. Last anchor ensures tail is covered.
193                dcae_anchors = list(range(dcae_anchor_offset, latent_len - dcae_anchor_offset, dcae_anchor_hop))
194                if not dcae_anchors:  # If latent is too short for the range, use one anchor
195                    dcae_anchors = [dcae_anchor_offset]
196
197                for i, anchor in enumerate(dcae_anchors):
198                    win_start_idx = max(0, anchor - dcae_anchor_offset)
199                    win_end_idx = min(latent_len, win_start_idx + dcae_win_len_latent)
200
201                    dcae_input_segment = current_latent[:, :, :, win_start_idx:win_end_idx]
202                    if dcae_input_segment.shape[3] == 0:
203                        continue
204
205                    mel_output_full = self.dcae.decoder(dcae_input_segment)  # (1, C, H_mel, W_mel_fixed_from_dcae)
206
207                    is_first = i == 0
208                    is_last = i == len(dcae_anchors) - 1
209
210                    if is_first and is_last:  # Only one segment
211                        # Use mel corresponding to actual input latent length
212                        true_mel_content_len = dcae_input_segment.shape[3] * DCAE_LATENT_TO_MEL_STRIDE
213                        mel_to_keep = mel_output_full[:, :, :, : min(true_mel_content_len, mel_output_full.shape[3])]
214                    elif is_first:  # First segment, trim end overlap
215                        mel_to_keep = mel_output_full[:, :, :, :-dcae_mel_overlap_len]
216                    elif is_last:  # Last segment, trim start overlap
217                        # And ensure we only take content relevant to the (potentially partial) last latent window
218                        # The mel_output_full is fixed length. The useful part starts after overlap.
219                        # The length of the useful part depends on how much of dcae_input_segment was actual content.
220                        # For simplicity in overlap-add, typically trim fixed overlap.
221                        # If dcae_input_segment was shorter than dcae_win_len_latent, mel_output_full might contain padding effects.
222                        # Standard OLA keeps the corresponding tail.
223                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:]
224                    else:  # Middle segment, trim both overlaps
225                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:-dcae_mel_overlap_len]
226
227                    if mel_to_keep.shape[3] > 0:
228                        mels_segments.append(mel_to_keep)
229
230            if not mels_segments:
231                num_mel_channels = current_latent.shape[1]
232                mel_height = self.dcae.decoder_output_mel_height
233                concatenated_mels = torch.empty((1, num_mel_channels, mel_height, 0), device=current_latent.device, dtype=current_latent.dtype)
234            else:
235                concatenated_mels = torch.cat(mels_segments, dim=3)
236
237            # Denormalize mels
238            concatenated_mels = concatenated_mels * 0.5 + 0.5
239            concatenated_mels = concatenated_mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
240
241            mel_total_frames = concatenated_mels.shape[3]
242
243            # 2. Vocoder: Mel Spectrogram to Waveform (Overlapped)
244            if mel_total_frames == 0:
245                # Assuming mono or stereo output based on mel channels (typically mono for vocoder from single mel)
246                num_audio_channels = 1  # Or determine from vocoder capabilities / mel channels
247                final_wav = torch.zeros((num_audio_channels, 0), device=self.device, dtype=torch.float32)
248            else:
249                # Initial vocoder window
250                # Vocoder expects (C_mel, H_mel, W_mel_block)
251                mel_block = concatenated_mels[0, :, :, :vocoder_input_mel_frames_per_block].to(self.device)
252
253                # Pad mel_block if it's shorter than vocoder_input_mel_frames_per_block (e.g. very short audio)
254                if 0 < mel_block.shape[2] < vocoder_input_mel_frames_per_block:
255                    pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
256                    mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)  # Pad last dim
257
258                current_audio_output = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
259                current_audio_output = current_audio_output[:, :, :-vocoder_overlap_len_audio]  # Remove end overlap
260
261                # p_audio_samples tracks the start of the *next* audio segment to generate (in conceptual total audio samples)
262                p_audio_samples = vocoder_hop_len_audio
263                conceptual_total_audio_len_native_sr = mel_total_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
264
265                pbar_total = 1 + max(0, (conceptual_total_audio_len_native_sr - (vocoder_win_len_audio - vocoder_overlap_len_audio))) // vocoder_hop_len_audio
266
267                # Use tqdm if you want a progress bar for the vocoder part
268                # with tqdm(total=pbar_total, desc=f"Vocoder {latent_idx+1}/{len(latents)}", leave=False) as pbar:
269                # pbar.update(1) # For initial window
270                # The loop for subsequent windows
271                while p_audio_samples < conceptual_total_audio_len_native_sr:
272                    mel_frame_start = p_audio_samples // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
273                    mel_frame_end = mel_frame_start + vocoder_input_mel_frames_per_block
274
275                    if mel_frame_start >= mel_total_frames:
276                        break  # No more mel frames
277
278                    mel_block = concatenated_mels[0, :, :, mel_frame_start : min(mel_frame_end, mel_total_frames)].to(self.device)
279
280                    if mel_block.shape[2] == 0:
281                        break  # Should not happen if mel_frame_start is valid
282
283                    # Pad if current mel_block is too short (end of sequence)
284                    if mel_block.shape[2] < vocoder_input_mel_frames_per_block:
285                        pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
286                        mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)
287
288                    new_audio_win = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
289
290                    # Crossfade
291                    # Determine actual crossfade length based on available audio
292                    actual_cf_len = min(crossfade_len_audio, current_audio_output.shape[2], new_audio_win.shape[2] - (vocoder_overlap_len_audio - crossfade_len_audio))
293                    if actual_cf_len > 0:  # Ensure valid slice lengths for crossfade
294                        tail_part = current_audio_output[:, :, -actual_cf_len:]
295                        head_part = new_audio_win[:, :, vocoder_overlap_len_audio - actual_cf_len : vocoder_overlap_len_audio]
296
297                        crossfaded_segment = tail_part * cf_win_tail[:, :, :actual_cf_len] + head_part * cf_win_head[:, :, :actual_cf_len]
298
299                        current_audio_output = torch.cat([current_audio_output[:, :, :-actual_cf_len], crossfaded_segment], dim=2)
300
301                    # Append non-overlapping part of new_audio_win
302                    is_final_append = p_audio_samples + vocoder_hop_len_audio >= conceptual_total_audio_len_native_sr
303                    if is_final_append:
304                        segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:]
305                    else:
306                        segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:-vocoder_overlap_len_audio]
307
308                    current_audio_output = torch.cat([current_audio_output, segment_to_append], dim=2)
309
310                    p_audio_samples += vocoder_hop_len_audio
311                    # pbar.update(1) # if using tqdm
312
313                final_wav = current_audio_output.squeeze(1)  # (C_audio, Samples)
314
315            # 3. Resampling (if necessary)
316            if final_output_sr != MODEL_INTERNAL_SR and final_wav.numel() > 0:
317                # Resample expects CPU tensor if using torchaudio.transforms on older versions or for some backends
318                resampler = torchaudio.transforms.Resample(MODEL_INTERNAL_SR, final_output_sr, dtype=final_wav.dtype)
319                final_wav = resampler(final_wav.cpu()).to(self.device)  # Move back to device if needed later
320
321            pred_wavs.append(final_wav)
322
323        # 4. Final Truncation
324        processed_pred_wavs = []
325        for i, wav in enumerate(pred_wavs):
326            # Calculate expected length based on original latent, at the FINAL output sample rate
327            _num_latent_frames = latents[i].shape[-1]  # Use original latent item for shape
328            _num_mel_frames = _num_latent_frames * DCAE_LATENT_TO_MEL_STRIDE
329            _conceptual_native_audio_len = _num_mel_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
330            max_possible_len = int(_conceptual_native_audio_len * final_output_sr / MODEL_INTERNAL_SR)
331
332            current_wav_len = wav.shape[1]
333
334            if audio_lengths is not None:
335                # User-provided length is the primary target, capped by actual and max possible
336                target_len = min(audio_lengths[i], current_wav_len, max_possible_len)
337            else:
338                # No user length, use max possible capped by actual
339                target_len = min(max_possible_len, current_wav_len)
340
341            processed_pred_wavs.append(wav[:, : max(0, target_len)].cpu())  # Ensure length is non-negative
342
343        return final_output_sr, processed_pred_wavs
344
345    def forward(self, audios, audio_lengths=None, sr=None):
346        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
347        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
348        return sr, pred_wavs, latents, latent_lengths

Base class for all models.

[ModelMixin] takes care of storing the model configuration and provides methods for loading, downloading and saving models.

- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
@register_to_config
MusicDCAE( source_sample_rate=None, dcae_checkpoint_path='/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep/checkpoints/music_dcae_f8c8', vocoder_checkpoint_path='/Users/e6d64/Documents/GitHub/darkshapes/divisor/divisor/acestep/checkpoints/music_vocoder')
32    @register_to_config
33    def __init__(
34        self,
35        source_sample_rate=None,
36        dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
37        vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
38    ):
39        super(MusicDCAE, self).__init__()
40
41        self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
42        self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)
43
44        if source_sample_rate is None:
45            source_sample_rate = 48000
46
47        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
48
49        self.transform = transforms.Compose(
50            [
51                transforms.Normalize(0.5, 0.5),
52            ]
53        )
54        self.min_mel_value = -11.0
55        self.max_mel_value = 3.0
56        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
57        self.mel_chunk_size = 1024
58        self.time_dimention_multiple = 8
59        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
60        self.scale_factor = 0.1786
61        self.shift_factor = -1.9091

Initialize internal Module state, shared by both nn.Module and ScriptModule.

dcae
vocoder
resampler
transform
min_mel_value
max_mel_value
audio_chunk_size
mel_chunk_size
time_dimention_multiple
latent_chunk_size
scale_factor
shift_factor
def load_audio(self, audio_path):
63    def load_audio(self, audio_path):
64        audio, sr = torchaudio.load(audio_path)
65        if audio.shape[0] == 1:
66            audio = audio.repeat(2, 1)
67        return audio, sr
def forward_mel(self, audios):
69    def forward_mel(self, audios):
70        mels = []
71        for i in range(len(audios)):
72            image = self.vocoder.mel_transform(audios[i])
73            mels.append(image)
74        mels = torch.stack(mels)
75        return mels
@torch.no_grad()
def encode(self, audios, audio_lengths=None, sr=None):
 77    @torch.no_grad()
 78    def encode(self, audios, audio_lengths=None, sr=None):
 79        if audio_lengths is None:
 80            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
 81            audio_lengths = audio_lengths.to(audios.device)
 82
 83        # audios: N x 2 x T, 48kHz
 84        device = audios.device
 85        dtype = audios.dtype
 86
 87        if sr is None:
 88            sr = 48000
 89            resampler = self.resampler
 90        else:
 91            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
 92
 93        audio = resampler(audios)
 94
 95        max_audio_len = audio.shape[-1]
 96        if max_audio_len % (8 * 512) != 0:
 97            audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))
 98
 99        mels = self.forward_mel(audio)
100        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
101        mels = self.transform(mels)
102        latents = []
103        for mel in mels:
104            latent = self.dcae.encoder(mel.unsqueeze(0))
105            latents.append(latent)
106        latents = torch.cat(latents, dim=0)
107        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
108        latents = (latents - self.shift_factor) * self.scale_factor
109        return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
111    @torch.no_grad()
112    def decode(self, latents, audio_lengths=None, sr=None):
113        latents = latents / self.scale_factor + self.shift_factor
114
115        pred_wavs = []
116
117        for latent in latents:
118            mels = self.dcae.decoder(latent.unsqueeze(0))
119            mels = mels * 0.5 + 0.5
120            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
121
122            # wav = self.vocoder.decode(mels[0]).squeeze(1)
123            # decode waveform for each channels to reduce vram footprint
124            wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
125            wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
126            wav = torch.cat([wav_ch1, wav_ch2], dim=0)
127
128            if sr is not None:
129                resampler = torchaudio.transforms.Resample(44100, sr)
130                wav = resampler(wav.cpu().float())
131            else:
132                sr = 44100
133            pred_wavs.append(wav)
134
135        if audio_lengths is not None:
136            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
137        return sr, pred_wavs
@torch.no_grad()
def decode_overlap(self, latents, audio_lengths=None, sr=None):
139    @torch.no_grad()
140    def decode_overlap(self, latents, audio_lengths=None, sr=None):
141        """
142        Decodes latents into waveforms using an overlapped DCAE and Vocoder.
143        """
144        print("Using Overlapped DCAE and Vocoder")
145
146        MODEL_INTERNAL_SR = 44100
147        DCAE_LATENT_TO_MEL_STRIDE = 8
148        VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME = 512
149
150        pred_wavs = []
151        final_output_sr = sr if sr is not None else MODEL_INTERNAL_SR
152
153        # --- DCAE Parameters ---
154        # dcae_win_len_latent: Window length in the latent domain for DCAE processing
155        dcae_win_len_latent = 512
156        # dcae_mel_win_len: Expected mel window length from DCAE decoder output (latent_win * stride)
157        dcae_mel_win_len = dcae_win_len_latent * 8
158        # dcae_anchor_offset: Offset from anchor point to actual start of latent window slice
159        dcae_anchor_offset = dcae_win_len_latent // 4
160        # dcae_anchor_hop: Hop size for anchor points in latent domain
161        dcae_anchor_hop = dcae_win_len_latent // 2
162        # dcae_mel_overlap_len: Overlap length in the mel domain to be trimmed/blended
163        dcae_mel_overlap_len = dcae_mel_win_len // 4
164
165        # --- Vocoder Parameters ---
166        # vocoder_win_len_audio: Audio samples per vocoder processing window
167        vocoder_win_len_audio = 512 * 512  # Example: 262144 samples
168        # vocoder_overlap_len_audio: Audio samples for overlap between vocoder windows
169        vocoder_overlap_len_audio = 1024
170        # vocoder_hop_len_audio: Hop size in audio samples for vocoder processing
171        vocoder_hop_len_audio = vocoder_win_len_audio - 2 * vocoder_overlap_len_audio
172        # vocoder_input_mel_frames_per_block: Number of mel frames fed to vocoder in one go
173        vocoder_input_mel_frames_per_block = vocoder_win_len_audio // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
174
175        crossfade_len_audio = 128  # Audio samples for crossfading vocoder outputs
176        cf_win_tail = torch.linspace(1, 0, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
177        cf_win_head = torch.linspace(0, 1, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
178
179        for latent_idx, latent_item in enumerate(latents):
180            latent_item = latent_item.to(self.device)
181            current_latent = (latent_item / self.scale_factor + self.shift_factor).unsqueeze(0)  # (1, C, H, W_latent)
182            latent_len = current_latent.shape[3]
183
184            # 1. DCAE: Latent to Mel Spectrogram (Overlapped)
185            mels_segments = []
186            if latent_len == 0:
187                pass  # No mel segments to generate
188            else:
189                # Determine anchor points for DCAE windows
190                # An anchor marks a reference point for a window slice.
191                # Window slice: current_latent[..., anchor - offset : anchor - offset + win_len]
192                # First anchor ensures window starts at 0. Last anchor ensures tail is covered.
193                dcae_anchors = list(range(dcae_anchor_offset, latent_len - dcae_anchor_offset, dcae_anchor_hop))
194                if not dcae_anchors:  # If latent is too short for the range, use one anchor
195                    dcae_anchors = [dcae_anchor_offset]
196
197                for i, anchor in enumerate(dcae_anchors):
198                    win_start_idx = max(0, anchor - dcae_anchor_offset)
199                    win_end_idx = min(latent_len, win_start_idx + dcae_win_len_latent)
200
201                    dcae_input_segment = current_latent[:, :, :, win_start_idx:win_end_idx]
202                    if dcae_input_segment.shape[3] == 0:
203                        continue
204
205                    mel_output_full = self.dcae.decoder(dcae_input_segment)  # (1, C, H_mel, W_mel_fixed_from_dcae)
206
207                    is_first = i == 0
208                    is_last = i == len(dcae_anchors) - 1
209
210                    if is_first and is_last:  # Only one segment
211                        # Use mel corresponding to actual input latent length
212                        true_mel_content_len = dcae_input_segment.shape[3] * DCAE_LATENT_TO_MEL_STRIDE
213                        mel_to_keep = mel_output_full[:, :, :, : min(true_mel_content_len, mel_output_full.shape[3])]
214                    elif is_first:  # First segment, trim end overlap
215                        mel_to_keep = mel_output_full[:, :, :, :-dcae_mel_overlap_len]
216                    elif is_last:  # Last segment, trim start overlap
217                        # And ensure we only take content relevant to the (potentially partial) last latent window
218                        # The mel_output_full is fixed length. The useful part starts after overlap.
219                        # The length of the useful part depends on how much of dcae_input_segment was actual content.
220                        # For simplicity in overlap-add, typically trim fixed overlap.
221                        # If dcae_input_segment was shorter than dcae_win_len_latent, mel_output_full might contain padding effects.
222                        # Standard OLA keeps the corresponding tail.
223                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:]
224                    else:  # Middle segment, trim both overlaps
225                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:-dcae_mel_overlap_len]
226
227                    if mel_to_keep.shape[3] > 0:
228                        mels_segments.append(mel_to_keep)
229
230            if not mels_segments:
231                num_mel_channels = current_latent.shape[1]
232                mel_height = self.dcae.decoder_output_mel_height
233                concatenated_mels = torch.empty((1, num_mel_channels, mel_height, 0), device=current_latent.device, dtype=current_latent.dtype)
234            else:
235                concatenated_mels = torch.cat(mels_segments, dim=3)
236
237            # Denormalize mels
238            concatenated_mels = concatenated_mels * 0.5 + 0.5
239            concatenated_mels = concatenated_mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
240
241            mel_total_frames = concatenated_mels.shape[3]
242
243            # 2. Vocoder: Mel Spectrogram to Waveform (Overlapped)
244            if mel_total_frames == 0:
245                # Assuming mono or stereo output based on mel channels (typically mono for vocoder from single mel)
246                num_audio_channels = 1  # Or determine from vocoder capabilities / mel channels
247                final_wav = torch.zeros((num_audio_channels, 0), device=self.device, dtype=torch.float32)
248            else:
249                # Initial vocoder window
250                # Vocoder expects (C_mel, H_mel, W_mel_block)
251                mel_block = concatenated_mels[0, :, :, :vocoder_input_mel_frames_per_block].to(self.device)
252
253                # Pad mel_block if it's shorter than vocoder_input_mel_frames_per_block (e.g. very short audio)
254                if 0 < mel_block.shape[2] < vocoder_input_mel_frames_per_block:
255                    pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
256                    mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)  # Pad last dim
257
258                current_audio_output = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
259                current_audio_output = current_audio_output[:, :, :-vocoder_overlap_len_audio]  # Remove end overlap
260
261                # p_audio_samples tracks the start of the *next* audio segment to generate (in conceptual total audio samples)
262                p_audio_samples = vocoder_hop_len_audio
263                conceptual_total_audio_len_native_sr = mel_total_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
264
265                pbar_total = 1 + max(0, (conceptual_total_audio_len_native_sr - (vocoder_win_len_audio - vocoder_overlap_len_audio))) // vocoder_hop_len_audio
266
267                # Use tqdm if you want a progress bar for the vocoder part
268                # with tqdm(total=pbar_total, desc=f"Vocoder {latent_idx+1}/{len(latents)}", leave=False) as pbar:
269                # pbar.update(1) # For initial window
270                # The loop for subsequent windows
271                while p_audio_samples < conceptual_total_audio_len_native_sr:
272                    mel_frame_start = p_audio_samples // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
273                    mel_frame_end = mel_frame_start + vocoder_input_mel_frames_per_block
274
275                    if mel_frame_start >= mel_total_frames:
276                        break  # No more mel frames
277
278                    mel_block = concatenated_mels[0, :, :, mel_frame_start : min(mel_frame_end, mel_total_frames)].to(self.device)
279
280                    if mel_block.shape[2] == 0:
281                        break  # Should not happen if mel_frame_start is valid
282
283                    # Pad if current mel_block is too short (end of sequence)
284                    if mel_block.shape[2] < vocoder_input_mel_frames_per_block:
285                        pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
286                        mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode="constant", value=0)
287
288                    new_audio_win = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
289
290                    # Crossfade
291                    # Determine actual crossfade length based on available audio
292                    actual_cf_len = min(crossfade_len_audio, current_audio_output.shape[2], new_audio_win.shape[2] - (vocoder_overlap_len_audio - crossfade_len_audio))
293                    if actual_cf_len > 0:  # Ensure valid slice lengths for crossfade
294                        tail_part = current_audio_output[:, :, -actual_cf_len:]
295                        head_part = new_audio_win[:, :, vocoder_overlap_len_audio - actual_cf_len : vocoder_overlap_len_audio]
296
297                        crossfaded_segment = tail_part * cf_win_tail[:, :, :actual_cf_len] + head_part * cf_win_head[:, :, :actual_cf_len]
298
299                        current_audio_output = torch.cat([current_audio_output[:, :, :-actual_cf_len], crossfaded_segment], dim=2)
300
301                    # Append non-overlapping part of new_audio_win
302                    is_final_append = p_audio_samples + vocoder_hop_len_audio >= conceptual_total_audio_len_native_sr
303                    if is_final_append:
304                        segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:]
305                    else:
306                        segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:-vocoder_overlap_len_audio]
307
308                    current_audio_output = torch.cat([current_audio_output, segment_to_append], dim=2)
309
310                    p_audio_samples += vocoder_hop_len_audio
311                    # pbar.update(1) # if using tqdm
312
313                final_wav = current_audio_output.squeeze(1)  # (C_audio, Samples)
314
315            # 3. Resampling (if necessary)
316            if final_output_sr != MODEL_INTERNAL_SR and final_wav.numel() > 0:
317                # Resample expects CPU tensor if using torchaudio.transforms on older versions or for some backends
318                resampler = torchaudio.transforms.Resample(MODEL_INTERNAL_SR, final_output_sr, dtype=final_wav.dtype)
319                final_wav = resampler(final_wav.cpu()).to(self.device)  # Move back to device if needed later
320
321            pred_wavs.append(final_wav)
322
323        # 4. Final Truncation
324        processed_pred_wavs = []
325        for i, wav in enumerate(pred_wavs):
326            # Calculate expected length based on original latent, at the FINAL output sample rate
327            _num_latent_frames = latents[i].shape[-1]  # Use original latent item for shape
328            _num_mel_frames = _num_latent_frames * DCAE_LATENT_TO_MEL_STRIDE
329            _conceptual_native_audio_len = _num_mel_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
330            max_possible_len = int(_conceptual_native_audio_len * final_output_sr / MODEL_INTERNAL_SR)
331
332            current_wav_len = wav.shape[1]
333
334            if audio_lengths is not None:
335                # User-provided length is the primary target, capped by actual and max possible
336                target_len = min(audio_lengths[i], current_wav_len, max_possible_len)
337            else:
338                # No user length, use max possible capped by actual
339                target_len = min(max_possible_len, current_wav_len)
340
341            processed_pred_wavs.append(wav[:, : max(0, target_len)].cpu())  # Ensure length is non-negative
342
343        return final_output_sr, processed_pred_wavs

Decodes latents into waveforms using an overlapped DCAE and Vocoder.

def forward(self, audios, audio_lengths=None, sr=None):
345    def forward(self, audios, audio_lengths=None, sr=None):
346        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
347        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
348        return sr, pred_wavs, latents, latent_lengths

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.