divisor.acestep.music_dcae.music_vocoder

ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License

  1"""
  2ACE-Step: A Step Towards Music Generation Foundation Model
  3
  4https://github.com/ace-step/ACE-Step
  5
  6Apache 2.0 License
  7"""
  8
  9import librosa
 10import torch
 11from torch import nn
 12
 13from functools import partial
 14from math import prod
 15from typing import Callable, Tuple, List
 16
 17import numpy as np
 18import torch.nn.functional as F
 19from torch.nn import Conv1d
 20from torch.nn.utils import weight_norm
 21from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
 22from diffusers.models.modeling_utils import ModelMixin
 23from diffusers.loaders import FromOriginalModelMixin
 24from diffusers.configuration_utils import ConfigMixin, register_to_config
 25
 26
 27try:
 28    from music_log_mel import LogMelSpectrogram
 29except ImportError:
 30    from .music_log_mel import LogMelSpectrogram
 31
 32
 33def drop_path(
 34    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 35):
 36    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
 37
 38    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
 39    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
 40    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
 41    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
 42    'survival rate' as the argument.
 43
 44    """  # noqa: E501
 45
 46    if drop_prob == 0.0 or not training:
 47        return x
 48    keep_prob = 1 - drop_prob
 49    shape = (x.shape[0],) + (1,) * (
 50        x.ndim - 1
 51    )  # work with diff dim tensors, not just 2D ConvNets
 52    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
 53    if keep_prob > 0.0 and scale_by_keep:
 54        random_tensor.div_(keep_prob)
 55    return x * random_tensor
 56
 57
 58class DropPath(nn.Module):
 59    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""  # noqa: E501
 60
 61    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
 62        super(DropPath, self).__init__()
 63        self.drop_prob = drop_prob
 64        self.scale_by_keep = scale_by_keep
 65
 66    def forward(self, x):
 67        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
 68
 69    def extra_repr(self):
 70        return f"drop_prob={round(self.drop_prob,3):0.3f}"
 71
 72
 73class LayerNorm(nn.Module):
 74    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
 75    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
 76    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
 77    with shape (batch_size, channels, height, width).
 78    """  # noqa: E501
 79
 80    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
 81        super().__init__()
 82        self.weight = nn.Parameter(torch.ones(normalized_shape))
 83        self.bias = nn.Parameter(torch.zeros(normalized_shape))
 84        self.eps = eps
 85        self.data_format = data_format
 86        if self.data_format not in ["channels_last", "channels_first"]:
 87            raise NotImplementedError
 88        self.normalized_shape = (normalized_shape,)
 89
 90    def forward(self, x):
 91        if self.data_format == "channels_last":
 92            return F.layer_norm(
 93                x, self.normalized_shape, self.weight, self.bias, self.eps
 94            )
 95        elif self.data_format == "channels_first":
 96            u = x.mean(1, keepdim=True)
 97            s = (x - u).pow(2).mean(1, keepdim=True)
 98            x = (x - u) / torch.sqrt(s + self.eps)
 99            x = self.weight[:, None] * x + self.bias[:, None]
100            return x
101
102
class ConvNeXtBlock(nn.Module):
    r"""1-D ConvNeXt block.

    Pipeline: depthwise Conv1d on ``(N, C, L)`` -> permute to ``(N, L, C)``
    -> LayerNorm (channels_last) -> Linear -> GELU -> Linear -> optional
    per-channel layer scale -> permute back -> drop path -> optional
    residual add.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a value
            <= 0 disables layer scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        # The padding is sized for a dilated kernel, so the dilation must also
        # be given to the convolution itself; otherwise the output length
        # differs from the input length and the residual add fails.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            dilation=dilation,
            groups=dim,
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        hidden_dim = int(mlp_ratio * dim)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        """Run the block; add the input back when ``apply_residual`` is true."""
        residual = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = residual + x

        return x
169
170
class ParallelConvNeXtBlock(nn.Module):
    """Several ``ConvNeXtBlock`` branches applied to the same input.

    One branch is built per entry of ``kernel_sizes``; all remaining
    positional and keyword arguments are forwarded to every branch. The
    output is the sum of all branch outputs (computed without their own
    residual) plus the input itself.
    """

    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        super().__init__()
        branches = []
        for size in kernel_sizes:
            branches.append(ConvNeXtBlock(*args, kernel_size=size, **kwargs))
        self.blocks = nn.ModuleList(branches)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum the branch outputs together with ``x``."""
        terms = [branch(x, apply_residual=False) for branch in self.blocks]
        terms.append(x)
        stacked = torch.stack(terms, dim=1)
        return stacked.sum(dim=1)
186
187
class ConvNeXtEncoder(nn.Module):
    """Stack of ConvNeXt stages operating on channels-first 1-D features.

    Each stage is preceded by a channel layer: a replicate-padded kernel-7
    conv plus LayerNorm for the first stage (the stem), and a LayerNorm plus
    1x1 conv that changes the channel count for every later stage. A final
    channels-first LayerNorm is applied to the output of the last stage.

    Args:
        input_channels: Channels of the input fed to the stem conv.
        depths: Number of blocks in each stage.
        dims: Channel width of each stage; must have the same length as
            ``depths``.
        drop_path_rate: Final stochastic-depth rate; per-block rates rise
            linearly from 0 to this value across all blocks.
        layer_scale_init_value: Forwarded to every block.
        kernel_sizes: A single entry selects plain ``ConvNeXtBlock``s;
            several entries select ``ParallelConvNeXtBlock``s.
    """

    def __init__(
        self,
        input_channels=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        # Channel layers: the stem first, then one transition per later stage.
        self.channel_layers = nn.ModuleList()
        stem_conv = nn.Conv1d(
            input_channels,
            dims[0],
            kernel_size=7,
            padding=3,
            padding_mode="replicate",
        )
        stem_norm = LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        self.channel_layers.append(nn.Sequential(stem_conv, stem_norm))

        for idx in range(len(depths) - 1):
            transition = nn.Sequential(
                LayerNorm(dims[idx], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[idx], dims[idx + 1], kernel_size=1),
            )
            self.channel_layers.append(transition)

        # One kernel size -> plain block; several -> parallel block.
        if len(kernel_sizes) == 1:
            block_fn = partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
        else:
            block_fn = partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)

        self.stages = nn.ModuleList()
        # Linearly increasing drop-path rate, one value per block overall.
        rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        offset = 0
        for idx, depth in enumerate(depths):
            blocks = [
                block_fn(
                    dim=dims[idx],
                    drop_path=rates[offset + j],
                    layer_scale_init_value=layer_scale_init_value,
                )
                for j in range(depth)
            ]
            self.stages.append(nn.Sequential(*blocks))
            offset += depth

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for conv and linear layers."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Alternate channel layers and stages, then apply the final norm."""
        for adapt, stage in zip(self.channel_layers, self.stages):
            x = stage(adapt(x))

        return self.norm(x)
264
265
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) for convolution modules.

    A module counts as a convolution when its class name contains "Conv";
    any other module is left untouched.
    """
    if "Conv" in type(m).__name__:
        m.weight.data.normal_(mean, std)
270
271
def get_padding(kernel_size, dilation=1):
    """Return the per-side padding ``dilation * (kernel_size - 1) // 2``.

    For odd kernel sizes this keeps the output length of a stride-1 dilated
    convolution equal to its input length.
    """
    span = dilation * (kernel_size - 1)
    return span // 2
274
275
class ResBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, plain conv) pairs.

    Each pair computes ``x + conv2(silu(conv1(silu(x))))``. ``conv1`` uses
    the pair's dilation, ``conv2`` uses dilation 1. All convolutions keep
    the channel count and are wrapped in ``torch.nn.utils.weight_norm``.

    Args:
        channels: Channel count of input and output.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Dilations of the three ``conv1`` layers; only the first
            three entries are used.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def make_conv(rate):
            # Stride-1 weight-normalized conv, padded via get_padding.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        # Index explicitly so that exactly three pairs are built, as before.
        rates = (dilation[0], dilation[1], dilation[2])

        self.convs1 = nn.ModuleList([make_conv(rate) for rate in rates])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in rates])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Apply the three residual pairs in sequence."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every convolution in the block.

        The layers are wrapped with the legacy ``torch.nn.utils.weight_norm``,
        so the matching ``torch.nn.utils.remove_weight_norm`` must be used.
        ``remove_parametrizations`` requires a ``tensor_name`` argument and
        only handles parametrization-based weight norm.
        """
        for conv in (*self.convs1, *self.convs2):
            torch.nn.utils.remove_weight_norm(conv)
366
367
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN style upsampling generator.

    ``conv_pre`` maps ``num_mels`` channels to ``upsample_initial_channel``.
    Each upsampling stage applies SiLU, a transposed convolution that halves
    the channel count, optionally adds a convolved ``template``, and averages
    the outputs of one ``ResBlock1`` per resblock kernel size. The result
    goes through ``post_activation``, ``conv_post`` (to 1 channel) and
    ``tanh``.

    Args:
        hop_length: Must equal the product of ``upsample_rates``.
        upsample_rates: Stride of each transposed convolution.
        upsample_kernel_sizes: Kernel size of each transposed convolution.
        resblock_kernel_sizes: Kernel size of each parallel ``ResBlock1``.
        resblock_dilation_sizes: Dilations of each parallel ``ResBlock1``.
        num_mels: Input channels of ``conv_pre``.
        upsample_initial_channel: Output channels of ``conv_pre``.
        use_template: When true, ``forward`` needs a ``template`` tensor that
            is convolved and added after every upsampling stage.
        pre_conv_kernel_size: Kernel size of ``conv_pre``.
        post_conv_kernel_size: Kernel size of ``conv_post``.
        post_activation: Zero-argument factory for the final activation.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at every upsampling stage.
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        c_cur,
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            if i + 1 < len(upsample_rates):
                # Stride the template by the upsampling factor still to come
                # (presumably so it matches the length of x at this stage).
                # math.prod keeps this a plain int rather than a numpy scalar.
                stride_f0 = prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Generate the output signal from ``x``.

        ``template`` must be provided when the generator was built with
        ``use_template=True``; it is ignored otherwise.
        """
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            # Average the parallel residual blocks that belong to this stage.
            xs = None

            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all weight-normalized layers.

        Uses ``torch.nn.utils.remove_weight_norm``, the counterpart of the
        legacy ``torch.nn.utils.weight_norm`` applied in ``__init__``.
        ``remove_parametrizations`` requires a ``tensor_name`` argument and
        does not apply to that legacy wrapper.
        """
        for up in self.ups:
            torch.nn.utils.remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        torch.nn.utils.remove_weight_norm(self.conv_pre)
        torch.nn.utils.remove_weight_norm(self.conv_post)
490
491
class ADaMoSHiFiGANV1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """Vocoder: ConvNeXt backbone + HiFi-GAN head, with a log-mel front end.

    ``encode`` turns its input into a log-mel spectrogram through
    ``LogMelSpectrogram``; ``decode`` / ``forward`` run a mel through the
    ``ConvNeXtEncoder`` backbone and the ``HiFiGANGenerator`` head. The
    model is put into eval mode at construction time.
    """

    @register_to_config
    def __init__(
        self,
        input_channels: int = 128,
        depths: List[int] = [3, 3, 9, 3],
        dims: List[int] = [128, 256, 384, 512],
        drop_path_rate: float = 0.0,
        kernel_sizes: Tuple[int] = (7,),
        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
        ),
        num_mels: int = 512,
        upsample_initial_channel: int = 1024,
        use_template: bool = False,
        pre_conv_kernel_size: int = 13,
        post_conv_kernel_size: int = 13,
        sampling_rate: int = 44100,
        n_fft: int = 2048,
        win_length: int = 2048,
        hop_length: int = 512,
        f_min: int = 40,
        f_max: int = 16000,
        n_mels: int = 128,
    ):
        super().__init__()

        # Feature extractor applied to the mel before the generator head.
        backbone_kwargs = dict(
            input_channels=input_channels,
            depths=depths,
            dims=dims,
            drop_path_rate=drop_path_rate,
            kernel_sizes=kernel_sizes,
        )
        self.backbone = ConvNeXtEncoder(**backbone_kwargs)

        # Upsampling generator fed with the backbone output.
        head_kwargs = dict(
            hop_length=hop_length,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            num_mels=num_mels,
            upsample_initial_channel=upsample_initial_channel,
            use_template=use_template,
            pre_conv_kernel_size=pre_conv_kernel_size,
            post_conv_kernel_size=post_conv_kernel_size,
        )
        self.head = HiFiGANGenerator(**head_kwargs)

        self.sampling_rate = sampling_rate
        # Front end used by ``encode``.
        mel_kwargs = dict(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
        )
        self.mel_transform = LogMelSpectrogram(**mel_kwargs)
        self.eval()

    @torch.no_grad()
    def decode(self, mel):
        """Run backbone and head on ``mel`` with gradients disabled."""
        features = self.backbone(mel)
        return self.head(features)

    @torch.no_grad()
    def encode(self, x):
        """Return ``self.mel_transform(x)`` with gradients disabled."""
        mel = self.mel_transform(x)
        return mel

    def forward(self, mel):
        """Same computation as ``decode`` but with gradient tracking."""
        features = self.backbone(mel)
        return self.head(features)
572
573
if __name__ == "__main__":
    # Smoke test: load a wav file, encode it to a mel, decode it back with the
    # pretrained vocoder and write the reconstruction to disk.
    import soundfile as sf

    input_path = "test_audio.wav"
    vocoder = ADaMoSHiFiGANV1.from_pretrained(
        "./checkpoints/music_vocoder", local_files_only=True
    )

    samples, _ = librosa.load(input_path, sr=44100, mono=True)
    waveform = torch.from_numpy(samples).float()[None]  # add a leading dim
    mel = vocoder.encode(waveform)

    reconstructed = vocoder.decode(mel)[0].mT
    sf.write("test_audio_vocoder_rec.wav", reconstructed.cpu().numpy(), 44100)
def drop_path( x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Apply stochastic depth: randomly zero whole samples of ``x``.

    One keep/drop decision is drawn per entry of the first dimension and
    broadcast over all remaining dimensions. A sample is kept with
    probability ``1 - drop_prob``.

    Args:
        x: Input tensor; the first dimension indexes samples.
        drop_prob: Probability of dropping a sample.
        training: The drop is only applied when this is true.
        scale_by_keep: When true (and the keep probability is positive),
            kept samples are divided by the keep probability.

    Returns:
        ``x`` itself when ``drop_prob`` is 0 or ``training`` is false,
        otherwise a new tensor with dropped samples zeroed.
    """
    if not training or drop_prob == 0.0:
        return x

    survive_prob = 1 - drop_prob
    # Shape (N, 1, ..., 1): one Bernoulli draw per sample, broadcastable
    # against inputs of any rank.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(survive_prob)
    if scale_by_keep and survive_prob > 0.0:
        mask.div_(survive_prob)
    return x * mask

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

This is the same as the DropConnect implementation I created for EfficientNet and similar networks. However, the original name is misleading, as 'Drop Connect' is a different form of dropout described in a separate paper. See the discussion at https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956. I've opted to change the layer and argument names to 'drop path' rather than mixing DropConnect as a layer name with 'survival rate' as the argument.

class DropPath(torch.nn.modules.module.Module):
class DropPath(nn.Module):
    """Module form of ``drop_path`` (stochastic depth per sample).

    Meant for the main path of residual blocks. The drop is only active
    while the module is in training mode (``self.training``).
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        """Store the drop probability and the rescaling switch."""
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``x`` with whole samples randomly dropped during training."""
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        """Show the drop probability (3 decimals) in the module repr."""
        return f"drop_prob={round(self.drop_prob,3):0.3f}"

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

DropPath(drop_prob: float = 0.0, scale_by_keep: bool = True)
62    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
63        super(DropPath, self).__init__()
64        self.drop_prob = drop_prob
65        self.scale_by_keep = scale_by_keep

Initialize internal Module state, shared by both nn.Module and ScriptModule.

drop_prob
scale_by_keep
def forward(self, x):
67    def forward(self, x):
68        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def extra_repr(self):
70    def extra_repr(self):
71        return f"drop_prob={round(self.drop_prob,3):0.3f}"

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

class LayerNorm(torch.nn.modules.module.Module):
 74class LayerNorm(nn.Module):
 75    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
 76    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
 77    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
 78    with shape (batch_size, channels, height, width).
 79    """  # noqa: E501
 80
 81    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
 82        super().__init__()
 83        self.weight = nn.Parameter(torch.ones(normalized_shape))
 84        self.bias = nn.Parameter(torch.zeros(normalized_shape))
 85        self.eps = eps
 86        self.data_format = data_format
 87        if self.data_format not in ["channels_last", "channels_first"]:
 88            raise NotImplementedError
 89        self.normalized_shape = (normalized_shape,)
 90
 91    def forward(self, x):
 92        if self.data_format == "channels_last":
 93            return F.layer_norm(
 94                x, self.normalized_shape, self.weight, self.bias, self.eps
 95            )
 96        elif self.data_format == "channels_first":
 97            u = x.mean(1, keepdim=True)
 98            s = (x - u).pow(2).mean(1, keepdim=True)
 99            x = (x - u) / torch.sqrt(s + self.eps)
100            x = self.weight[:, None] * x + self.bias[:, None]
101            return x

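LayerNorm that supports two data formats: channels_last (default) or channels_first. The data format specifies the ordering of the dimensions in the inputs: channels_last corresponds to inputs with shape (batch_size, height, width, channels), while channels_first corresponds to inputs with shape (batch_size, channels, height, width). Note: in this module the channels_first branch broadcasts its weight as weight[:, None], which matches 3-D inputs of shape (batch_size, channels, length).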

LayerNorm(normalized_shape, eps=1e-06, data_format='channels_last')
81    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
82        super().__init__()
83        self.weight = nn.Parameter(torch.ones(normalized_shape))
84        self.bias = nn.Parameter(torch.zeros(normalized_shape))
85        self.eps = eps
86        self.data_format = data_format
87        if self.data_format not in ["channels_last", "channels_first"]:
88            raise NotImplementedError
89        self.normalized_shape = (normalized_shape,)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

weight
bias
eps
data_format
normalized_shape
def forward(self, x):
 91    def forward(self, x):
 92        if self.data_format == "channels_last":
 93            return F.layer_norm(
 94                x, self.normalized_shape, self.weight, self.bias, self.eps
 95            )
 96        elif self.data_format == "channels_first":
 97            u = x.mean(1, keepdim=True)
 98            s = (x - u).pow(2).mean(1, keepdim=True)
 99            x = (x - u) / torch.sqrt(s + self.eps)
100            x = self.weight[:, None] * x + self.bias[:, None]
101            return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class ConvNeXtBlock(torch.nn.modules.module.Module):
class ConvNeXtBlock(nn.Module):
    r"""1-D ConvNeXt block.

    Pipeline: depthwise Conv1d on ``(N, C, L)`` -> permute to ``(N, L, C)``
    -> LayerNorm (channels_last) -> Linear -> GELU -> Linear -> optional
    per-channel layer scale -> permute back -> drop path -> optional
    residual add.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a value
            <= 0 disables layer scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        # The padding is sized for a dilated kernel, so the dilation must also
        # be given to the convolution itself; otherwise the output length
        # differs from the input length and the residual add fails.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            dilation=dilation,
            groups=dim,
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        hidden_dim = int(mlp_ratio * dim)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        """Run the block; add the input back when ``apply_residual`` is true."""
        residual = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = residual + x

        return x

ConvNeXt Block. There are two equivalent implementations: (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back We use (2) as we find it slightly faster in PyTorch

Args:
- dim (int): Number of input channels.
- drop_path (float): Stochastic depth rate. Default: 0.0.
- layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
- kernel_size (int): Kernel size for depthwise conv. Default: 7.
- dilation (int): Dilation for depthwise conv. Default: 1.

ConvNeXtBlock( dim: int, drop_path: float = 0.0, layer_scale_init_value: float = 1e-06, mlp_ratio: float = 4.0, kernel_size: int = 7, dilation: int = 1)
119    def __init__(
120        self,
121        dim: int,
122        drop_path: float = 0.0,
123        layer_scale_init_value: float = 1e-6,
124        mlp_ratio: float = 4.0,
125        kernel_size: int = 7,
126        dilation: int = 1,
127    ):
128        super().__init__()
129
130        self.dwconv = nn.Conv1d(
131            dim,
132            dim,
133            kernel_size=kernel_size,
134            padding=int(dilation * (kernel_size - 1) / 2),
135            groups=dim,
136        )  # depthwise conv
137        self.norm = LayerNorm(dim, eps=1e-6)
138        self.pwconv1 = nn.Linear(
139            dim, int(mlp_ratio * dim)
140        )  # pointwise/1x1 convs, implemented with linear layers
141        self.act = nn.GELU()
142        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
143        self.gamma = (
144            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
145            if layer_scale_init_value > 0
146            else None
147        )
148        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

dwconv
norm
pwconv1
act
pwconv2
gamma
drop_path
def forward(self, x, apply_residual: bool = True):
def forward(self, x, apply_residual: bool = True):
    """Apply the block to ``x``.

    Args:
        x: Input tensor, laid out as (N, C, L).
        apply_residual: When true, the input is added back onto the block
            output; when false, only the transformed branch is returned.

    Returns:
        A tensor laid out as (N, C, L).
    """
    shortcut = x

    # Depthwise conv runs on (N, C, L); the norm and the pointwise layers
    # act on the last axis, so channels are moved there: (N, C, L) -> (N, L, C).
    hidden = self.dwconv(x).permute(0, 2, 1)
    hidden = self.pwconv2(self.act(self.pwconv1(self.norm(hidden))))

    if self.gamma is not None:
        hidden = self.gamma * hidden

    # Back to (N, C, L) before stochastic depth and the optional residual.
    branch = self.drop_path(hidden.permute(0, 2, 1))

    if not apply_residual:
        return branch
    return shortcut + branch

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class ParallelConvNeXtBlock(torch.nn.modules.module.Module):
class ParallelConvNeXtBlock(nn.Module):
    """Several ``ConvNeXtBlock`` branches with different kernel sizes, summed.

    Every branch is evaluated on the same input without its own residual;
    the input itself is added once to the sum of all branch outputs.
    """

    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        """Create one ``ConvNeXtBlock`` per entry of ``kernel_sizes``.

        Args:
            kernel_sizes: Kernel size of each parallel branch.
            *args: Extra positional arguments forwarded to every ``ConvNeXtBlock``.
            **kwargs: Extra keyword arguments forwarded to every ``ConvNeXtBlock``.
        """
        super().__init__()
        branches = []
        for size in kernel_sizes:
            branches.append(ConvNeXtBlock(*args, kernel_size=size, **kwargs))
        self.blocks = nn.ModuleList(branches)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the sum of all branch outputs."""
        terms = [block(x, apply_residual=False) for block in self.blocks]
        terms.append(x)  # the shared residual connection
        stacked = torch.stack(terms, dim=1)
        return stacked.sum(dim=1)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ParallelConvNeXtBlock(kernel_sizes: List[int], *args, **kwargs)
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
    """Create one ``ConvNeXtBlock`` per entry of ``kernel_sizes``.

    Args:
        kernel_sizes: Kernel size of each parallel branch.
        *args: Extra positional arguments forwarded to every ``ConvNeXtBlock``.
        **kwargs: Extra keyword arguments forwarded to every ``ConvNeXtBlock``.
    """
    super().__init__()
    branches = []
    for size in kernel_sizes:
        branches.append(ConvNeXtBlock(*args, kernel_size=size, **kwargs))
    self.blocks = nn.ModuleList(branches)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

blocks
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` plus the sum of every branch output.

    Each block in ``self.blocks`` is called with ``apply_residual=False``, so
    the input is only added once, here.
    """
    terms = [block(x, apply_residual=False) for block in self.blocks]
    terms.append(x)  # the shared residual connection
    stacked = torch.stack(terms, dim=1)
    return stacked.sum(dim=1)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class ConvNeXtEncoder(torch.nn.modules.module.Module):
class ConvNeXtEncoder(nn.Module):
    """Stack of ConvNeXt stages operating on 1D sequences.

    Every stage consists of a channel layer (the stem for the first stage, a
    LayerNorm + 1x1 conv projection for the others) followed by a sequence of
    ConvNeXt blocks. A final channels-first LayerNorm is applied to the output.
    """

    def __init__(
        self,
        input_channels=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        """Build the channel layers, the stages and the output norm.

        Args:
            input_channels: Channel count of the input tensor.
            depths: Number of blocks in each stage.
            dims: Channel width of each stage; must have as many entries as
                ``depths``.
            drop_path_rate: Largest stochastic depth rate; per-block rates are
                spaced linearly from 0 up to this value over all blocks.
            layer_scale_init_value: Forwarded to every block.
            kernel_sizes: A single entry selects plain ``ConvNeXtBlock``s with
                that kernel size; several entries select
                ``ParallelConvNeXtBlock``s with one branch per size.
        """
        super().__init__()
        assert len(depths) == len(dims)

        # Stem: widen the input to the first stage width.
        stem = nn.Sequential(
            nn.Conv1d(
                input_channels,
                dims[0],
                kernel_size=7,
                padding=3,
                padding_mode="replicate",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        projections = [stem]
        # Between consecutive stages: normalize, then project with a 1x1 conv.
        for width_in, width_out in zip(dims[:-1], dims[1:]):
            projections.append(
                nn.Sequential(
                    LayerNorm(width_in, eps=1e-6, data_format="channels_first"),
                    nn.Conv1d(width_in, width_out, kernel_size=1),
                )
            )
        self.channel_layers = nn.ModuleList(projections)

        if len(kernel_sizes) == 1:
            block_fn = partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
        else:
            block_fn = partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)

        self.stages = nn.ModuleList()
        rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        first = 0
        for depth, width in zip(depths, dims):
            stage_blocks = [
                block_fn(
                    dim=width,
                    drop_path=rate,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for rate in rates[first : first + depth]
            ]
            self.stages.append(nn.Sequential(*stage_blocks))
            first += depth

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for conv and linear layers."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run every (channel layer, stage) pair in order, then the final norm."""
        for index, stage in enumerate(self.stages):
            x = stage(self.channel_layers[index](x))

        return self.norm(x)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ConvNeXtEncoder( input_channels=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.0, layer_scale_init_value=1e-06, kernel_sizes: Tuple[int] = (7,))
def __init__(
    self,
    input_channels=3,
    depths=[3, 3, 9, 3],
    dims=[96, 192, 384, 768],
    drop_path_rate=0.0,
    layer_scale_init_value=1e-6,
    kernel_sizes: Tuple[int] = (7,),
):
    """Build the channel layers, the stages and the output norm.

    Args:
        input_channels: Channel count of the input tensor.
        depths: Number of blocks in each stage.
        dims: Channel width of each stage; must have as many entries as
            ``depths``.
        drop_path_rate: Largest stochastic depth rate; per-block rates are
            spaced linearly from 0 up to this value over all blocks.
        layer_scale_init_value: Forwarded to every block.
        kernel_sizes: A single entry selects plain ``ConvNeXtBlock``s with that
            kernel size; several entries select ``ParallelConvNeXtBlock``s
            with one branch per size.
    """
    super().__init__()
    assert len(depths) == len(dims)

    # Stem: widen the input to the first stage width.
    stem = nn.Sequential(
        nn.Conv1d(
            input_channels,
            dims[0],
            kernel_size=7,
            padding=3,
            padding_mode="replicate",
        ),
        LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
    )
    projections = [stem]
    # Between consecutive stages: normalize, then project with a 1x1 conv.
    for width_in, width_out in zip(dims[:-1], dims[1:]):
        projections.append(
            nn.Sequential(
                LayerNorm(width_in, eps=1e-6, data_format="channels_first"),
                nn.Conv1d(width_in, width_out, kernel_size=1),
            )
        )
    self.channel_layers = nn.ModuleList(projections)

    if len(kernel_sizes) == 1:
        block_fn = partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
    else:
        block_fn = partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)

    self.stages = nn.ModuleList()
    rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

    first = 0
    for depth, width in zip(depths, dims):
        stage_blocks = [
            block_fn(
                dim=width,
                drop_path=rate,
                layer_scale_init_value=layer_scale_init_value,
            )
            for rate in rates[first : first + depth]
        ]
        self.stages.append(nn.Sequential(*stage_blocks))
        first += depth

    self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
    self.apply(self._init_weights)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

channel_layers
stages
norm
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    """Run every (channel layer, stage) pair in order, then the final norm.

    Pairs are formed with ``zip``, so iteration stops at the shorter of
    ``self.channel_layers`` and ``self.stages``.
    """
    pairs = zip(self.channel_layers, self.stages)
    for project, stage in pairs:
        x = stage(project(x))

    normalized = self.norm(x)
    return normalized

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def init_weights(m, mean=0.0, std=0.01):
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from a normal distribution for conv-like modules.

    A module counts as conv-like when its class name contains ``"Conv"``;
    every other module is left untouched.

    Args:
        m: Module to (possibly) initialize.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
def get_padding(kernel_size, dilation=1):
    """Return the one-sided padding ``dilation * (kernel_size - 1) // 2``.

    For odd kernel sizes this keeps the length of a stride-1 convolution
    output equal to its input length.
    """
    span = dilation * (kernel_size - 1)
    return span // 2
class ResBlock1(torch.nn.modules.module.Module):
class ResBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, plain conv) pairs.

    Each pair computes ``x + conv2(silu(conv1(silu(x))))``. All six
    convolutions keep the channel count, use stride 1, are padded via
    ``get_padding`` and are wrapped in weight normalization.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        """Create the two lists of weight-normalized convolutions.

        Args:
            channels: Input and output channel count of every convolution.
            kernel_size: Kernel size shared by all convolutions.
            dilation: Dilation rates of the first convolution of each pair;
                entries 0, 1 and 2 are used.
        """
        super().__init__()

        def make_conv(rate):
            # One weight-normalized, length-preserving convolution.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        # First conv of each pair: dilated.
        self.convs1 = nn.ModuleList(
            [
                make_conv(dilation[0]),
                make_conv(dilation[1]),
                make_conv(dilation[2]),
            ]
        )
        self.convs1.apply(init_weights)

        # Second conv of each pair: dilation 1.
        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Apply the residual pairs one after another and return the result."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    @staticmethod
    def _strip_weight_norm(module):
        """Fold the weight normalization of ``module`` into a plain ``weight``."""
        # Local imports: the module-level name ``remove_weight_norm`` is an
        # alias of ``parametrize.remove_parametrizations``, which needs a
        # tensor name and does not handle hook-based ``weight_norm``.
        from torch.nn.utils import parametrize
        from torch.nn.utils import remove_weight_norm as remove_hook_weight_norm

        if parametrize.is_parametrized(module, "weight"):
            parametrize.remove_parametrizations(module, "weight")
        else:
            remove_hook_weight_norm(module)

    def remove_weight_norm(self):
        """Remove weight normalization from every convolution of the block.

        The previous implementation called ``remove_parametrizations(conv)``
        without the required tensor name and therefore always raised
        ``TypeError``.
        """
        for conv in self.convs1:
            self._strip_weight_norm(conv)
        for conv in self.convs2:
            self._strip_weight_norm(conv)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ResBlock1(channels, kernel_size=3, dilation=(1, 3, 5))
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
    """Create the two lists of weight-normalized convolutions.

    Args:
        channels: Input and output channel count of every convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Dilation rates of the first convolution of each pair;
            entries 0, 1 and 2 are used.
    """
    super().__init__()

    def build(rate):
        # One weight-normalized stride-1 convolution padded via get_padding.
        conv = Conv1d(
            channels,
            channels,
            kernel_size,
            1,
            dilation=rate,
            padding=get_padding(kernel_size, rate),
        )
        return weight_norm(conv)

    # First conv of each pair: dilated.
    dilated = [build(dilation[0]), build(dilation[1]), build(dilation[2])]
    self.convs1 = nn.ModuleList(dilated)
    self.convs1.apply(init_weights)

    # Second conv of each pair: dilation 1.
    plain = [build(1), build(1), build(1)]
    self.convs2 = nn.ModuleList(plain)
    self.convs2.apply(init_weights)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

convs1
convs2
def forward(self, x):
def forward(self, x):
    """Apply the residual pairs one after another.

    Every pair adds ``conv2(silu(conv1(silu(x))))`` onto its input.
    """
    for dilated, plain in zip(self.convs1, self.convs2):
        update = plain(F.silu(dilated(F.silu(x))))
        x = update + x
    return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def remove_weight_norm(self):
def remove_weight_norm(self):
    """Remove weight normalization from every convolution of the block.

    The module-level name ``remove_weight_norm`` is an alias of
    ``torch.nn.utils.parametrize.remove_parametrizations``. Calling it with
    only the module raised ``TypeError`` (the tensor name is required), and
    it does not handle hook-based ``torch.nn.utils.weight_norm`` either. Both
    weight-norm flavours are handled explicitly here.
    """
    # Local imports so this method does not depend on the shadowed alias.
    from torch.nn.utils import parametrize
    from torch.nn.utils import remove_weight_norm as remove_hook_weight_norm

    for conv in (*self.convs1, *self.convs2):
        if parametrize.is_parametrized(conv, "weight"):
            parametrize.remove_parametrizations(conv, "weight")
        else:
            remove_hook_weight_norm(conv)
class HiFiGANGenerator(torch.nn.modules.module.Module):
class HiFiGANGenerator(nn.Module):
    """Upsampling generator built from transposed convs and ``ResBlock1`` groups.

    The input goes through ``conv_pre``, then through one stage per entry of
    ``upsample_rates`` (SiLU, transposed conv, optional template injection,
    average of the stage's residual blocks), and finally through the post
    activation, ``conv_post`` and ``tanh``.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        """Build all layers of the generator.

        Args:
            hop_length: Must equal the product of ``upsample_rates``.
            upsample_rates: Stride of each transposed convolution.
            upsample_kernel_sizes: Kernel size of each transposed convolution.
            resblock_kernel_sizes: Kernel size of each residual block in a stage.
            resblock_dilation_sizes: Dilation rates of each residual block in
                a stage, paired with ``resblock_kernel_sizes``.
            num_mels: Channel count of the input to ``conv_pre``.
            upsample_initial_channel: Output channels of ``conv_pre``; halved
                by every upsampling stage.
            use_template: When true, one ``noise_convs`` entry per stage is
                created and ``forward`` adds its output for ``template``.
            pre_conv_kernel_size: Kernel size of ``conv_pre``.
            post_conv_kernel_size: Kernel size of ``conv_post``.
            post_activation: Zero-argument factory of the activation applied
                before ``conv_post``.
        """
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        c_cur,
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            if i + 1 < len(upsample_rates):
                # Stride equal to the product of the upsampling rates of the
                # stages that still follow this one.
                stride_f0 = np.prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Generate the output signal for ``x``.

        Args:
            x: Input to ``conv_pre``.
            template: Tensor fed to ``noise_convs``; it is only read when the
                generator was built with ``use_template=True``.

        Returns:
            The ``tanh`` of the ``conv_post`` output.
        """
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            # Average the outputs of this stage's residual blocks.
            xs = None

            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    @staticmethod
    def _strip_weight_norm(module):
        """Fold the weight normalization of ``module`` into a plain ``weight``."""
        # Local imports: the module-level name ``remove_weight_norm`` is an
        # alias of ``parametrize.remove_parametrizations``, which needs a
        # tensor name and does not handle hook-based ``weight_norm``.
        from torch.nn.utils import parametrize
        from torch.nn.utils import remove_weight_norm as remove_hook_weight_norm

        if parametrize.is_parametrized(module, "weight"):
            parametrize.remove_parametrizations(module, "weight")
        else:
            remove_hook_weight_norm(module)

    def remove_weight_norm(self):
        """Remove weight normalization from all weight-normalized layers.

        Covers the upsampling convs, every residual block (through its own
        ``remove_weight_norm``), ``conv_pre`` and ``conv_post``. The previous
        implementation called ``remove_parametrizations(module)`` without the
        required tensor name and therefore always raised ``TypeError``.
        """
        for up in self.ups:
            self._strip_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        self._strip_weight_norm(self.conv_pre)
        self._strip_weight_norm(self.conv_post)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

HiFiGANGenerator( *, hop_length: int = 512, upsample_rates: Tuple[int] = (8, 8, 2, 2, 2), upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2), resblock_kernel_sizes: Tuple[int] = (3, 7, 11), resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), num_mels: int = 128, upsample_initial_channel: int = 512, use_template: bool = True, pre_conv_kernel_size: int = 7, post_conv_kernel_size: int = 7, post_activation: Callable = functools.partial(<class 'torch.nn.modules.activation.SiLU'>, inplace=True))
def __init__(
    self,
    *,
    hop_length: int = 512,
    upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
    upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
    resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
    resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    num_mels: int = 128,
    upsample_initial_channel: int = 512,
    use_template: bool = True,
    pre_conv_kernel_size: int = 7,
    post_conv_kernel_size: int = 7,
    post_activation: Callable = partial(nn.SiLU, inplace=True),
):
    """Build all layers of the generator.

    Args:
        hop_length: Must equal the product of ``upsample_rates``.
        upsample_rates: Stride of each transposed convolution.
        upsample_kernel_sizes: Kernel size of each transposed convolution.
        resblock_kernel_sizes: Kernel size of each residual block in a stage.
        resblock_dilation_sizes: Dilation rates of each residual block in a
            stage, paired with ``resblock_kernel_sizes``.
        num_mels: Channel count of the input to ``conv_pre``.
        upsample_initial_channel: Output channels of ``conv_pre``; halved by
            every upsampling stage.
        use_template: When true, one ``noise_convs`` entry per stage is created.
        pre_conv_kernel_size: Kernel size of ``conv_pre``.
        post_conv_kernel_size: Kernel size of ``conv_post``.
        post_activation: Zero-argument factory of the activation applied
            before ``conv_post``.
    """
    super().__init__()

    assert (
        prod(upsample_rates) == hop_length
    ), f"hop_length must be {prod(upsample_rates)}"

    pre_conv = nn.Conv1d(
        num_mels,
        upsample_initial_channel,
        pre_conv_kernel_size,
        1,
        padding=get_padding(pre_conv_kernel_size),
    )
    self.conv_pre = weight_norm(pre_conv)

    self.num_upsamples = len(upsample_rates)
    self.num_kernels = len(resblock_kernel_sizes)

    self.noise_convs = nn.ModuleList()
    self.use_template = use_template
    self.ups = nn.ModuleList()

    stage_count = len(upsample_rates)
    for i, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
        in_ch = upsample_initial_channel // (2**i)
        out_ch = upsample_initial_channel // (2 ** (i + 1))
        upsample = nn.ConvTranspose1d(
            in_ch,
            out_ch,
            kernel,
            rate,
            padding=(kernel - rate) // 2,
        )
        self.ups.append(weight_norm(upsample))

        if use_template:
            if i + 1 >= stage_count:
                # Last stage: plain 1x1 convolution.
                noise_conv = Conv1d(1, out_ch, kernel_size=1)
            else:
                # Stride equal to the product of the upsampling rates of
                # the stages that still follow this one.
                stride_f0 = np.prod(upsample_rates[i + 1 :])
                noise_conv = Conv1d(
                    1,
                    out_ch,
                    kernel_size=stride_f0 * 2,
                    stride=stride_f0,
                    padding=stride_f0 // 2,
                )
            self.noise_convs.append(noise_conv)

    self.resblocks = nn.ModuleList()
    for i in range(len(self.ups)):
        ch = upsample_initial_channel // (2 ** (i + 1))
        for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
            self.resblocks.append(ResBlock1(ch, k, d))

    self.activation_post = post_activation()
    # ``ch`` is the channel count of the last upsampling stage.
    post_conv = nn.Conv1d(
        ch,
        1,
        post_conv_kernel_size,
        1,
        padding=get_padding(post_conv_kernel_size),
    )
    self.conv_post = weight_norm(post_conv)
    self.ups.apply(init_weights)
    self.conv_post.apply(init_weights)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv_pre
num_upsamples
num_kernels
noise_convs
use_template
ups
resblocks
activation_post
conv_post
def forward(self, x, template=None):
def forward(self, x, template=None):
    """Generate the output signal for ``x``.

    Args:
        x: Input to ``conv_pre``.
        template: Tensor fed to ``noise_convs``; it is only read when
            ``self.use_template`` is true.

    Returns:
        The ``tanh`` of the ``conv_post`` output.
    """
    x = self.conv_pre(x)

    for i in range(self.num_upsamples):
        x = self.ups[i](F.silu(x, inplace=True))

        if self.use_template:
            x = x + self.noise_convs[i](template)

        # Average the outputs of the residual blocks that belong to stage i.
        base = i * self.num_kernels
        acc = None
        for offset in range(self.num_kernels):
            branch = self.resblocks[base + offset](x)
            if acc is None:
                acc = branch
            else:
                acc += branch

        x = acc / self.num_kernels

    return torch.tanh(self.conv_post(self.activation_post(x)))

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def remove_weight_norm(self):
def remove_weight_norm(self):
    """Remove weight normalization from all weight-normalized layers.

    Covers the upsampling convs, every residual block (through its own
    ``remove_weight_norm``), ``conv_pre`` and ``conv_post``.

    The module-level name ``remove_weight_norm`` is an alias of
    ``torch.nn.utils.parametrize.remove_parametrizations``. Calling it with
    only the module raised ``TypeError`` (the tensor name is required), and
    it does not handle hook-based ``torch.nn.utils.weight_norm`` either. Both
    weight-norm flavours are handled explicitly here.
    """
    # Local imports so this method does not depend on the shadowed alias.
    from torch.nn.utils import parametrize
    from torch.nn.utils import remove_weight_norm as remove_hook_weight_norm

    def strip(module):
        # Fold the normalization of ``module`` back into a plain ``weight``.
        if parametrize.is_parametrized(module, "weight"):
            parametrize.remove_parametrizations(module, "weight")
        else:
            remove_hook_weight_norm(module)

    for up in self.ups:
        strip(up)
    for block in self.resblocks:
        block.remove_weight_norm()
    strip(self.conv_pre)
    strip(self.conv_post)
class ADaMoSHiFiGANV1(diffusers.models.modeling_utils.ModelMixin, diffusers.configuration_utils.ConfigMixin, diffusers.loaders.single_file_model.FromOriginalModelMixin):
class ADaMoSHiFiGANV1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """Vocoder combining a ``ConvNeXtEncoder`` backbone with a ``HiFiGANGenerator`` head.

    ``encode`` applies the log-mel transform to its input; ``decode`` and
    ``forward`` run the backbone followed by the head. The model is put into
    eval mode at the end of construction.
    """

    @register_to_config
    def __init__(
        self,
        input_channels: int = 128,
        depths: List[int] = [3, 3, 9, 3],
        dims: List[int] = [128, 256, 384, 512],
        drop_path_rate: float = 0.0,
        kernel_sizes: Tuple[int] = (7,),
        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
        ),
        num_mels: int = 512,
        upsample_initial_channel: int = 1024,
        use_template: bool = False,
        pre_conv_kernel_size: int = 13,
        post_conv_kernel_size: int = 13,
        sampling_rate: int = 44100,
        n_fft: int = 2048,
        win_length: int = 2048,
        hop_length: int = 512,
        f_min: int = 40,
        f_max: int = 16000,
        n_mels: int = 128,
    ):
        """Build the backbone, the generator head and the mel transform.

        ``input_channels`` through ``kernel_sizes`` configure the
        ``ConvNeXtEncoder``; ``upsample_rates`` through
        ``post_conv_kernel_size`` (plus ``hop_length``) configure the
        ``HiFiGANGenerator``; ``sampling_rate`` through ``n_mels`` configure
        the ``LogMelSpectrogram``.
        """
        super().__init__()

        backbone_options = dict(
            input_channels=input_channels,
            depths=depths,
            dims=dims,
            drop_path_rate=drop_path_rate,
            kernel_sizes=kernel_sizes,
        )
        self.backbone = ConvNeXtEncoder(**backbone_options)

        head_options = dict(
            hop_length=hop_length,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            num_mels=num_mels,
            upsample_initial_channel=upsample_initial_channel,
            use_template=use_template,
            pre_conv_kernel_size=pre_conv_kernel_size,
            post_conv_kernel_size=post_conv_kernel_size,
        )
        self.head = HiFiGANGenerator(**head_options)

        self.sampling_rate = sampling_rate

        mel_options = dict(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
        )
        self.mel_transform = LogMelSpectrogram(**mel_options)

        self.eval()

    @torch.no_grad()
    def decode(self, mel):
        """Run backbone and head on ``mel`` with gradients disabled."""
        features = self.backbone(mel)
        return self.head(features)

    @torch.no_grad()
    def encode(self, x):
        """Return the log-mel transform of ``x`` with gradients disabled."""
        return self.mel_transform(x)

    def forward(self, mel):
        """Run backbone and head on ``mel`` (same computation as ``decode``)."""
        features = self.backbone(mel)
        return self.head(features)

Base class for all models.

[ModelMixin] takes care of storing the model configuration and provides methods for loading, downloading and saving models.

- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
@register_to_config
ADaMoSHiFiGANV1( input_channels: int = 128, depths: List[int] = [3, 3, 9, 3], dims: List[int] = [128, 256, 384, 512], drop_path_rate: float = 0.0, kernel_sizes: Tuple[int] = (7,), upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2), upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4), resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13), resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)), num_mels: int = 512, upsample_initial_channel: int = 1024, use_template: bool = False, pre_conv_kernel_size: int = 13, post_conv_kernel_size: int = 13, sampling_rate: int = 44100, n_fft: int = 2048, win_length: int = 2048, hop_length: int = 512, f_min: int = 40, f_max: int = 16000, n_mels: int = 128)
495    @register_to_config
496    def __init__(
497        self,
498        input_channels: int = 128,
499        depths: List[int] = [3, 3, 9, 3],
500        dims: List[int] = [128, 256, 384, 512],
501        drop_path_rate: float = 0.0,
502        kernel_sizes: Tuple[int] = (7,),
503        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
504        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
505        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
506        resblock_dilation_sizes: Tuple[Tuple[int]] = (
507            (1, 3, 5),
508            (1, 3, 5),
509            (1, 3, 5),
510            (1, 3, 5),
511        ),
512        num_mels: int = 512,
513        upsample_initial_channel: int = 1024,
514        use_template: bool = False,
515        pre_conv_kernel_size: int = 13,
516        post_conv_kernel_size: int = 13,
517        sampling_rate: int = 44100,
518        n_fft: int = 2048,
519        win_length: int = 2048,
520        hop_length: int = 512,
521        f_min: int = 40,
522        f_max: int = 16000,
523        n_mels: int = 128,
524    ):
525        super().__init__()
526
527        self.backbone = ConvNeXtEncoder(
528            input_channels=input_channels,
529            depths=depths,
530            dims=dims,
531            drop_path_rate=drop_path_rate,
532            kernel_sizes=kernel_sizes,
533        )
534
535        self.head = HiFiGANGenerator(
536            hop_length=hop_length,
537            upsample_rates=upsample_rates,
538            upsample_kernel_sizes=upsample_kernel_sizes,
539            resblock_kernel_sizes=resblock_kernel_sizes,
540            resblock_dilation_sizes=resblock_dilation_sizes,
541            num_mels=num_mels,
542            upsample_initial_channel=upsample_initial_channel,
543            use_template=use_template,
544            pre_conv_kernel_size=pre_conv_kernel_size,
545            post_conv_kernel_size=post_conv_kernel_size,
546        )
547        self.sampling_rate = sampling_rate
548        self.mel_transform = LogMelSpectrogram(
549            sample_rate=sampling_rate,
550            n_fft=n_fft,
551            win_length=win_length,
552            hop_length=hop_length,
553            f_min=f_min,
554            f_max=f_max,
555            n_mels=n_mels,
556        )
557        self.eval()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

backbone
head
sampling_rate
mel_transform
@torch.no_grad()
def decode(self, mel):
559    @torch.no_grad()
560    def decode(self, mel):
561        y = self.backbone(mel)
562        y = self.head(y)
563        return y
@torch.no_grad()
def encode(self, x):
565    @torch.no_grad()
566    def encode(self, x):
567        return self.mel_transform(x)
def forward(self, mel):
569    def forward(self, mel):
570        y = self.backbone(mel)
571        y = self.head(y)
572        return y

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.