divisor.acestep.music_dcae.music_log_mel
ACE-Step: A Step Towards Music Generation Foundation Model
https://github.com/ace-step/ACE-Step
Apache 2.0 License
"""
ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License
"""

import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import MelScale


class LinearSpectrogram(nn.Module):
    """Linear-frequency spectrogram computed with ``torch.stft``.

    The waveform is reflect-padded by ``win_length - hop_length`` samples in
    total (split across both ends) before the STFT. With the default
    ``mode="pow2_sqrt"`` the output is ``sqrt(re**2 + im**2 + 1e-6)``; with any
    other ``mode`` the real and imaginary parts are returned stacked on a
    trailing axis of size 2.

    Args:
        n_fft: FFT size passed to ``torch.stft``.
        win_length: Length of the Hann window; also sizes the manual padding.
        hop_length: Hop between successive frames.
        center: Forwarded to ``torch.stft`` as ``center``.
        mode: ``"pow2_sqrt"`` selects the magnitude output described above.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        # Registered as a buffer so the window follows the module across
        # ``.to()`` / ``.cuda()`` calls.
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        """Compute the spectrogram of ``y``.

        Args:
            y: Waveform tensor. A 3-D input has dim 1 squeezed first, so the
                layout is presumably ``(batch, time)`` or ``(batch, 1, time)``.

        Returns:
            The spectrogram, cast back to the dtype of ``y``.
        """
        if y.ndim == 3:
            # ``squeeze(1)`` only drops dim 1 when it is a singleton.
            y = y.squeeze(1)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                (self.win_length - self.hop_length) // 2,
                (self.win_length - self.hop_length + 1) // 2,
            ),
            mode="reflect",
        ).squeeze(1)

        # The STFT runs in float32; the caller's dtype is restored at the end.
        dtype = y.dtype
        spec = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            # The 1e-6 term keeps the square-root argument strictly positive.
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        spec = spec.to(dtype)
        return spec


class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram built on :class:`LinearSpectrogram`.

    Args:
        sample_rate: Sample rate handed to ``MelScale``.
        n_fft: FFT size of the underlying spectrogram.
        win_length: Window length of the underlying spectrogram.
        hop_length: Hop length of the underlying spectrogram.
        n_mels: Number of mel bins.
        center: Forwarded to :class:`LinearSpectrogram`.
        f_min: Lowest mel filter frequency.
        f_max: Highest mel filter frequency; a falsy value (``None`` or ``0``)
            falls back to ``sample_rate // 2``.
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            n_stft=self.n_fft // 2 + 1,
            norm="slaney",
            mel_scale="slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        """Return the natural log of ``x`` after flooring it at ``1e-5``."""
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        """Return ``exp(x)``, undoing :meth:`compress` for unclamped values."""
        return torch.exp(x)

    def forward(
        self, x: Tensor, return_linear: bool = False
    ) -> "Tensor | tuple[Tensor, Tensor]":
        """Compute the log-mel spectrogram of ``x``.

        Args:
            x: Waveform tensor accepted by :class:`LinearSpectrogram`.
            return_linear: When true, also return the log-compressed linear
                spectrogram.

        Returns:
            The log-mel spectrogram, or the pair
            ``(log_mel, log_linear)`` when ``return_linear`` is true.
        """
        linear = self.spectrogram(x)
        x = self.mel_scale(linear)
        x = self.compress(x)
        if return_linear:
            return x, self.compress(linear)

        return x
class LinearSpectrogram(nn.Module):
    """STFT front end that turns a waveform into a linear-frequency spectrogram.

    Before the transform the signal is reflect-padded with
    ``win_length - hop_length`` samples overall. In ``"pow2_sqrt"`` mode (the
    default) the result is the magnitude ``sqrt(re**2 + im**2 + 1e-6)``; any
    other ``mode`` yields the real/imaginary pair on a trailing axis of size 2.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.mode = mode
        self.center = center
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = n_fft

        # Buffer (not a parameter): moves with the module, never trained.
        hann = torch.hann_window(win_length)
        self.register_buffer("window", hann)

    def forward(self, y: Tensor) -> Tensor:
        """Return the spectrogram of ``y`` in the dtype of ``y``.

        ``y`` is presumably ``(batch, time)`` or ``(batch, 1, time)``: a 3-D
        input has its dim 1 squeezed before processing.
        """
        if y.ndim == 3:
            # No-op unless dim 1 has size 1.
            y = y.squeeze(1)

        overlap = self.win_length - self.hop_length
        pad_left = overlap // 2
        pad_right = (overlap + 1) // 2
        padded = torch.nn.functional.pad(
            y.unsqueeze(1), (pad_left, pad_right), mode="reflect"
        )
        y = padded.squeeze(1)

        # Transform in float32, then hand back the caller's dtype.
        out_dtype = y.dtype
        complex_spec = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(complex_spec)

        if self.mode == "pow2_sqrt":
            # Magnitude; the 1e-6 keeps the sqrt argument strictly positive.
            power = spec.pow(2).sum(-1)
            spec = torch.sqrt(power + 1e-6)
        return spec.to(out_dtype)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
