# divisor.acestep.models.lyrics_utils.lyric_encoder
from typing import Optional, Tuple, Union
import math

import torch
import torch.utils.checkpoint
from torch import nn


class ConvolutionModule(nn.Module):
    """Convolution module of a Conformer block.

    Pipeline: pointwise conv (channel expansion) -> GLU -> depthwise conv ->
    normalization -> activation -> pointwise conv.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 15,
        activation: nn.Module = nn.ReLU(),
        norm: str = "batch_norm",
        causal: bool = False,
        bias: bool = True,
    ):
        """Construct a ConvolutionModule object.

        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of the depthwise conv layer.
            activation (nn.Module): Activation applied after normalization.
            norm (str): Either "batch_norm" or "layer_norm".
            causal (bool): Whether to use causal convolution or not.
            bias (bool): Whether the conv layers carry a bias term.
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for non-causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ["batch_norm", "layer_norm"]
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New left-context cache; an empty (0, 0, 0) tensor
                when the convolution is not causal.
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        # NOTE: this is an in-place fill on a transposed view, so the padded
        # frames of the tensor passed in by the caller are zeroed as well.
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert x.size(2) > self.lorder
            new_cache = x[:, :, -self.lorder :]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        # LayerNorm normalizes the last dim, so channels must be moved there.
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    FeedForward is applied on each position of the sequence.
    The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a PositionwiseFeedForward object."""
        super().__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))


class Swish(torch.nn.Module):
    """Swish activation, ``x * sigmoid(x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Swish activation function."""
        return x * torch.sigmoid(x)


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.

    """

    def __init__(
        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
    ):
        """Construct a MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)

        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, : scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float("inf"))
            # The second fill zeroes masked positions after the softmax, which
            # also clears the NaNs produced by rows that are fully masked.
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T) shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L) shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                CosyVoice.
            pos_emb (torch.Tensor): Not used by this class; accepted so the
                signature matches `RelPositionMultiHeadedAttention.forward`.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)
        if cache.size(0) > 0:
            # The cache stores keys and values concatenated on the last dim.
            key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(
        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
    ):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.empty(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.empty(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor; the last dim is cut down to
                ``x.size(-1) // 2 + 1`` entries.

        """
        zero_pad = torch.zeros(
            (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
        )
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache


def subsequent_mask(
    size: int,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask


def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Row ``i`` may attend to every column up to the end of its own chunk and,
    when ``num_left_chunks >= 0``, no further back than that many chunks
    before its own.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Raises:
        ValueError: if ``chunk_size`` is not positive while ``size`` is
            non-zero.

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    if size == 0:
        # Nothing to mask; also avoids dividing by a zero chunk size.
        return torch.zeros(0, 0, device=device, dtype=torch.bool)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    pos = torch.arange(size, device=device)
    # Exclusive end column of the chunk each row belongs to.
    chunk_end = (torch.div(pos, chunk_size, rounding_mode="floor") + 1) * chunk_size
    ret = pos.unsqueeze(0) < chunk_end.unsqueeze(1)
    if num_left_chunks >= 0:
        # First visible column: start of the chunk `num_left_chunks` to the
        # left. A negative start needs no clamping because columns are >= 0.
        chunk_start = chunk_end - (num_left_chunks + 1) * chunk_size
        ret = ret & (pos.unsqueeze(0) >= chunk_start.unsqueeze(1))
    return ret


def add_optional_chunk_mask(
    xs: torch.Tensor,
    masks: torch.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
    enable_full_context: bool = True,
):
    """Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            # torch.randint needs low < high, so sequences of length <= 1
            # cannot be sampled from and fall back to a chunk size of 1.
            if max_len > 1:
                chunk_size = torch.randint(1, max_len, (1,)).item()
            else:
                chunk_size = 1
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    # When the sequence is shorter than one chunk there is
                    # nothing to sample from: no left chunk is available.
                    if max_left_chunks > 0:
                        num_left_chunks = torch.randint(
                            0, max_left_chunks, (1,)
                        ).item()
                    else:
                        num_left_chunks = 0
        chunk_masks = subsequent_chunk_mask(
            xs.size(1), chunk_size, num_left_chunks, xs.device
        )  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(
            xs.size(1), static_chunk_size, num_left_chunks, xs.device
        )  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks


class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
             instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            # Macaron style has two feed-forward modules, each weighted 0.5.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-5
            )  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache


class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct an EspnetRelPositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Pre-compute the table for `max_len` positions; only the size of
        # dim 1 (plus dtype/device) of the dummy tensor is used.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Reset the positional encodings.

        The cached table is rebuilt only when it is too short for ``x``;
        otherwise it is just moved to the dtype/device of ``x`` if needed.
        """
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` means the position of query vector and `j` means the
        # position of key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        # Drop position 0 from the negative half; it is already the last
        # entry of the flipped positive half.
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(
        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input and return it with its positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int or torch.Tensor): forwarded to `position_encoding`.

        Returns:
            torch.Tensor: Scaled input after dropout (batch, time, `*`).
            torch.Tensor: Positional embedding after dropout
                (1, 2*time-1, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(
        self, offset: Union[int, torch.Tensor], size: int
    ) -> torch.Tensor:
        """For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a none
        streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset; not used by this
                implementation.
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding, the ``2 * size - 1`` entries
                centered on the middle of the cached table.
        """
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
        ]
        return pos_emb


class LinearEmbed(torch.nn.Module):
    """Linear transform the input without subsampling

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): Positional encoding module instance.

    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct a LinearEmbed object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class  # rel_pos_espnet

    def position_encoding(
        self, offset: Union[int, torch.Tensor], size: int
    ) -> torch.Tensor:
        """Delegate to the positional encoding module."""
        return self.pos_enc.position_encoding(offset, size)

    def forward(
        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project the input and attach the positional encoding.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            offset (int or torch.Tensor): forwarded to the positional
                encoding module.

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: positional embedding returned by the positional
                encoding module.

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb


ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
}

ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    # Prefer the built-in SiLU when this torch version provides it.
    "swish": getattr(torch.nn, "SiLU", Swish),
    "gelu": torch.nn.GELU,
}


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Width of the mask; when not positive, the largest
            value in `lengths` is used.

    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


# https://github.com/FunAudioLLM/CosyVoice/blob/main/examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml
class ConformerEncoder(torch.nn.Module):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 1024,
        attention_heads: int = 16,
        linear_units: int = 4096,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "linear",
        pos_enc_layer_type: str = "rel_pos_espnet",
        normalize_before: bool = True,
        static_chunk_size: int = 1,  # 1: causal_mask; 0: full_mask
        use_dynamic_chunk: bool = False,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = False,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = False,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """Construct ConformerEncoder

        The embedding is always `LinearEmbed` with
        `EspnetRelPositionalEncoding`, and attention is always
        `RelPositionMultiHeadedAttention`; `input_layer`,
        `pos_enc_layer_type`, `selfattention_layer_type` and
        `positionwise_conv_kernel_size` are accepted but not read.

        Args:
            input_size (int): Feature dimension of the input.
            output_size (int): Model dimension of the encoder.
            attention_heads (int): Number of attention heads.
            linear_units (int): Hidden units of the feed-forward modules.
            num_blocks (int): Number of encoder layers.
            dropout_rate (float): Dropout rate.
            positional_dropout_rate (float): Dropout rate of the positional
                encoding module.
            attention_dropout_rate (float): Dropout rate inside attention.
            normalize_before (bool): Use layer norm before each sub-block and
                apply a final layer norm on the encoder output.
            static_chunk_size (int): Chunk size of the static chunk mask,
                1: causal_mask; 0: full_mask.
            use_dynamic_chunk (bool): Whether to use dynamic chunk masks.
            use_dynamic_left_chunk (bool): Whether to sample the number of
                left chunks when dynamic chunk masks are used.
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configure
                compatibility. #'rel_selfattn'
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            cnn_module_norm (str): Normalization of the convolution module.
            key_bias: whether use bias in attention.linear_k, False for whisper models.
            gradient_checkpointing (bool): Recompute layer activations during
                the backward pass when training.
        """
        super().__init__()
        self.output_size = output_size
        self.embed = LinearEmbed(
            input_size,
            output_size,
            dropout_rate,
            EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
        )
        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.gradient_checkpointing = gradient_checkpointing

        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        activation = ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (
            output_size,
            cnn_module_kernel,
            activation,
            cnn_module_norm,
            causal,
        )

        self.encoders = torch.nn.ModuleList(
            [
                ConformerEncoderLayer(
                    output_size,
                    RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
                    PositionwiseFeedForward(*positionwise_layer_args),
                    (
                        PositionwiseFeedForward(*positionwise_layer_args)
                        if macaron_style
                        else None
                    ),
                    (
                        ConvolutionModule(*convolution_layer_args)
                        if use_cnn_module
                        else None
                    ),
                    dropout_rate,
                    normalize_before,
                )
                for _ in range(num_blocks)
            ]
        )

    def forward_layers(
        self,
        xs: torch.Tensor,
        chunk_masks: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor,
    ) -> torch.Tensor:
        """Run all encoder layers in sequence and return the final features."""
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    @torch.jit.unused
    def forward_layers_checkpointed(
        self,
        xs: torch.Tensor,
        chunk_masks: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor,
    ) -> torch.Tensor:
        """Run all encoder layers through the gradient checkpointing API.

        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        for layer in self.encoders:
            xs, chunk_masks, _, _ = torch.utils.checkpoint.checkpoint(
                layer.__call__, xs, chunk_masks, pos_emb, mask_pad, use_reentrant=False
            )
        return xs

    def forward(
        self,
        xs: torch.Tensor,
        pad_mask: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a padded batch.

        Args:
            xs: padded input tensor (B, T, D)
            pad_mask: padding mask (B, T); it is cast to bool, and non-zero
                entries mark valid frames.
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks

        Returns:
            xs: padded output tensor (B, T, output_size)
            masks: torch.Tensor batch padding mask (B, 1, T)
        """
        masks = pad_mask.to(torch.bool).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb = self.embed(xs)
        mask_pad = masks  # (B, 1, T)
        chunk_masks = add_optional_chunk_mask(
            xs,
            masks,
            self.use_dynamic_chunk,
            self.use_dynamic_left_chunk,
            decoding_chunk_size,
            self.static_chunk_size,
            num_decoding_left_chunks,
        )
        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad)
        else:
            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
class ConvolutionModule(nn.Module):
    """Convolution module of a Conformer block.

    The input passes through a pointwise convolution that doubles the channel
    count, a GLU gate, a depthwise convolution, a normalisation layer, the
    activation and a final pointwise convolution.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 15,
        activation: nn.Module = nn.ReLU(),
        norm: str = "batch_norm",
        causal: bool = False,
        bias: bool = True,
    ):
        """Construct a ConvolutionModule object.

        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of the depthwise conv layer; it
                must be odd when ``causal`` is False.
            activation (nn.Module): Activation applied after normalisation.
            norm (str): Either "batch_norm" or "layer_norm".
            causal (bool): Whether to use causal convolution or not.
            bias (bool): Whether the convolution layers use a bias term.
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for non-causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ["batch_norm", "layer_norm"]
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        The input tensor ``x`` is left untouched; masking is done on a copy.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New cache holding the last ``self.lorder`` frames of
                the (padded) input in causal mode, otherwise a fake
                (0, 0, 0) tensor.
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            # Out-of-place on purpose: `x` is a view of the caller's tensor,
            # so an in-place fill would overwrite the caller's data.
            x = x.masked_fill(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert x.size(2) > self.lorder
            new_cache = x[:, :, -self.lorder :]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            # LayerNorm normalises the last dimension, so move channels there.
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding; `x` is a fresh tensor here, in-place is safe
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
ConvolutionModule in Conformer model.
11 def __init__( 12 self, 13 channels: int, 14 kernel_size: int = 15, 15 activation: nn.Module = nn.ReLU(), 16 norm: str = "batch_norm", 17 causal: bool = False, 18 bias: bool = True, 19 ): 20 """Construct an ConvolutionModule object. 21 Args: 22 channels (int): The number of channels of conv layers. 23 kernel_size (int): Kernel size of conv layers. 24 causal (int): Whether use causal convolution or not 25 """ 26 super().__init__() 27 28 self.pointwise_conv1 = nn.Conv1d( 29 channels, 30 2 * channels, 31 kernel_size=1, 32 stride=1, 33 padding=0, 34 bias=bias, 35 ) 36 # self.lorder is used to distinguish if it's a causal convolution, 37 # if self.lorder > 0: it's a causal convolution, the input will be 38 # padded with self.lorder frames on the left in forward. 39 # else: it's a symmetrical convolution 40 if causal: 41 padding = 0 42 self.lorder = kernel_size - 1 43 else: 44 # kernel_size should be an odd number for none causal convolution 45 assert (kernel_size - 1) % 2 == 0 46 padding = (kernel_size - 1) // 2 47 self.lorder = 0 48 self.depthwise_conv = nn.Conv1d( 49 channels, 50 channels, 51 kernel_size, 52 stride=1, 53 padding=padding, 54 groups=channels, 55 bias=bias, 56 ) 57 58 assert norm in ["batch_norm", "layer_norm"] 59 if norm == "batch_norm": 60 self.use_layer_norm = False 61 self.norm = nn.BatchNorm1d(channels) 62 else: 63 self.use_layer_norm = True 64 self.norm = nn.LayerNorm(channels) 65 66 self.pointwise_conv2 = nn.Conv1d( 67 channels, 68 channels, 69 kernel_size=1, 70 stride=1, 71 padding=0, 72 bias=bias, 73 ) 74 self.activation = activation
Construct a ConvolutionModule object. Args: channels (int): The number of channels of conv layers. kernel_size (int): Kernel size of conv layers. causal (bool): Whether to use causal convolution or not
76 def forward( 77 self, 78 x: torch.Tensor, 79 mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 80 cache: torch.Tensor = torch.zeros((0, 0, 0)), 81 ) -> Tuple[torch.Tensor, torch.Tensor]: 82 """Compute convolution module. 83 Args: 84 x (torch.Tensor): Input tensor (#batch, time, channels). 85 mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 86 (0, 0, 0) means fake mask. 87 cache (torch.Tensor): left context cache, it is only 88 used in causal convolution (#batch, channels, cache_t), 89 (0, 0, 0) meas fake cache. 90 Returns: 91 torch.Tensor: Output tensor (#batch, time, channels). 92 """ 93 # exchange the temporal dimension and the feature dimension 94 x = x.transpose(1, 2) # (#batch, channels, time) 95 96 # mask batch padding 97 if mask_pad.size(2) > 0: # time > 0 98 x.masked_fill_(~mask_pad, 0.0) 99 100 if self.lorder > 0: 101 if cache.size(2) == 0: # cache_t == 0 102 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) 103 else: 104 assert cache.size(0) == x.size(0) # equal batch 105 assert cache.size(1) == x.size(1) # equal channel 106 x = torch.cat((cache, x), dim=2) 107 assert x.size(2) > self.lorder 108 new_cache = x[:, :, -self.lorder :] 109 else: 110 # It's better we just return None if no cache is required, 111 # However, for JIT export, here we just fake one tensor instead of 112 # None. 113 new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 114 115 # GLU mechanism 116 x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 117 x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 118 119 # 1D Depthwise Conv 120 x = self.depthwise_conv(x) 121 if self.use_layer_norm: 122 x = x.transpose(1, 2) 123 x = self.activation(self.norm(x)) 124 if self.use_layer_norm: 125 x = x.transpose(1, 2) 126 x = self.pointwise_conv2(x) 127 # mask batch padding 128 if mask_pad.size(2) > 0: # time > 0 129 x.masked_fill_(~mask_pad, 0.0) 130 131 return x.transpose(1, 2), new_cache
Compute convolution module. Args: x (torch.Tensor): Input tensor (#batch, time, channels). mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), (0, 0, 0) means fake mask. cache (torch.Tensor): left context cache, it is only used in causal convolution (#batch, channels, cache_t), (0, 0, 0) means fake cache. Returns: torch.Tensor: Output tensor (#batch, time, channels).
class PositionwiseFeedForward(torch.nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    The input is projected to ``hidden_units``, passed through the activation
    and dropout, then projected back, so the output dimension equals the
    input dimension.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Create the two linear layers, the activation and the dropout."""
        super().__init__()
        self.w_1 = torch.nn.Linear(in_features=idim, out_features=hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.w_2 = torch.nn.Linear(in_features=hidden_units, out_features=idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Apply linear -> activation -> dropout -> linear.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        hidden = self.w_1(xs)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
Positionwise feed forward layer.
The feed-forward network is applied at each position of the sequence. The output dimension is the same as the input dimension.
Args: idim (int): Input dimension. hidden_units (int): The number of hidden units. dropout_rate (float): Dropout rate. activation (torch.nn.Module): Activation function
147 def __init__( 148 self, 149 idim: int, 150 hidden_units: int, 151 dropout_rate: float, 152 activation: torch.nn.Module = torch.nn.ReLU(), 153 ): 154 """Construct a PositionwiseFeedForward object.""" 155 super(PositionwiseFeedForward, self).__init__() 156 self.w_1 = torch.nn.Linear(idim, hidden_units) 157 self.activation = activation 158 self.dropout = torch.nn.Dropout(dropout_rate) 159 self.w_2 = torch.nn.Linear(hidden_units, idim)
Construct a PositionwiseFeedForward object.
161 def forward(self, xs: torch.Tensor) -> torch.Tensor: 162 """Forward function. 163 164 Args: 165 xs: input tensor (B, L, D) 166 Returns: 167 output tensor, (B, L, D) 168 """ 169 return self.w_2(self.dropout(self.activation(self.w_1(xs))))
Forward function.
Args: xs: input tensor (B, L, D) Returns: output tensor, (B, L, D)
class Swish(torch.nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * sigmoid(x)`` element-wise."""
        gate = torch.sigmoid(x)
        return x * gate
Construct a Swish object.
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features; must be divisible by n_head.
        dropout_rate (float): Dropout rate applied to the attention weights.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(
        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
    ):
        """Create the q/k/v/output projections and the dropout layer."""
        super().__init__()
        assert n_feat % n_head == 0
        # d_v is taken to be equal to d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(in_features=n_feat, out_features=n_feat)
        self.linear_k = nn.Linear(
            in_features=n_feat, out_features=n_feat, bias=key_bias
        )
        self.linear_v = nn.Linear(in_features=n_feat, out_features=n_feat)
        self.linear_out = nn.Linear(in_features=n_feat, out_features=n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query, key and value and split them into heads.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        heads, head_dim = self.h, self.d_k
        # project, split the feature axis into (head, d_k), then move the
        # head axis in front of the time axis
        q = self.linear_q(query).view(n_batch, -1, heads, head_dim).transpose(1, 2)
        k = self.linear_k(key).view(n_batch, -1, heads, head_dim).transpose(1, 2)
        v = self.linear_v(value).view(n_batch, -1, heads, head_dim).transpose(1, 2)
        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> torch.Tensor:
        """Turn attention scores into a context tensor.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        if mask.size(2) == 0:
            # fake mask: plain softmax over the key axis
            weights = torch.softmax(scores, dim=-1)
        else:
            blocked = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For the last chunk the mask may be longer than the scores.
            blocked = blocked[:, :, :, : scores.size(-1)]
            weights = torch.softmax(
                scores.masked_fill(blocked, -float("inf")), dim=-1
            )
            # zero the blocked entries again after the softmax
            weights = weights.masked_fill(blocked, 0.0)

        context = torch.matmul(self.dropout(weights), value)
        merged = (
            context.transpose(1, 2)
            .contiguous()
            .view(value.size(0), -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(merged)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2); (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Not used by this method.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2);
                keys and values are concatenated on the last axis.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2).
        """
        q, k, v = self.forward_qkv(query, key, value)
        if cache.size(0) > 0:
            half = cache.size(-1) // 2
            past_k, past_v = torch.split(cache, half, dim=-1)
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        new_cache = torch.cat((k, v), dim=-1)

        scale = math.sqrt(self.d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        return self.forward_attention(v, scores, mask), new_cache
Multi-Head Attention layer.
Args: n_head (int): The number of heads. n_feat (int): The number of features. dropout_rate (float): Dropout rate.
190 def __init__( 191 self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True 192 ): 193 """Construct an MultiHeadedAttention object.""" 194 super().__init__() 195 assert n_feat % n_head == 0 196 # We assume d_v always equals d_k 197 self.d_k = n_feat // n_head 198 self.h = n_head 199 self.linear_q = nn.Linear(n_feat, n_feat) 200 self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias) 201 self.linear_v = nn.Linear(n_feat, n_feat) 202 self.linear_out = nn.Linear(n_feat, n_feat) 203 self.dropout = nn.Dropout(p=dropout_rate)
Construct a MultiHeadedAttention object.
205 def forward_qkv( 206 self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor 207 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 208 """Transform query, key and value. 209 210 Args: 211 query (torch.Tensor): Query tensor (#batch, time1, size). 212 key (torch.Tensor): Key tensor (#batch, time2, size). 213 value (torch.Tensor): Value tensor (#batch, time2, size). 214 215 Returns: 216 torch.Tensor: Transformed query tensor, size 217 (#batch, n_head, time1, d_k). 218 torch.Tensor: Transformed key tensor, size 219 (#batch, n_head, time2, d_k). 220 torch.Tensor: Transformed value tensor, size 221 (#batch, n_head, time2, d_k). 222 223 """ 224 n_batch = query.size(0) 225 q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) 226 k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) 227 v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 228 q = q.transpose(1, 2) # (batch, head, time1, d_k) 229 k = k.transpose(1, 2) # (batch, head, time2, d_k) 230 v = v.transpose(1, 2) # (batch, head, time2, d_k) 231 return q, k, v
Transform query, key and value.
Args: query (torch.Tensor): Query tensor (#batch, time1, size). key (torch.Tensor): Key tensor (#batch, time2, size). value (torch.Tensor): Value tensor (#batch, time2, size).
Returns: torch.Tensor: Transformed query tensor, size (#batch, n_head, time1, d_k). torch.Tensor: Transformed key tensor, size (#batch, n_head, time2, d_k). torch.Tensor: Transformed value tensor, size (#batch, n_head, time2, d_k).
233 def forward_attention( 234 self, 235 value: torch.Tensor, 236 scores: torch.Tensor, 237 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 238 ) -> torch.Tensor: 239 """Compute attention context vector. 240 241 Args: 242 value (torch.Tensor): Transformed value, size 243 (#batch, n_head, time2, d_k). 244 scores (torch.Tensor): Attention score, size 245 (#batch, n_head, time1, time2). 246 mask (torch.Tensor): Mask, size (#batch, 1, time2) or 247 (#batch, time1, time2), (0, 0, 0) means fake mask. 248 249 Returns: 250 torch.Tensor: Transformed value (#batch, time1, d_model) 251 weighted by the attention score (#batch, time1, time2). 252 253 """ 254 n_batch = value.size(0) 255 256 if mask.size(2) > 0: # time2 > 0 257 mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) 258 # For last chunk, time2 might be larger than scores.size(-1) 259 mask = mask[:, :, :, : scores.size(-1)] # (batch, 1, *, time2) 260 scores = scores.masked_fill(mask, -float("inf")) 261 attn = torch.softmax(scores, dim=-1).masked_fill( 262 mask, 0.0 263 ) # (batch, head, time1, time2) 264 265 else: 266 attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 267 268 p_attn = self.dropout(attn) 269 x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) 270 x = ( 271 x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) 272 ) # (batch, time1, d_model) 273 274 return self.linear_out(x) # (batch, time1, d_model)
Compute attention context vector.
Args: value (torch.Tensor): Transformed value, size (#batch, n_head, time2, d_k). scores (torch.Tensor): Attention score, size (#batch, n_head, time1, time2). mask (torch.Tensor): Mask, size (#batch, 1, time2) or (#batch, time1, time2), (0, 0, 0) means fake mask.
Returns: torch.Tensor: Transformed value (#batch, time1, d_model) weighted by the attention score (#batch, time1, time2).
276 def forward( 277 self, 278 query: torch.Tensor, 279 key: torch.Tensor, 280 value: torch.Tensor, 281 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 282 pos_emb: torch.Tensor = torch.empty(0), 283 cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 284 ) -> Tuple[torch.Tensor, torch.Tensor]: 285 """Compute scaled dot product attention. 286 287 Args: 288 query (torch.Tensor): Query tensor (#batch, time1, size). 289 key (torch.Tensor): Key tensor (#batch, time2, size). 290 value (torch.Tensor): Value tensor (#batch, time2, size). 291 mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 292 (#batch, time1, time2). 293 1.When applying cross attention between decoder and encoder, 294 the batch padding mask for input is in (#batch, 1, T) shape. 295 2.When applying self attention of encoder, 296 the mask is in (#batch, T, T) shape. 297 3.When applying self attention of decoder, 298 the mask is in (#batch, L, L) shape. 299 4.If the different position in decoder see different block 300 of the encoder, such as Mocha, the passed in mask could be 301 in (#batch, L, T) shape. But there is no such case in current 302 CosyVoice. 303 cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), 304 where `cache_t == chunk_size * num_decoding_left_chunks` 305 and `head * d_k == size` 306 307 308 Returns: 309 torch.Tensor: Output tensor (#batch, time1, d_model). 310 torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) 311 where `cache_t == chunk_size * num_decoding_left_chunks` 312 and `head * d_k == size` 313 314 """ 315 q, k, v = self.forward_qkv(query, key, value) 316 if cache.size(0) > 0: 317 key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1) 318 k = torch.cat([key_cache, k], dim=2) 319 v = torch.cat([value_cache, v], dim=2) 320 new_cache = torch.cat((k, v), dim=-1) 321 322 scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) 323 return self.forward_attention(v, scores, mask), new_cache
Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
1.When applying cross attention between decoder and encoder,
the batch padding mask for input is in (#batch, 1, T) shape.
2.When applying self attention of encoder,
the mask is in (#batch, T, T) shape.
3.When applying self attention of decoder,
the mask is in (#batch, L, L) shape.
4.If different positions in the decoder see different blocks
of the encoder (e.g. MoChA), the passed-in mask could be
in (#batch, L, T) shape. There is no such case in the current
CosyVoice.
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where cache_t == chunk_size * num_decoding_left_chunks
and head * d_k == size
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where cache_t == chunk_size * num_decoding_left_chunks
and head * d_k == size
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-head attention with relative position encoding.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(
        self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
    ):
        """Add the positional projection and the two learnable biases."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # projection applied to the positional embedding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # learnable biases for matrix c and matrix d,
        # see https://arxiv.org/abs/1901.02860 Section 3.3
        bias_shape = (self.h, self.d_k)
        self.pos_bias_u = nn.Parameter(torch.Tensor(*bias_shape))
        self.pos_bias_v = nn.Parameter(torch.Tensor(*bias_shape))
        for bias in (self.pos_bias_u, self.pos_bias_v):
            torch.nn.init.xavier_uniform_(bias)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Shift the relative-position scores and keep the leading columns.

        A zero column is prepended on the last axis, the last two axes are
        reinterpreted with swapped extents, the first row is dropped and the
        result is viewed back in the shape of ``x``. Only the first
        ``x.size(-1) // 2 + 1`` columns are returned.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor
                (batch, head, time1, x.size(-1) // 2 + 1).
        """
        n_batch = x.size(0)
        n_head = x.size(1)
        n_query = x.size(2)
        n_pos = x.size(3)

        pad_column = torch.zeros(
            (n_batch, n_head, n_query, 1), device=x.device, dtype=x.dtype
        )
        padded = torch.cat([pad_column, x], dim=-1)
        padded = padded.view(n_batch, n_head, n_pos + 1, n_query)
        shifted = padded[:, :, 1:].view_as(x)
        # only keep the positions from 0 to time2
        return shifted[:, :, :, : n_pos // 2 + 1]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        # (batch, time1, head, d_k) so the (head, d_k) biases broadcast
        q = q.transpose(1, 2)

        if cache.size(0) > 0:
            half = cache.size(-1) // 2
            past_k, past_v = torch.split(cache, half, dim=-1)
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        # NOTE(xcsong): cache slicing is done in encoder.forward_chunk, since
        #   it is non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        pos = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k)
        pos = pos.transpose(1, 2)  # (batch, head, time1, d_k)

        # both are (batch, head, time1, d_k)
        q_u = (q + self.pos_bias_u).transpose(1, 2)
        q_v = (q + self.pos_bias_v).transpose(1, 2)

        # matrix a + matrix c of https://arxiv.org/abs/1901.02860 Section 3.3
        content_scores = torch.matmul(q_u, k.transpose(-2, -1))
        # matrix b + matrix d
        position_scores = torch.matmul(q_v, pos.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if content_scores.shape != position_scores.shape:
            position_scores = self.rel_shift(position_scores)

        scale = math.sqrt(self.d_k)
        scores = (content_scores + position_scores) / scale

        return self.forward_attention(v, scores, mask), new_cache
Multi-Head Attention layer with relative position encoding. Paper: https://arxiv.org/abs/1901.02860 Args: n_head (int): The number of heads. n_feat (int): The number of features. dropout_rate (float): Dropout rate.
335 def __init__( 336 self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True 337 ): 338 """Construct an RelPositionMultiHeadedAttention object.""" 339 super().__init__(n_head, n_feat, dropout_rate, key_bias) 340 # linear transformation for positional encoding 341 self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) 342 # these two learnable bias are used in matrix c and matrix d 343 # as described in https://arxiv.org/abs/1901.02860 Section 3.3 344 self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) 345 self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) 346 torch.nn.init.xavier_uniform_(self.pos_bias_u) 347 torch.nn.init.xavier_uniform_(self.pos_bias_v)
Construct a RelPositionMultiHeadedAttention object.
349 def rel_shift(self, x: torch.Tensor) -> torch.Tensor: 350 """Compute relative positional encoding. 351 352 Args: 353 x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). 354 time1 means the length of query vector. 355 356 Returns: 357 torch.Tensor: Output tensor. 358 359 """ 360 zero_pad = torch.zeros( 361 (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype 362 ) 363 x_padded = torch.cat([zero_pad, x], dim=-1) 364 365 x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2)) 366 x = x_padded[:, :, 1:].view_as(x)[ 367 :, :, :, : x.size(-1) // 2 + 1 368 ] # only keep the positions from 0 to time2 369 return x
Compute relative positional encoding.
Args: x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). time1 means the length of query vector.
Returns: torch.Tensor: Output tensor.
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    pos_emb: torch.Tensor = torch.empty(0),
    cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

    Args:
        query (torch.Tensor): Query tensor (#batch, time1, size).
        key (torch.Tensor): Key tensor (#batch, time2, size).
        value (torch.Tensor): Value tensor (#batch, time2, size).
        mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
            (#batch, time1, time2), (0, 0, 0) means fake mask.
        pos_emb (torch.Tensor): Positional embedding tensor; it is projected
            by ``self.linear_pos`` and split into ``self.h`` heads of
            ``self.d_k`` features.
        cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
            where `cache_t == chunk_size * num_decoding_left_chunks`
            and `head * d_k == size`. An empty first dimension means
            "no cache".

    Returns:
        torch.Tensor: Output of ``self.forward_attention``
            (#batch, time1, d_model).
        torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2),
            i.e. the (possibly cache-extended) keys and values concatenated
            along the last axis.
    """
    q, k, v = self.forward_qkv(query, key, value)

    if cache.size(0) > 0:
        # The cache stores keys and values side by side on the last axis.
        half = cache.size(-1) // 2
        cached_k, cached_v = torch.split(cache, half, dim=-1)
        k = torch.cat([cached_k, k], dim=2)
        v = torch.cat([cached_v, v], dim=2)
    # NOTE(xcsong): cache slicing is done in encoder.forward_chunk, since
    #   it's non-trivial to calculate `next_cache_start` here.
    new_cache = torch.cat((k, v), dim=-1)

    # Project the positional embedding and split it into heads:
    # (batch, head, n_pos, d_k)
    pos = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k)
    pos = pos.transpose(1, 2)

    # The learned biases are added in (batch, time1, head, d_k) layout and
    # the result is moved back to (batch, head, time1, d_k).
    q = q.transpose(1, 2)
    q_content = (q + self.pos_bias_u).transpose(1, 2)
    q_position = (q + self.pos_bias_v).transpose(1, 2)

    # Terms (a)+(c) and (b)+(d) of https://arxiv.org/abs/1901.02860 Sec. 3.3
    content_scores = torch.matmul(q_content, k.transpose(-2, -1))
    position_scores = torch.matmul(q_position, pos.transpose(-2, -1))
    # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
    if content_scores.shape != position_scores.shape:
        position_scores = self.rel_shift(position_scores)

    # (batch, head, time1, time2)
    scores = (content_scores + position_scores) / math.sqrt(self.d_k)

    return self.forward_attention(v, scores, mask), new_cache
Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where cache_t == chunk_size * num_decoding_left_chunks
and head * d_k == size
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where cache_t == chunk_size * num_decoding_left_chunks
and head * d_k == size
Inherited Members
def subsequent_mask(
    size: int,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    Entry ``[i, j]`` is True when ``j <= i``, i.e. step ``i`` may attend to
    itself and to the steps on its left (auto-regressive decoding).

    For chunk-based (streaming) attention see subsequent_chunk_mask.

    Args:
        size (int): size of mask
        device (torch.device): device on which the mask is created

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    steps = torch.arange(size, device=device)
    columns = steps.expand(size, size)  # column index j, repeated per row
    rows = steps.unsqueeze(-1)  # row index i as a column vector
    return columns <= rows
Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode. This means the current step could only do attention with its left steps.
In the encoder, full attention is used when streaming is not necessary and the sequence is not long. In this case, no attention mask is needed.
When streaming is needed, chunk-based attention is used in the encoder. See subsequent_chunk_mask for the chunk-based attention mask.
Args: size (int): size of mask device (torch.device): device on which the mask is created, "cpu" by default
Returns: torch.Tensor: mask
Examples:
subsequent_mask(3) [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
    this is for streaming encoder

    Row ``i`` is True for every column from the start of the earliest
    allowed chunk up to the end of the chunk that contains ``i``.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk, must be positive when size > 0
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Raises:
        ZeroDivisionError: if chunk_size is 0 and size > 0.
        ValueError: if chunk_size is negative and size > 0.

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    if size == 0:
        # Nothing to mask; chunk_size is irrelevant for an empty sequence.
        return torch.zeros(0, 0, device=device, dtype=torch.bool)
    if chunk_size == 0:
        # Same exception type the per-row integer division used to raise.
        raise ZeroDivisionError("chunk_size must be non-zero")
    if chunk_size < 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    positions = torch.arange(size, device=device)
    chunk_index = torch.div(positions, chunk_size, rounding_mode="floor")

    # Row i may see columns strictly below the end of its own chunk. Columns
    # never exceed size - 1, so no explicit clamp to `size` is required.
    row_end = (chunk_index + 1) * chunk_size
    mask = positions.unsqueeze(0) < row_end.unsqueeze(1)

    if num_left_chunks >= 0:
        # Limit the history to `num_left_chunks` chunks before the current one.
        row_start = ((chunk_index - num_left_chunks) * chunk_size).clamp(min=0)
        mask = mask & (positions.unsqueeze(0) >= row_start.unsqueeze(1))
    return mask
Create mask for subsequent steps (size, size) with chunk size, this is for streaming encoder
Args: size (int): size of mask chunk_size (int): size of chunk num_left_chunks (int): number of left chunks <0: use full chunk
>=0: use num_left_chunks device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns: torch.Tensor: mask
Examples:
subsequent_chunk_mask(4, 2) [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]
def _sample_dynamic_chunk(
    max_len: int,
    use_dynamic_left_chunk: bool,
    enable_full_context: bool,
):
    """Draw a random (chunk_size, num_left_chunks) pair for dynamic-chunk training."""
    num_left_chunks = -1
    if max_len <= 1:
        # torch.randint(1, max_len) requires max_len > 1. With at most one
        # frame every chunk size yields the same mask, so use a single chunk.
        return 1, num_left_chunks

    chunk_size = torch.randint(1, max_len, (1,)).item()
    if chunk_size > max_len // 2 and enable_full_context:
        return max_len, num_left_chunks

    # chunk size is in [1, 25]: with 4 times subsampling and up to
    # 1s (100 frames) of delay, the maximum frame is 100 / 4 = 25.
    chunk_size = chunk_size % 25 + 1
    if use_dynamic_left_chunk:
        max_left_chunks = (max_len - 1) // chunk_size
        # torch.randint(0, 0) raises. Zero means a single chunk already
        # covers the whole sequence, so there is no left context to limit.
        if max_left_chunks > 0:
            num_left_chunks = torch.randint(0, max_left_chunks, (1,)).item()
    return chunk_size, num_left_chunks


def add_optional_chunk_mask(
    xs: torch.Tensor,
    masks: torch.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
    enable_full_context: bool = True,
):
    """Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs. This is ``masks`` itself
            when neither dynamic nor static chunking is enabled, otherwise
            ``masks`` AND-ed with an (L, L) chunk mask, giving (B, L, L).
    """
    max_len = xs.size(1)

    # Decide which chunk layout (if any) applies.
    if use_dynamic_chunk:
        if decoding_chunk_size < 0:
            chunk_size, num_left_chunks = max_len, -1
        elif decoding_chunk_size > 0:
            chunk_size, num_left_chunks = decoding_chunk_size, num_decoding_left_chunks
        else:
            chunk_size, num_left_chunks = _sample_dynamic_chunk(
                max_len, use_dynamic_left_chunk, enable_full_context
            )
    elif static_chunk_size > 0:
        chunk_size, num_left_chunks = static_chunk_size, num_decoding_left_chunks
    else:
        return masks

    chunk_masks = subsequent_chunk_mask(
        max_len, chunk_size, num_left_chunks, xs.device
    )  # (L, L)
    chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
    return masks & chunk_masks  # (B, L, L)
Apply optional mask for encoder.
Args: xs (torch.Tensor): padded input, (B, L, D), L for max length masks (torch.Tensor): mask for xs, (B, 1, L) use_dynamic_chunk (bool): whether to use dynamic chunk or not use_dynamic_left_chunk (bool): whether to use dynamic left chunk for training. decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's 0: default for training, use random dynamic chunk. <0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set. static_chunk_size (int): chunk size for static chunk training/decoding if it's greater than 0, if use_dynamic_chunk is true, this parameter will be ignored num_decoding_left_chunks: number of left chunks, this is for decoding, the chunk size is decoding_chunk_size. >=0: use num_decoding_left_chunks <0: use all left chunks enable_full_context (bool): True: chunk size is either [1, 25] or full context(max_len) False: chunk size ~ U[1, 25]
Returns: torch.Tensor: chunk mask of the input xs.
class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.

    The layer chains, each with a residual connection: an optional macaron
    feed-forward, self-attention, an optional convolution module and a
    feed-forward module. When a convolution module is present a final layer
    norm is applied to the output.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
            `forward` calls it unconditionally, so it must be provided even
            though the parameter defaults to None.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
             instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()

        def layer_norm() -> nn.LayerNorm:
            return nn.LayerNorm(size, eps=1e-5)

        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = layer_norm()  # for the FNN module
        self.norm_mha = layer_norm()  # for the MHA module

        has_macaron = feed_forward_macaron is not None
        if has_macaron:
            self.norm_ff_macaron = layer_norm()
        # With two feed-forward modules each contributes half of the update.
        self.ff_scale = 0.5 if has_macaron else 1.0

        if conv_module is not None:
            self.norm_conv = layer_norm()  # for the CNN module
            self.norm_final = layer_norm()  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1,time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time), returned unchanged.
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2); an empty
                (0, 0, 0) tensor when there is no convolution module.
        """
        pre_norm = self.normalize_before

        # macaron-style feed forward (optional)
        if self.feed_forward_macaron is not None:
            skip = x
            hidden = self.norm_ff_macaron(x) if pre_norm else x
            x = skip + self.ff_scale * self.dropout(self.feed_forward_macaron(hidden))
            if not pre_norm:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        skip = x
        hidden = self.norm_mha(x) if pre_norm else x
        attended, new_att_cache = self.self_attn(
            hidden, hidden, hidden, mask, pos_emb, att_cache
        )
        x = skip + self.dropout(attended)
        if not pre_norm:
            x = self.norm_mha(x)

        # convolution module
        # Placeholder cache, replaced by the conv module's cache when present.
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            skip = x
            hidden = self.norm_conv(x) if pre_norm else x
            convolved, new_cnn_cache = self.conv_module(hidden, mask_pad, cnn_cache)
            x = skip + self.dropout(convolved)
            if not pre_norm:
                x = self.norm_conv(x)

        # feed forward module
        skip = x
        hidden = self.norm_ff(x) if pre_norm else x
        x = skip + self.ff_scale * self.dropout(self.feed_forward(hidden))
        if not pre_norm:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
MultiHeadedAttention or RelPositionMultiHeadedAttention
instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
PositionwiseFeedForward instance can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
instance.
PositionwiseFeedForward instance can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
ConvolutionModule instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
def __init__(
    self,
    size: int,
    self_attn: torch.nn.Module,
    feed_forward: Optional[nn.Module] = None,
    feed_forward_macaron: Optional[nn.Module] = None,
    conv_module: Optional[nn.Module] = None,
    dropout_rate: float = 0.1,
    normalize_before: bool = True,
):
    """Construct an EncoderLayer object.

    Stores the given sub-modules and creates one LayerNorm (eps=1e-5) per
    sub-block. ``norm_ff_macaron`` exists only when a macaron feed-forward
    is given; ``norm_conv`` and ``norm_final`` only when a convolution
    module is given.
    """
    super().__init__()

    def layer_norm() -> nn.LayerNorm:
        return nn.LayerNorm(size, eps=1e-5)

    self.self_attn = self_attn
    self.feed_forward = feed_forward
    self.feed_forward_macaron = feed_forward_macaron
    self.conv_module = conv_module
    self.norm_ff = layer_norm()  # for the FNN module
    self.norm_mha = layer_norm()  # for the MHA module

    has_macaron = feed_forward_macaron is not None
    if has_macaron:
        self.norm_ff_macaron = layer_norm()
    # With two feed-forward modules each contributes half of the update.
    self.ff_scale = 0.5 if has_macaron else 1.0

    if conv_module is not None:
        self.norm_conv = layer_norm()  # for the CNN module
        self.norm_final = layer_norm()  # for the final output of the block
    self.dropout = nn.Dropout(dropout_rate)
    self.size = size
    self.normalize_before = normalize_before
Construct an EncoderLayer object.
def forward(
    self,
    x: torch.Tensor,
    mask: torch.Tensor,
    pos_emb: torch.Tensor,
    mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute encoded features.

    Args:
        x (torch.Tensor): (#batch, time, size)
        mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
            (0, 0, 0) means fake mask.
        pos_emb (torch.Tensor): positional encoding, must not be None
            for ConformerEncoderLayer.
        mask_pad (torch.Tensor): batch padding mask used for conv module.
            (#batch, 1,time), (0, 0, 0) means fake mask.
        att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
            (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
        cnn_cache (torch.Tensor): Convolution cache in conformer layer
            (#batch=1, size, cache_t2)
    Returns:
        torch.Tensor: Output tensor (#batch, time, size).
        torch.Tensor: Mask tensor (#batch, time, time), returned unchanged.
        torch.Tensor: att_cache tensor,
            (#batch=1, head, cache_t1 + time, d_k * 2).
        torch.Tensor: cnn_cache tensor (#batch, size, cache_t2); an empty
            (0, 0, 0) tensor when there is no convolution module.
    """
    pre_norm = self.normalize_before

    # macaron-style feed forward (optional)
    if self.feed_forward_macaron is not None:
        skip = x
        hidden = self.norm_ff_macaron(x) if pre_norm else x
        x = skip + self.ff_scale * self.dropout(self.feed_forward_macaron(hidden))
        if not pre_norm:
            x = self.norm_ff_macaron(x)

    # multi-headed self-attention module
    skip = x
    hidden = self.norm_mha(x) if pre_norm else x
    attended, new_att_cache = self.self_attn(
        hidden, hidden, hidden, mask, pos_emb, att_cache
    )
    x = skip + self.dropout(attended)
    if not pre_norm:
        x = self.norm_mha(x)

    # convolution module
    # Placeholder cache, replaced by the conv module's cache when present.
    new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
    if self.conv_module is not None:
        skip = x
        hidden = self.norm_conv(x) if pre_norm else x
        convolved, new_cnn_cache = self.conv_module(hidden, mask_pad, cnn_cache)
        x = skip + self.dropout(convolved)
        if not pre_norm:
            x = self.norm_conv(x)

    # feed forward module
    skip = x
    hidden = self.norm_ff(x) if pre_norm else x
    x = skip + self.ff_scale * self.dropout(self.feed_forward(hidden))
    if not pre_norm:
        x = self.norm_ff(x)

    if self.conv_module is not None:
        x = self.norm_final(x)

    return x, mask, new_att_cache, new_cnn_cache
Compute encoded features.
Args: x (torch.Tensor): (#batch, time, size) mask (torch.Tensor): Mask tensor for the input (#batch, time, time), (0, 0, 0) means fake mask. pos_emb (torch.Tensor): positional encoding, must not be None for ConformerEncoderLayer. mask_pad (torch.Tensor): batch padding mask used for conv module. (#batch, 1, time), (0, 0, 0) means fake mask. att_cache (torch.Tensor): Cache tensor of the KEY & VALUE (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. cnn_cache (torch.Tensor): Convolution cache in conformer layer (#batch=1, size, cache_t2) Returns: torch.Tensor: Output tensor (#batch, time, size). torch.Tensor: Mask tensor (#batch, time, time). torch.Tensor: att_cache tensor, (#batch=1, head, cache_t1 + time, d_k * 2). torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct an PositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Filled by extend_pe; a plain attribute, not a registered buffer.
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Reset the positional encodings.

        Rebuilds ``self.pe`` unless the cached table already covers
        ``2 * x.size(1) - 1`` positions, in which case it is only moved to
        the dtype/device of ``x`` when they differ.
        """
        length = x.size(1)
        cached = self.pe
        # self.pe holds positive and negative offsets: 2 * input_len - 1 rows
        if cached is not None and cached.size(1) >= length * 2 - 1:
            if cached.dtype != x.dtype or cached.device != x.device:
                self.pe = cached.to(dtype=x.dtype, device=x.device)
            return

        # Let `i` be the query position and `j` the key position. Positive
        # relative positions are used when keys are to the left (i>j) and
        # negative relative positions otherwise (i<j).
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term

        positive = torch.zeros(length, self.d_model)
        negative = torch.zeros(length, self.d_model)
        positive[:, 0::2] = torch.sin(angles)
        positive[:, 1::2] = torch.cos(angles)
        negative[:, 0::2] = torch.sin(-angles)
        negative[:, 1::2] = torch.cos(-angles)

        # Reverse the positive offsets and append the negative ones (offset 0
        # is kept once). This layout supports the shifting trick of
        # https://arxiv.org/abs/1901.02860
        table = torch.cat(
            [torch.flip(positive, [0]).unsqueeze(0), negative[1:].unsqueeze(0)],
            dim=1,
        )
        self.pe = table.to(device=x.device, dtype=x.dtype)

    def forward(
        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int or torch.Tensor): forwarded to position_encoding.

        Returns:
            torch.Tensor: Input scaled by sqrt(d_model), after dropout
                (batch, time, `*`).
            torch.Tensor: Relative positional embedding after dropout
                (1, 2*time-1, d_model).

        """
        self.extend_pe(x)
        scaled = x * self.xscale
        rel_emb = self.position_encoding(size=scaled.size(1), offset=offset)
        return self.dropout(scaled), self.dropout(rel_emb)

    def position_encoding(
        self, offset: Union[int, torch.Tensor], size: int
    ) -> torch.Tensor:
        """For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a
        non-streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset (not used by this
                implementation)
            size (int): required size of position encoding

        Returns:
            torch.Tensor: the ``2 * size - 1`` rows of ``self.pe`` centred on
                its middle row
        """
        center = self.pe.size(1) // 2
        return self.pe[:, center - size + 1 : center + size]
Relative positional encoding module (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args: d_model (int): Embedding dimension. dropout_rate (float): Dropout rate. max_len (int): Maximum input length.
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
    """Construct an PositionalEncoding object.

    Stores the embedding size and its square root (used to scale inputs),
    creates the dropout layer and pre-computes the encoding table for
    ``max_len`` positions via ``extend_pe``.
    """
    super().__init__()
    self.d_model = d_model
    self.xscale = math.sqrt(d_model)
    self.dropout = torch.nn.Dropout(p=dropout_rate)
    # Filled by extend_pe; a plain attribute, not a registered buffer.
    self.pe = None
    # Only the second dimension of this dummy input is read by extend_pe.
    length_probe = torch.tensor(0.0).expand(1, max_len)
    self.extend_pe(length_probe)
Construct a PositionalEncoding object.
def extend_pe(self, x: torch.Tensor):
    """Reset the positional encodings.

    Rebuilds ``self.pe`` unless the cached table already covers
    ``2 * x.size(1) - 1`` positions, in which case it is only moved to the
    dtype/device of ``x`` when they differ.
    """
    length = x.size(1)
    cached = self.pe
    # self.pe holds positive and negative offsets: 2 * input_len - 1 rows
    if cached is not None and cached.size(1) >= length * 2 - 1:
        if cached.dtype != x.dtype or cached.device != x.device:
            self.pe = cached.to(dtype=x.dtype, device=x.device)
        return

    # Let `i` be the query position and `j` the key position. Positive
    # relative positions are used when keys are to the left (i>j) and
    # negative relative positions otherwise (i<j).
    position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, self.d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / self.d_model)
    )
    angles = position * div_term

    positive = torch.zeros(length, self.d_model)
    negative = torch.zeros(length, self.d_model)
    positive[:, 0::2] = torch.sin(angles)
    positive[:, 1::2] = torch.cos(angles)
    negative[:, 0::2] = torch.sin(-angles)
    negative[:, 1::2] = torch.cos(-angles)

    # Reverse the positive offsets and append the negative ones (offset 0 is
    # kept once). This layout supports the shifting trick of
    # https://arxiv.org/abs/1901.02860
    table = torch.cat(
        [torch.flip(positive, [0]).unsqueeze(0), negative[1:].unsqueeze(0)],
        dim=1,
    )
    self.pe = table.to(device=x.device, dtype=x.dtype)
Reset the positional encodings.
def forward(
    self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add positional encoding.

    Args:
        x (torch.Tensor): Input tensor (batch, time, `*`).
        offset (int or torch.Tensor): forwarded to position_encoding.

    Returns:
        torch.Tensor: Input scaled by ``self.xscale``, after dropout
            (batch, time, `*`).
        torch.Tensor: Positional embedding returned by position_encoding
            for ``size = x.size(1)``, after dropout.

    """
    # Make sure the cached table covers this length / dtype / device.
    self.extend_pe(x)
    scaled = x * self.xscale
    rel_emb = self.position_encoding(size=scaled.size(1), offset=offset)
    return self.dropout(scaled), self.dropout(rel_emb)
Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, *).
Returns:
torch.Tensor: Encoded tensor (batch, time, *).
def position_encoding(
    self, offset: Union[int, torch.Tensor], size: int
) -> torch.Tensor:
    """For getting encoding in a streaming fashion

    Attention!!!!!
    we apply dropout only once at the whole utterance level in a
    non-streaming way, but will call this function several times with
    increasing input size in a streaming scenario, so the dropout will
    be applied several times.

    Args:
        offset (int or torch.tensor): start offset (not used by this
            implementation)
        size (int): required size of position encoding

    Returns:
        torch.Tensor: the ``2 * size - 1`` rows of ``self.pe`` centred on
            its middle row
    """
    center = self.pe.size(1) // 2
    return self.pe[:, center - size + 1 : center + size]
For getting encoding in a streaming fashion
Attention!!!!! We apply dropout only once at the whole-utterance level in the non-streaming case, but this function is called several times with increasing input size in a streaming scenario, so the dropout will be applied several times.
Args: offset (int or torch.tensor): start offset size (int): required size of position encoding
Returns: torch.Tensor: Corresponding encoding
class LinearEmbed(torch.nn.Module):
    """Linear transform the input without subsampling

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): Positional-encoding module
            *instance* (despite the name). It is called as
            ``pos_enc(x, offset)`` and must also expose
            ``position_encoding(offset, size)``.

    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an linear object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class  # rel_pos_espnet

    def position_encoding(
        self, offset: Union[int, torch.Tensor], size: int
    ) -> torch.Tensor:
        """Return ``self.pos_enc.position_encoding(offset, size)``."""
        return self.pos_enc.position_encoding(offset, size)

    def forward(
        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project the input and apply the positional encoding module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            offset (int or torch.Tensor): offset forwarded to the positional
                encoding module.

        Returns:
            torch.Tensor: first output of the positional encoding module
                applied to the projected input (#batch, time', odim),
                where time' = time .
            torch.Tensor: positional embedding, the second output of the
                positional encoding module.

        """
        x = self.out(x)
        # The return annotation previously promised three tensors; only the
        # encoded input and the positional embedding are produced.
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb
Linear transform the input without subsampling
Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate.
def __init__(
    self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
):
    """Construct an linear object.

    Builds ``self.out`` as Linear(idim, odim) -> LayerNorm(odim) -> Dropout
    and stores ``pos_enc_class`` (a module instance) as ``self.pos_enc``.
    """
    super().__init__()
    projection = torch.nn.Linear(idim, odim)
    norm = torch.nn.LayerNorm(odim, eps=1e-5)
    drop = torch.nn.Dropout(dropout_rate)
    self.out = torch.nn.Sequential(projection, norm, drop)
    self.pos_enc = pos_enc_class  # rel_pos_espnet
Construct a linear embedding object.
def forward(
    self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Project the input and apply the positional encoding module.

    Args:
        x (torch.Tensor): Input tensor (#batch, time, idim).
        offset (int or torch.Tensor): offset forwarded to the positional
            encoding module.

    Returns:
        torch.Tensor: first output of the positional encoding module applied
            to the projected input (#batch, time', odim),
            where time' = time .
        torch.Tensor: positional embedding, the second output of the
            positional encoding module.

    """
    x = self.out(x)
    # The return annotation previously promised three tensors; only the
    # encoded input and the positional embedding are produced.
    x, pos_emb = self.pos_enc(x, offset)
    return x, pos_emb
Input x.
Args: x (torch.Tensor): Input tensor (#batch, time, idim). offset (int or torch.Tensor): offset forwarded to the positional encoding module.
Returns: torch.Tensor: linear input tensor (#batch, time', odim), where time' = time. torch.Tensor: positional embedding returned by the positional encoding module.
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    Entry ``[b, t]`` is True when ``t >= lengths[b]``, i.e. when position
    ``t`` is padding for sequence ``b``.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Number of mask columns; when not positive,
            ``lengths.max()`` is used.
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    width = max_len if max_len > 0 else lengths.max().item()
    steps = torch.arange(0, width, dtype=torch.int64, device=lengths.device)
    # One row of time indices per sequence, compared with that row's length.
    grid = steps.unsqueeze(0).expand(lengths.size(0), width)
    return grid >= lengths.unsqueeze(-1)
Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args: lengths (torch.Tensor): Batch of lengths (B,). Returns: torch.Tensor: Mask tensor containing indices of padded part.
Examples:
lengths = torch.tensor([5, 3, 2]) make_pad_mask(lengths) masks = [[0, 0, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]]
class ConformerEncoder(torch.nn.Module):
    """Conformer encoder module.

    Embeds the input with ``LinearEmbed`` (using an
    ``EspnetRelPositionalEncoding``), runs it through ``num_blocks``
    ``ConformerEncoderLayer`` blocks and, when ``normalize_before`` is set,
    applies a final ``LayerNorm``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 1024,
        attention_heads: int = 16,
        linear_units: int = 4096,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "linear",
        pos_enc_layer_type: str = "rel_pos_espnet",
        normalize_before: bool = True,
        static_chunk_size: int = 1,  # 1: causal_mask; 0: full_mask
        use_dynamic_chunk: bool = False,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = False,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = False,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """Construct ConformerEncoder.

        Args:
            input_size (int): Feature size handed to the input embedding.
            output_size (int): Model width; also the size of the final
                LayerNorm and of every encoder block.
            attention_heads (int): Number of heads given to each
                ``RelPositionMultiHeadedAttention``.
            linear_units (int): Hidden size given to each
                ``PositionwiseFeedForward``.
            num_blocks (int): Number of encoder blocks to stack.
            dropout_rate (float): Dropout rate given to the embedding, the
                feed-forward modules and the encoder blocks.
            positional_dropout_rate (float): Dropout rate given to the
                positional encoding.
            attention_dropout_rate (float): Dropout rate given to the
                attention modules.
            input_layer (str): Not read by this constructor; the embedding
                is always ``LinearEmbed``. Kept for config compatibility.
            pos_enc_layer_type (str): Not read by this constructor; the
                positional encoding is always
                ``EspnetRelPositionalEncoding``. Kept for config
                compatibility.
            normalize_before (bool): Forwarded to every encoder block; also
                decides whether ``forward`` applies the final LayerNorm.
            static_chunk_size (int): Stored and forwarded to
                ``add_optional_chunk_mask`` in ``forward``.
            use_dynamic_chunk (bool): Stored and forwarded to
                ``add_optional_chunk_mask`` in ``forward``.
            use_dynamic_left_chunk (bool): Stored and forwarded to
                ``add_optional_chunk_mask`` in ``forward``.
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer. Not read by this constructor.
            macaron_style (bool): Whether to add a second feed-forward
                module (macaron style) to every block.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configure
                compatibility.
            activation_type (str): Key into ``ACTIVATION_CLASSES`` selecting
                the activation function.
            use_cnn_module (bool): Whether to add a ``ConvolutionModule`` to
                every block.
            cnn_module_kernel (int): Kernel size of the convolution module.
            causal (bool): Whether to use causal convolution or not.
            cnn_module_norm (str): Norm type given to the convolution
                module.
            key_bias (bool): Whether to use bias in attention.linear_k,
                False for whisper models.
            gradient_checkpointing (bool): When True, ``forward`` runs the
                blocks under activation checkpointing while training.
        """
        super().__init__()
        self.output_size = output_size
        self.embed = LinearEmbed(
            input_size,
            output_size,
            dropout_rate,
            EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
        )
        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.gradient_checkpointing = gradient_checkpointing

        # Chunk-mask configuration, consumed by forward().
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk

        # A single activation instance is created here and handed to every
        # feed-forward / convolution module of every block.
        activation = ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (
            output_size,
            cnn_module_kernel,
            activation,
            cnn_module_norm,
            causal,
        )

        self.encoders = torch.nn.ModuleList(
            [
                ConformerEncoderLayer(
                    output_size,
                    RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
                    PositionwiseFeedForward(*positionwise_layer_args),
                    (
                        PositionwiseFeedForward(*positionwise_layer_args)
                        if macaron_style
                        else None
                    ),
                    (
                        ConvolutionModule(*convolution_layer_args)
                        if use_cnn_module
                        else None
                    ),
                    dropout_rate,
                    normalize_before,
                )
                for _ in range(num_blocks)
            ]
        )

    def forward_layers(
        self,
        xs: torch.Tensor,
        chunk_masks: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``xs`` through every encoder block in order.

        Each block returns an updated ``xs`` and ``chunk_masks`` (plus two
        values that are discarded); both are fed to the next block.

        Returns:
            The output ``xs`` of the last block.
        """
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    @torch.jit.unused
    def forward_layers_checkpointed(
        self,
        xs: torch.Tensor,
        chunk_masks: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor,
    ) -> torch.Tensor:
        """Same as ``forward_layers`` but with activation checkpointing.

        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward`
            to the checkpointing API because `__call__` attaches all the
            hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

        Returns:
            The output ``xs`` of the last block.
        """
        # Import the submodule explicitly: ``import torch`` alone is not
        # guaranteed to expose ``torch.utils.checkpoint`` as an attribute.
        import torch.utils.checkpoint

        for layer in self.encoders:
            xs, chunk_masks, _, _ = torch.utils.checkpoint.checkpoint(
                layer.__call__, xs, chunk_masks, pos_emb, mask_pad, use_reentrant=False
            )
        return xs

    def forward(
        self,
        xs: torch.Tensor,
        pad_mask: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed the input and encode it with the stacked blocks.

        Args:
            xs: padded input tensor (B, T, D)
            pad_mask: padding mask (B, T); it is cast to bool here.
                NOTE(review): presumably True/non-zero marks valid frames;
                confirm against the caller.
            decoding_chunk_size: decoding chunk size for dynamic chunk,
                forwarded to ``add_optional_chunk_mask``
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for
                decoding, the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks

        Returns:
            A tuple ``(xs, masks)``:
            xs: encoder output tensor; the final LayerNorm is applied when
                ``normalize_before`` is set.
            masks: ``pad_mask`` cast to bool with a singleton axis inserted
                at dim 1, i.e. (B, 1, T).
        """
        masks = pad_mask.to(torch.bool).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb = self.embed(xs)
        # The same mask object is used as the padding mask of every block.
        mask_pad = masks
        chunk_masks = add_optional_chunk_mask(
            xs,
            masks,
            self.use_dynamic_chunk,
            self.use_dynamic_left_chunk,
            decoding_chunk_size,
            self.static_chunk_size,
            num_decoding_left_chunks,
        )
        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad)
        else:
            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
Conformer encoder module.
906 def __init__( 907 self, 908 input_size: int, 909 output_size: int = 1024, 910 attention_heads: int = 16, 911 linear_units: int = 4096, 912 num_blocks: int = 6, 913 dropout_rate: float = 0.1, 914 positional_dropout_rate: float = 0.1, 915 attention_dropout_rate: float = 0.0, 916 input_layer: str = "linear", 917 pos_enc_layer_type: str = "rel_pos_espnet", 918 normalize_before: bool = True, 919 static_chunk_size: int = 1, # 1: causal_mask; 0: full_mask 920 use_dynamic_chunk: bool = False, 921 use_dynamic_left_chunk: bool = False, 922 positionwise_conv_kernel_size: int = 1, 923 macaron_style: bool = False, 924 selfattention_layer_type: str = "rel_selfattn", 925 activation_type: str = "swish", 926 use_cnn_module: bool = False, 927 cnn_module_kernel: int = 15, 928 causal: bool = False, 929 cnn_module_norm: str = "batch_norm", 930 key_bias: bool = True, 931 gradient_checkpointing: bool = False, 932 ): 933 """Construct ConformerEncoder 934 935 Args: 936 input_size to use_dynamic_chunk, see in BaseEncoder 937 positionwise_conv_kernel_size (int): Kernel size of positionwise 938 conv1d layer. 939 macaron_style (bool): Whether to use macaron style for 940 positionwise layer. 941 selfattention_layer_type (str): Encoder attention layer type, 942 the parameter has no effect now, it's just for configure 943 compatibility. #'rel_selfattn' 944 activation_type (str): Encoder activation function type. 945 use_cnn_module (bool): Whether to use convolution module. 946 cnn_module_kernel (int): Kernel size of convolution module. 947 causal (bool): whether to use causal convolution or not. 948 key_bias: whether use bias in attention.linear_k, False for whisper models. 
949 """ 950 super().__init__() 951 self.output_size = output_size 952 self.embed = LinearEmbed( 953 input_size, 954 output_size, 955 dropout_rate, 956 EspnetRelPositionalEncoding(output_size, positional_dropout_rate), 957 ) 958 self.normalize_before = normalize_before 959 self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) 960 self.gradient_checkpointing = gradient_checkpointing 961 self.use_dynamic_chunk = use_dynamic_chunk 962 963 self.static_chunk_size = static_chunk_size 964 self.use_dynamic_chunk = use_dynamic_chunk 965 self.use_dynamic_left_chunk = use_dynamic_left_chunk 966 activation = ACTIVATION_CLASSES[activation_type]() 967 968 # self-attention module definition 969 encoder_selfattn_layer_args = ( 970 attention_heads, 971 output_size, 972 attention_dropout_rate, 973 key_bias, 974 ) 975 # feed-forward module definition 976 positionwise_layer_args = ( 977 output_size, 978 linear_units, 979 dropout_rate, 980 activation, 981 ) 982 # convolution module definition 983 convolution_layer_args = ( 984 output_size, 985 cnn_module_kernel, 986 activation, 987 cnn_module_norm, 988 causal, 989 ) 990 991 self.encoders = torch.nn.ModuleList( 992 [ 993 ConformerEncoderLayer( 994 output_size, 995 RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args), 996 PositionwiseFeedForward(*positionwise_layer_args), 997 ( 998 PositionwiseFeedForward(*positionwise_layer_args) 999 if macaron_style 1000 else None 1001 ), 1002 ( 1003 ConvolutionModule(*convolution_layer_args) 1004 if use_cnn_module 1005 else None 1006 ), 1007 dropout_rate, 1008 normalize_before, 1009 ) 1010 for _ in range(num_blocks) 1011 ] 1012 )
Construct ConformerEncoder
Args: input_size to use_dynamic_chunk, see in BaseEncoder positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. macaron_style (bool): Whether to use macaron style for positionwise layer. selfattention_layer_type (str): Encoder attention layer type, the parameter has no effect now, it's just for configure compatibility. #'rel_selfattn' activation_type (str): Encoder activation function type. use_cnn_module (bool): Whether to use convolution module. cnn_module_kernel (int): Kernel size of convolution module. causal (bool): whether to use causal convolution or not. key_bias: whether use bias in attention.linear_k, False for whisper models.
1025 @torch.jit.unused 1026 def forward_layers_checkpointed( 1027 self, 1028 xs: torch.Tensor, 1029 chunk_masks: torch.Tensor, 1030 pos_emb: torch.Tensor, 1031 mask_pad: torch.Tensor, 1032 ) -> torch.Tensor: 1033 for layer in self.encoders: 1034 xs, chunk_masks, _, _ = torch.utils.checkpoint.checkpoint( 1035 layer.__call__, xs, chunk_masks, pos_emb, mask_pad, use_reentrant=False 1036 ) 1037 return xs
1039 def forward( 1040 self, 1041 xs: torch.Tensor, 1042 pad_mask: torch.Tensor, 1043 decoding_chunk_size: int = 0, 1044 num_decoding_left_chunks: int = -1, 1045 ) -> Tuple[torch.Tensor, torch.Tensor]: 1046 """Embed positions in tensor. 1047 1048 Args: 1049 xs: padded input tensor (B, T, D) 1050 xs_lens: input length (B) 1051 decoding_chunk_size: decoding chunk size for dynamic chunk 1052 0: default for training, use random dynamic chunk. 1053 <0: for decoding, use full chunk. 1054 >0: for decoding, use fixed chunk size as set. 1055 num_decoding_left_chunks: number of left chunks, this is for decoding, 1056 the chunk size is decoding_chunk_size. 1057 >=0: use num_decoding_left_chunks 1058 <0: use all left chunks 1059 Returns: 1060 encoder output tensor xs, and subsampled masks 1061 xs: padded output tensor (B, T' ~= T/subsample_rate, D) 1062 masks: torch.Tensor batch padding mask after subsample 1063 (B, 1, T' ~= T/subsample_rate) 1064 NOTE(xcsong): 1065 We pass the `__call__` method of the modules instead of `forward` to the 1066 checkpointing API because `__call__` attaches all the hooks of the module. 
1067 https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 1068 """ 1069 T = xs.size(1) 1070 masks = pad_mask.to(torch.bool).unsqueeze(1) # (B, 1, T) 1071 xs, pos_emb = self.embed(xs) 1072 mask_pad = masks # (B, 1, T/subsample_rate) 1073 chunk_masks = add_optional_chunk_mask( 1074 xs, 1075 masks, 1076 self.use_dynamic_chunk, 1077 self.use_dynamic_left_chunk, 1078 decoding_chunk_size, 1079 self.static_chunk_size, 1080 num_decoding_left_chunks, 1081 ) 1082 if self.gradient_checkpointing and self.training: 1083 xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad) 1084 else: 1085 xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) 1086 if self.normalize_before: 1087 xs = self.after_norm(xs) 1088 # Here we assume the mask is not changed in encoder layers, so just 1089 # return the masks before encoder layers, and the masks will be used 1090 # for cross attention with decoder later 1091 return xs, masks
Embed positions in tensor.
Args: xs: padded input tensor (B, T, D) pad_mask: padding mask (B, T) decoding_chunk_size: decoding chunk size for dynamic chunk 0: default for training, use random dynamic chunk. <0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set. num_decoding_left_chunks: number of left chunks, this is for decoding, the chunk size is decoding_chunk_size. >=0: use num_decoding_left_chunks <0: use all left chunks Returns: encoder output tensor xs, and subsampled masks xs: padded output tensor (B, T' ~= T/subsample_rate, D) masks: torch.Tensor batch padding mask after subsample (B, 1, T' ~= T/subsample_rate) NOTE(xcsong): We pass the
`__call__` method of the modules instead of `forward` to the checkpointing API because `__call__` attaches all the hooks of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2