divisor.acestep.music_dcae.music_log_mel

ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License

  1"""
  2ACE-Step: A Step Towards Music Generation Foundation Model
  3
  4https://github.com/ace-step/ACE-Step
  5
  6Apache 2.0 License
  7"""
  8
  9import torch
 10import torch.nn as nn
 11from torch import Tensor
 12from torchaudio.transforms import MelScale
 13
 14
 15class LinearSpectrogram(nn.Module):
 16    def __init__(
 17        self,
 18        n_fft=2048,
 19        win_length=2048,
 20        hop_length=512,
 21        center=False,
 22        mode="pow2_sqrt",
 23    ):
 24        super().__init__()
 25
 26        self.n_fft = n_fft
 27        self.win_length = win_length
 28        self.hop_length = hop_length
 29        self.center = center
 30        self.mode = mode
 31
 32        self.register_buffer("window", torch.hann_window(win_length))
 33
 34    def forward(self, y: Tensor) -> Tensor:
 35        if y.ndim == 3:
 36            y = y.squeeze(1)
 37
 38        y = torch.nn.functional.pad(
 39            y.unsqueeze(1),
 40            (
 41                (self.win_length - self.hop_length) // 2,
 42                (self.win_length - self.hop_length + 1) // 2,
 43            ),
 44            mode="reflect",
 45        ).squeeze(1)
 46        dtype = y.dtype
 47        spec = torch.stft(
 48            y.float(),
 49            self.n_fft,
 50            hop_length=self.hop_length,
 51            win_length=self.win_length,
 52            window=self.window,
 53            center=self.center,
 54            pad_mode="reflect",
 55            normalized=False,
 56            onesided=True,
 57            return_complex=True,
 58        )
 59        spec = torch.view_as_real(spec)
 60
 61        if self.mode == "pow2_sqrt":
 62            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
 63        spec = spec.to(dtype)
 64        return spec
 65
 66
 67class LogMelSpectrogram(nn.Module):
 68    def __init__(
 69        self,
 70        sample_rate=44100,
 71        n_fft=2048,
 72        win_length=2048,
 73        hop_length=512,
 74        n_mels=128,
 75        center=False,
 76        f_min=0.0,
 77        f_max=None,
 78    ):
 79        super().__init__()
 80
 81        self.sample_rate = sample_rate
 82        self.n_fft = n_fft
 83        self.win_length = win_length
 84        self.hop_length = hop_length
 85        self.center = center
 86        self.n_mels = n_mels
 87        self.f_min = f_min
 88        self.f_max = f_max or sample_rate // 2
 89
 90        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
 91        self.mel_scale = MelScale(
 92            self.n_mels,
 93            self.sample_rate,
 94            self.f_min,
 95            self.f_max,
 96            self.n_fft // 2 + 1,
 97            "slaney",
 98            "slaney",
 99        )
100
101    def compress(self, x: Tensor) -> Tensor:
102        return torch.log(torch.clamp(x, min=1e-5))
103
104    def decompress(self, x: Tensor) -> Tensor:
105        return torch.exp(x)
106
107    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
108        linear = self.spectrogram(x)
109        x = self.mel_scale(linear)
110        x = self.compress(x)
111        # print(x.shape)
112        if return_linear:
113            return x, self.compress(linear)
114
115        return x
class LinearSpectrogram(torch.nn.modules.module.Module):
class LinearSpectrogram(nn.Module):
    """Hann-windowed STFT front-end returning magnitudes or real/imag parts.

    The waveform is reflect-padded by ``win_length - hop_length`` samples in
    total (split as evenly as possible between both ends, extra sample on the
    right) before ``torch.stft`` is applied.

    Args:
        n_fft: FFT size; the output has ``n_fft // 2 + 1`` frequency bins.
        win_length: Length of the Hann analysis window.
        hop_length: Number of samples between successive frames.
        center: Forwarded to ``torch.stft``.
        mode: ``"pow2_sqrt"`` returns ``sqrt(re**2 + im**2 + 1e-6)``; any other
            value returns real and imaginary parts stacked in a trailing
            dimension of size 2.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        # A buffer (not a parameter): it follows the module on ``.to()`` and is
        # saved in the state dict, but is not trained.
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        """Compute the spectrogram of a waveform.

        Args:
            y: Waveform of shape ``(batch, 1, samples)``, ``(batch, samples)``
               or ``(samples,)``.

        Returns:
            ``(batch, n_fft // 2 + 1, frames)`` in ``"pow2_sqrt"`` mode (no
            batch dimension for 1-D input), or the same with a trailing
            real/imag dimension of size 2 otherwise. The result is cast back
            to the dtype of ``y``.
        """
        if y.ndim == 3:
            # Drop the singleton channel: (batch, 1, samples) -> (batch, samples).
            y = y.squeeze(1)

        # Reflect padding acts on the last dim of a 2-D/3-D tensor, so a
        # singleton dim is inserted right before the time axis. Using -2
        # (instead of 1) also makes 1-D input work: (samples,) -> (1, samples).
        total_pad = self.win_length - self.hop_length
        y = torch.nn.functional.pad(
            y.unsqueeze(-2),
            (total_pad // 2, (total_pad + 1) // 2),
            mode="reflect",
        ).squeeze(-2)

        dtype = y.dtype
        # The STFT itself is always computed in float32.
        spec = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            # The epsilon presumably keeps sqrt's gradient finite at zero magnitude.
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        return spec.to(dtype)

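LinearSpectrogram computes a Hann-windowed STFT of a waveform and returns magnitudes (mode="pow2_sqrt") or stacked real/imaginary parts. The rest of this section is the inherited torch.nn.Module documentation.

Base class for all neural network modules.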

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked 5x5 convolutions (1 -> 20 -> 20 channels), each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Assigning submodules as attributes registers them with the parent.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU."""
        return F.relu(self.conv2(F.relu(self.conv1(x))))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that indicates whether this module is in training or evaluation mode.
:vartype training: bool

LinearSpectrogram( n_fft=2048, win_length=2048, hop_length=512, center=False, mode='pow2_sqrt')
17    def __init__(
18        self,
19        n_fft=2048,
20        win_length=2048,
21        hop_length=512,
22        center=False,
23        mode="pow2_sqrt",
24    ):
25        super().__init__()
26
27        self.n_fft = n_fft
28        self.win_length = win_length
29        self.hop_length = hop_length
30        self.center = center
31        self.mode = mode
32
33        self.register_buffer("window", torch.hann_window(win_length))

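Store the STFT configuration (n_fft, win_length, hop_length, center, mode) and register a Hann window of length win_length as the "window" buffer. (Inherited nn.Module description: initialize internal Module state, shared by both nn.Module and ScriptModule.)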

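n_fft — FFT size passed to torch.stft; the output has n_fft // 2 + 1 frequency bins.
win_length — length of the Hann window and the STFT window length.
hop_length — number of samples between successive frames.
center — forwarded to torch.stft.
mode — "pow2_sqrt" returns magnitudes; any other value returns stacked real/imaginary parts.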
def forward(self, y: torch.Tensor) -> torch.Tensor:
35    def forward(self, y: Tensor) -> Tensor:
36        if y.ndim == 3:
37            y = y.squeeze(1)
38
39        y = torch.nn.functional.pad(
40            y.unsqueeze(1),
41            (
42                (self.win_length - self.hop_length) // 2,
43                (self.win_length - self.hop_length + 1) // 2,
44            ),
45            mode="reflect",
46        ).squeeze(1)
47        dtype = y.dtype
48        spec = torch.stft(
49            y.float(),
50            self.n_fft,
51            hop_length=self.hop_length,
52            win_length=self.win_length,
53            window=self.window,
54            center=self.center,
55            pad_mode="reflect",
56            normalized=False,
57            onesided=True,
58            return_complex=True,
59        )
60        spec = torch.view_as_real(spec)
61
62        if self.mode == "pow2_sqrt":
63            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
64        spec = spec.to(dtype)
65        return spec

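Compute the spectrogram of a waveform: reflect-pad it by win_length - hop_length samples in total, run torch.stft in float32, take magnitudes when mode == "pow2_sqrt", and cast the result back to the input dtype. (Inherited nn.Module description: define the computation performed at every call.)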

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class LogMelSpectrogram(torch.nn.modules.module.Module):
 68class LogMelSpectrogram(nn.Module):
 69    def __init__(
 70        self,
 71        sample_rate=44100,
 72        n_fft=2048,
 73        win_length=2048,
 74        hop_length=512,
 75        n_mels=128,
 76        center=False,
 77        f_min=0.0,
 78        f_max=None,
 79    ):
 80        super().__init__()
 81
 82        self.sample_rate = sample_rate
 83        self.n_fft = n_fft
 84        self.win_length = win_length
 85        self.hop_length = hop_length
 86        self.center = center
 87        self.n_mels = n_mels
 88        self.f_min = f_min
 89        self.f_max = f_max or sample_rate // 2
 90
 91        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
 92        self.mel_scale = MelScale(
 93            self.n_mels,
 94            self.sample_rate,
 95            self.f_min,
 96            self.f_max,
 97            self.n_fft // 2 + 1,
 98            "slaney",
 99            "slaney",
100        )
101
102    def compress(self, x: Tensor) -> Tensor:
103        return torch.log(torch.clamp(x, min=1e-5))
104
105    def decompress(self, x: Tensor) -> Tensor:
106        return torch.exp(x)
107
108    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
109        linear = self.spectrogram(x)
110        x = self.mel_scale(linear)
111        x = self.compress(x)
112        # print(x.shape)
113        if return_linear:
114            return x, self.compress(linear)
115
116        return x

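LogMelSpectrogram computes a log-compressed mel spectrogram: a LinearSpectrogram magnitude STFT, a Slaney-normalized torchaudio MelScale filterbank, then log(clamp(x, min=1e-5)). The rest of this section is the inherited torch.nn.Module documentation.

Base class for all neural network modules.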

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked 5x5 convolutions (1 -> 20 -> 20 channels), each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Assigning submodules as attributes registers them with the parent.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU."""
        return F.relu(self.conv2(F.relu(self.conv1(x))))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that indicates whether this module is in training or evaluation mode.
:vartype training: bool

LogMelSpectrogram( sample_rate=44100, n_fft=2048, win_length=2048, hop_length=512, n_mels=128, center=False, f_min=0.0, f_max=None)
 69    def __init__(
 70        self,
 71        sample_rate=44100,
 72        n_fft=2048,
 73        win_length=2048,
 74        hop_length=512,
 75        n_mels=128,
 76        center=False,
 77        f_min=0.0,
 78        f_max=None,
 79    ):
 80        super().__init__()
 81
 82        self.sample_rate = sample_rate
 83        self.n_fft = n_fft
 84        self.win_length = win_length
 85        self.hop_length = hop_length
 86        self.center = center
 87        self.n_mels = n_mels
 88        self.f_min = f_min
 89        self.f_max = f_max or sample_rate // 2
 90
 91        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
 92        self.mel_scale = MelScale(
 93            self.n_mels,
 94            self.sample_rate,
 95            self.f_min,
 96            self.f_max,
 97            self.n_fft // 2 + 1,
 98            "slaney",
 99            "slaney",
100        )

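Store the mel-spectrogram configuration, default f_max to sample_rate // 2 when it is not given, and build the LinearSpectrogram and Slaney-normalized MelScale submodules. (Inherited nn.Module description: initialize internal Module state, shared by both nn.Module and ScriptModule.)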

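sample_rate — sample rate of the input waveform, passed to MelScale.
n_fft — FFT size; MelScale receives n_fft // 2 + 1 frequency bins.
win_length — STFT window length.
hop_length — STFT hop size in samples.
center — forwarded to LinearSpectrogram.
n_mels — number of mel bands.
f_min — lowest mel filterbank frequency in Hz.
f_max — highest mel filterbank frequency in Hz (sample_rate // 2 when not given).
spectrogram — the LinearSpectrogram submodule that produces magnitude spectra.
mel_scale — the torchaudio MelScale submodule (Slaney norm, Slaney mel scale).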
def compress(self, x: torch.Tensor) -> torch.Tensor:
102    def compress(self, x: Tensor) -> Tensor:
103        return torch.log(torch.clamp(x, min=1e-5))
def decompress(self, x: torch.Tensor) -> torch.Tensor:
105    def decompress(self, x: Tensor) -> Tensor:
106        return torch.exp(x)
def forward(self, x: torch.Tensor, return_linear: bool = False) -> torch.Tensor:
108    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
109        linear = self.spectrogram(x)
110        x = self.mel_scale(linear)
111        x = self.compress(x)
112        # print(x.shape)
113        if return_linear:
114            return x, self.compress(linear)
115
116        return x

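Compute log(clamp(MelScale(LinearSpectrogram(x)), min=1e-5)); when return_linear is True, also return the log-compressed linear spectrogram as a second tensor. (Inherited nn.Module description: define the computation performed at every call.)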

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.