divisor.acestep.models.lyrics_utils.lyric_encoder

   1from typing import Optional, Tuple, Union
   2import math
   3import torch
   4from torch import nn
   5
   6
   7class ConvolutionModule(nn.Module):
   8    """ConvolutionModule in Conformer model."""
   9
  10    def __init__(
  11        self,
  12        channels: int,
  13        kernel_size: int = 15,
  14        activation: nn.Module = nn.ReLU(),
  15        norm: str = "batch_norm",
  16        causal: bool = False,
  17        bias: bool = True,
  18    ):
  19        """Construct an ConvolutionModule object.
  20        Args:
  21            channels (int): The number of channels of conv layers.
  22            kernel_size (int): Kernel size of conv layers.
  23            causal (int): Whether use causal convolution or not
  24        """
  25        super().__init__()
  26
  27        self.pointwise_conv1 = nn.Conv1d(
  28            channels,
  29            2 * channels,
  30            kernel_size=1,
  31            stride=1,
  32            padding=0,
  33            bias=bias,
  34        )
  35        # self.lorder is used to distinguish if it's a causal convolution,
  36        # if self.lorder > 0: it's a causal convolution, the input will be
  37        #    padded with self.lorder frames on the left in forward.
  38        # else: it's a symmetrical convolution
  39        if causal:
  40            padding = 0
  41            self.lorder = kernel_size - 1
  42        else:
  43            # kernel_size should be an odd number for none causal convolution
  44            assert (kernel_size - 1) % 2 == 0
  45            padding = (kernel_size - 1) // 2
  46            self.lorder = 0
  47        self.depthwise_conv = nn.Conv1d(
  48            channels,
  49            channels,
  50            kernel_size,
  51            stride=1,
  52            padding=padding,
  53            groups=channels,
  54            bias=bias,
  55        )
  56
  57        assert norm in ["batch_norm", "layer_norm"]
  58        if norm == "batch_norm":
  59            self.use_layer_norm = False
  60            self.norm = nn.BatchNorm1d(channels)
  61        else:
  62            self.use_layer_norm = True
  63            self.norm = nn.LayerNorm(channels)
  64
  65        self.pointwise_conv2 = nn.Conv1d(
  66            channels,
  67            channels,
  68            kernel_size=1,
  69            stride=1,
  70            padding=0,
  71            bias=bias,
  72        )
  73        self.activation = activation
  74
  75    def forward(
  76        self,
  77        x: torch.Tensor,
  78        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
  79        cache: torch.Tensor = torch.zeros((0, 0, 0)),
  80    ) -> Tuple[torch.Tensor, torch.Tensor]:
  81        """Compute convolution module.
  82        Args:
  83            x (torch.Tensor): Input tensor (#batch, time, channels).
  84            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
  85                (0, 0, 0) means fake mask.
  86            cache (torch.Tensor): left context cache, it is only
  87                used in causal convolution (#batch, channels, cache_t),
  88                (0, 0, 0) meas fake cache.
  89        Returns:
  90            torch.Tensor: Output tensor (#batch, time, channels).
  91        """
  92        # exchange the temporal dimension and the feature dimension
  93        x = x.transpose(1, 2)  # (#batch, channels, time)
  94
  95        # mask batch padding
  96        if mask_pad.size(2) > 0:  # time > 0
  97            x.masked_fill_(~mask_pad, 0.0)
  98
  99        if self.lorder > 0:
 100            if cache.size(2) == 0:  # cache_t == 0
 101                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
 102            else:
 103                assert cache.size(0) == x.size(0)  # equal batch
 104                assert cache.size(1) == x.size(1)  # equal channel
 105                x = torch.cat((cache, x), dim=2)
 106            assert x.size(2) > self.lorder
 107            new_cache = x[:, :, -self.lorder :]
 108        else:
 109            # It's better we just return None if no cache is required,
 110            # However, for JIT export, here we just fake one tensor instead of
 111            # None.
 112            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
 113
 114        # GLU mechanism
 115        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
 116        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
 117
 118        # 1D Depthwise Conv
 119        x = self.depthwise_conv(x)
 120        if self.use_layer_norm:
 121            x = x.transpose(1, 2)
 122        x = self.activation(self.norm(x))
 123        if self.use_layer_norm:
 124            x = x.transpose(1, 2)
 125        x = self.pointwise_conv2(x)
 126        # mask batch padding
 127        if mask_pad.size(2) > 0:  # time > 0
 128            x.masked_fill_(~mask_pad, 0.0)
 129
 130        return x.transpose(1, 2), new_cache
 131
 132
 133class PositionwiseFeedForward(torch.nn.Module):
 134    """Positionwise feed forward layer.
 135
 136    FeedForward are appied on each position of the sequence.
 137    The output dim is same with the input dim.
 138
 139    Args:
 140        idim (int): Input dimenstion.
 141        hidden_units (int): The number of hidden units.
 142        dropout_rate (float): Dropout rate.
 143        activation (torch.nn.Module): Activation function
 144    """
 145
 146    def __init__(
 147        self,
 148        idim: int,
 149        hidden_units: int,
 150        dropout_rate: float,
 151        activation: torch.nn.Module = torch.nn.ReLU(),
 152    ):
 153        """Construct a PositionwiseFeedForward object."""
 154        super(PositionwiseFeedForward, self).__init__()
 155        self.w_1 = torch.nn.Linear(idim, hidden_units)
 156        self.activation = activation
 157        self.dropout = torch.nn.Dropout(dropout_rate)
 158        self.w_2 = torch.nn.Linear(hidden_units, idim)
 159
 160    def forward(self, xs: torch.Tensor) -> torch.Tensor:
 161        """Forward function.
 162
 163        Args:
 164            xs: input tensor (B, L, D)
 165        Returns:
 166            output tensor, (B, L, D)
 167        """
 168        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
 169
 170
 171class Swish(torch.nn.Module):
 172    """Construct an Swish object."""
 173
 174    def forward(self, x: torch.Tensor) -> torch.Tensor:
 175        """Return Swish activation function."""
 176        return x * torch.sigmoid(x)
 177
 178
 179class MultiHeadedAttention(nn.Module):
 180    """Multi-Head Attention layer.
 181
 182    Args:
 183        n_head (int): The number of heads.
 184        n_feat (int): The number of features.
 185        dropout_rate (float): Dropout rate.
 186
 187    """
 188
 189    def __init__(
 190        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
 191    ):
 192        """Construct an MultiHeadedAttention object."""
 193        super().__init__()
 194        assert n_feat % n_head == 0
 195        # We assume d_v always equals d_k
 196        self.d_k = n_feat // n_head
 197        self.h = n_head
 198        self.linear_q = nn.Linear(n_feat, n_feat)
 199        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
 200        self.linear_v = nn.Linear(n_feat, n_feat)
 201        self.linear_out = nn.Linear(n_feat, n_feat)
 202        self.dropout = nn.Dropout(p=dropout_rate)
 203
 204    def forward_qkv(
 205        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 206    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 207        """Transform query, key and value.
 208
 209        Args:
 210            query (torch.Tensor): Query tensor (#batch, time1, size).
 211            key (torch.Tensor): Key tensor (#batch, time2, size).
 212            value (torch.Tensor): Value tensor (#batch, time2, size).
 213
 214        Returns:
 215            torch.Tensor: Transformed query tensor, size
 216                (#batch, n_head, time1, d_k).
 217            torch.Tensor: Transformed key tensor, size
 218                (#batch, n_head, time2, d_k).
 219            torch.Tensor: Transformed value tensor, size
 220                (#batch, n_head, time2, d_k).
 221
 222        """
 223        n_batch = query.size(0)
 224        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
 225        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
 226        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
 227        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
 228        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
 229        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
 230        return q, k, v
 231
 232    def forward_attention(
 233        self,
 234        value: torch.Tensor,
 235        scores: torch.Tensor,
 236        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
 237    ) -> torch.Tensor:
 238        """Compute attention context vector.
 239
 240        Args:
 241            value (torch.Tensor): Transformed value, size
 242                (#batch, n_head, time2, d_k).
 243            scores (torch.Tensor): Attention score, size
 244                (#batch, n_head, time1, time2).
 245            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
 246                (#batch, time1, time2), (0, 0, 0) means fake mask.
 247
 248        Returns:
 249            torch.Tensor: Transformed value (#batch, time1, d_model)
 250                weighted by the attention score (#batch, time1, time2).
 251
 252        """
 253        n_batch = value.size(0)
 254
 255        if mask.size(2) > 0:  # time2 > 0
 256            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
 257            # For last chunk, time2 might be larger than scores.size(-1)
 258            mask = mask[:, :, :, : scores.size(-1)]  # (batch, 1, *, time2)
 259            scores = scores.masked_fill(mask, -float("inf"))
 260            attn = torch.softmax(scores, dim=-1).masked_fill(
 261                mask, 0.0
 262            )  # (batch, head, time1, time2)
 263
 264        else:
 265            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
 266
 267        p_attn = self.dropout(attn)
 268        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
 269        x = (
 270            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
 271        )  # (batch, time1, d_model)
 272
 273        return self.linear_out(x)  # (batch, time1, d_model)
 274
 275    def forward(
 276        self,
 277        query: torch.Tensor,
 278        key: torch.Tensor,
 279        value: torch.Tensor,
 280        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
 281        pos_emb: torch.Tensor = torch.empty(0),
 282        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
 283    ) -> Tuple[torch.Tensor, torch.Tensor]:
 284        """Compute scaled dot product attention.
 285
 286        Args:
 287            query (torch.Tensor): Query tensor (#batch, time1, size).
 288            key (torch.Tensor): Key tensor (#batch, time2, size).
 289            value (torch.Tensor): Value tensor (#batch, time2, size).
 290            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
 291                (#batch, time1, time2).
 292                1.When applying cross attention between decoder and encoder,
 293                the batch padding mask for input is in (#batch, 1, T) shape.
 294                2.When applying self attention of encoder,
 295                the mask is in (#batch, T, T)  shape.
 296                3.When applying self attention of decoder,
 297                the mask is in (#batch, L, L)  shape.
 298                4.If the different position in decoder see different block
 299                of the encoder, such as Mocha, the passed in mask could be
 300                in (#batch, L, T) shape. But there is no such case in current
 301                CosyVoice.
 302            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
 303                where `cache_t == chunk_size * num_decoding_left_chunks`
 304                and `head * d_k == size`
 305
 306
 307        Returns:
 308            torch.Tensor: Output tensor (#batch, time1, d_model).
 309            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
 310                where `cache_t == chunk_size * num_decoding_left_chunks`
 311                and `head * d_k == size`
 312
 313        """
 314        q, k, v = self.forward_qkv(query, key, value)
 315        if cache.size(0) > 0:
 316            key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
 317            k = torch.cat([key_cache, k], dim=2)
 318            v = torch.cat([value_cache, v], dim=2)
 319        new_cache = torch.cat((k, v), dim=-1)
 320
 321        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
 322        return self.forward_attention(v, scores, mask), new_cache
 323
 324
 325class RelPositionMultiHeadedAttention(MultiHeadedAttention):
 326    """Multi-Head Attention layer with relative position encoding.
 327    Paper: https://arxiv.org/abs/1901.02860
 328    Args:
 329        n_head (int): The number of heads.
 330        n_feat (int): The number of features.
 331        dropout_rate (float): Dropout rate.
 332    """
 333
 334    def __init__(
 335        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
 336    ):
 337        """Construct an RelPositionMultiHeadedAttention object."""
 338        super().__init__(n_head, n_feat, dropout_rate, key_bias)
 339        # linear transformation for positional encoding
 340        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
 341        # these two learnable bias are used in matrix c and matrix d
 342        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
 343        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
 344        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
 345        torch.nn.init.xavier_uniform_(self.pos_bias_u)
 346        torch.nn.init.xavier_uniform_(self.pos_bias_v)
 347
 348    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
 349        """Compute relative positional encoding.
 350
 351        Args:
 352            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
 353            time1 means the length of query vector.
 354
 355        Returns:
 356            torch.Tensor: Output tensor.
 357
 358        """
 359        zero_pad = torch.zeros(
 360            (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
 361        )
 362        x_padded = torch.cat([zero_pad, x], dim=-1)
 363
 364        x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
 365        x = x_padded[:, :, 1:].view_as(x)[
 366            :, :, :, : x.size(-1) // 2 + 1
 367        ]  # only keep the positions from 0 to time2
 368        return x
 369
 370    def forward(
 371        self,
 372        query: torch.Tensor,
 373        key: torch.Tensor,
 374        value: torch.Tensor,
 375        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
 376        pos_emb: torch.Tensor = torch.empty(0),
 377        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
 378    ) -> Tuple[torch.Tensor, torch.Tensor]:
 379        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
 380        Args:
 381            query (torch.Tensor): Query tensor (#batch, time1, size).
 382            key (torch.Tensor): Key tensor (#batch, time2, size).
 383            value (torch.Tensor): Value tensor (#batch, time2, size).
 384            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
 385                (#batch, time1, time2), (0, 0, 0) means fake mask.
 386            pos_emb (torch.Tensor): Positional embedding tensor
 387                (#batch, time2, size).
 388            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
 389                where `cache_t == chunk_size * num_decoding_left_chunks`
 390                and `head * d_k == size`
 391        Returns:
 392            torch.Tensor: Output tensor (#batch, time1, d_model).
 393            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
 394                where `cache_t == chunk_size * num_decoding_left_chunks`
 395                and `head * d_k == size`
 396        """
 397        q, k, v = self.forward_qkv(query, key, value)
 398        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
 399
 400        if cache.size(0) > 0:
 401            key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
 402            k = torch.cat([key_cache, k], dim=2)
 403            v = torch.cat([value_cache, v], dim=2)
 404        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
 405        #   non-trivial to calculate `next_cache_start` here.
 406        new_cache = torch.cat((k, v), dim=-1)
 407
 408        n_batch_pos = pos_emb.size(0)
 409        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
 410        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
 411
 412        # (batch, head, time1, d_k)
 413        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
 414        # (batch, head, time1, d_k)
 415        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
 416
 417        # compute attention score
 418        # first compute matrix a and matrix c
 419        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
 420        # (batch, head, time1, time2)
 421        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
 422
 423        # compute matrix b and matrix d
 424        # (batch, head, time1, time2)
 425        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
 426        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
 427        if matrix_ac.shape != matrix_bd.shape:
 428            matrix_bd = self.rel_shift(matrix_bd)
 429
 430        scores = (matrix_ac + matrix_bd) / math.sqrt(
 431            self.d_k
 432        )  # (batch, head, time1, time2)
 433
 434        return self.forward_attention(v, scores, mask), new_cache
 435
 436
 437def subsequent_mask(
 438    size: int,
 439    device: torch.device = torch.device("cpu"),
 440) -> torch.Tensor:
 441    """Create mask for subsequent steps (size, size).
 442
 443    This mask is used only in decoder which works in an auto-regressive mode.
 444    This means the current step could only do attention with its left steps.
 445
 446    In encoder, fully attention is used when streaming is not necessary and
 447    the sequence is not long. In this  case, no attention mask is needed.
 448
 449    When streaming is need, chunk-based attention is used in encoder. See
 450    subsequent_chunk_mask for the chunk-based attention mask.
 451
 452    Args:
 453        size (int): size of mask
 454        str device (str): "cpu" or "cuda" or torch.Tensor.device
 455        dtype (torch.device): result dtype
 456
 457    Returns:
 458        torch.Tensor: mask
 459
 460    Examples:
 461        >>> subsequent_mask(3)
 462        [[1, 0, 0],
 463         [1, 1, 0],
 464         [1, 1, 1]]
 465    """
 466    arange = torch.arange(size, device=device)
 467    mask = arange.expand(size, size)
 468    arange = arange.unsqueeze(-1)
 469    mask = mask <= arange
 470    return mask
 471
 472
 473def subsequent_chunk_mask(
 474    size: int,
 475    chunk_size: int,
 476    num_left_chunks: int = -1,
 477    device: torch.device = torch.device("cpu"),
 478) -> torch.Tensor:
 479    """Create mask for subsequent steps (size, size) with chunk size,
 480       this is for streaming encoder
 481
 482    Args:
 483        size (int): size of mask
 484        chunk_size (int): size of chunk
 485        num_left_chunks (int): number of left chunks
 486            <0: use full chunk
 487            >=0: use num_left_chunks
 488        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
 489
 490    Returns:
 491        torch.Tensor: mask
 492
 493    Examples:
 494        >>> subsequent_chunk_mask(4, 2)
 495        [[1, 1, 0, 0],
 496         [1, 1, 0, 0],
 497         [1, 1, 1, 1],
 498         [1, 1, 1, 1]]
 499    """
 500    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
 501    for i in range(size):
 502        if num_left_chunks < 0:
 503            start = 0
 504        else:
 505            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
 506        ending = min((i // chunk_size + 1) * chunk_size, size)
 507        ret[i, start:ending] = True
 508    return ret
 509
 510
 511def add_optional_chunk_mask(
 512    xs: torch.Tensor,
 513    masks: torch.Tensor,
 514    use_dynamic_chunk: bool,
 515    use_dynamic_left_chunk: bool,
 516    decoding_chunk_size: int,
 517    static_chunk_size: int,
 518    num_decoding_left_chunks: int,
 519    enable_full_context: bool = True,
 520):
 521    """Apply optional mask for encoder.
 522
 523    Args:
 524        xs (torch.Tensor): padded input, (B, L, D), L for max length
 525        mask (torch.Tensor): mask for xs, (B, 1, L)
 526        use_dynamic_chunk (bool): whether to use dynamic chunk or not
 527        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
 528            training.
 529        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
 530            0: default for training, use random dynamic chunk.
 531            <0: for decoding, use full chunk.
 532            >0: for decoding, use fixed chunk size as set.
 533        static_chunk_size (int): chunk size for static chunk training/decoding
 534            if it's greater than 0, if use_dynamic_chunk is true,
 535            this parameter will be ignored
 536        num_decoding_left_chunks: number of left chunks, this is for decoding,
 537            the chunk size is decoding_chunk_size.
 538            >=0: use num_decoding_left_chunks
 539            <0: use all left chunks
 540        enable_full_context (bool):
 541            True: chunk size is either [1, 25] or full context(max_len)
 542            False: chunk size ~ U[1, 25]
 543
 544    Returns:
 545        torch.Tensor: chunk mask of the input xs.
 546    """
 547    # Whether to use chunk mask or not
 548    if use_dynamic_chunk:
 549        max_len = xs.size(1)
 550        if decoding_chunk_size < 0:
 551            chunk_size = max_len
 552            num_left_chunks = -1
 553        elif decoding_chunk_size > 0:
 554            chunk_size = decoding_chunk_size
 555            num_left_chunks = num_decoding_left_chunks
 556        else:
 557            # chunk size is either [1, 25] or full context(max_len).
 558            # Since we use 4 times subsampling and allow up to 1s(100 frames)
 559            # delay, the maximum frame is 100 / 4 = 25.
 560            chunk_size = torch.randint(1, max_len, (1,)).item()
 561            num_left_chunks = -1
 562            if chunk_size > max_len // 2 and enable_full_context:
 563                chunk_size = max_len
 564            else:
 565                chunk_size = chunk_size % 25 + 1
 566                if use_dynamic_left_chunk:
 567                    max_left_chunks = (max_len - 1) // chunk_size
 568                    num_left_chunks = torch.randint(0, max_left_chunks, (1,)).item()
 569        chunk_masks = subsequent_chunk_mask(
 570            xs.size(1), chunk_size, num_left_chunks, xs.device
 571        )  # (L, L)
 572        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
 573        chunk_masks = masks & chunk_masks  # (B, L, L)
 574    elif static_chunk_size > 0:
 575        num_left_chunks = num_decoding_left_chunks
 576        chunk_masks = subsequent_chunk_mask(
 577            xs.size(1), static_chunk_size, num_left_chunks, xs.device
 578        )  # (L, L)
 579        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
 580        chunk_masks = masks & chunk_masks  # (B, L, L)
 581    else:
 582        chunk_masks = masks
 583    return chunk_masks
 584
 585
 586class ConformerEncoderLayer(nn.Module):
 587    """Encoder layer module.
 588    Args:
 589        size (int): Input dimension.
 590        self_attn (torch.nn.Module): Self-attention module instance.
 591            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
 592            instance can be used as the argument.
 593        feed_forward (torch.nn.Module): Feed-forward module instance.
 594            `PositionwiseFeedForward` instance can be used as the argument.
 595        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
 596             instance.
 597            `PositionwiseFeedForward` instance can be used as the argument.
 598        conv_module (torch.nn.Module): Convolution module instance.
 599            `ConvlutionModule` instance can be used as the argument.
 600        dropout_rate (float): Dropout rate.
 601        normalize_before (bool):
 602            True: use layer_norm before each sub-block.
 603            False: use layer_norm after each sub-block.
 604    """
 605
 606    def __init__(
 607        self,
 608        size: int,
 609        self_attn: torch.nn.Module,
 610        feed_forward: Optional[nn.Module] = None,
 611        feed_forward_macaron: Optional[nn.Module] = None,
 612        conv_module: Optional[nn.Module] = None,
 613        dropout_rate: float = 0.1,
 614        normalize_before: bool = True,
 615    ):
 616        """Construct an EncoderLayer object."""
 617        super().__init__()
 618        self.self_attn = self_attn
 619        self.feed_forward = feed_forward
 620        self.feed_forward_macaron = feed_forward_macaron
 621        self.conv_module = conv_module
 622        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
 623        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
 624        if feed_forward_macaron is not None:
 625            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
 626            self.ff_scale = 0.5
 627        else:
 628            self.ff_scale = 1.0
 629        if self.conv_module is not None:
 630            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
 631            self.norm_final = nn.LayerNorm(
 632                size, eps=1e-5
 633            )  # for the final output of the block
 634        self.dropout = nn.Dropout(dropout_rate)
 635        self.size = size
 636        self.normalize_before = normalize_before
 637
 638    def forward(
 639        self,
 640        x: torch.Tensor,
 641        mask: torch.Tensor,
 642        pos_emb: torch.Tensor,
 643        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
 644        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
 645        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
 646    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 647        """Compute encoded features.
 648
 649        Args:
 650            x (torch.Tensor): (#batch, time, size)
 651            mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
 652                (0, 0, 0) means fake mask.
 653            pos_emb (torch.Tensor): positional encoding, must not be None
 654                for ConformerEncoderLayer.
 655            mask_pad (torch.Tensor): batch padding mask used for conv module.
 656                (#batch, 1,time), (0, 0, 0) means fake mask.
 657            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
 658                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
 659            cnn_cache (torch.Tensor): Convolution cache in conformer layer
 660                (#batch=1, size, cache_t2)
 661        Returns:
 662            torch.Tensor: Output tensor (#batch, time, size).
 663            torch.Tensor: Mask tensor (#batch, time, time).
 664            torch.Tensor: att_cache tensor,
 665                (#batch=1, head, cache_t1 + time, d_k * 2).
 666            torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
 667        """
 668
 669        # whether to use macaron style
 670        if self.feed_forward_macaron is not None:
 671            residual = x
 672            if self.normalize_before:
 673                x = self.norm_ff_macaron(x)
 674            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
 675            if not self.normalize_before:
 676                x = self.norm_ff_macaron(x)
 677
 678        # multi-headed self-attention module
 679        residual = x
 680        if self.normalize_before:
 681            x = self.norm_mha(x)
 682        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
 683        x = residual + self.dropout(x_att)
 684        if not self.normalize_before:
 685            x = self.norm_mha(x)
 686
 687        # convolution module
 688        # Fake new cnn cache here, and then change it in conv_module
 689        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
 690        if self.conv_module is not None:
 691            residual = x
 692            if self.normalize_before:
 693                x = self.norm_conv(x)
 694            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
 695            x = residual + self.dropout(x)
 696
 697            if not self.normalize_before:
 698                x = self.norm_conv(x)
 699
 700        # feed forward module
 701        residual = x
 702        if self.normalize_before:
 703            x = self.norm_ff(x)
 704
 705        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
 706        if not self.normalize_before:
 707            x = self.norm_ff(x)
 708
 709        if self.conv_module is not None:
 710            x = self.norm_final(x)
 711
 712        return x, mask, new_att_cache, new_cnn_cache
 713
 714
 715class EspnetRelPositionalEncoding(torch.nn.Module):
 716    """Relative positional encoding module (new implementation).
 717
 718    Details can be found in https://github.com/espnet/espnet/pull/2816.
 719
 720    See : Appendix B in https://arxiv.org/abs/1901.02860
 721
 722    Args:
 723        d_model (int): Embedding dimension.
 724        dropout_rate (float): Dropout rate.
 725        max_len (int): Maximum input length.
 726
 727    """
 728
 729    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
 730        """Construct an PositionalEncoding object."""
 731        super(EspnetRelPositionalEncoding, self).__init__()
 732        self.d_model = d_model
 733        self.xscale = math.sqrt(self.d_model)
 734        self.dropout = torch.nn.Dropout(p=dropout_rate)
 735        self.pe = None
 736        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 737
 738    def extend_pe(self, x: torch.Tensor):
 739        """Reset the positional encodings."""
 740        if self.pe is not None:
 741            # self.pe contains both positive and negative parts
 742            # the length of self.pe is 2 * input_len - 1
 743            if self.pe.size(1) >= x.size(1) * 2 - 1:
 744                if self.pe.dtype != x.dtype or self.pe.device != x.device:
 745                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
 746                return
 747        # Suppose `i` means to the position of query vecotr and `j` means the
 748        # position of key vector. We use position relative positions when keys
 749        # are to the left (i>j) and negative relative positions otherwise (i<j).
 750        pe_positive = torch.zeros(x.size(1), self.d_model)
 751        pe_negative = torch.zeros(x.size(1), self.d_model)
 752        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
 753        div_term = torch.exp(
 754            torch.arange(0, self.d_model, 2, dtype=torch.float32)
 755            * -(math.log(10000.0) / self.d_model)
 756        )
 757        pe_positive[:, 0::2] = torch.sin(position * div_term)
 758        pe_positive[:, 1::2] = torch.cos(position * div_term)
 759        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
 760        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
 761
 762        # Reserve the order of positive indices and concat both positive and
 763        # negative indices. This is used to support the shifting trick
 764        # as in https://arxiv.org/abs/1901.02860
 765        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
 766        pe_negative = pe_negative[1:].unsqueeze(0)
 767        pe = torch.cat([pe_positive, pe_negative], dim=1)
 768        self.pe = pe.to(device=x.device, dtype=x.dtype)
 769
 770    def forward(
 771        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
 772    ) -> Tuple[torch.Tensor, torch.Tensor]:
 773        """Add positional encoding.
 774
 775        Args:
 776            x (torch.Tensor): Input tensor (batch, time, `*`).
 777
 778        Returns:
 779            torch.Tensor: Encoded tensor (batch, time, `*`).
 780
 781        """
 782        self.extend_pe(x)
 783        x = x * self.xscale
 784        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
 785        return self.dropout(x), self.dropout(pos_emb)
 786
 787    def position_encoding(
 788        self, offset: Union[int, torch.Tensor], size: int
 789    ) -> torch.Tensor:
 790        """For getting encoding in a streaming fashion
 791
 792        Attention!!!!!
 793        we apply dropout only once at the whole utterance level in a none
 794        streaming way, but will call this function several times with
 795        increasing input size in a streaming scenario, so the dropout will
 796        be applied several times.
 797
 798        Args:
 799            offset (int or torch.tensor): start offset
 800            size (int): required size of position encoding
 801
 802        Returns:
 803            torch.Tensor: Corresponding encoding
 804        """
 805        pos_emb = self.pe[
 806            :,
 807            self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
 808        ]
 809        return pos_emb
 810
 811
 812class LinearEmbed(torch.nn.Module):
 813    """Linear transform the input without subsampling
 814
 815    Args:
 816        idim (int): Input dimension.
 817        odim (int): Output dimension.
 818        dropout_rate (float): Dropout rate.
 819
 820    """
 821
 822    def __init__(
 823        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
 824    ):
 825        """Construct an linear object."""
 826        super().__init__()
 827        self.out = torch.nn.Sequential(
 828            torch.nn.Linear(idim, odim),
 829            torch.nn.LayerNorm(odim, eps=1e-5),
 830            torch.nn.Dropout(dropout_rate),
 831        )
 832        self.pos_enc = pos_enc_class  # rel_pos_espnet
 833
 834    def position_encoding(
 835        self, offset: Union[int, torch.Tensor], size: int
 836    ) -> torch.Tensor:
 837        return self.pos_enc.position_encoding(offset, size)
 838
 839    def forward(
 840        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
 841    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 842        """Input x.
 843
 844        Args:
 845            x (torch.Tensor): Input tensor (#batch, time, idim).
 846            x_mask (torch.Tensor): Input mask (#batch, 1, time).
 847
 848        Returns:
 849            torch.Tensor: linear input tensor (#batch, time', odim),
 850                where time' = time .
 851            torch.Tensor: linear input mask (#batch, 1, time'),
 852                where time' = time .
 853
 854        """
 855        x = self.out(x)
 856        x, pos_emb = self.pos_enc(x, offset)
 857        return x, pos_emb
 858
 859
 860ATTENTION_CLASSES = {
 861    "selfattn": MultiHeadedAttention,
 862    "rel_selfattn": RelPositionMultiHeadedAttention,
 863}
 864
 865ACTIVATION_CLASSES = {
 866    "hardtanh": torch.nn.Hardtanh,
 867    "tanh": torch.nn.Tanh,
 868    "relu": torch.nn.ReLU,
 869    "selu": torch.nn.SELU,
 870    "swish": getattr(torch.nn, "SiLU", Swish),
 871    "gelu": torch.nn.GELU,
 872}
 873
 874
 875def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
 876    """Make mask tensor containing indices of padded part.
 877
 878    See description of make_non_pad_mask.
 879
 880    Args:
 881        lengths (torch.Tensor): Batch of lengths (B,).
 882    Returns:
 883        torch.Tensor: Mask tensor containing indices of padded part.
 884
 885    Examples:
 886        >>> lengths = [5, 3, 2]
 887        >>> make_pad_mask(lengths)
 888        masks = [[0, 0, 0, 0 ,0],
 889                 [0, 0, 0, 1, 1],
 890                 [0, 0, 1, 1, 1]]
 891    """
 892    batch_size = lengths.size(0)
 893    max_len = max_len if max_len > 0 else lengths.max().item()
 894    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
 895    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
 896    seq_length_expand = lengths.unsqueeze(-1)
 897    mask = seq_range_expand >= seq_length_expand
 898    return mask
 899
 900
 901# https://github.com/FunAudioLLM/CosyVoice/blob/main/examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml
 902class ConformerEncoder(torch.nn.Module):
 903    """Conformer encoder module."""
 904
 905    def __init__(
 906        self,
 907        input_size: int,
 908        output_size: int = 1024,
 909        attention_heads: int = 16,
 910        linear_units: int = 4096,
 911        num_blocks: int = 6,
 912        dropout_rate: float = 0.1,
 913        positional_dropout_rate: float = 0.1,
 914        attention_dropout_rate: float = 0.0,
 915        input_layer: str = "linear",
 916        pos_enc_layer_type: str = "rel_pos_espnet",
 917        normalize_before: bool = True,
 918        static_chunk_size: int = 1,  # 1: causal_mask; 0: full_mask
 919        use_dynamic_chunk: bool = False,
 920        use_dynamic_left_chunk: bool = False,
 921        positionwise_conv_kernel_size: int = 1,
 922        macaron_style: bool = False,
 923        selfattention_layer_type: str = "rel_selfattn",
 924        activation_type: str = "swish",
 925        use_cnn_module: bool = False,
 926        cnn_module_kernel: int = 15,
 927        causal: bool = False,
 928        cnn_module_norm: str = "batch_norm",
 929        key_bias: bool = True,
 930        gradient_checkpointing: bool = False,
 931    ):
 932        """Construct ConformerEncoder
 933
 934        Args:
 935            input_size to use_dynamic_chunk, see in BaseEncoder
 936            positionwise_conv_kernel_size (int): Kernel size of positionwise
 937                conv1d layer.
 938            macaron_style (bool): Whether to use macaron style for
 939                positionwise layer.
 940            selfattention_layer_type (str): Encoder attention layer type,
 941                the parameter has no effect now, it's just for configure
 942                compatibility. #'rel_selfattn'
 943            activation_type (str): Encoder activation function type.
 944            use_cnn_module (bool): Whether to use convolution module.
 945            cnn_module_kernel (int): Kernel size of convolution module.
 946            causal (bool): whether to use causal convolution or not.
 947            key_bias: whether use bias in attention.linear_k, False for whisper models.
 948        """
 949        super().__init__()
 950        self.output_size = output_size
 951        self.embed = LinearEmbed(
 952            input_size,
 953            output_size,
 954            dropout_rate,
 955            EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
 956        )
 957        self.normalize_before = normalize_before
 958        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
 959        self.gradient_checkpointing = gradient_checkpointing
 960        self.use_dynamic_chunk = use_dynamic_chunk
 961
 962        self.static_chunk_size = static_chunk_size
 963        self.use_dynamic_chunk = use_dynamic_chunk
 964        self.use_dynamic_left_chunk = use_dynamic_left_chunk
 965        activation = ACTIVATION_CLASSES[activation_type]()
 966
 967        # self-attention module definition
 968        encoder_selfattn_layer_args = (
 969            attention_heads,
 970            output_size,
 971            attention_dropout_rate,
 972            key_bias,
 973        )
 974        # feed-forward module definition
 975        positionwise_layer_args = (
 976            output_size,
 977            linear_units,
 978            dropout_rate,
 979            activation,
 980        )
 981        # convolution module definition
 982        convolution_layer_args = (
 983            output_size,
 984            cnn_module_kernel,
 985            activation,
 986            cnn_module_norm,
 987            causal,
 988        )
 989
 990        self.encoders = torch.nn.ModuleList(
 991            [
 992                ConformerEncoderLayer(
 993                    output_size,
 994                    RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
 995                    PositionwiseFeedForward(*positionwise_layer_args),
 996                    (
 997                        PositionwiseFeedForward(*positionwise_layer_args)
 998                        if macaron_style
 999                        else None
1000                    ),
1001                    (
1002                        ConvolutionModule(*convolution_layer_args)
1003                        if use_cnn_module
1004                        else None
1005                    ),
1006                    dropout_rate,
1007                    normalize_before,
1008                )
1009                for _ in range(num_blocks)
1010            ]
1011        )
1012
1013    def forward_layers(
1014        self,
1015        xs: torch.Tensor,
1016        chunk_masks: torch.Tensor,
1017        pos_emb: torch.Tensor,
1018        mask_pad: torch.Tensor,
1019    ) -> torch.Tensor:
1020        for layer in self.encoders:
1021            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
1022        return xs
1023
1024    @torch.jit.unused
1025    def forward_layers_checkpointed(
1026        self,
1027        xs: torch.Tensor,
1028        chunk_masks: torch.Tensor,
1029        pos_emb: torch.Tensor,
1030        mask_pad: torch.Tensor,
1031    ) -> torch.Tensor:
1032        for layer in self.encoders:
1033            xs, chunk_masks, _, _ = torch.utils.checkpoint.checkpoint(
1034                layer.__call__, xs, chunk_masks, pos_emb, mask_pad, use_reentrant=False
1035            )
1036        return xs
1037
1038    def forward(
1039        self,
1040        xs: torch.Tensor,
1041        pad_mask: torch.Tensor,
1042        decoding_chunk_size: int = 0,
1043        num_decoding_left_chunks: int = -1,
1044    ) -> Tuple[torch.Tensor, torch.Tensor]:
1045        """Embed positions in tensor.
1046
1047        Args:
1048            xs: padded input tensor (B, T, D)
1049            xs_lens: input length (B)
1050            decoding_chunk_size: decoding chunk size for dynamic chunk
1051                0: default for training, use random dynamic chunk.
1052                <0: for decoding, use full chunk.
1053                >0: for decoding, use fixed chunk size as set.
1054            num_decoding_left_chunks: number of left chunks, this is for decoding,
1055            the chunk size is decoding_chunk_size.
1056                >=0: use num_decoding_left_chunks
1057                <0: use all left chunks
1058        Returns:
1059            encoder output tensor xs, and subsampled masks
1060            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
1061            masks: torch.Tensor batch padding mask after subsample
1062                (B, 1, T' ~= T/subsample_rate)
1063        NOTE(xcsong):
1064            We pass the `__call__` method of the modules instead of `forward` to the
1065            checkpointing API because `__call__` attaches all the hooks of the module.
1066            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
1067        """
1068        T = xs.size(1)
1069        masks = pad_mask.to(torch.bool).unsqueeze(1)  # (B, 1, T)
1070        xs, pos_emb = self.embed(xs)
1071        mask_pad = masks  # (B, 1, T/subsample_rate)
1072        chunk_masks = add_optional_chunk_mask(
1073            xs,
1074            masks,
1075            self.use_dynamic_chunk,
1076            self.use_dynamic_left_chunk,
1077            decoding_chunk_size,
1078            self.static_chunk_size,
1079            num_decoding_left_chunks,
1080        )
1081        if self.gradient_checkpointing and self.training:
1082            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad)
1083        else:
1084            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
1085        if self.normalize_before:
1086            xs = self.after_norm(xs)
1087        # Here we assume the mask is not changed in encoder layers, so just
1088        # return the masks before encoder layers, and the masks will be used
1089        # for cross attention with decoder later
1090        return xs, masks
class ConvolutionModule(torch.nn.modules.module.Module):
  8class ConvolutionModule(nn.Module):
  9    """ConvolutionModule in Conformer model."""
 10
 11    def __init__(
 12        self,
 13        channels: int,
 14        kernel_size: int = 15,
 15        activation: nn.Module = nn.ReLU(),
 16        norm: str = "batch_norm",
 17        causal: bool = False,
 18        bias: bool = True,
 19    ):
 20        """Construct an ConvolutionModule object.
 21        Args:
 22            channels (int): The number of channels of conv layers.
 23            kernel_size (int): Kernel size of conv layers.
 24            causal (int): Whether use causal convolution or not
 25        """
 26        super().__init__()
 27
 28        self.pointwise_conv1 = nn.Conv1d(
 29            channels,
 30            2 * channels,
 31            kernel_size=1,
 32            stride=1,
 33            padding=0,
 34            bias=bias,
 35        )
 36        # self.lorder is used to distinguish if it's a causal convolution,
 37        # if self.lorder > 0: it's a causal convolution, the input will be
 38        #    padded with self.lorder frames on the left in forward.
 39        # else: it's a symmetrical convolution
 40        if causal:
 41            padding = 0
 42            self.lorder = kernel_size - 1
 43        else:
 44            # kernel_size should be an odd number for none causal convolution
 45            assert (kernel_size - 1) % 2 == 0
 46            padding = (kernel_size - 1) // 2
 47            self.lorder = 0
 48        self.depthwise_conv = nn.Conv1d(
 49            channels,
 50            channels,
 51            kernel_size,
 52            stride=1,
 53            padding=padding,
 54            groups=channels,
 55            bias=bias,
 56        )
 57
 58        assert norm in ["batch_norm", "layer_norm"]
 59        if norm == "batch_norm":
 60            self.use_layer_norm = False
 61            self.norm = nn.BatchNorm1d(channels)
 62        else:
 63            self.use_layer_norm = True
 64            self.norm = nn.LayerNorm(channels)
 65
 66        self.pointwise_conv2 = nn.Conv1d(
 67            channels,
 68            channels,
 69            kernel_size=1,
 70            stride=1,
 71            padding=0,
 72            bias=bias,
 73        )
 74        self.activation = activation
 75
 76    def forward(
 77        self,
 78        x: torch.Tensor,
 79        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
 80        cache: torch.Tensor = torch.zeros((0, 0, 0)),
 81    ) -> Tuple[torch.Tensor, torch.Tensor]:
 82        """Compute convolution module.
 83        Args:
 84            x (torch.Tensor): Input tensor (#batch, time, channels).
 85            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
 86                (0, 0, 0) means fake mask.
 87            cache (torch.Tensor): left context cache, it is only
 88                used in causal convolution (#batch, channels, cache_t),
 89                (0, 0, 0) meas fake cache.
 90        Returns:
 91            torch.Tensor: Output tensor (#batch, time, channels).
 92        """
 93        # exchange the temporal dimension and the feature dimension
 94        x = x.transpose(1, 2)  # (#batch, channels, time)
 95
 96        # mask batch padding
 97        if mask_pad.size(2) > 0:  # time > 0
 98            x.masked_fill_(~mask_pad, 0.0)
 99
100        if self.lorder > 0:
101            if cache.size(2) == 0:  # cache_t == 0
102                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
103            else:
104                assert cache.size(0) == x.size(0)  # equal batch
105                assert cache.size(1) == x.size(1)  # equal channel
106                x = torch.cat((cache, x), dim=2)
107            assert x.size(2) > self.lorder
108            new_cache = x[:, :, -self.lorder :]
109        else:
110            # It's better we just return None if no cache is required,
111            # However, for JIT export, here we just fake one tensor instead of
112            # None.
113            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
114
115        # GLU mechanism
116        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
117        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
118
119        # 1D Depthwise Conv
120        x = self.depthwise_conv(x)
121        if self.use_layer_norm:
122            x = x.transpose(1, 2)
123        x = self.activation(self.norm(x))
124        if self.use_layer_norm:
125            x = x.transpose(1, 2)
126        x = self.pointwise_conv2(x)
127        # mask batch padding
128        if mask_pad.size(2) > 0:  # time > 0
129            x.masked_fill_(~mask_pad, 0.0)
130
131        return x.transpose(1, 2), new_cache

ConvolutionModule in Conformer model.

ConvolutionModule( channels: int, kernel_size: int = 15, activation: torch.nn.modules.module.Module = ReLU(), norm: str = 'batch_norm', causal: bool = False, bias: bool = True)
11    def __init__(
12        self,
13        channels: int,
14        kernel_size: int = 15,
15        activation: nn.Module = nn.ReLU(),
16        norm: str = "batch_norm",
17        causal: bool = False,
18        bias: bool = True,
19    ):
20        """Construct an ConvolutionModule object.
21        Args:
22            channels (int): The number of channels of conv layers.
23            kernel_size (int): Kernel size of conv layers.
24            causal (int): Whether use causal convolution or not
25        """
26        super().__init__()
27
28        self.pointwise_conv1 = nn.Conv1d(
29            channels,
30            2 * channels,
31            kernel_size=1,
32            stride=1,
33            padding=0,
34            bias=bias,
35        )
36        # self.lorder is used to distinguish if it's a causal convolution,
37        # if self.lorder > 0: it's a causal convolution, the input will be
38        #    padded with self.lorder frames on the left in forward.
39        # else: it's a symmetrical convolution
40        if causal:
41            padding = 0
42            self.lorder = kernel_size - 1
43        else:
44            # kernel_size should be an odd number for none causal convolution
45            assert (kernel_size - 1) % 2 == 0
46            padding = (kernel_size - 1) // 2
47            self.lorder = 0
48        self.depthwise_conv = nn.Conv1d(
49            channels,
50            channels,
51            kernel_size,
52            stride=1,
53            padding=padding,
54            groups=channels,
55            bias=bias,
56        )
57
58        assert norm in ["batch_norm", "layer_norm"]
59        if norm == "batch_norm":
60            self.use_layer_norm = False
61            self.norm = nn.BatchNorm1d(channels)
62        else:
63            self.use_layer_norm = True
64            self.norm = nn.LayerNorm(channels)
65
66        self.pointwise_conv2 = nn.Conv1d(
67            channels,
68            channels,
69            kernel_size=1,
70            stride=1,
71            padding=0,
72            bias=bias,
73        )
74        self.activation = activation

Construct a ConvolutionModule object. Args: channels (int): The number of channels of conv layers. kernel_size (int): Kernel size of conv layers. causal (bool): Whether to use causal convolution or not.

pointwise_conv1
depthwise_conv
pointwise_conv2
activation
def forward( self, x: torch.Tensor, mask_pad: torch.Tensor = tensor([], size=(0, 0, 0), dtype=torch.bool), cache: torch.Tensor = tensor([], size=(0, 0, 0))) -> Tuple[torch.Tensor, torch.Tensor]:
 76    def forward(
 77        self,
 78        x: torch.Tensor,
 79        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
 80        cache: torch.Tensor = torch.zeros((0, 0, 0)),
 81    ) -> Tuple[torch.Tensor, torch.Tensor]:
 82        """Compute convolution module.
 83        Args:
 84            x (torch.Tensor): Input tensor (#batch, time, channels).
 85            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
 86                (0, 0, 0) means fake mask.
 87            cache (torch.Tensor): left context cache, it is only
 88                used in causal convolution (#batch, channels, cache_t),
 89                (0, 0, 0) meas fake cache.
 90        Returns:
 91            torch.Tensor: Output tensor (#batch, time, channels).
 92        """
 93        # exchange the temporal dimension and the feature dimension
 94        x = x.transpose(1, 2)  # (#batch, channels, time)
 95
 96        # mask batch padding
 97        if mask_pad.size(2) > 0:  # time > 0
 98            x.masked_fill_(~mask_pad, 0.0)
 99
100        if self.lorder > 0:
101            if cache.size(2) == 0:  # cache_t == 0
102                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
103            else:
104                assert cache.size(0) == x.size(0)  # equal batch
105                assert cache.size(1) == x.size(1)  # equal channel
106                x = torch.cat((cache, x), dim=2)
107            assert x.size(2) > self.lorder
108            new_cache = x[:, :, -self.lorder :]
109        else:
110            # It's better we just return None if no cache is required,
111            # However, for JIT export, here we just fake one tensor instead of
112            # None.
113            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
114
115        # GLU mechanism
116        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
117        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
118
119        # 1D Depthwise Conv
120        x = self.depthwise_conv(x)
121        if self.use_layer_norm:
122            x = x.transpose(1, 2)
123        x = self.activation(self.norm(x))
124        if self.use_layer_norm:
125            x = x.transpose(1, 2)
126        x = self.pointwise_conv2(x)
127        # mask batch padding
128        if mask_pad.size(2) > 0:  # time > 0
129            x.masked_fill_(~mask_pad, 0.0)
130
131        return x.transpose(1, 2), new_cache

Compute convolution module. Args: x (torch.Tensor): Input tensor (#batch, time, channels). mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), (0, 0, 0) means fake mask. cache (torch.Tensor): left context cache, it is only used in causal convolution (#batch, channels, cache_t), (0, 0, 0) means fake cache. Returns: torch.Tensor: Output tensor (#batch, time, channels), and the new left-context cache tensor.

class PositionwiseFeedForward(torch.nn.modules.module.Module):
class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    The feed-forward network is applied to each position of the sequence
    independently. The output dimension is the same as the input dimension.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate (applied after the activation).
        activation (torch.nn.Module, optional): Activation function. Defaults
            to a fresh ``torch.nn.ReLU()`` per instance.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: Optional[torch.nn.Module] = None,
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        # Create the default per instance: a module default argument would be
        # evaluated once and shared by every PositionwiseFeedForward.
        self.activation = activation if activation is not None else torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))

Positionwise feed forward layer.

The feed-forward network is applied to each position of the sequence. The output dim is the same as the input dim.

Args: idim (int): Input dimension. hidden_units (int): The number of hidden units. dropout_rate (float): Dropout rate. activation (torch.nn.Module): Activation function

PositionwiseFeedForward( idim: int, hidden_units: int, dropout_rate: float, activation: torch.nn.modules.module.Module = ReLU())
147    def __init__(
148        self,
149        idim: int,
150        hidden_units: int,
151        dropout_rate: float,
152        activation: torch.nn.Module = torch.nn.ReLU(),
153    ):
154        """Construct a PositionwiseFeedForward object."""
155        super(PositionwiseFeedForward, self).__init__()
156        self.w_1 = torch.nn.Linear(idim, hidden_units)
157        self.activation = activation
158        self.dropout = torch.nn.Dropout(dropout_rate)
159        self.w_2 = torch.nn.Linear(hidden_units, idim)

Construct a PositionwiseFeedForward object.

w_1
activation
dropout
w_2
def forward(self, xs: torch.Tensor) -> torch.Tensor:
161    def forward(self, xs: torch.Tensor) -> torch.Tensor:
162        """Forward function.
163
164        Args:
165            xs: input tensor (B, L, D)
166        Returns:
167            output tensor, (B, L, D)
168        """
169        return self.w_2(self.dropout(self.activation(self.w_1(xs))))

Forward function.

Args: xs: input tensor (B, L, D) Returns: output tensor, (B, L, D)

class Swish(torch.nn.modules.module.Module):
class Swish(torch.nn.Module):
    """Swish activation module, computing ``x * sigmoid(x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Swish activation element-wise."""
        gate = torch.sigmoid(x)
        return x * gate

Construct a Swish object.

def forward(self, x: torch.Tensor) -> torch.Tensor:
175    def forward(self, x: torch.Tensor) -> torch.Tensor:
176        """Return Swish activation function."""
177        return x * torch.sigmoid(x)

Return Swish activation function.

class MultiHeadedAttention(torch.nn.modules.module.Module):
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer (scaled dot-product attention).

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features; must be divisible by n_head.
        dropout_rate (float): Dropout rate applied to the attention weights.
        key_bias (bool): Whether ``linear_k`` has a bias term.

    """

    def __init__(
        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
    ):
        """Construct a MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        self.h = n_head
        # Per-head width; value heads use the same width as key heads.
        self.d_k = n_feat // n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query/key/value and split them into heads.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            Three tensors of shape (#batch, n_head, time1, d_k),
            (#batch, n_head, time2, d_k) and (#batch, n_head, time2, d_k).

        """
        batch_size = query.size(0)
        heads_q = self.linear_q(query).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        heads_k = self.linear_k(key).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        heads_v = self.linear_v(value).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        return heads_q, heads_k, heads_v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> torch.Tensor:
        """Turn attention scores into the attended output.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention scores
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or
                (#batch, time1, time2); 0/False marks positions to ignore.
                (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Output (#batch, time1, d_model).

        """
        batch_size = value.size(0)

        if mask.size(2) == 0:  # fake mask: attend everywhere
            weights = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        else:
            # True where a key must be ignored. For the last streaming chunk
            # the mask may be wider than scores, so trim it to time2.
            blocked = mask.unsqueeze(1).eq(0)[:, :, :, : scores.size(-1)]
            # Fully blocked rows give NaN after softmax; zero them afterwards.
            weights = torch.softmax(
                scores.masked_fill(blocked, -float("inf")), dim=-1
            ).masked_fill(blocked, 0.0)

        context = torch.matmul(self.dropout(weights), value)  # (batch, head, time1, d_k)
        merged = (
            context.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(merged)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2); see ``forward_attention``.
            pos_emb (torch.Tensor): Unused here; accepted so the signature
                matches the relative-position variant.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2)
                holding previous keys and values concatenated on the last
                dim; an empty tensor means no cache.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2).

        """
        q, k, v = self.forward_qkv(query, key, value)
        if cache.size(0) > 0:
            past_k, past_v = torch.split(cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        updated_cache = torch.cat((k, v), dim=-1)

        scale = math.sqrt(self.d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        return self.forward_attention(v, scores, mask), updated_cache

Multi-Head Attention layer.

Args: n_head (int): The number of heads. n_feat (int): The number of features. dropout_rate (float): Dropout rate.

MultiHeadedAttention(n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True)
190    def __init__(
191        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
192    ):
193        """Construct an MultiHeadedAttention object."""
194        super().__init__()
195        assert n_feat % n_head == 0
196        # We assume d_v always equals d_k
197        self.d_k = n_feat // n_head
198        self.h = n_head
199        self.linear_q = nn.Linear(n_feat, n_feat)
200        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
201        self.linear_v = nn.Linear(n_feat, n_feat)
202        self.linear_out = nn.Linear(n_feat, n_feat)
203        self.dropout = nn.Dropout(p=dropout_rate)

Construct a MultiHeadedAttention object.

d_k
h
linear_q
linear_k
linear_v
linear_out
dropout
def forward_qkv( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def forward_qkv(
    self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Project query, key and value and split each into attention heads.

    Args:
        query (torch.Tensor): Query tensor (#batch, time1, size).
        key (torch.Tensor): Key tensor (#batch, time2, size).
        value (torch.Tensor): Value tensor (#batch, time2, size).

    Returns:
        Tuple of three tensors: the query as (#batch, n_head, time1, d_k),
        and the key and value as (#batch, n_head, time2, d_k).
    """
    batch = query.size(0)

    def to_heads(projection: nn.Module, x: torch.Tensor) -> torch.Tensor:
        # (batch, time, size) -> (batch, time, head, d_k) -> (batch, head, time, d_k)
        return projection(x).view(batch, -1, self.h, self.d_k).transpose(1, 2)

    return (
        to_heads(self.linear_q, query),
        to_heads(self.linear_k, key),
        to_heads(self.linear_v, value),
    )

Transform query, key and value.

Args: query (torch.Tensor): Query tensor (#batch, time1, size). key (torch.Tensor): Key tensor (#batch, time2, size). value (torch.Tensor): Value tensor (#batch, time2, size).

Returns: torch.Tensor: Transformed query tensor, size (#batch, n_head, time1, d_k). torch.Tensor: Transformed key tensor, size (#batch, n_head, time2, d_k). torch.Tensor: Transformed value tensor, size (#batch, n_head, time2, d_k).

def forward_attention( self, value: torch.Tensor, scores: torch.Tensor, mask: torch.Tensor = tensor([], size=(0, 0, 0), dtype=torch.bool)) -> torch.Tensor:
def forward_attention(
    self,
    value: torch.Tensor,
    scores: torch.Tensor,
    mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
) -> torch.Tensor:
    """Turn attention scores into a context vector.

    Args:
        value (torch.Tensor): Transformed value, (#batch, n_head, time2, d_k).
        scores (torch.Tensor): Attention scores, (#batch, n_head, time1, time2).
        mask (torch.Tensor): Mask, (#batch, 1, time2) or (#batch, time1, time2);
            positions equal to 0 are blocked. A (0, 0, 0) tensor means
            "no mask".

    Returns:
        torch.Tensor: Output of ``linear_out`` applied to the attention-weighted
        values, (#batch, time1, d_model).
    """
    n_batch = value.size(0)

    if mask.size(2) == 0:
        # Empty sentinel mask: plain softmax over all key positions.
        attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
    else:
        # True where attention is forbidden. For the last chunk the mask can
        # be wider than time2, so it is cut to the width of ``scores``.
        blocked = mask.unsqueeze(1).eq(0)[:, :, :, : scores.size(-1)]
        weights = torch.softmax(scores.masked_fill(blocked, -float("inf")), dim=-1)
        # Zero the blocked entries explicitly, so a fully masked row yields
        # zeros rather than the NaNs softmax produces for all -inf inputs.
        attn = weights.masked_fill(blocked, 0.0)

    context = torch.matmul(self.dropout(attn), value)  # (batch, head, time1, d_k)
    context = context.transpose(1, 2).contiguous()
    context = context.view(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
    return self.linear_out(context)

Compute attention context vector.

Args: value (torch.Tensor): Transformed value, size (#batch, n_head, time2, d_k). scores (torch.Tensor): Attention score, size (#batch, n_head, time1, time2). mask (torch.Tensor): Mask, size (#batch, 1, time2) or (#batch, time1, time2), (0, 0, 0) means fake mask.

Returns: torch.Tensor: Transformed value (#batch, time1, d_model) weighted by the attention score (#batch, time1, time2).

def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = tensor([], size=(0, 0, 0), dtype=torch.bool), pos_emb: torch.Tensor = tensor([]), cache: torch.Tensor = tensor([], size=(0, 0, 0, 0))) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    pos_emb: torch.Tensor = torch.empty(0),
    cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute scaled dot-product attention with an optional key/value cache.

    Args:
        query (torch.Tensor): Query tensor (#batch, time1, size).
        key (torch.Tensor): Key tensor (#batch, time2, size).
        value (torch.Tensor): Value tensor (#batch, time2, size).
        mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
            (#batch, time1, time2):
            1. cross attention between decoder and encoder uses the batch
               padding mask of shape (#batch, 1, T);
            2. encoder self-attention uses (#batch, T, T);
            3. decoder self-attention uses (#batch, L, L);
            4. a decoder whose positions see different encoder blocks (e.g.
               Mocha) could pass (#batch, L, T); CosyVoice has no such case.
        pos_emb (torch.Tensor): Not used by this class; the relative-position
            subclass consumes it.
        cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), where
            `cache_t == chunk_size * num_decoding_left_chunks` and
            `head * d_k == size`. An empty (0, 0, 0, 0) tensor means no cache.

    Returns:
        torch.Tensor: Output tensor (#batch, time1, d_model).
        torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2), i.e.
            keys and values concatenated along the last axis.
    """
    q, k, v = self.forward_qkv(query, key, value)

    if cache.size(0) > 0:
        # The cache holds keys and values side by side on the last axis.
        k_prev, v_prev = torch.split(cache, cache.size(-1) // 2, dim=-1)
        k = torch.cat([k_prev, k], dim=2)
        v = torch.cat([v_prev, v], dim=2)
    new_cache = torch.cat((k, v), dim=-1)

    scale = math.sqrt(self.d_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    return self.forward_attention(v, scores, mask), new_cache

Compute scaled dot product attention.

Args: query (torch.Tensor): Query tensor (#batch, time1, size). key (torch.Tensor): Key tensor (#batch, time2, size). value (torch.Tensor): Value tensor (#batch, time2, size). mask (torch.Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2). 1. When applying cross attention between decoder and encoder, the batch padding mask for the input has shape (#batch, 1, T). 2. When applying encoder self-attention, the mask has shape (#batch, T, T). 3. When applying decoder self-attention, the mask has shape (#batch, L, L). 4. If different decoder positions see different encoder blocks, as in Mocha, the mask could have shape (#batch, L, T); there is no such case in the current CosyVoice. cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), where cache_t == chunk_size * num_decoding_left_chunks and head * d_k == size

Returns: torch.Tensor: Output tensor (#batch, time1, d_model). torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) where cache_t == chunk_size * num_decoding_left_chunks and head * d_k == size

class RelPositionMultiHeadedAttention(MultiHeadedAttention):
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-head attention with relative positional encoding.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(
        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
    ):
        """Add the positional projection and the two learnable bias vectors."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # Projects positional embeddings into the attention feature space.
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Learnable biases u and v used for matrices (c) and (d), see
        # https://arxiv.org/abs/1901.02860 Section 3.3.
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        for bias in (self.pos_bias_u, self.pos_bias_v):
            torch.nn.init.xavier_uniform_(bias)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Realign relative-position scores so each column indexes a key.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1),
                where time1 is the length of the query.

        Returns:
            torch.Tensor: For that input shape, a (batch, head, time1, time1)
            tensor whose entry [..., i, j] equals x[..., i, time1 - 1 - i + j].
        """
        n_batch, n_head, n_query, n_pos = x.size(0), x.size(1), x.size(2), x.size(3)
        left_pad = torch.zeros(
            (n_batch, n_head, n_query, 1), device=x.device, dtype=x.dtype
        )
        # Prepend a zero column and reinterpret the buffer with the last two
        # lengths swapped; dropping the first row then shifts every row.
        folded = torch.cat([left_pad, x], dim=-1).view(
            n_batch, n_head, n_pos + 1, n_query
        )
        shifted = folded[:, :, 1:].view_as(x)
        # Only keep the positions from 0 to time2.
        return shifted[:, :, :, : x.size(-1) // 2 + 1]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention with relative positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2); (0, 0, 0) means no mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k) so the biases broadcast

        if cache.size(0) > 0:
            k_prev, v_prev = torch.split(cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)
        # NOTE(xcsong): cache slicing happens in encoder.forward_chunk, since
        # `next_cache_start` is non-trivial to compute here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch_pos, head, n_pos, d_k)

        # Back to (batch, head, time1, d_k) after adding each bias.
        content_q = (q + self.pos_bias_u).transpose(1, 2)
        position_q = (q + self.pos_bias_v).transpose(1, 2)

        # Matrices (a)+(c) and (b)+(d) from arXiv:1901.02860 Section 3.3.
        matrix_ac = torch.matmul(content_q, k.transpose(-2, -1))
        matrix_bd = torch.matmul(position_q, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): keep rel_shift since espnet rel_pos_emb is used.
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache

Multi-Head Attention layer with relative position encoding. Paper: https://arxiv.org/abs/1901.02860 Args: n_head (int): The number of heads. n_feat (int): The number of features. dropout_rate (float): Dropout rate.

RelPositionMultiHeadedAttention(n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True)
def __init__(
    self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
):
    """Add the positional projection and the two learnable bias vectors."""
    super().__init__(n_head, n_feat, dropout_rate, key_bias)
    # Projects positional embeddings into the attention feature space.
    self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
    # Learnable biases u and v used for matrices (c) and (d), see
    # https://arxiv.org/abs/1901.02860 Section 3.3.
    self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
    self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
    for bias in (self.pos_bias_u, self.pos_bias_v):
        torch.nn.init.xavier_uniform_(bias)

Construct a RelPositionMultiHeadedAttention object.

linear_pos
pos_bias_u
pos_bias_v
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
    """Realign relative-position scores so each column indexes a key.

    Args:
        x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1),
            where time1 is the length of the query.

    Returns:
        torch.Tensor: For that input shape, a (batch, head, time1, time1)
        tensor whose entry [..., i, j] equals x[..., i, time1 - 1 - i + j].
    """
    n_batch, n_head, n_query, n_pos = x.size(0), x.size(1), x.size(2), x.size(3)
    left_pad = torch.zeros(
        (n_batch, n_head, n_query, 1), device=x.device, dtype=x.dtype
    )
    # Prepend a zero column and reinterpret the buffer with the last two
    # lengths swapped; dropping the first row then shifts every row.
    folded = torch.cat([left_pad, x], dim=-1).view(
        n_batch, n_head, n_pos + 1, n_query
    )
    shifted = folded[:, :, 1:].view_as(x)
    # Only keep the positions from 0 to time2.
    return shifted[:, :, :, : x.size(-1) // 2 + 1]

Compute relative positional encoding.

Args: x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). time1 means the length of query vector.

Returns: torch.Tensor: Output tensor.

def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = tensor([], size=(0, 0, 0), dtype=torch.bool), pos_emb: torch.Tensor = tensor([]), cache: torch.Tensor = tensor([], size=(0, 0, 0, 0))) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    pos_emb: torch.Tensor = torch.empty(0),
    cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention with relative positional encoding.

    Args:
        query (torch.Tensor): Query tensor (#batch, time1, size).
        key (torch.Tensor): Key tensor (#batch, time2, size).
        value (torch.Tensor): Value tensor (#batch, time2, size).
        mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
            (#batch, time1, time2); (0, 0, 0) means no mask.
        pos_emb (torch.Tensor): Positional embedding tensor
            (#batch, time2, size).
        cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
            where `cache_t == chunk_size * num_decoding_left_chunks`
            and `head * d_k == size`.

    Returns:
        torch.Tensor: Output tensor (#batch, time1, d_model).
        torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2).
    """
    q, k, v = self.forward_qkv(query, key, value)
    q = q.transpose(1, 2)  # (batch, time1, head, d_k) so the biases broadcast

    if cache.size(0) > 0:
        k_prev, v_prev = torch.split(cache, cache.size(-1) // 2, dim=-1)
        k = torch.cat([k_prev, k], dim=2)
        v = torch.cat([v_prev, v], dim=2)
    # NOTE(xcsong): cache slicing happens in encoder.forward_chunk, since
    # `next_cache_start` is non-trivial to compute here.
    new_cache = torch.cat((k, v), dim=-1)

    n_batch_pos = pos_emb.size(0)
    p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
    p = p.transpose(1, 2)  # (batch_pos, head, n_pos, d_k)

    # Back to (batch, head, time1, d_k) after adding each bias.
    content_q = (q + self.pos_bias_u).transpose(1, 2)
    position_q = (q + self.pos_bias_v).transpose(1, 2)

    # Matrices (a)+(c) and (b)+(d) from arXiv:1901.02860 Section 3.3.
    matrix_ac = torch.matmul(content_q, k.transpose(-2, -1))
    matrix_bd = torch.matmul(position_q, p.transpose(-2, -1))
    # NOTE(Xiang Lyu): keep rel_shift since espnet rel_pos_emb is used.
    if matrix_ac.shape != matrix_bd.shape:
        matrix_bd = self.rel_shift(matrix_bd)

    scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
    return self.forward_attention(v, scores, mask), new_cache

Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (torch.Tensor): Query tensor (#batch, time1, size). key (torch.Tensor): Key tensor (#batch, time2, size). value (torch.Tensor): Value tensor (#batch, time2, size). mask (torch.Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2), (0, 0, 0) means fake mask. pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size). cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), where cache_t == chunk_size * num_decoding_left_chunks and head * d_k == size Returns: torch.Tensor: Output tensor (#batch, time1, d_model). torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) where cache_t == chunk_size * num_decoding_left_chunks and head * d_k == size

def subsequent_mask(size: int, device: torch.device = device(type='cpu')) -> torch.Tensor:
def subsequent_mask(
    size: int,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a lower-triangular mask for subsequent steps, shape (size, size).

    This mask is used only in a decoder running in auto-regressive mode: each
    step may only attend to itself and the steps to its left.

    In the encoder, full attention is used when streaming is not necessary
    and the sequence is not long, so no attention mask is needed. When
    streaming is needed, chunk-based attention is used instead; see
    ``subsequent_chunk_mask``.

    Args:
        size (int): Size of the mask.
        device (torch.device): Device on which the mask is created, e.g.
            "cpu" or "cuda".

    Returns:
        torch.Tensor: Boolean mask where entry [i, j] is True iff j <= i.

    Examples:
        >>> subsequent_mask(3).int().tolist()
        [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
    """
    steps = torch.arange(size, device=device)
    # Compare each column index with its row index; broadcasting gives (size, size).
    return steps.unsqueeze(0) <= steps.unsqueeze(1)

Create mask for subsequent steps (size, size).

This mask is used only in decoder which works in an auto-regressive mode. This means the current step could only do attention with its left steps.

In the encoder, full attention is used when streaming is not necessary and the sequence is not long. In this case, no attention mask is needed.

When streaming is needed, chunk-based attention is used in the encoder. See subsequent_chunk_mask for the chunk-based attention mask.

Args: size (int): size of mask device (torch.device): "cpu" or "cuda" or torch.Tensor.device

Returns: torch.Tensor: mask

Examples:

subsequent_mask(3) [[1, 0, 0], [1, 1, 0], [1, 1, 1]]

def subsequent_chunk_mask( size: int, chunk_size: int, num_left_chunks: int = -1, device: torch.device = device(type='cpu')) -> torch.Tensor:
def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a chunk-wise attention mask (size, size) for a streaming encoder.

    Row ``i`` may attend to every frame of its own chunk and of up to
    ``num_left_chunks`` preceding chunks (all preceding chunks when
    ``num_left_chunks < 0``), but never to a later chunk.

    Args:
        size (int): Size of the mask.
        chunk_size (int): Size of each chunk; must be positive when size > 0.
        num_left_chunks (int): Number of left chunks.
            <0: use all left chunks (full history)
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device.

    Returns:
        torch.Tensor: Boolean mask of shape (size, size).

    Raises:
        ValueError: If ``size > 0`` and ``chunk_size <= 0``.

    Examples:
        >>> subsequent_chunk_mask(4, 2).int().tolist()
        [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]
    """
    if size == 0:
        return torch.zeros(0, 0, device=device, dtype=torch.bool)
    if chunk_size <= 0:
        # Previously this surfaced as ZeroDivisionError (0) or a silently
        # wrong mask (negative values).
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # One broadcast comparison replaces the former per-row Python loop, which
    # issued `size` separate slice assignments.
    frames = torch.arange(size, device=device)
    chunk_idx = torch.div(frames, chunk_size, rounding_mode="floor")
    # Exclusive end of each row's own chunk, clipped for a partial last chunk.
    ending = torch.clamp((chunk_idx + 1) * chunk_size, max=size)
    if num_left_chunks < 0:
        start = torch.zeros_like(frames)
    else:
        start = torch.clamp((chunk_idx - num_left_chunks) * chunk_size, min=0)
    cols = frames.unsqueeze(0)
    return (cols >= start.unsqueeze(1)) & (cols < ending.unsqueeze(1))

Create mask for subsequent steps (size, size) with chunk size, this is for streaming encoder

Args: size (int): size of mask chunk_size (int): size of chunk num_left_chunks (int): number of left chunks; <0: use full chunk, >=0: use num_left_chunks device (torch.device): "cpu" or "cuda" or torch.Tensor.device

Returns: torch.Tensor: mask

Examples:

subsequent_chunk_mask(4, 2) [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]

def add_optional_chunk_mask( xs: torch.Tensor, masks: torch.Tensor, use_dynamic_chunk: bool, use_dynamic_left_chunk: bool, decoding_chunk_size: int, static_chunk_size: int, num_decoding_left_chunks: int, enable_full_context: bool = True):
def add_optional_chunk_mask(
    xs: torch.Tensor,
    masks: torch.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
    enable_full_context: bool = True,
):
    """Apply an optional chunk-based attention mask for the encoder.

    Args:
        xs (torch.Tensor): Padded input, (B, L, D), L for max length.
        masks (torch.Tensor): Mask for xs, (B, 1, L).
        use_dynamic_chunk (bool): Whether to use dynamic chunk or not.
        use_dynamic_left_chunk (bool): Whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): Decoding chunk size for dynamic chunk:
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): Chunk size for static chunk training/decoding
            when greater than 0; ignored if use_dynamic_chunk is True.
        num_decoding_left_chunks (int): Number of left chunks for decoding;
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context (max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: ``masks & chunk_mask`` with shape (B, L, L) when a chunk
        mask is applied; otherwise ``masks`` itself, unchanged.
    """
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # Chunk size is either [1, 25] or full context (max_len).
            # Since we use 4 times subsampling and allow up to 1s (100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            # torch.randint requires low < high; keep the upper bound >= 2 so
            # one-frame inputs don't crash. For max_len >= 2 the draw is
            # exactly the original randint(1, max_len).
            chunk_size = torch.randint(1, max(max_len, 2), (1,)).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    if max_left_chunks > 0:
                        num_left_chunks = torch.randint(
                            0, max_left_chunks, (1,)
                        ).item()
                    else:
                        # Every frame lies in the first chunk, so there is no
                        # left chunk to sample; randint(0, 0) would raise.
                        num_left_chunks = 0
        chunk_masks = subsequent_chunk_mask(
            xs.size(1), chunk_size, num_left_chunks, xs.device
        )  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(
            xs.size(1), static_chunk_size, num_left_chunks, xs.device
        )  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks

Apply optional mask for encoder.

Args: xs (torch.Tensor): padded input, (B, L, D), L for max length masks (torch.Tensor): mask for xs, (B, 1, L) use_dynamic_chunk (bool): whether to use dynamic chunk or not use_dynamic_left_chunk (bool): whether to use dynamic left chunk for training. decoding_chunk_size (int): decoding chunk size for dynamic chunk: 0: default for training, use random dynamic chunk; <0: for decoding, use full chunk; >0: for decoding, use fixed chunk size as set. static_chunk_size (int): chunk size for static chunk training/decoding if it's greater than 0; if use_dynamic_chunk is true, this parameter will be ignored num_decoding_left_chunks: number of left chunks, this is for decoding, the chunk size is decoding_chunk_size: >=0: use num_decoding_left_chunks; <0: use all left chunks enable_full_context (bool): True: chunk size is either [1, 25] or full context(max_len) False: chunk size ~ U[1, 25]

Returns: torch.Tensor: chunk mask of the input xs.

class ConformerEncoderLayer(torch.nn.modules.module.Module):
587class ConformerEncoderLayer(nn.Module):
588    """Encoder layer module.
589    Args:
590        size (int): Input dimension.
591        self_attn (torch.nn.Module): Self-attention module instance.
592            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
593            instance can be used as the argument.
594        feed_forward (torch.nn.Module): Feed-forward module instance.
595            `PositionwiseFeedForward` instance can be used as the argument.
596        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
597             instance.
598            `PositionwiseFeedForward` instance can be used as the argument.
599        conv_module (torch.nn.Module): Convolution module instance.
600            `ConvlutionModule` instance can be used as the argument.
601        dropout_rate (float): Dropout rate.
602        normalize_before (bool):
603            True: use layer_norm before each sub-block.
604            False: use layer_norm after each sub-block.
605    """
606
607    def __init__(
608        self,
609        size: int,
610        self_attn: torch.nn.Module,
611        feed_forward: Optional[nn.Module] = None,
612        feed_forward_macaron: Optional[nn.Module] = None,
613        conv_module: Optional[nn.Module] = None,
614        dropout_rate: float = 0.1,
615        normalize_before: bool = True,
616    ):
617        """Construct an EncoderLayer object."""
618        super().__init__()
619        self.self_attn = self_attn
620        self.feed_forward = feed_forward
621        self.feed_forward_macaron = feed_forward_macaron
622        self.conv_module = conv_module
623        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
624        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
625        if feed_forward_macaron is not None:
626            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
627            self.ff_scale = 0.5
628        else:
629            self.ff_scale = 1.0
630        if self.conv_module is not None:
631            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
632            self.norm_final = nn.LayerNorm(
633                size, eps=1e-5
634            )  # for the final output of the block
635        self.dropout = nn.Dropout(dropout_rate)
636        self.size = size
637        self.normalize_before = normalize_before
638
639    def forward(
640        self,
641        x: torch.Tensor,
642        mask: torch.Tensor,
643        pos_emb: torch.Tensor,
644        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
645        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
646        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
647    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
648        """Compute encoded features.
649
650        Args:
651            x (torch.Tensor): (#batch, time, size)
652            mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
653                (0, 0, 0) means fake mask.
654            pos_emb (torch.Tensor): positional encoding, must not be None
655                for ConformerEncoderLayer.
656            mask_pad (torch.Tensor): batch padding mask used for conv module.
657                (#batch, 1,time), (0, 0, 0) means fake mask.
658            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
659                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
660            cnn_cache (torch.Tensor): Convolution cache in conformer layer
661                (#batch=1, size, cache_t2)
662        Returns:
663            torch.Tensor: Output tensor (#batch, time, size).
664            torch.Tensor: Mask tensor (#batch, time, time).
665            torch.Tensor: att_cache tensor,
666                (#batch=1, head, cache_t1 + time, d_k * 2).
667            torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
668        """
669
670        # whether to use macaron style
671        if self.feed_forward_macaron is not None:
672            residual = x
673            if self.normalize_before:
674                x = self.norm_ff_macaron(x)
675            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
676            if not self.normalize_before:
677                x = self.norm_ff_macaron(x)
678
679        # multi-headed self-attention module
680        residual = x
681        if self.normalize_before:
682            x = self.norm_mha(x)
683        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
684        x = residual + self.dropout(x_att)
685        if not self.normalize_before:
686            x = self.norm_mha(x)
687
688        # convolution module
689        # Fake new cnn cache here, and then change it in conv_module
690        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
691        if self.conv_module is not None:
692            residual = x
693            if self.normalize_before:
694                x = self.norm_conv(x)
695            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
696            x = residual + self.dropout(x)
697
698            if not self.normalize_before:
699                x = self.norm_conv(x)
700
701        # feed forward module
702        residual = x
703        if self.normalize_before:
704            x = self.norm_ff(x)
705
706        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
707        if not self.normalize_before:
708            x = self.norm_ff(x)
709
710        if self.conv_module is not None:
711            x = self.norm_final(x)
712
713        return x, mask, new_att_cache, new_cnn_cache

Encoder layer module. Args: size (int): Input dimension. self_attn (torch.nn.Module): Self-attention module instance. MultiHeadedAttention or RelPositionMultiHeadedAttention instance can be used as the argument. feed_forward (torch.nn.Module): Feed-forward module instance. PositionwiseFeedForward instance can be used as the argument. Although it defaults to None, forward() calls it unconditionally, so it must be provided in practice. feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. PositionwiseFeedForward instance can be used as the argument. conv_module (torch.nn.Module): Convolution module instance. ConvolutionModule instance can be used as the argument. dropout_rate (float): Dropout rate. normalize_before (bool): True: use layer_norm before each sub-block. False: use layer_norm after each sub-block.

ConformerEncoderLayer( size: int, self_attn: torch.nn.modules.module.Module, feed_forward: Optional[torch.nn.modules.module.Module] = None, feed_forward_macaron: Optional[torch.nn.modules.module.Module] = None, conv_module: Optional[torch.nn.modules.module.Module] = None, dropout_rate: float = 0.1, normalize_before: bool = True)
607    def __init__(
608        self,
609        size: int,
610        self_attn: torch.nn.Module,
611        feed_forward: Optional[nn.Module] = None,
612        feed_forward_macaron: Optional[nn.Module] = None,
613        conv_module: Optional[nn.Module] = None,
614        dropout_rate: float = 0.1,
615        normalize_before: bool = True,
616    ):
617        """Construct an EncoderLayer object."""
618        super().__init__()
619        self.self_attn = self_attn
620        self.feed_forward = feed_forward
621        self.feed_forward_macaron = feed_forward_macaron
622        self.conv_module = conv_module
623        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
624        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
625        if feed_forward_macaron is not None:
626            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
627            self.ff_scale = 0.5
628        else:
629            self.ff_scale = 1.0
630        if self.conv_module is not None:
631            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
632            self.norm_final = nn.LayerNorm(
633                size, eps=1e-5
634            )  # for the final output of the block
635        self.dropout = nn.Dropout(dropout_rate)
636        self.size = size
637        self.normalize_before = normalize_before

Construct an EncoderLayer object.

self_attn
feed_forward
feed_forward_macaron
conv_module
norm_ff
norm_mha
dropout
size
normalize_before
def forward( self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor = tensor([], size=(0, 0, 0), dtype=torch.bool), att_cache: torch.Tensor = tensor([], size=(0, 0, 0, 0)), cnn_cache: torch.Tensor = tensor([], size=(0, 0, 0, 0))) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
639    def forward(
640        self,
641        x: torch.Tensor,
642        mask: torch.Tensor,
643        pos_emb: torch.Tensor,
644        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
645        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
646        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
647    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
648        """Compute encoded features.
649
650        Args:
651            x (torch.Tensor): (#batch, time, size)
652            mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
653                (0, 0, 0) means fake mask.
654            pos_emb (torch.Tensor): positional encoding, must not be None
655                for ConformerEncoderLayer.
656            mask_pad (torch.Tensor): batch padding mask used for conv module.
657                (#batch, 1,time), (0, 0, 0) means fake mask.
658            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
659                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
660            cnn_cache (torch.Tensor): Convolution cache in conformer layer
661                (#batch=1, size, cache_t2)
662        Returns:
663            torch.Tensor: Output tensor (#batch, time, size).
664            torch.Tensor: Mask tensor (#batch, time, time).
665            torch.Tensor: att_cache tensor,
666                (#batch=1, head, cache_t1 + time, d_k * 2).
667            torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
668        """
669
670        # whether to use macaron style
671        if self.feed_forward_macaron is not None:
672            residual = x
673            if self.normalize_before:
674                x = self.norm_ff_macaron(x)
675            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
676            if not self.normalize_before:
677                x = self.norm_ff_macaron(x)
678
679        # multi-headed self-attention module
680        residual = x
681        if self.normalize_before:
682            x = self.norm_mha(x)
683        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
684        x = residual + self.dropout(x_att)
685        if not self.normalize_before:
686            x = self.norm_mha(x)
687
688        # convolution module
689        # Fake new cnn cache here, and then change it in conv_module
690        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
691        if self.conv_module is not None:
692            residual = x
693            if self.normalize_before:
694                x = self.norm_conv(x)
695            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
696            x = residual + self.dropout(x)
697
698            if not self.normalize_before:
699                x = self.norm_conv(x)
700
701        # feed forward module
702        residual = x
703        if self.normalize_before:
704            x = self.norm_ff(x)
705
706        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
707        if not self.normalize_before:
708            x = self.norm_ff(x)
709
710        if self.conv_module is not None:
711            x = self.norm_final(x)
712
713        return x, mask, new_att_cache, new_cnn_cache

Compute encoded features.

Args: x (torch.Tensor): (#batch, time, size). mask (torch.Tensor): Mask tensor for the input (#batch, time, time); (0, 0, 0) means fake mask. pos_emb (torch.Tensor): positional encoding, must not be None for ConformerEncoderLayer. mask_pad (torch.Tensor): batch padding mask used for the conv module, (#batch, 1, time); (0, 0, 0) means fake mask. att_cache (torch.Tensor): Cache tensor of the KEY & VALUE, (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. cnn_cache (torch.Tensor): Convolution cache in the conformer layer, (#batch=1, size, cache_t2). Returns: torch.Tensor: Output tensor (#batch, time, size). torch.Tensor: Mask tensor (#batch, time, time). torch.Tensor: att_cache tensor, (#batch=1, head, cache_t1 + time, d_k * 2). torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).

class EspnetRelPositionalEncoding(torch.nn.modules.module.Module):
716class EspnetRelPositionalEncoding(torch.nn.Module):
717    """Relative positional encoding module (new implementation).
718
719    Details can be found in https://github.com/espnet/espnet/pull/2816.
720
721    See : Appendix B in https://arxiv.org/abs/1901.02860
722
723    Args:
724        d_model (int): Embedding dimension.
725        dropout_rate (float): Dropout rate.
726        max_len (int): Maximum input length.
727
728    """
729
730    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
731        """Construct an PositionalEncoding object."""
732        super(EspnetRelPositionalEncoding, self).__init__()
733        self.d_model = d_model
734        self.xscale = math.sqrt(self.d_model)
735        self.dropout = torch.nn.Dropout(p=dropout_rate)
736        self.pe = None
737        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
738
739    def extend_pe(self, x: torch.Tensor):
740        """Reset the positional encodings."""
741        if self.pe is not None:
742            # self.pe contains both positive and negative parts
743            # the length of self.pe is 2 * input_len - 1
744            if self.pe.size(1) >= x.size(1) * 2 - 1:
745                if self.pe.dtype != x.dtype or self.pe.device != x.device:
746                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
747                return
748        # Suppose `i` means to the position of query vecotr and `j` means the
749        # position of key vector. We use position relative positions when keys
750        # are to the left (i>j) and negative relative positions otherwise (i<j).
751        pe_positive = torch.zeros(x.size(1), self.d_model)
752        pe_negative = torch.zeros(x.size(1), self.d_model)
753        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
754        div_term = torch.exp(
755            torch.arange(0, self.d_model, 2, dtype=torch.float32)
756            * -(math.log(10000.0) / self.d_model)
757        )
758        pe_positive[:, 0::2] = torch.sin(position * div_term)
759        pe_positive[:, 1::2] = torch.cos(position * div_term)
760        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
761        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
762
763        # Reserve the order of positive indices and concat both positive and
764        # negative indices. This is used to support the shifting trick
765        # as in https://arxiv.org/abs/1901.02860
766        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
767        pe_negative = pe_negative[1:].unsqueeze(0)
768        pe = torch.cat([pe_positive, pe_negative], dim=1)
769        self.pe = pe.to(device=x.device, dtype=x.dtype)
770
771    def forward(
772        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
773    ) -> Tuple[torch.Tensor, torch.Tensor]:
774        """Add positional encoding.
775
776        Args:
777            x (torch.Tensor): Input tensor (batch, time, `*`).
778
779        Returns:
780            torch.Tensor: Encoded tensor (batch, time, `*`).
781
782        """
783        self.extend_pe(x)
784        x = x * self.xscale
785        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
786        return self.dropout(x), self.dropout(pos_emb)
787
788    def position_encoding(
789        self, offset: Union[int, torch.Tensor], size: int
790    ) -> torch.Tensor:
791        """For getting encoding in a streaming fashion
792
793        Attention!!!!!
794        we apply dropout only once at the whole utterance level in a none
795        streaming way, but will call this function several times with
796        increasing input size in a streaming scenario, so the dropout will
797        be applied several times.
798
799        Args:
800            offset (int or torch.tensor): start offset
801            size (int): required size of position encoding
802
803        Returns:
804            torch.Tensor: Corresponding encoding
805        """
806        pos_emb = self.pe[
807            :,
808            self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
809        ]
810        return pos_emb

Relative positional encoding module (new implementation).

Details can be found in https://github.com/espnet/espnet/pull/2816.

See : Appendix B in https://arxiv.org/abs/1901.02860

Args: d_model (int): Embedding dimension. dropout_rate (float): Dropout rate. max_len (int): Maximum input length.

EspnetRelPositionalEncoding(d_model: int, dropout_rate: float, max_len: int = 5000)
730    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
731        """Construct an PositionalEncoding object."""
732        super(EspnetRelPositionalEncoding, self).__init__()
733        self.d_model = d_model
734        self.xscale = math.sqrt(self.d_model)
735        self.dropout = torch.nn.Dropout(p=dropout_rate)
736        self.pe = None
737        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

Construct an EspnetRelPositionalEncoding object.

d_model
xscale
dropout
pe
def extend_pe(self, x: torch.Tensor):
739    def extend_pe(self, x: torch.Tensor):
740        """Reset the positional encodings."""
741        if self.pe is not None:
742            # self.pe contains both positive and negative parts
743            # the length of self.pe is 2 * input_len - 1
744            if self.pe.size(1) >= x.size(1) * 2 - 1:
745                if self.pe.dtype != x.dtype or self.pe.device != x.device:
746                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
747                return
748        # Suppose `i` means to the position of query vecotr and `j` means the
749        # position of key vector. We use position relative positions when keys
750        # are to the left (i>j) and negative relative positions otherwise (i<j).
751        pe_positive = torch.zeros(x.size(1), self.d_model)
752        pe_negative = torch.zeros(x.size(1), self.d_model)
753        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
754        div_term = torch.exp(
755            torch.arange(0, self.d_model, 2, dtype=torch.float32)
756            * -(math.log(10000.0) / self.d_model)
757        )
758        pe_positive[:, 0::2] = torch.sin(position * div_term)
759        pe_positive[:, 1::2] = torch.cos(position * div_term)
760        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
761        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
762
763        # Reserve the order of positive indices and concat both positive and
764        # negative indices. This is used to support the shifting trick
765        # as in https://arxiv.org/abs/1901.02860
766        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
767        pe_negative = pe_negative[1:].unsqueeze(0)
768        pe = torch.cat([pe_positive, pe_negative], dim=1)
769        self.pe = pe.to(device=x.device, dtype=x.dtype)

Reset the positional encodings.

def forward( self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
771    def forward(
772        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
773    ) -> Tuple[torch.Tensor, torch.Tensor]:
774        """Add positional encoding.
775
776        Args:
777            x (torch.Tensor): Input tensor (batch, time, `*`).
778
779        Returns:
780            torch.Tensor: Encoded tensor (batch, time, `*`).
781
782        """
783        self.extend_pe(x)
784        x = x * self.xscale
785        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
786        return self.dropout(x), self.dropout(pos_emb)

Add positional encoding.

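Args: x (torch.Tensor): Input tensor (batch, time, *). offset (int or torch.Tensor): start offset passed on to position_encoding.

Returns: torch.Tensor: Encoded tensor (batch, time, *), scaled by sqrt(d_model) and passed through dropout. torch.Tensor: Relative positional embedding (1, 2 * time - 1, d_model), passed through the same dropout module.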

def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
788    def position_encoding(
789        self, offset: Union[int, torch.Tensor], size: int
790    ) -> torch.Tensor:
791        """For getting encoding in a streaming fashion
792
793        Attention!!!!!
794        we apply dropout only once at the whole utterance level in a none
795        streaming way, but will call this function several times with
796        increasing input size in a streaming scenario, so the dropout will
797        be applied several times.
798
799        Args:
800            offset (int or torch.tensor): start offset
801            size (int): required size of position encoding
802
803        Returns:
804            torch.Tensor: Corresponding encoding
805        """
806        pos_emb = self.pe[
807            :,
808            self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
809        ]
810        return pos_emb

For getting encoding in a streaming fashion

Attention: dropout is applied only once at the whole-utterance level in a non-streaming setting. In a streaming scenario, however, this function is called several times with increasing input size, so dropout will be applied several times.

Args: offset (int or torch.Tensor): start offset. It is currently unused: the returned slice is always centered in the encoding table. size (int): required size of position encoding.

Returns: torch.Tensor: Corresponding encoding

class LinearEmbed(torch.nn.modules.module.Module):
813class LinearEmbed(torch.nn.Module):
814    """Linear transform the input without subsampling
815
816    Args:
817        idim (int): Input dimension.
818        odim (int): Output dimension.
819        dropout_rate (float): Dropout rate.
820
821    """
822
823    def __init__(
824        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
825    ):
826        """Construct an linear object."""
827        super().__init__()
828        self.out = torch.nn.Sequential(
829            torch.nn.Linear(idim, odim),
830            torch.nn.LayerNorm(odim, eps=1e-5),
831            torch.nn.Dropout(dropout_rate),
832        )
833        self.pos_enc = pos_enc_class  # rel_pos_espnet
834
835    def position_encoding(
836        self, offset: Union[int, torch.Tensor], size: int
837    ) -> torch.Tensor:
838        return self.pos_enc.position_encoding(offset, size)
839
840    def forward(
841        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
842    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
843        """Input x.
844
845        Args:
846            x (torch.Tensor): Input tensor (#batch, time, idim).
847            x_mask (torch.Tensor): Input mask (#batch, 1, time).
848
849        Returns:
850            torch.Tensor: linear input tensor (#batch, time', odim),
851                where time' = time .
852            torch.Tensor: linear input mask (#batch, 1, time'),
853                where time' = time .
854
855        """
856        x = self.out(x)
857        x, pos_emb = self.pos_enc(x, offset)
858        return x, pos_emb

Linearly transform the input without subsampling.

Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate.

LinearEmbed( idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.modules.module.Module)
823    def __init__(
824        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
825    ):
826        """Construct an linear object."""
827        super().__init__()
828        self.out = torch.nn.Sequential(
829            torch.nn.Linear(idim, odim),
830            torch.nn.LayerNorm(odim, eps=1e-5),
831            torch.nn.Dropout(dropout_rate),
832        )
833        self.pos_enc = pos_enc_class  # rel_pos_espnet

Construct a LinearEmbed object.

out
pos_enc
def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
835    def position_encoding(
836        self, offset: Union[int, torch.Tensor], size: int
837    ) -> torch.Tensor:
838        return self.pos_enc.position_encoding(offset, size)
def forward( self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
840    def forward(
841        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
842    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
843        """Input x.
844
845        Args:
846            x (torch.Tensor): Input tensor (#batch, time, idim).
847            x_mask (torch.Tensor): Input mask (#batch, 1, time).
848
849        Returns:
850            torch.Tensor: linear input tensor (#batch, time', odim),
851                where time' = time .
852            torch.Tensor: linear input mask (#batch, 1, time'),
853                where time' = time .
854
855        """
856        x = self.out(x)
857        x, pos_emb = self.pos_enc(x, offset)
858        return x, pos_emb

Input x.

Args: x (torch.Tensor): Input tensor (#batch, time, idim). offset (int or torch.Tensor): start offset forwarded to the positional encoding module.

Returns: torch.Tensor: linearly transformed, positionally encoded tensor (#batch, time', odim), where time' = time. torch.Tensor: positional embedding returned by the positional encoding module.

ATTENTION_CLASSES = {'selfattn': <class 'MultiHeadedAttention'>, 'rel_selfattn': <class 'RelPositionMultiHeadedAttention'>}
ACTIVATION_CLASSES = {'hardtanh': <class 'torch.nn.modules.activation.Hardtanh'>, 'tanh': <class 'torch.nn.modules.activation.Tanh'>, 'relu': <class 'torch.nn.modules.activation.ReLU'>, 'selu': <class 'torch.nn.modules.activation.SELU'>, 'swish': <class 'torch.nn.modules.activation.SiLU'>, 'gelu': <class 'torch.nn.modules.activation.GELU'>}
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
876def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
877    """Make mask tensor containing indices of padded part.
878
879    See description of make_non_pad_mask.
880
881    Args:
882        lengths (torch.Tensor): Batch of lengths (B,).
883    Returns:
884        torch.Tensor: Mask tensor containing indices of padded part.
885
886    Examples:
887        >>> lengths = [5, 3, 2]
888        >>> make_pad_mask(lengths)
889        masks = [[0, 0, 0, 0 ,0],
890                 [0, 0, 0, 1, 1],
891                 [0, 0, 1, 1, 1]]
892    """
893    batch_size = lengths.size(0)
894    max_len = max_len if max_len > 0 else lengths.max().item()
895    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
896    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
897    seq_length_expand = lengths.unsqueeze(-1)
898    mask = seq_range_expand >= seq_length_expand
899    return mask

Make mask tensor containing indices of padded part.

See description of make_non_pad_mask.

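Args: lengths (torch.Tensor): Batch of lengths (B,). max_len (int): length of the time axis of the mask; if it is not greater than 0, lengths.max() is used. Returns: torch.Tensor: Boolean mask (B, max_len) that is True at the padded positions.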

Examples:

>>> lengths = torch.tensor([5, 3, 2])
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1],
         [0, 0, 1, 1, 1]]

class ConformerEncoder(torch.nn.modules.module.Module):
 903class ConformerEncoder(torch.nn.Module):
 904    """Conformer encoder module."""
 905
 906    def __init__(
 907        self,
 908        input_size: int,
 909        output_size: int = 1024,
 910        attention_heads: int = 16,
 911        linear_units: int = 4096,
 912        num_blocks: int = 6,
 913        dropout_rate: float = 0.1,
 914        positional_dropout_rate: float = 0.1,
 915        attention_dropout_rate: float = 0.0,
 916        input_layer: str = "linear",
 917        pos_enc_layer_type: str = "rel_pos_espnet",
 918        normalize_before: bool = True,
 919        static_chunk_size: int = 1,  # 1: causal_mask; 0: full_mask
 920        use_dynamic_chunk: bool = False,
 921        use_dynamic_left_chunk: bool = False,
 922        positionwise_conv_kernel_size: int = 1,
 923        macaron_style: bool = False,
 924        selfattention_layer_type: str = "rel_selfattn",
 925        activation_type: str = "swish",
 926        use_cnn_module: bool = False,
 927        cnn_module_kernel: int = 15,
 928        causal: bool = False,
 929        cnn_module_norm: str = "batch_norm",
 930        key_bias: bool = True,
 931        gradient_checkpointing: bool = False,
 932    ):
 933        """Construct ConformerEncoder
 934
 935        Args:
 936            input_size to use_dynamic_chunk, see in BaseEncoder
 937            positionwise_conv_kernel_size (int): Kernel size of positionwise
 938                conv1d layer.
 939            macaron_style (bool): Whether to use macaron style for
 940                positionwise layer.
 941            selfattention_layer_type (str): Encoder attention layer type,
 942                the parameter has no effect now, it's just for configure
 943                compatibility. #'rel_selfattn'
 944            activation_type (str): Encoder activation function type.
 945            use_cnn_module (bool): Whether to use convolution module.
 946            cnn_module_kernel (int): Kernel size of convolution module.
 947            causal (bool): whether to use causal convolution or not.
 948            key_bias: whether use bias in attention.linear_k, False for whisper models.
 949        """
 950        super().__init__()
 951        self.output_size = output_size
 952        self.embed = LinearEmbed(
 953            input_size,
 954            output_size,
 955            dropout_rate,
 956            EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
 957        )
 958        self.normalize_before = normalize_before
 959        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
 960        self.gradient_checkpointing = gradient_checkpointing
 961        self.use_dynamic_chunk = use_dynamic_chunk
 962
 963        self.static_chunk_size = static_chunk_size
 964        self.use_dynamic_chunk = use_dynamic_chunk
 965        self.use_dynamic_left_chunk = use_dynamic_left_chunk
 966        activation = ACTIVATION_CLASSES[activation_type]()
 967
 968        # self-attention module definition
 969        encoder_selfattn_layer_args = (
 970            attention_heads,
 971            output_size,
 972            attention_dropout_rate,
 973            key_bias,
 974        )
 975        # feed-forward module definition
 976        positionwise_layer_args = (
 977            output_size,
 978            linear_units,
 979            dropout_rate,
 980            activation,
 981        )
 982        # convolution module definition
 983        convolution_layer_args = (
 984            output_size,
 985            cnn_module_kernel,
 986            activation,
 987            cnn_module_norm,
 988            causal,
 989        )
 990
 991        self.encoders = torch.nn.ModuleList(
 992            [
 993                ConformerEncoderLayer(
 994                    output_size,
 995                    RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
 996                    PositionwiseFeedForward(*positionwise_layer_args),
 997                    (
 998                        PositionwiseFeedForward(*positionwise_layer_args)
 999                        if macaron_style
1000                        else None
1001                    ),
1002                    (
1003                        ConvolutionModule(*convolution_layer_args)
1004                        if use_cnn_module
1005                        else None
1006                    ),
1007                    dropout_rate,
1008                    normalize_before,
1009                )
1010                for _ in range(num_blocks)
1011            ]
1012        )
1013
1014    def forward_layers(
1015        self,
1016        xs: torch.Tensor,
1017        chunk_masks: torch.Tensor,
1018        pos_emb: torch.Tensor,
1019        mask_pad: torch.Tensor,
1020    ) -> torch.Tensor:
1021        for layer in self.encoders:
1022            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
1023        return xs
1024
1025    @torch.jit.unused
1026    def forward_layers_checkpointed(
1027        self,
1028        xs: torch.Tensor,
1029        chunk_masks: torch.Tensor,
1030        pos_emb: torch.Tensor,
1031        mask_pad: torch.Tensor,
1032    ) -> torch.Tensor:
1033        for layer in self.encoders:
1034            xs, chunk_masks, _, _ = torch.utils.checkpoint.checkpoint(
1035                layer.__call__, xs, chunk_masks, pos_emb, mask_pad, use_reentrant=False
1036            )
1037        return xs
1038
1039    def forward(
1040        self,
1041        xs: torch.Tensor,
1042        pad_mask: torch.Tensor,
1043        decoding_chunk_size: int = 0,
1044        num_decoding_left_chunks: int = -1,
1045    ) -> Tuple[torch.Tensor, torch.Tensor]:
1046        """Embed positions in tensor.
1047
1048        Args:
1049            xs: padded input tensor (B, T, D)
1050            xs_lens: input length (B)
1051            decoding_chunk_size: decoding chunk size for dynamic chunk
1052                0: default for training, use random dynamic chunk.
1053                <0: for decoding, use full chunk.
1054                >0: for decoding, use fixed chunk size as set.
1055            num_decoding_left_chunks: number of left chunks, this is for decoding,
1056            the chunk size is decoding_chunk_size.
1057                >=0: use num_decoding_left_chunks
1058                <0: use all left chunks
1059        Returns:
1060            encoder output tensor xs, and subsampled masks
1061            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
1062            masks: torch.Tensor batch padding mask after subsample
1063                (B, 1, T' ~= T/subsample_rate)
1064        NOTE(xcsong):
1065            We pass the `__call__` method of the modules instead of `forward` to the
1066            checkpointing API because `__call__` attaches all the hooks of the module.
1067            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
1068        """
1069        T = xs.size(1)
1070        masks = pad_mask.to(torch.bool).unsqueeze(1)  # (B, 1, T)
1071        xs, pos_emb = self.embed(xs)
1072        mask_pad = masks  # (B, 1, T/subsample_rate)
1073        chunk_masks = add_optional_chunk_mask(
1074            xs,
1075            masks,
1076            self.use_dynamic_chunk,
1077            self.use_dynamic_left_chunk,
1078            decoding_chunk_size,
1079            self.static_chunk_size,
1080            num_decoding_left_chunks,
1081        )
1082        if self.gradient_checkpointing and self.training:
1083            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad)
1084        else:
1085            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
1086        if self.normalize_before:
1087            xs = self.after_norm(xs)
1088        # Here we assume the mask is not changed in encoder layers, so just
1089        # return the masks before encoder layers, and the masks will be used
1090        # for cross attention with decoder later
1091        return xs, masks

Conformer encoder module.

ConformerEncoder( input_size: int, output_size: int = 1024, attention_heads: int = 16, linear_units: int = 4096, num_blocks: int = 6, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, input_layer: str = 'linear', pos_enc_layer_type: str = 'rel_pos_espnet', normalize_before: bool = True, static_chunk_size: int = 1, use_dynamic_chunk: bool = False, use_dynamic_left_chunk: bool = False, positionwise_conv_kernel_size: int = 1, macaron_style: bool = False, selfattention_layer_type: str = 'rel_selfattn', activation_type: str = 'swish', use_cnn_module: bool = False, cnn_module_kernel: int = 15, causal: bool = False, cnn_module_norm: str = 'batch_norm', key_bias: bool = True, gradient_checkpointing: bool = False)
 906    def __init__(
 907        self,
 908        input_size: int,
 909        output_size: int = 1024,
 910        attention_heads: int = 16,
 911        linear_units: int = 4096,
 912        num_blocks: int = 6,
 913        dropout_rate: float = 0.1,
 914        positional_dropout_rate: float = 0.1,
 915        attention_dropout_rate: float = 0.0,
 916        input_layer: str = "linear",
 917        pos_enc_layer_type: str = "rel_pos_espnet",
 918        normalize_before: bool = True,
 919        static_chunk_size: int = 1,  # 1: causal_mask; 0: full_mask
 920        use_dynamic_chunk: bool = False,
 921        use_dynamic_left_chunk: bool = False,
 922        positionwise_conv_kernel_size: int = 1,
 923        macaron_style: bool = False,
 924        selfattention_layer_type: str = "rel_selfattn",
 925        activation_type: str = "swish",
 926        use_cnn_module: bool = False,
 927        cnn_module_kernel: int = 15,
 928        causal: bool = False,
 929        cnn_module_norm: str = "batch_norm",
 930        key_bias: bool = True,
 931        gradient_checkpointing: bool = False,
 932    ):
 933        """Construct ConformerEncoder
 934
 935        Args:
 936            input_size to use_dynamic_chunk, see in BaseEncoder
 937            positionwise_conv_kernel_size (int): Kernel size of positionwise
 938                conv1d layer.
 939            macaron_style (bool): Whether to use macaron style for
 940                positionwise layer.
 941            selfattention_layer_type (str): Encoder attention layer type,
 942                the parameter has no effect now, it's just for configure
 943                compatibility. #'rel_selfattn'
 944            activation_type (str): Encoder activation function type.
 945            use_cnn_module (bool): Whether to use convolution module.
 946            cnn_module_kernel (int): Kernel size of convolution module.
 947            causal (bool): whether to use causal convolution or not.
 948            key_bias: whether use bias in attention.linear_k, False for whisper models.
 949        """
 950        super().__init__()
 951        self.output_size = output_size
 952        self.embed = LinearEmbed(
 953            input_size,
 954            output_size,
 955            dropout_rate,
 956            EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
 957        )
 958        self.normalize_before = normalize_before
 959        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
 960        self.gradient_checkpointing = gradient_checkpointing
 961        self.use_dynamic_chunk = use_dynamic_chunk
 962
 963        self.static_chunk_size = static_chunk_size
 964        self.use_dynamic_chunk = use_dynamic_chunk
 965        self.use_dynamic_left_chunk = use_dynamic_left_chunk
 966        activation = ACTIVATION_CLASSES[activation_type]()
 967
 968        # self-attention module definition
 969        encoder_selfattn_layer_args = (
 970            attention_heads,
 971            output_size,
 972            attention_dropout_rate,
 973            key_bias,
 974        )
 975        # feed-forward module definition
 976        positionwise_layer_args = (
 977            output_size,
 978            linear_units,
 979            dropout_rate,
 980            activation,
 981        )
 982        # convolution module definition
 983        convolution_layer_args = (
 984            output_size,
 985            cnn_module_kernel,
 986            activation,
 987            cnn_module_norm,
 988            causal,
 989        )
 990
 991        self.encoders = torch.nn.ModuleList(
 992            [
 993                ConformerEncoderLayer(
 994                    output_size,
 995                    RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
 996                    PositionwiseFeedForward(*positionwise_layer_args),
 997                    (
 998                        PositionwiseFeedForward(*positionwise_layer_args)
 999                        if macaron_style
1000                        else None
1001                    ),
1002                    (
1003                        ConvolutionModule(*convolution_layer_args)
1004                        if use_cnn_module
1005                        else None
1006                    ),
1007                    dropout_rate,
1008                    normalize_before,
1009                )
1010                for _ in range(num_blocks)
1011            ]
1012        )

Construct ConformerEncoder

Args: input_size through use_dynamic_chunk: see BaseEncoder. positionwise_conv_kernel_size (int): Kernel size of the position-wise conv1d layer. macaron_style (bool): Whether to use macaron style for the position-wise layer. selfattention_layer_type (str): Encoder attention layer type; the parameter currently has no effect and is kept only for configuration compatibility ('rel_selfattn'). activation_type (str): Encoder activation function type. use_cnn_module (bool): Whether to use the convolution module. cnn_module_kernel (int): Kernel size of the convolution module. causal (bool): Whether to use causal convolution. key_bias (bool): Whether to use a bias in attention.linear_k (False for Whisper models).

output_size
embed
normalize_before
after_norm
gradient_checkpointing
use_dynamic_chunk
static_chunk_size
use_dynamic_left_chunk
encoders
def forward_layers( self, xs: torch.Tensor, chunk_masks: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor) -> torch.Tensor:
1014    def forward_layers(
1015        self,
1016        xs: torch.Tensor,
1017        chunk_masks: torch.Tensor,
1018        pos_emb: torch.Tensor,
1019        mask_pad: torch.Tensor,
1020    ) -> torch.Tensor:
1021        for layer in self.encoders:
1022            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
1023        return xs
@torch.jit.unused
def forward_layers_checkpointed( self, xs: torch.Tensor, chunk_masks: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor) -> torch.Tensor:
1025    @torch.jit.unused
1026    def forward_layers_checkpointed(
1027        self,
1028        xs: torch.Tensor,
1029        chunk_masks: torch.Tensor,
1030        pos_emb: torch.Tensor,
1031        mask_pad: torch.Tensor,
1032    ) -> torch.Tensor:
1033        for layer in self.encoders:
1034            xs, chunk_masks, _, _ = torch.utils.checkpoint.checkpoint(
1035                layer.__call__, xs, chunk_masks, pos_emb, mask_pad, use_reentrant=False
1036            )
1037        return xs
def forward( self, xs: torch.Tensor, pad_mask: torch.Tensor, decoding_chunk_size: int = 0, num_decoding_left_chunks: int = -1) -> Tuple[torch.Tensor, torch.Tensor]:
1039    def forward(
1040        self,
1041        xs: torch.Tensor,
1042        pad_mask: torch.Tensor,
1043        decoding_chunk_size: int = 0,
1044        num_decoding_left_chunks: int = -1,
1045    ) -> Tuple[torch.Tensor, torch.Tensor]:
1046        """Embed positions in tensor.
1047
1048        Args:
1049            xs: padded input tensor (B, T, D)
1050            xs_lens: input length (B)
1051            decoding_chunk_size: decoding chunk size for dynamic chunk
1052                0: default for training, use random dynamic chunk.
1053                <0: for decoding, use full chunk.
1054                >0: for decoding, use fixed chunk size as set.
1055            num_decoding_left_chunks: number of left chunks, this is for decoding,
1056            the chunk size is decoding_chunk_size.
1057                >=0: use num_decoding_left_chunks
1058                <0: use all left chunks
1059        Returns:
1060            encoder output tensor xs, and subsampled masks
1061            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
1062            masks: torch.Tensor batch padding mask after subsample
1063                (B, 1, T' ~= T/subsample_rate)
1064        NOTE(xcsong):
1065            We pass the `__call__` method of the modules instead of `forward` to the
1066            checkpointing API because `__call__` attaches all the hooks of the module.
1067            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
1068        """
1069        T = xs.size(1)
1070        masks = pad_mask.to(torch.bool).unsqueeze(1)  # (B, 1, T)
1071        xs, pos_emb = self.embed(xs)
1072        mask_pad = masks  # (B, 1, T/subsample_rate)
1073        chunk_masks = add_optional_chunk_mask(
1074            xs,
1075            masks,
1076            self.use_dynamic_chunk,
1077            self.use_dynamic_left_chunk,
1078            decoding_chunk_size,
1079            self.static_chunk_size,
1080            num_decoding_left_chunks,
1081        )
1082        if self.gradient_checkpointing and self.training:
1083            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad)
1084        else:
1085            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
1086        if self.normalize_before:
1087            xs = self.after_norm(xs)
1088        # Here we assume the mask is not changed in encoder layers, so just
1089        # return the masks before encoder layers, and the masks will be used
1090        # for cross attention with decoder later
1091        return xs, masks

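Encode a padded input batch: embed it with positional encoding, run the encoder layers, and apply the final LayerNorm when normalize_before is set.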

Args: xs: padded input tensor (B, T, D). pad_mask: padding mask (B, T), converted to bool and unsqueezed to (B, 1, T). decoding_chunk_size: decoding chunk size for dynamic chunk; 0: default for training, use random dynamic chunk; <0: for decoding, use full chunk; >0: for decoding, use fixed chunk size as set. num_decoding_left_chunks: number of left chunks, used for decoding with chunk size decoding_chunk_size; >=0: use num_decoding_left_chunks; <0: use all left chunks.

Returns: encoder output tensor xs and the masks. xs: padded output tensor (B, T' ~= T/subsample_rate, D). masks: torch.Tensor batch padding mask after subsample (B, 1, T' ~= T/subsample_rate). NOTE(xcsong): We pass the __call__ method of the modules instead of forward to the checkpointing API because __call__ attaches all the hooks of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2