17 def __init__( 18 self, 19 n_fft=2048, 20 win_length=2048, 21 hop_length=512, 22 center=False, 23 mode="pow2_sqrt", 24 ): 25 super().__init__() 26 27 self.n_fft = n_fft 28 self.win_length = win_length 29 self.hop_length = hop_length 30 self.center = center 31 self.mode = mode 32 33 self.register_buffer("window", torch.hann_window(win_length))
Initialize internal Module state, shared by both nn.Module and ScriptModule.
35 def forward(self, y: Tensor) -> Tensor: 36 if y.ndim == 3: 37 y = y.squeeze(1) 38 39 y = torch.nn.functional.pad( 40 y.unsqueeze(1), 41 ( 42 (self.win_length - self.hop_length) // 2, 43 (self.win_length - self.hop_length + 1) // 2, 44 ), 45 mode="reflect", 46 ).squeeze(1) 47 dtype = y.dtype 48 spec = torch.stft( 49 y.float(), 50 self.n_fft, 51 hop_length=self.hop_length, 52 win_length=self.win_length, 53 window=self.window, 54 center=self.center, 55 pad_mode="reflect", 56 normalized=False, 57 onesided=True, 58 return_complex=True, 59 ) 60 spec = torch.view_as_real(spec) 61 62 if self.mode == "pow2_sqrt": 63 spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 64 spec = spec.to(dtype) 65 return spec
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram built on ``LinearSpectrogram``.

    The waveform is converted to a linear spectrogram, projected onto a mel
    filterbank with ``MelScale`` and passed through :meth:`compress`.

    Args:
        sample_rate: Sample rate handed to ``MelScale``.
        n_fft: FFT size of the underlying spectrogram.
        win_length: Window length of the underlying spectrogram.
        hop_length: Hop length of the underlying spectrogram.
        n_mels: Number of mel bins.
        center: Forwarded to ``LinearSpectrogram``.
        f_min: Lowest mel filter frequency.
        f_max: Highest mel filter frequency; a falsy value (``None`` or ``0``)
            falls back to ``sample_rate // 2``.
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            # One-sided STFT bin count.
            n_stft=self.n_fft // 2 + 1,
            norm="slaney",
            mel_scale="slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        """Return the natural log of ``x`` after flooring it at ``1e-5``."""
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        """Return ``exp(x)``, undoing :meth:`compress` for unclamped values."""
        return torch.exp(x)

    def forward(
        self, x: Tensor, return_linear: bool = False
    ) -> "Tensor | tuple[Tensor, Tensor]":
        """Compute the log-mel spectrogram of ``x``.

        Args:
            x: Waveform tensor accepted by the underlying spectrogram module.
            return_linear: When true, also return the log-compressed linear
                spectrogram.

        Returns:
            The log-mel spectrogram, or the pair ``(log_mel, log_linear)``
            when ``return_linear`` is true.
        """
        linear = self.spectrogram(x)
        log_mel = self.compress(self.mel_scale(linear))
        if return_linear:
            return log_mel, self.compress(linear)
        return log_mel
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    sample_rate=44100,
    n_fft=2048,
    win_length=2048,
    hop_length=512,
    n_mels=128,
    center=False,
    f_min=0.0,
    f_max=None,
):
    """Store the settings and build the spectrogram and mel-scale submodules.

    Args:
        sample_rate: Sample rate handed to ``MelScale``.
        n_fft: FFT size of the underlying spectrogram.
        win_length: Window length of the underlying spectrogram.
        hop_length: Hop length of the underlying spectrogram.
        n_mels: Number of mel bins.
        center: Forwarded to ``LinearSpectrogram``.
        f_min: Lowest mel filter frequency.
        f_max: Highest mel filter frequency; a falsy value (``None`` or ``0``)
            falls back to ``sample_rate // 2``.
    """
    super().__init__()

    # Mel filterbank settings.
    self.sample_rate = sample_rate
    self.n_mels = n_mels
    self.f_min = f_min
    self.f_max = f_max or sample_rate // 2

    # STFT settings.
    self.n_fft = n_fft
    self.win_length = win_length
    self.hop_length = hop_length
    self.center = center

    self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)

    # Fifth positional argument: the one-sided STFT bin count.
    stft_bins = self.n_fft // 2 + 1
    self.mel_scale = MelScale(
        self.n_mels,
        self.sample_rate,
        self.f_min,
        self.f_max,
        stft_bins,
        "slaney",
        "slaney",
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
108 def forward(self, x: Tensor, return_linear: bool = False) -> Tensor: 109 linear = self.spectrogram(x) 110 x = self.mel_scale(linear) 111 x = self.compress(x) 112 # print(x.shape) 113 if return_linear: 114 return x, self.compress(linear) 115 116 return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.