
API Reference

Zipformer Model

zipformer.modules.zipformer

Zipformer

Bases: Module

Note: every "int or Tuple[int]" argument below is expanded to a tuple with the same length as downsampling_factor when it is given as a single int or a one-element tuple. The length of downsampling_factor defines the number of encoder stacks.

Args:

output_downsampling_factor (int): how much to downsample at the output.  Note:
    we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
    forward() asserts that this value is 2, so leave it at 2.
downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
    Note: this is in addition to the downsampling factor of 2 that is applied in
    the frontend (self.encoder_embed).
encoder_dim (int or Tuple[int]): embedding dimension of each encoder stack.
num_encoder_layers (int or Tuple[int]): number of encoder layers in each stack.
encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
    the encoder stacks for purposes of feature masking (recommend 256 for
    now).  Must not exceed the corresponding encoder_dim.
query_head_dim (int or Tuple[int]): dimension of query and key per attention
    head (per stack, if a tuple).
pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
    attention head.
value_head_dim (int or Tuple[int]): dimension of value in each attention head.
num_heads (int or Tuple[int]): number of heads in the self-attention mechanism.
    Must be at least 4.
feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules.
cnn_module_kernel (int or Tuple[int]): kernel size of the convolution module.

pos_dim (int): the dimension of each positional-encoding vector prior to projection,
    e.g. 128.

dropout (FloatLike): dropout rate.  If None, it defaults to
    ScheduledFloat((0.0, 0.3), (20000.0, 0.1)).
warmup_batches (float): number of batches to warm up over; this controls
    dropout of encoder layers.
causal (bool): if True, support chunkwise causal convolution.  This should
    not hurt WER as no modeling power is lost, but the convolution modules will be
    slightly slower and use more memory.  Enables the chunk_size and
    left_context_frames options, which forward() uses to simulate streaming
    decoding.
chunk_size (list of int): only set this to other than [-1] if causal;
    the chunk size will be randomly chosen from this list.  -1 means no chunking.
left_context_frames (list of int): determines the number of left-context
    chunks for causal training; will be rounded to a number of
    chunks.  Must not be less than cnn_module_kernel (after factoring in
    rounding and downsampling); an error is raised if this is violated.
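
The broadcasting rule described in the note can be sketched as a standalone function. This is a minimal illustration that mirrors the `_to_tuple` helper defined inside the constructor; the name `to_tuple` and the explicit `num_stacks` parameter are assumptions made here so the snippet runs on its own.

```python
from typing import Sequence, Tuple, Union


def to_tuple(x: Union[int, Sequence[int]], num_stacks: int) -> Tuple[int, ...]:
    """Broadcast an int or a one-element sequence to `num_stacks` entries."""
    if isinstance(x, int):
        x = (x,)
    x = tuple(x)
    if len(x) == 1:
        # A single value is repeated once per encoder stack.
        return x * num_stacks
    if len(x) != num_stacks:
        raise ValueError(f"expected 1 or {num_stacks} values, got {len(x)}")
    return x


downsampling_factor = (2, 4)  # two encoder stacks
num_stacks = len(downsampling_factor)

encoder_dim = to_tuple(384, num_stacks)               # int -> one value per stack
num_heads = to_tuple((8,), num_stacks)                # 1-tuple -> one value per stack
feedforward_dim = to_tuple((1024, 1536), num_stacks)  # full tuple is kept as given

print(encoder_dim, num_heads, feedforward_dim)
```

Unlike the original helper, which uses an assert, this sketch raises a ValueError on a length mismatch.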
Source code in zipformer/modules/zipformer.py
class Zipformer(torch.nn.Module):
    """
    Note: every "int or Tuple[int]" argument below is expanded to a tuple with the same
    length as downsampling_factor when it is given as a single int or a one-element tuple.
    The length of downsampling_factor defines the number of encoder stacks.

    Args:
        output_downsampling_factor (int): how much to downsample at the output.  Note:
            we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
            forward() asserts that this value is 2, so leave it at 2.
        downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
            Note: this is in addition to the downsampling factor of 2 that is applied in
            the frontend (self.encoder_embed).
        encoder_dim (int or Tuple[int]): embedding dimension of each encoder stack.
        num_encoder_layers (int or Tuple[int]): number of encoder layers in each stack.
        encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
            the encoder stacks for purposes of feature masking (recommend 256 for
            now).  Must not exceed the corresponding encoder_dim.
        query_head_dim (int or Tuple[int]): dimension of query and key per attention
            head (per stack, if a tuple).
        pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
            attention head.
        value_head_dim (int or Tuple[int]): dimension of value in each attention head.
        num_heads (int or Tuple[int]): number of heads in the self-attention mechanism.
            Must be at least 4.
        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules.
        cnn_module_kernel (int or Tuple[int]): kernel size of the convolution module.

        pos_dim (int): the dimension of each positional-encoding vector prior to projection,
            e.g. 128.

        dropout (FloatLike): dropout rate.  If None, it defaults to
            ScheduledFloat((0.0, 0.3), (20000.0, 0.1)).
        warmup_batches (float): number of batches to warm up over; this controls
            dropout of encoder layers.
        causal (bool): if True, support chunkwise causal convolution.  This should
            not hurt WER as no modeling power is lost, but the convolution modules will be
            slightly slower and use more memory.  Enables the chunk_size and
            left_context_frames options, which forward() uses to simulate streaming
            decoding.
        chunk_size (list of int): only set this to other than [-1] if causal;
            the chunk size will be randomly chosen from this list.  -1 means no chunking.
        left_context_frames (list of int): determines the number of left-context
            chunks for causal training; will be rounded to a number of
            chunks.  Must not be less than cnn_module_kernel (after factoring in
            rounding and downsampling); an error is raised if this is violated.
    """

    def __init__(
        self,
        output_downsampling_factor: int = 2,
        downsampling_factor: Tuple[int] = (2, 4),
        encoder_dim: Union[int, Tuple[int]] = 384,
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
        query_head_dim: Union[int, Tuple[int]] = 24,
        pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,
        chunk_size: Tuple[int] = [-1],
        left_context_frames: Tuple[int] = [-1],
    ) -> None:
        super(Zipformer, self).__init__()

        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with the same length
            as downsampling_factor"""
            if isinstance(x, int):
                x = (x,)
            if len(x) == 1:
                x = x * len(downsampling_factor)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
            return x

        self.output_downsampling_factor = output_downsampling_factor  # int
        self.downsampling_factor = downsampling_factor  # tuple
        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
            encoder_unmasked_dim
        )  # tuple
        num_encoder_layers = _to_tuple(num_encoder_layers)
        self.num_encoder_layers = num_encoder_layers
        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
        pos_head_dim = _to_tuple(pos_head_dim)
        self.num_heads = num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

        self.causal = causal
        self.chunk_size = chunk_size
        self.left_context_frames = left_context_frames

        for u, d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d

        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
        encoders = []

        num_encoders = len(downsampling_factor)
        for i in range(num_encoders):
            encoder_layer = ZipformerEncoderLayer(
                embed_dim=encoder_dim[i],
                pos_dim=pos_dim,
                num_heads=num_heads[i],
                query_head_dim=query_head_dim[i],
                pos_head_dim=pos_head_dim[i],
                value_head_dim=value_head_dim[i],
                feedforward_dim=feedforward_dim[i],
                dropout=dropout,
                cnn_module_kernel=cnn_module_kernel[i],
                causal=causal,
            )

            # For the first segment of the warmup period, we let the Conv2dSubsampling
            # layer learn something.  Then we start to warm up the encoders one by one.
            encoder = ZipformerEncoder(
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
                dropout=dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
            )

            if downsampling_factor[i] != 1:
                encoder = DownsampledZipformerEncoder(
                    encoder,
                    dim=encoder_dim[i],
                    downsample=downsampling_factor[i],
                    dropout=dropout,
                    causal=causal,
                )

            encoders.append(encoder)

        self.encoders = torch.nn.ModuleList(encoders)

        self.downsample_output = SimpleDownsample(
            max(encoder_dim),
            downsample=output_downsampling_factor,
            dropout=dropout,
            causal=causal,
        )

    def get_feature_masks(
        self, x: torch.Tensor
    ) -> Union[List[float], List[torch.Tensor]]:
        """
        In eval mode, returns [1.0] * num_encoders; in training mode, returns a list of
        randomized feature masks, one per encoder, each of shape
        (1, batch_size, encoder_dim[i]).
        For a random subset of sequences (each random draw masks a sequence with
        probability 0.125), these masks zero out the encoder dims at or above
        encoder_unmasked_dim, e.g. >= 256, so in effect those sequences use
        a smaller encoder dim.

        We generate the random masks at this level because we want the 2 masks to 'agree'
        all the way up the encoder stack: the same two per-sequence draws are reused
        for every encoder.  Each mask has a time dimension of 1, so it broadcasts
        over all frames whatever the downsampling factor of the stack.

        Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
             (num_frames0, batch_size, encoder_dim[0])
        """
        num_encoders = len(self.encoder_dim)
        if not self.training:
            return [1.0] * num_encoders

        (num_frames0, batch_size, _encoder_dims0) = x.shape

        assert self.encoder_dim[0] == _encoder_dims0, (
            self.encoder_dim[0],
            _encoder_dims0,
        )

        feature_mask_dropout_prob = 0.125

        # mask1 shape: (1, batch_size, 1)
        mask1 = (
            torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
        ).to(x.dtype)

        # mask2 has additional sequences masked, about twice the number.
        mask2 = torch.logical_and(
            mask1,
            (
                torch.rand(1, batch_size, 1, device=x.device)
                > feature_mask_dropout_prob
            ).to(x.dtype),
        )

        # dim: (1, batch_size, 2)
        mask = torch.cat((mask1, mask2), dim=-1)

        feature_masks = []
        for i in range(num_encoders):
            channels = self.encoder_dim[i]
            feature_mask = torch.ones(
                1, batch_size, channels, dtype=x.dtype, device=x.device
            )
            u1 = self.encoder_unmasked_dim[i]
            u2 = u1 + (channels - u1) // 2

            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
            feature_mask[:, :, u2:] *= mask[..., 1:2]

            feature_masks.append(feature_mask)

        return feature_masks

    def get_chunk_info(self) -> Tuple[int, int]:
        """
        Returns chunk_size and left_context_chunks.
        """
        if not self.causal:
            return -1, -1

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert len(self.chunk_size) == 1, self.chunk_size
            chunk_size = self.chunk_size[0]
        else:
            chunk_size = random.choice(self.chunk_size)

        if chunk_size == -1:
            left_context_chunks = -1
        else:
            if torch.jit.is_scripting() or torch.jit.is_tracing():
                assert len(self.left_context_frames) == 1, self.left_context_frames
                left_context_frames = self.left_context_frames[0]
            else:
                left_context_frames = random.choice(self.left_context_frames)
            # Note: in Python, -1 // n == -1 for n > 0
            left_context_chunks = left_context_frames // chunk_size
            if left_context_chunks == 0:
                left_context_chunks = 1

        return chunk_size, left_context_chunks

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
        Returns:
          Return a tuple containing 2 tensors:
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
        """
        outputs = []
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            feature_masks = [1.0] * len(self.encoder_dim)
        else:
            feature_masks = self.get_feature_masks(x)

        chunk_size, left_context_chunks = self.get_chunk_info()

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Exporting a model that simulates streaming decoding is not supported
            attn_mask = None
        else:
            attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

        for i, module in enumerate(self.encoders):
            ds = self.downsampling_factor[i]
            x = convert_num_channels(x, self.encoder_dim[i])

            x = module(
                x,
                chunk_size=chunk_size,
                feature_mask=feature_masks[i],
                src_key_padding_mask=(
                    None
                    if src_key_padding_mask is None
                    else src_key_padding_mask[..., ::ds]
                ),
                attn_mask=attn_mask,
            )
            outputs.append(x)

        # if the last output has the largest dimension, x will be unchanged,
        # it will be the same as outputs[-1].  Otherwise it will be concatenated
        # from different pieces of 'outputs', taking each dimension from the
        # most recent output that has it present.
        x = self._get_full_dim_output(outputs)
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2

        return x, lengths

    def _get_attn_mask(
        self, x: torch.Tensor, chunk_size: int, left_context_chunks: int
    ) -> Optional[torch.Tensor]:
        """
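        Return None if chunk_size <= 0, else return attention mask of shape
          (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
           means a masked position.
        Args:
           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
           chunk_size: chunk size; must be divisible by every entry of
             self.downsampling_factor.
           left_context_chunks: number of previous chunks each frame may attend to;
             a negative value means unlimited left context.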
        """
        if chunk_size <= 0:
            return None
        assert all(chunk_size % d == 0 for d in self.downsampling_factor)
        if left_context_chunks >= 0:
            num_encoders = len(self.encoder_dim)
            assert all(
                chunk_size * left_context_chunks
                >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
                for i in range(num_encoders)
            )
        else:
            left_context_chunks = 1000000

        seq_len = x.shape[0]

        # t is frame index, shape (seq_len,)
        t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
        # c is chunk index for each frame, shape (seq_len,)
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            c = t // chunk_size
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                c = t // chunk_size
        src_c = c
        tgt_c = c.unsqueeze(-1)

        attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
        if __name__ == "__main__":
            logging.info(f"attn_mask = {attn_mask}")
        return attn_mask

    def _get_full_dim_output(self, outputs: List[torch.Tensor]):
        num_encoders = len(self.encoder_dim)
        assert len(outputs) == num_encoders
        output_dim = max(self.encoder_dim)
        output_pieces = [outputs[-1]]
        cur_dim = self.encoder_dim[-1]
        for i in range(num_encoders - 2, -1, -1):
            d = self.encoder_dim[i]
            if d > cur_dim:
                this_output = outputs[i]
                output_pieces.append(this_output[..., cur_dim:d])
                cur_dim = d
        assert cur_dim == output_dim
        return torch.cat(output_pieces, dim=-1)

    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        states: List[torch.Tensor],
        src_key_padding_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Args:
          x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          states: list of cached tensors of all encoder layers. For layer-i,
            states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
            cached_conv1, cached_conv2).
          src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. It is indexed unconditionally below, so it must
            not be None.
        Returns:
          Return a tuple containing 3 elements:
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
            - updated states, a list in the same layout as `states`
        """
        outputs = []
        new_states = []
        layer_offset = 0

        for i, module in enumerate(self.encoders):
            num_layers = module.num_layers
            ds = self.downsampling_factor[i]
            x = convert_num_channels(x, self.encoder_dim[i])

            x, new_layer_states = module.streaming_forward(
                x,
                states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
                left_context_len=self.left_context_frames[0] // ds,
                src_key_padding_mask=src_key_padding_mask[..., ::ds],
            )
            layer_offset += num_layers
            outputs.append(x)
            new_states += new_layer_states

        # if the last output has the largest dimension, x will be unchanged,
        # it will be the same as outputs[-1].  Otherwise it will be concatenated
        # from different pieces of 'outputs', taking each dimension from the
        # most recent output that has it present.
        x = self._get_full_dim_output(outputs)
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2

        return x, lengths, new_states

    @torch.jit.export
    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[torch.Tensor]:
        """Get initial states.

        A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
        is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
        """
        states = []
        for i, module in enumerate(self.encoders):
            num_layers = module.num_layers
            embed_dim = self.encoder_dim[i]
            ds = self.downsampling_factor[i]
            num_heads = self.num_heads[i]
            key_dim = self.query_head_dim[i] * num_heads
            value_dim = self.value_head_dim[i] * num_heads
            downsample_left = self.left_context_frames[0] // ds
            nonlin_attn_head_dim = 3 * embed_dim // 4
            conv_left_pad = self.cnn_module_kernel[i] // 2
            for layer in range(num_layers):
                cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
                    device
                )
                cached_nonlin_attn = torch.zeros(
                    1, batch_size, downsample_left, nonlin_attn_head_dim
                ).to(device)
                cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
                    device
                )
                cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
                    device
                )
                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
                    device
                )
                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
                    device
                )
                states += [
                    cached_key,
                    cached_nonlin_attn,
                    cached_val1,
                    cached_val2,
                    cached_conv1,
                    cached_conv2,
                ]

        return states
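
The state layout built by get_init_states can be checked without PyTorch by computing only the tensor shapes. The sketch below follows the arithmetic in the method above; the function name `init_state_shapes` and the value `left_context_frames=64` are assumptions chosen for illustration (the default of -1 would not give a usable cache length here).

```python
from typing import List, Sequence, Tuple


def init_state_shapes(
    batch_size: int,
    left_context_frames: int,
    downsampling_factor: Sequence[int],
    num_encoder_layers: Sequence[int],
    encoder_dim: Sequence[int],
    num_heads: Sequence[int],
    query_head_dim: Sequence[int],
    value_head_dim: Sequence[int],
    cnn_module_kernel: Sequence[int],
) -> List[Tuple[int, ...]]:
    """Return the shapes of the cached tensors, 6 per encoder layer."""
    shapes: List[Tuple[int, ...]] = []
    for i, ds in enumerate(downsampling_factor):
        embed_dim = encoder_dim[i]
        key_dim = query_head_dim[i] * num_heads[i]
        value_dim = value_head_dim[i] * num_heads[i]
        left = left_context_frames // ds      # cache length at this stack's frame rate
        nonlin_dim = 3 * embed_dim // 4
        conv_pad = cnn_module_kernel[i] // 2
        for _ in range(num_encoder_layers[i]):
            shapes += [
                (left, batch_size, key_dim),        # cached_key
                (1, batch_size, left, nonlin_dim),  # cached_nonlin_attn
                (left, batch_size, value_dim),      # cached_val1
                (left, batch_size, value_dim),      # cached_val2
                (batch_size, embed_dim, conv_pad),  # cached_conv1
                (batch_size, embed_dim, conv_pad),  # cached_conv2
            ]
    return shapes


shapes = init_state_shapes(
    batch_size=1,
    left_context_frames=64,  # assumed value for illustration
    downsampling_factor=(2, 4),
    num_encoder_layers=(4, 4),
    encoder_dim=(384, 384),
    num_heads=(8, 8),
    query_head_dim=(24, 24),
    value_head_dim=(12, 12),
    cnn_module_kernel=(31, 31),
)
print(len(shapes), shapes[:6])
```

With two stacks of four layers each, the list holds 2 * 4 * 6 = 48 entries, and the entries for layer i occupy positions i*6 to (i+1)*6, counting layers across all stacks.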

__init__(output_downsampling_factor=2, downsampling_factor=(2, 4), encoder_dim=384, num_encoder_layers=4, encoder_unmasked_dim=256, query_head_dim=24, pos_head_dim=4, value_head_dim=12, num_heads=8, feedforward_dim=1536, cnn_module_kernel=31, pos_dim=192, dropout=None, warmup_batches=4000.0, causal=False, chunk_size=[-1], left_context_frames=[-1])

Source code in zipformer/modules/zipformer.py
def __init__(
    self,
    output_downsampling_factor: int = 2,
    downsampling_factor: Tuple[int] = (2, 4),
    encoder_dim: Union[int, Tuple[int]] = 384,
    num_encoder_layers: Union[int, Tuple[int]] = 4,
    encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
    query_head_dim: Union[int, Tuple[int]] = 24,
    pos_head_dim: Union[int, Tuple[int]] = 4,
    value_head_dim: Union[int, Tuple[int]] = 12,
    num_heads: Union[int, Tuple[int]] = 8,
    feedforward_dim: Union[int, Tuple[int]] = 1536,
    cnn_module_kernel: Union[int, Tuple[int]] = 31,
    pos_dim: int = 192,
    dropout: FloatLike = None,  # see code below for default
    warmup_batches: float = 4000.0,
    causal: bool = False,
    chunk_size: Tuple[int] = [-1],
    left_context_frames: Tuple[int] = [-1],
) -> None:
    super(Zipformer, self).__init__()

    if dropout is None:
        dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

    def _to_tuple(x):
        """Converts a single int or a 1-tuple of an int to a tuple with the same length
        as downsampling_factor"""
        if isinstance(x, int):
            x = (x,)
        if len(x) == 1:
            x = x * len(downsampling_factor)
        else:
            assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
        return x

    self.output_downsampling_factor = output_downsampling_factor  # int
    self.downsampling_factor = downsampling_factor  # tuple
    self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
    self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
        encoder_unmasked_dim
    )  # tuple
    num_encoder_layers = _to_tuple(num_encoder_layers)
    self.num_encoder_layers = num_encoder_layers
    self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
    self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
    pos_head_dim = _to_tuple(pos_head_dim)
    self.num_heads = num_heads = _to_tuple(num_heads)
    feedforward_dim = _to_tuple(feedforward_dim)
    self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

    self.causal = causal
    self.chunk_size = chunk_size
    self.left_context_frames = left_context_frames

    for u, d in zip(encoder_unmasked_dim, encoder_dim):
        assert u <= d

    # each one will be ZipformerEncoder or DownsampledZipformerEncoder
    encoders = []

    num_encoders = len(downsampling_factor)
    for i in range(num_encoders):
        encoder_layer = ZipformerEncoderLayer(
            embed_dim=encoder_dim[i],
            pos_dim=pos_dim,
            num_heads=num_heads[i],
            query_head_dim=query_head_dim[i],
            pos_head_dim=pos_head_dim[i],
            value_head_dim=value_head_dim[i],
            feedforward_dim=feedforward_dim[i],
            dropout=dropout,
            cnn_module_kernel=cnn_module_kernel[i],
            causal=causal,
        )

        # For the first segment of the warmup period, we let the Conv2dSubsampling
        # layer learn something.  Then we start to warm up the encoders one by one.
        encoder = ZipformerEncoder(
            encoder_layer,
            num_encoder_layers[i],
            pos_dim=pos_dim,
            dropout=dropout,
            warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
            warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
            final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
        )

        if downsampling_factor[i] != 1:
            encoder = DownsampledZipformerEncoder(
                encoder,
                dim=encoder_dim[i],
                downsample=downsampling_factor[i],
                dropout=dropout,
                causal=causal,
            )

        encoders.append(encoder)

    self.encoders = torch.nn.ModuleList(encoders)

    self.downsample_output = SimpleDownsample(
        max(encoder_dim),
        downsample=output_downsampling_factor,
        dropout=dropout,
        causal=causal,
    )
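
The warmup window and final layer-drop rate passed to each ZipformerEncoder follow directly from the expressions in the constructor. The sketch below only reproduces that arithmetic; the helper names `warmup_windows` and `final_layerdrop_rate` are made up for this example.

```python
from typing import List, Tuple


def warmup_windows(warmup_batches: float, num_encoders: int) -> List[Tuple[float, float]]:
    """(warmup_begin, warmup_end) for each encoder stack, in batches."""
    return [
        (
            warmup_batches * (i + 1) / (num_encoders + 1),
            warmup_batches * (i + 2) / (num_encoders + 1),
        )
        for i in range(num_encoders)
    ]


def final_layerdrop_rate(downsampling_factor: int) -> float:
    """Layer-drop rate after warmup; grows with the square root of the factor."""
    return 0.035 * (downsampling_factor ** 0.5)


windows = warmup_windows(4000.0, 2)  # defaults: 4000 batches, 2 stacks
rates = [final_layerdrop_rate(ds) for ds in (2, 4)]

for i, ((begin, end), rate) in enumerate(zip(windows, rates)):
    print(f"stack {i}: warmup {begin:.0f}..{end:.0f} batches, layerdrop {rate:.4f}")
```

The windows are contiguous: each stack starts warming up where the previous one ends, and the first 1 / (num_encoders + 1) of the warmup period is not assigned to any stack.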

get_feature_masks(x)

In eval mode, returns [1.0] * num_encoders; in training mode, returns a list of randomized feature masks, one per encoder, each of shape (1, batch_size, encoder_dim[i]). For a random subset of sequences (each random draw masks a sequence with probability 0.125), these masks zero out the encoder dims at or above encoder_unmasked_dim, e.g. >= 256, so in effect those sequences use a smaller encoder dim.

We generate the random masks at this level because we want the 2 masks to 'agree' all the way up the encoder stack: the same two per-sequence draws are reused for every encoder. Each mask has a time dimension of 1, so it broadcasts over all frames whatever the downsampling factor of the stack.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Tensor | the embeddings (needed for the shape and dtype and device), of shape (num_frames0, batch_size, encoder_dim[0]) | required |

Source code in zipformer/modules/zipformer.py
def get_feature_masks(
    self, x: torch.Tensor
) -> Union[List[float], List[torch.Tensor]]:
    """
    In eval mode, returns [1.0] * num_encoders; in training mode, returns a list of
    randomized feature masks, one per encoder, each of shape
    (1, batch_size, encoder_dim[i]).
    For a random subset of sequences (each random draw masks a sequence with
    probability 0.125), these masks zero out the encoder dims at or above
    encoder_unmasked_dim, e.g. >= 256, so in effect those sequences use
    a smaller encoder dim.

    We generate the random masks at this level because we want the 2 masks to 'agree'
    all the way up the encoder stack: the same two per-sequence draws are reused
    for every encoder.  Each mask has a time dimension of 1, so it broadcasts
    over all frames whatever the downsampling factor of the stack.

    Args:
       x: the embeddings (needed for the shape and dtype and device), of shape
         (num_frames0, batch_size, encoder_dim[0])
    """
    num_encoders = len(self.encoder_dim)
    if not self.training:
        return [1.0] * num_encoders

    (num_frames0, batch_size, _encoder_dims0) = x.shape

    assert self.encoder_dim[0] == _encoder_dims0, (
        self.encoder_dim[0],
        _encoder_dims0,
    )

    feature_mask_dropout_prob = 0.125

    # mask1 shape: (1, batch_size, 1)
    mask1 = (
        torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
    ).to(x.dtype)

    # mask2 has additional sequences masked, about twice the number.
    mask2 = torch.logical_and(
        mask1,
        (
            torch.rand(1, batch_size, 1, device=x.device)
            > feature_mask_dropout_prob
        ).to(x.dtype),
    )

    # dim: (1, batch_size, 2)
    mask = torch.cat((mask1, mask2), dim=-1)

    feature_masks = []
    for i in range(num_encoders):
        channels = self.encoder_dim[i]
        feature_mask = torch.ones(
            1, batch_size, channels, dtype=x.dtype, device=x.device
        )
        u1 = self.encoder_unmasked_dim[i]
        u2 = u1 + (channels - u1) // 2

        feature_mask[:, :, u1:u2] *= mask[..., 0:1]
        feature_mask[:, :, u2:] *= mask[..., 1:2]

        feature_masks.append(feature_mask)

    return feature_masks
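
The channel ranges and masking probabilities implied by the code above can be worked out with plain arithmetic. This is a small sketch, not part of the library; `masked_channel_ranges` is a hypothetical helper name.

```python
from typing import Tuple


def masked_channel_ranges(
    channels: int, unmasked_dim: int
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Half-open channel ranges controlled by mask1 and mask2 respectively."""
    u1 = unmasked_dim
    u2 = u1 + (channels - u1) // 2
    return (u1, u2), (u2, channels)


p = 0.125  # feature_mask_dropout_prob in the source

# mask1 is zero when its single random draw falls at or below p.
p_mask1_zero = p
# mask2 is the logical AND of mask1 and a second independent draw,
# so it is zero unless both draws keep the sequence.
p_mask2_zero = 1.0 - (1.0 - p) ** 2

lower, upper = masked_channel_ranges(channels=384, unmasked_dim=256)
print(f"channels {lower} zeroed with probability {p_mask1_zero}")
print(f"channels {upper} zeroed with probability {p_mask2_zero}")
```

For the default encoder_dim of 384 and encoder_unmasked_dim of 256, channels 0 to 255 are never masked, 256 to 319 follow mask1, and 320 to 383 follow mask2, which is zero roughly twice as often.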

get_chunk_info()

Returns chunk_size and left_context_chunks.

Source code in zipformer/modules/zipformer.py
def get_chunk_info(self) -> Tuple[int, int]:
    """
    Returns chunk_size and left_context_chunks.
    """
    if not self.causal:
        return -1, -1

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        assert len(self.chunk_size) == 1, self.chunk_size
        chunk_size = self.chunk_size[0]
    else:
        chunk_size = random.choice(self.chunk_size)

    if chunk_size == -1:
        left_context_chunks = -1
    else:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert len(self.left_context_frames) == 1, self.left_context_frames
            left_context_frames = self.left_context_frames[0]
        else:
            left_context_frames = random.choice(self.left_context_frames)
        # Note: in Python, -1 // n == -1 for n > 0
        left_context_chunks = left_context_frames // chunk_size
        if left_context_chunks == 0:
            left_context_chunks = 1

    return chunk_size, left_context_chunks
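
The rounding rules above are easy to misread, so the following sketch reproduces only the arithmetic for one fixed choice of values, without the random selection or the TorchScript branches. The function name `left_context_chunks_for` is hypothetical.

```python
def left_context_chunks_for(chunk_size: int, left_context_frames: int) -> int:
    """Reproduce the left-context arithmetic of get_chunk_info for fixed inputs."""
    if chunk_size == -1:
        return -1
    # Floor division: -1 // n == -1 for n > 0, so a left context of -1 is preserved.
    chunks = left_context_frames // chunk_size
    # A left context shorter than one chunk is rounded up to a single chunk.
    return 1 if chunks == 0 else chunks


print(left_context_chunks_for(16, 64))  # 4
print(left_context_chunks_for(32, 16))  # 1
print(left_context_chunks_for(16, -1))  # -1
print(left_context_chunks_for(-1, 64))  # -1
```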

forward(x, x_lens, src_key_padding_mask=None)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor. Its shape is (seq_len, batch_size, feature_dim). | required |
| x_lens | Tensor | A tensor of shape (batch_size,) containing the number of frames in x before padding. | required |
| src_key_padding_mask | Optional[Tensor] | The mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. | None |

Returns:

A tuple containing 2 tensors:
- embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
- lengths: a tensor of shape (batch_size,) containing the number of frames in embeddings before padding.

Source code in zipformer/modules/zipformer.py
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    src_key_padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
      x:
        The input tensor. Its shape is (seq_len, batch_size, feature_dim).
      x_lens:
        A tensor of shape (batch_size,) containing the number of frames in
        `x` before padding.
      src_key_padding_mask:
        The mask for padding, of shape (batch_size, seq_len); True means
        masked position. May be None.
    Returns:
      Return a tuple containing 2 tensors:
        - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
        - lengths, a tensor of shape (batch_size,) containing the number
          of frames in `embeddings` before padding.
    """
    outputs = []
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        feature_masks = [1.0] * len(self.encoder_dim)
    else:
        feature_masks = self.get_feature_masks(x)

    chunk_size, left_context_chunks = self.get_chunk_info()

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        # Exporting a model for simulated streaming decoding is not supported
        attn_mask = None
    else:
        attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

    for i, module in enumerate(self.encoders):
        ds = self.downsampling_factor[i]
        x = convert_num_channels(x, self.encoder_dim[i])

        x = module(
            x,
            chunk_size=chunk_size,
            feature_mask=feature_masks[i],
            src_key_padding_mask=(
                None
                if src_key_padding_mask is None
                else src_key_padding_mask[..., ::ds]
            ),
            attn_mask=attn_mask,
        )
        outputs.append(x)

    # if the last output has the largest dimension, x will be unchanged,
    # it will be the same as outputs[-1].  Otherwise it will be concatenated
    # from different pieces of 'outputs', taking each dimension from the
    # most recent output that has it present.
    x = self._get_full_dim_output(outputs)
    x = self.downsample_output(x)
    # class Downsample has this rounding behavior..
    assert self.output_downsampling_factor == 2, self.output_downsampling_factor
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        lengths = (x_lens + 1) // 2
    else:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            lengths = (x_lens + 1) // 2

    return x, lengths
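
Two small pieces of index arithmetic in forward are worth seeing in isolation: the output lengths are a ceiling division by 2, and each stack sees a padding mask subsampled by its downsampling factor. The sketch below uses plain Python lists in place of tensors; the helper names are hypothetical.

```python
from typing import List


def output_lengths(x_lens: List[int]) -> List[int]:
    """Ceil-divide each length by 2, as forward() does after downsample_output."""
    return [(n + 1) // 2 for n in x_lens]


def subsample_mask(mask: List[bool], ds: int) -> List[bool]:
    """Keep every ds-th entry, like src_key_padding_mask[..., ::ds] for one row."""
    return mask[::ds]


print(output_lengths([7, 8, 1]))  # [4, 4, 1]

padding = [False] * 5 + [True] * 3  # 5 valid frames followed by 3 padded frames
print(subsample_mask(padding, 2))  # [False, False, False, True]
```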

streaming_forward(x, x_lens, states, src_key_padding_mask)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor. Its shape is (seq_len, batch_size, feature_dim). | required |
| x_lens | Tensor | A tensor of shape (batch_size,) containing the number of frames in x before padding. | required |
| states | List[Tensor] | List of cached tensors of all encoder layers. For layer i, `states[i*6:(i+1)*6]` is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). | required |
| src_key_padding_mask | Tensor | The mask for padding, of shape (batch_size, seq_len); True means masked position. | required |

Returns:

A tuple containing 3 elements:
- embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
- lengths: a tensor of shape (batch_size,) containing the number of frames in embeddings before padding.
- updated states

Source code in zipformer/modules/zipformer.py
def streaming_forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    states: List[torch.Tensor],
    src_key_padding_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
    """
    Args:
      x:
        The input tensor. Its shape is (seq_len, batch_size, feature_dim).
      x_lens:
        A tensor of shape (batch_size,) containing the number of frames in
        `x` before padding.
      states: list of cached tensors of all encoder layers. For layer-i,
        states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
        cached_conv1, cached_conv2).
      src_key_padding_mask:
        The mask for padding, of shape (batch_size, seq_len); True means
        masked position.
    Returns:
      Return a tuple containing 3 elements:
        - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
        - lengths, a tensor of shape (batch_size,) containing the number
          of frames in `embeddings` before padding.
        - updated states
    """
    outputs = []
    new_states = []
    layer_offset = 0

    for i, module in enumerate(self.encoders):
        num_layers = module.num_layers
        ds = self.downsampling_factor[i]
        x = convert_num_channels(x, self.encoder_dim[i])

        x, new_layer_states = module.streaming_forward(
            x,
            states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
            left_context_len=self.left_context_frames[0] // ds,
            src_key_padding_mask=src_key_padding_mask[..., ::ds],
        )
        layer_offset += num_layers
        outputs.append(x)
        new_states += new_layer_states

    # if the last output has the largest dimension, x will be unchanged,
    # it will be the same as outputs[-1].  Otherwise it will be concatenated
    # from different pieces of 'outputs', taking each dimension from the
    # most recent output that has it present.
    x = self._get_full_dim_output(outputs)
    x = self.downsample_output(x)
    # class Downsample has this rounding behavior..
    assert self.output_downsampling_factor == 2
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        lengths = (x_lens + 1) // 2
    else:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            lengths = (x_lens + 1) // 2

    return x, lengths, new_states
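
The flat `states` list is partitioned between stacks using a running layer offset. The following sketch reproduces that bookkeeping with plain integers; the function name and the example layer counts are assumptions for illustration.

```python
from typing import List, Tuple

# cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2
TENSORS_PER_LAYER = 6


def state_slices(num_layers_per_stack: List[int]) -> List[Tuple[int, int]]:
    """Return the (start, end) indices into `states` consumed by each stack."""
    slices = []
    layer_offset = 0
    for num_layers in num_layers_per_stack:
        start = layer_offset * TENSORS_PER_LAYER
        end = (layer_offset + num_layers) * TENSORS_PER_LAYER
        slices.append((start, end))
        layer_offset += num_layers
    return slices


# Hypothetical model with three stacks of 2, 2 and 3 layers.
print(state_slices([2, 2, 3]))  # [(0, 12), (12, 24), (24, 42)]
```

The slices are contiguous and non-overlapping, so concatenating each stack's returned states in order reproduces a list with the same layout as the input.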

get_init_states(batch_size=1, device=torch.device('cpu'))

Get initial states.

A list of cached tensors of all encoder layers. For layer i, `states[i*6:(i+1)*6]` is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).

Source code in zipformer/modules/zipformer.py
@torch.jit.export
def get_init_states(
    self,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
    """Get initial states.

    A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
    is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
    """
    states = []
    for i, module in enumerate(self.encoders):
        num_layers = module.num_layers
        embed_dim = self.encoder_dim[i]
        ds = self.downsampling_factor[i]
        num_heads = self.num_heads[i]
        key_dim = self.query_head_dim[i] * num_heads
        value_dim = self.value_head_dim[i] * num_heads
        downsample_left = self.left_context_frames[0] // ds
        nonlin_attn_head_dim = 3 * embed_dim // 4
        conv_left_pad = self.cnn_module_kernel[i] // 2
        for layer in range(num_layers):
            cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
                device
            )
            cached_nonlin_attn = torch.zeros(
                1, batch_size, downsample_left, nonlin_attn_head_dim
            ).to(device)
            cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
                device
            )
            cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
                device
            )
            cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
                device
            )
            cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
                device
            )
            states += [
                cached_key,
                cached_nonlin_attn,
                cached_val1,
                cached_val2,
                cached_conv1,
                cached_conv2,
            ]

    return states
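
As a quick check on the cache shapes, the sketch below computes the six per-layer shapes for one stack without allocating any tensors. The function name and the configuration values in the example are hypothetical; only the shape formulas are taken from get_init_states.

```python
from typing import Dict, Tuple


def init_state_shapes(
    embed_dim: int,
    num_heads: int,
    query_head_dim: int,
    value_head_dim: int,
    cnn_module_kernel: int,
    left_context_frames: int,
    ds: int,
    batch_size: int = 1,
) -> Dict[str, Tuple[int, ...]]:
    """Shapes of the six cached tensors for one layer of one stack."""
    left = left_context_frames // ds  # left context at this stack's frame rate
    key_dim = query_head_dim * num_heads
    value_dim = value_head_dim * num_heads
    conv_left_pad = cnn_module_kernel // 2
    return {
        "cached_key": (left, batch_size, key_dim),
        "cached_nonlin_attn": (1, batch_size, left, 3 * embed_dim // 4),
        "cached_val1": (left, batch_size, value_dim),
        "cached_val2": (left, batch_size, value_dim),
        "cached_conv1": (batch_size, embed_dim, conv_left_pad),
        "cached_conv2": (batch_size, embed_dim, conv_left_pad),
    }


# Hypothetical stack configuration.
shapes = init_state_shapes(
    embed_dim=384, num_heads=4, query_head_dim=32, value_head_dim=12,
    cnn_module_kernel=31, left_context_frames=64, ds=2,
)
for name, shape in shapes.items():
    print(name, shape)
```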

ZipformerEncoderLayer

Bases: Module

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embed_dim | int | the number of expected features in the input (required). | required |
| nhead | | the number of heads in the multi-head attention models (required). | required |
| feedforward_dim | int | the dimension of the feedforward network model (required). | required |
| dropout | FloatLike | the dropout value (default=0.1). | 0.1 |
| cnn_module_kernel | int | Kernel size of convolution module (default=31). | 31 |

Examples::

    >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
    >>> src = torch.rand(10, 32, 512)
    >>> pos_emb = torch.rand(32, 19, 512)
    >>> out = encoder_layer(src, pos_emb)

Source code in zipformer/modules/zipformer.py
class ZipformerEncoderLayer(torch.nn.Module):
    """
    Args:
        embed_dim: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        feedforward_dim: the dimension of the feedforward network model (required).
        dropout: the dropout value (default=0.1).
        cnn_module_kernel (int): Kernel size of convolution module (default=31).

    Examples::
        >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        value_head_dim: int,
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
        causal: bool = False,
        attention_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
        ),
        conv_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
        ),
        const_attention_rate: FloatLike = ScheduledFloat(
            (0.0, 0.25), (4000.0, 0.025), default=0
        ),
        ff2_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
        ),
        ff3_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
        ),
        bypass_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.5), (4000.0, 0.02), default=0
        ),
    ) -> None:
        super(ZipformerEncoderLayer, self).__init__()
        self.embed_dim = embed_dim

        # self.bypass implements layer skipping as well as bypass; see its default values.
        self.bypass = BypassModule(
            embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
        )
        # bypass_mid is bypass used in the middle of the layer.
        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)

        # skip probability for dynamic modules (meaning: anything but feedforward).
        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
        # an additional skip probability that applies to ConvModule to stop it from
        # contributing too much early on.
        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)

        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
        # compared to its residual.
        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
        self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)

        self.const_attention_rate = copy.deepcopy(const_attention_rate)

        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
            embed_dim,
            pos_dim=pos_dim,
            num_heads=num_heads,
            query_head_dim=query_head_dim,
            pos_head_dim=pos_head_dim,
            dropout=0.0,
        )

        self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)

        self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)

        self.feed_forward1 = FeedforwardModule(
            embed_dim, (feedforward_dim * 3) // 4, dropout
        )

        self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)

        self.feed_forward3 = FeedforwardModule(
            embed_dim, (feedforward_dim * 5) // 4, dropout
        )

        self.nonlin_attention = NonlinAttention(
            embed_dim, hidden_channels=3 * embed_dim // 4
        )

        self.conv_module1 = ConvolutionModule(
            embed_dim, cnn_module_kernel, causal=causal
        )

        self.conv_module2 = ConvolutionModule(
            embed_dim, cnn_module_kernel, causal=causal
        )

        # TODO: remove it
        self.bypass_scale = torch.nn.Parameter(torch.full((embed_dim,), 0.5))

        self.norm = BiasNorm(embed_dim)

        self.balancer1 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            min_abs=0.2,
            max_abs=4.0,
        )

        # balancer for output of NonlinAttentionModule
        self.balancer_na = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
            prob=0.05,  # out of concern for memory usage
        )

        # balancer for output of feedforward2, prevent it from staying too
        # small.  give this a very small probability, even at the start of
        # training, it's to fix a rare problem and it's OK to fix it slowly.
        self.balancer_ff2 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
            max_abs=2.0,
            prob=0.05,
        )

        self.balancer_ff3 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
            max_abs=4.0,
            prob=0.05,
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(4.0, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.balancer2 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            min_abs=0.1,
            max_abs=4.0,
        )

    def get_sequence_dropout_mask(
        self, x: torch.Tensor, dropout_rate: float
    ) -> Optional[torch.Tensor]:
        if (
            dropout_rate == 0.0
            or not self.training
            or torch.jit.is_scripting()
            or torch.jit.is_tracing()
        ):
            return None
        batch_size = x.shape[1]
        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
        return mask

    def sequence_dropout(self, x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
        """
        Apply sequence-level dropout to x.
        x shape: (seq_len, batch_size, embed_dim)
        """
        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
        if dropout_mask is None:
            return x
        else:
            return x * dropout_mask

    def forward(
        self,
        src: torch.Tensor,
        pos_emb: torch.Tensor,
        chunk_size: int = -1,
        attn_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
            chunk_size: the number of frames per chunk, >= 0; if -1, no chunking.
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.

        Returns:
            A tensor which has the same shape as src
        """
        src_orig = src

        # dropout rate for non-feedforward submodules
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            attention_skip_rate = 0.0
        else:
            attention_skip_rate = (
                float(self.attention_skip_rate) if self.training else 0.0
            )

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights = self.self_attn_weights(
            src,
            pos_emb=pos_emb,
            attn_mask=attn_mask,
            key_padding_mask=src_key_padding_mask,
        )

        src = src + self.feed_forward1(src)

        self_attn_dropout_mask = self.get_sequence_dropout_mask(
            src, attention_skip_rate
        )

        selected_attn_weights = attn_weights[0:1]
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif self.training and random.random() < float(self.const_attention_rate):
            # Make attention weights constant.  The intention is to
            # encourage these modules to do something similar to an
            # averaging-over-time operation.
            # only need the mask, can just use the 1st one and expand later
            selected_attn_weights = selected_attn_weights[0:1]
            selected_attn_weights = (selected_attn_weights > 0.0).to(
                selected_attn_weights.dtype
            )
            selected_attn_weights = selected_attn_weights * (
                1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
            )

        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))

        src = src + (
            na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
        )

        self_attn = self.self_attn1(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module1(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff2_skip_rate = 0.0
        else:
            ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
        )

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn = self.self_attn2(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module2(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff3_skip_rate = 0.0
        else:
            ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
        )

        src = self.balancer1(src)
        src = self.norm(src)

        src = self.bypass(src_orig, src)

        src = self.balancer2(src)
        src = self.whiten(src)

        return src

    def streaming_forward(
        self,
        src: torch.Tensor,
        pos_emb: torch.Tensor,
        cached_key: torch.Tensor,
        cached_nonlin_attn: torch.Tensor,
        cached_val1: torch.Tensor,
        cached_val2: torch.Tensor,
        cached_conv1: torch.Tensor,
        cached_conv2: torch.Tensor,
        left_context_len: int,
        src_key_padding_mask: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Pass the input through the encoder layer in streaming forward mode.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
              (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
            cached_key: cached attention key tensor of left context,
              of shape (left_context_len, batch_size, key_dim)
            cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
              (num_heads, batch_size, left_context_len, head_dim)
            cached_val1: cached left context for the first attention module,
              of shape (left_context_len, batch_size, value_dim)
            cached_val2: cached left context for the second attention module,
              of shape (left_context_len, batch_size, value_dim)
            cached_conv1: cached left context for the first convolution module,
              of shape (batch_size, channels, left_pad)
            cached_conv2: cached left context for the second convolution module,
              of shape (batch_size, channels, left_pad)
            left_context_len: number of left context frames.
            src_key_padding_mask:  the mask for padding, of shape
              (batch_size, left_context_len + seq_len); True means masked position.

        Returns:
            - x, with the same shape as src
            - updated cached_key
            - updated cached_nonlin_attn
            - updated cached_val1
            - updated cached_val2
            - updated cached_conv1
            - updated cached_conv2
        """
        src_orig = src

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights, cached_key = self.self_attn_weights.streaming_forward(
            src,
            pos_emb=pos_emb,
            cached_key=cached_key,
            left_context_len=left_context_len,
            key_padding_mask=src_key_padding_mask,
        )

        src = src + self.feed_forward1(src)

        na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
            src,
            attn_weights[0:1],
            cached_x=cached_nonlin_attn,
            left_context_len=left_context_len,
        )
        src = src + na

        self_attn, cached_val1 = self.self_attn1.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val1,
            left_context_len=left_context_len,
        )
        src = src + self_attn

        src_conv, cached_conv1 = self.conv_module1.streaming_forward(
            src,
            cache=cached_conv1,
            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward2(src)

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn, cached_val2 = self.self_attn2.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val2,
            left_context_len=left_context_len,
        )
        src = src + self_attn

        src_conv, cached_conv2 = self.conv_module2.streaming_forward(
            src,
            cache=cached_conv2,
            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward3(src)

        src = self.norm(src)

        src = self.bypass(src_orig, src)

        return (
            src,
            cached_key,
            cached_nonlin_attn,
            cached_val1,
            cached_val2,
            cached_conv1,
            cached_conv2,
        )
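
The skip rates in `__init__` are given as `ScheduledFloat` objects, whose definition is not shown in this section. Assuming that `ScheduledFloat` interpolates piecewise-linearly between its (batch_count, value) points and holds the end values outside that range (an assumption about the helper, not something stated on this page), the default `attention_skip_rate` schedule can be sketched as follows. The function `piecewise_linear` is a hypothetical stand-in, not the library class.

```python
from typing import Sequence, Tuple


def piecewise_linear(points: Sequence[Tuple[float, float]], x: float) -> float:
    """Linearly interpolate between sorted (x, y) points, clamping at both ends."""
    if x <= points[0][0]:
        return points[0][1]
    if x >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= x <= x1:
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
    raise ValueError("points must be sorted by x")


# Points copied from the default attention_skip_rate argument.
attention_skip_schedule = [(0.0, 0.2), (4000.0, 0.05), (16000.0, 0.0)]
for batch_count in (0, 2000, 4000, 10000, 20000):
    rate = piecewise_linear(attention_skip_schedule, batch_count)
    print(batch_count, round(rate, 4))
```

Under that assumption, the attention and convolution branches are skipped most often at the start of training and never after 16000 batches.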

sequence_dropout(x, dropout_rate)

Apply sequence-level dropout to x. x shape: (seq_len, batch_size, embed_dim)

Source code in zipformer/modules/zipformer.py
def sequence_dropout(self, x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    """
    Apply sequence-level dropout to x.
    x shape: (seq_len, batch_size, embed_dim)
    """
    dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
    if dropout_mask is None:
        return x
    else:
        return x * dropout_mask
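
Sequence-level dropout keeps or zeroes an entire sequence at once, rather than individual elements. The sketch below shows the same idea with nested Python lists in place of tensors; the function names and the seeded `random.Random` generator are assumptions for illustration, and the sketch omits the additional checks in the real method (training mode, TorchScript scripting or tracing).

```python
import random
from typing import List, Optional

Tensor3D = List[List[List[float]]]  # nested as [seq_len][batch_size][embed_dim]


def sequence_dropout_mask(
    batch_size: int, dropout_rate: float, rng: random.Random
) -> Optional[List[float]]:
    """Draw one keep (1.0) or drop (0.0) value per sequence; None means no dropout."""
    if dropout_rate == 0.0:
        return None
    return [1.0 if rng.random() > dropout_rate else 0.0 for _ in range(batch_size)]


def apply_sequence_mask(x: Tensor3D, mask: Optional[List[float]]) -> Tensor3D:
    """Scale every frame and channel of sequence b by mask[b]."""
    if mask is None:
        return x
    return [[[v * mask[b] for v in seq] for b, seq in enumerate(frame)] for frame in x]


x = [[[1.0, 1.0] for _ in range(3)] for _ in range(4)]  # seq_len=4, batch=3, dim=2
mask = sequence_dropout_mask(batch_size=3, dropout_rate=0.5, rng=random.Random(0))
y = apply_sequence_mask(x, mask)

# Each sequence is either fully kept or fully zeroed across all 4 frames.
for b in range(3):
    print(b, mask[b], {y[t][b][0] for t in range(4)})
```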

forward(src, pos_emb, chunk_size=-1, attn_mask=None, src_key_padding_mask=None)

Pass the input through the encoder layer.
Args:
    src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
    pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
    chunk_size: the number of frames per chunk, >= 0; if -1, no chunking.
    attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
        interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
        True means masked position. May be None.
    src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
        masked position. May be None.


Returns:
   A tensor which has the same shape as src
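
During training, this method occasionally (with probability `const_attention_rate`) replaces the attention weights passed to the nonlinear attention module with uniform weights over the positions that had non-zero weight. The sketch below shows that normalization for a single row of weights, using a plain list; the function name is hypothetical, and it assumes at least one weight in the row is positive.

```python
from typing import List


def uniform_over_unmasked(weights: List[float]) -> List[float]:
    """Replace each positive weight with an equal share; zero weights stay zero."""
    support = [1.0 if w > 0.0 else 0.0 for w in weights]
    total = sum(support)  # number of unmasked positions (assumed to be > 0)
    return [s / total for s in support]


row = [0.7, 0.2, 0.1, 0.0]  # last position is masked out
uniform = uniform_over_unmasked(row)
print(uniform)
print(sum(uniform))
```

The result behaves like an average over the unmasked time steps, which matches the stated intention of encouraging an averaging-over-time operation.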

Source code in zipformer/modules/zipformer.py
def forward(
    self,
    src: torch.Tensor,
    pos_emb: torch.Tensor,
    chunk_size: int = -1,
    attn_mask: Optional[torch.Tensor] = None,
    src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Pass the input through the encoder layer.

    Args:
        src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
        pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
        chunk_size: the number of frames per chunk, >= 0; if -1, no chunking.
        attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
            interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
            True means masked position. May be None.
        src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.

    Returns:
        A tensor which has the same shape as src
    """
    src_orig = src

    # dropout rate for non-feedforward submodules
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        attention_skip_rate = 0.0
    else:
        attention_skip_rate = (
            float(self.attention_skip_rate) if self.training else 0.0
        )

    # attn_weights: (num_heads, batch_size, seq_len, seq_len)
    attn_weights = self.self_attn_weights(
        src,
        pos_emb=pos_emb,
        attn_mask=attn_mask,
        key_padding_mask=src_key_padding_mask,
    )

    src = src + self.feed_forward1(src)

    self_attn_dropout_mask = self.get_sequence_dropout_mask(
        src, attention_skip_rate
    )

    selected_attn_weights = attn_weights[0:1]
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        pass
    elif self.training and random.random() < float(self.const_attention_rate):
        # Make attention weights constant.  The intention is to
        # encourage these modules to do something similar to an
        # averaging-over-time operation.
        # only need the mask, can just use the 1st one and expand later
        selected_attn_weights = selected_attn_weights[0:1]
        selected_attn_weights = (selected_attn_weights > 0.0).to(
            selected_attn_weights.dtype
        )
        selected_attn_weights = selected_attn_weights * (
            1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
        )

    na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))

    src = src + (
        na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
    )

    self_attn = self.self_attn1(src, attn_weights)

    src = src + (
        self_attn
        if self_attn_dropout_mask is None
        else self_attn * self_attn_dropout_mask
    )

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        conv_skip_rate = 0.0
    else:
        conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
    src = src + self.sequence_dropout(
        self.conv_module1(
            src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
        ),
        conv_skip_rate,
    )

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        ff2_skip_rate = 0.0
    else:
        ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
    src = src + self.sequence_dropout(
        self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
    )

    # bypass in the middle of the layer.
    src = self.bypass_mid(src_orig, src)

    self_attn = self.self_attn2(src, attn_weights)

    src = src + (
        self_attn
        if self_attn_dropout_mask is None
        else self_attn * self_attn_dropout_mask
    )

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        conv_skip_rate = 0.0
    else:
        conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
    src = src + self.sequence_dropout(
        self.conv_module2(
            src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
        ),
        conv_skip_rate,
    )

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        ff3_skip_rate = 0.0
    else:
        ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
    src = src + self.sequence_dropout(
        self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
    )

    src = self.balancer1(src)
    src = self.norm(src)

    src = self.bypass(src_orig, src)

    src = self.balancer2(src)
    src = self.whiten(src)

    return src
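
The constant-attention branch in the listing above replaces the selected attention weights with a uniform distribution over the positions that had a non-zero weight. The following is a minimal sketch of that step on a single row of weights, written in plain Python rather than torch so it runs without dependencies. The helper name `make_constant` and the sample values are hypothetical and not part of the module.

```python
from typing import List


def make_constant(row: List[float]) -> List[float]:
    """Replace positive weights with equal weights that sum to 1.

    Mirrors `(w > 0.0).to(dtype)` followed by division by the row sum.
    Assumes at least one weight in the row is positive.
    """
    mask = [1.0 if w > 0.0 else 0.0 for w in row]
    total = sum(mask)
    return [m / total for m in mask]


# A row where the last position was masked out (weight 0).
weights = [0.7, 0.2, 0.1, 0.0]
uniform = make_constant(weights)
print(uniform)
```

Positions with zero weight stay at zero, so masking is preserved while the remaining positions are averaged equally.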

streaming_forward(src, pos_emb, cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2, left_context_len, src_key_padding_mask)

Pass the input through the encoder layer in streaming forward mode.

Parameters:

Name Type Description Default
src Tensor

the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).

required
pos_emb Tensor

(1, left_context_len+2*seq_len-1, pos_emb_dim) or (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)

required
cached_key Tensor

cached attention key tensor of left context, of shape (left_context_len, batch_size, key_dim)

required
cached_nonlin_attn Tensor

left context for nonlin_attention module, a Tensor of shape (num_heads, batch_size, left_context_len, head_dim)

required
cached_val1 Tensor

cached left context for the first attention module, of shape (left_context_len, batch_size, value_dim)

required
cached_val2 Tensor

cached left context for the second attention module, of shape (left_context_len, batch_size, value_dim)

required
cached_conv1 Tensor

cached left context for the first convolution module, of shape (batch_size, channels, left_pad)

required
cached_conv2 Tensor

cached left context for the second convolution module, of shape (batch_size, channels, left_pad)

required
left_context_len int

number of left context frames.

required
src_key_padding_mask Tensor

the mask for padding, of shape (batch_size, left_context_len + seq_len); True means masked position. May be None.

required

Returns:

Type Description
Tensor
  • x, with the same shape as src
Tensor
  • updated cached_key
Tensor
  • updated cached_nonlin_attn
Tensor
  • updated cached_val1
Tensor
  • updated cached_val2
Tensor
  • updated cached_conv1
Tensor
  • updated cached_conv2
Source code in zipformer/modules/zipformer.py
def streaming_forward(
    self,
    src: torch.Tensor,
    pos_emb: torch.Tensor,
    cached_key: torch.Tensor,
    cached_nonlin_attn: torch.Tensor,
    cached_val1: torch.Tensor,
    cached_val2: torch.Tensor,
    cached_conv1: torch.Tensor,
    cached_conv2: torch.Tensor,
    left_context_len: int,
    src_key_padding_mask: torch.Tensor,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Pass the input through the encoder layer in streaming forward mode.

    Args:
        src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
        pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
          (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
        cached_key: cached attention key tensor of left context,
          of shape (left_context_len, batch_size, key_dim)
        cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
          (num_heads, batch_size, left_context_len, head_dim)
        cached_val1: cached left context for the first attention module,
          of shape (left_context_len, batch_size, value_dim)
        cached_val2: cached left context for the second attention module,
          of shape (left_context_len, batch_size, value_dim)
        cached_conv1: cached left context for the first convolution module,
          of shape (batch_size, channels, left_pad)
        cached_conv2: cached left context for the second convolution module,
          of shape (batch_size, channels, left_pad)
        left_context_len: number of left context frames.
        src_key_padding_mask:  the mask for padding, of shape
          (batch_size, left_context_len + seq_len); True means masked position.
          May be None.

    Returns:
        - x, with the same shape as src
        - updated cached_key
        - updated cached_nonlin_attn
        - updated cached_val1
        - updated cached_val2
        - updated cached_conv1
        - updated cached_conv2
    """
    src_orig = src

    # attn_weights: (num_heads, batch_size, seq_len, left_context_len + seq_len)
    attn_weights, cached_key = self.self_attn_weights.streaming_forward(
        src,
        pos_emb=pos_emb,
        cached_key=cached_key,
        left_context_len=left_context_len,
        key_padding_mask=src_key_padding_mask,
    )

    src = src + self.feed_forward1(src)

    na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
        src,
        attn_weights[0:1],
        cached_x=cached_nonlin_attn,
        left_context_len=left_context_len,
    )
    src = src + na

    self_attn, cached_val1 = self.self_attn1.streaming_forward(
        src,
        attn_weights=attn_weights,
        cached_val=cached_val1,
        left_context_len=left_context_len,
    )
    src = src + self_attn

    src_conv, cached_conv1 = self.conv_module1.streaming_forward(
        src,
        cache=cached_conv1,
        src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
    )
    src = src + src_conv

    src = src + self.feed_forward2(src)

    # bypass in the middle of the layer.
    src = self.bypass_mid(src_orig, src)

    self_attn, cached_val2 = self.self_attn2.streaming_forward(
        src,
        attn_weights=attn_weights,
        cached_val=cached_val2,
        left_context_len=left_context_len,
    )
    src = src + self_attn

    src_conv, cached_conv2 = self.conv_module2.streaming_forward(
        src,
        cache=cached_conv2,
        src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
    )
    src = src + src_conv

    src = src + self.feed_forward3(src)

    src = self.norm(src)

    src = self.bypass(src_orig, src)

    return (
        src,
        cached_key,
        cached_nonlin_attn,
        cached_val1,
        cached_val2,
        cached_conv1,
        cached_conv2,
    )
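
The streaming listing above relies on a little shape bookkeeping: `pos_emb` is documented as having `left_context_len+2*seq_len-1` positions, and the convolution modules only receive the part of the padding mask that covers the current chunk. A rough sketch of that arithmetic, using nested lists in place of tensors (the concrete sizes are illustrative assumptions):

```python
from typing import List

left_context_len = 4  # cached frames (illustrative value)
seq_len = 3           # frames in the current chunk (illustrative value)

# Time dimension that pos_emb is documented to have.
pos_emb_len = left_context_len + 2 * seq_len - 1

# One mask row per batch element: cached positions first, then the chunk.
padding_mask: List[List[bool]] = [
    [True, True, False, False, False, False, False],
]

# Equivalent of src_key_padding_mask[:, left_context_len:] for nested lists.
chunk_mask = [row[left_context_len:] for row in padding_mask]

print(pos_emb_len)
print(chunk_mask)
```

Because the mask is sliced unconditionally in the listing, passing `None` for `src_key_padding_mask` would fail at that step.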

ZipformerEncoder

Bases: Module

ZipformerEncoder is a stack of N encoder layers

Parameters:

Name Type Description Default
encoder_layer Module

an instance of the ZipformerEncoderLayer() class (required).

required
num_layers int

the number of sub-encoder-layers in the encoder (required).

required

pos_dim: the dimension for the relative positional encoding

Examples::
    >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
    >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
    >>> src = torch.rand(10, 32, 512)
    >>> out = zipformer_encoder(src)

Source code in zipformer/modules/zipformer.py
class ZipformerEncoder(torch.nn.Module):
    r"""ZipformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the ZipformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        pos_dim: the dimension for the relative positional encoding

    Examples::
        >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
        >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = zipformer_encoder(src)
    """

    def __init__(
        self,
        encoder_layer: torch.nn.Module,
        num_layers: int,
        pos_dim: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
    ) -> None:
        super().__init__()
        self.encoder_pos = CompactRelPositionalEncoding(
            pos_dim, dropout_rate=0.15, length_factor=1.0
        )

        self.layers = torch.nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers

        assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end)

        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
        cur_begin = warmup_begin  # interpreted as a training batch index
        for i in range(num_layers):
            cur_end = cur_begin + delta
            self.layers[i].bypass.skip_rate = ScheduledFloat(
                (cur_begin, initial_layerdrop_rate),
                (cur_end, final_layerdrop_rate),
                default=0.0,
            )
            cur_begin = cur_end

    def forward(
        self,
        src: torch.Tensor,
        chunk_size: int = -1,
        feature_mask: Union[torch.Tensor, float] = 1.0,
        attn_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            chunk_size: the number of frames per chunk; must be >= 0, or -1 for no chunking.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                 True means masked position. May be None.
            src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
                 masked position.  May be None.

        Returns: a Tensor with the same shape as src.
        """
        pos_emb = self.encoder_pos(src)
        output = src

        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            output = output * feature_mask

        for i, mod in enumerate(self.layers):
            output = mod(
                output,
                pos_emb,
                chunk_size=chunk_size,
                attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
            )

            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                output = output * feature_mask

        return output

    def streaming_forward(
        self,
        src: torch.Tensor,
        states: List[torch.Tensor],
        left_context_len: int,
        src_key_padding_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
              (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
            left_context_len: Number of left context frames.
            src_key_padding_mask:  the mask for padding, of shape
              (batch_size, left_context_len + seq_len); True means masked position.
              May be None.

        Returns:
          - output, a Tensor with the same shape as src.
          - updated states
        """
        pos_emb = self.encoder_pos(src, left_context_len)
        output = src

        new_states = []
        for i, mod in enumerate(self.layers):
            (
                cached_key,
                cached_nonlin_attn,
                cached_val1,
                cached_val2,
                cached_conv1,
                cached_conv2,
            ) = states[i * 6 : (i + 1) * 6]
            (
                output,
                new_cached_key,
                new_cached_nonlin_attn,
                new_cached_val1,
                new_cached_val2,
                new_cached_conv1,
                new_cached_conv2,
            ) = mod.streaming_forward(
                output,
                pos_emb,
                cached_key=cached_key,
                cached_nonlin_attn=cached_nonlin_attn,
                cached_val1=cached_val1,
                cached_val2=cached_val2,
                cached_conv1=cached_conv1,
                cached_conv2=cached_conv2,
                left_context_len=left_context_len,
                src_key_padding_mask=src_key_padding_mask,
            )
            new_states += [
                new_cached_key,
                new_cached_nonlin_attn,
                new_cached_val1,
                new_cached_val2,
                new_cached_conv1,
                new_cached_conv2,
            ]

        return output, new_states
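
The constructor in the listing above staggers the layer-skipping schedule: the warmup interval is split into `num_layers` equal windows, and each layer's `bypass.skip_rate` is scheduled from `initial_layerdrop_rate` at the start of its window to `final_layerdrop_rate` at the end. The sketch below reproduces only the window arithmetic; the helper name `layerdrop_windows` and the sample values are hypothetical, and how `ScheduledFloat` interpolates between the two points is not shown here.

```python
from typing import List, Tuple


def layerdrop_windows(
    num_layers: int, warmup_begin: float, warmup_end: float
) -> List[Tuple[float, float]]:
    """Return the (cur_begin, cur_end) batch-index window for each layer."""
    assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end)
    delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
    windows = []
    cur_begin = warmup_begin
    for _ in range(num_layers):
        cur_end = cur_begin + delta
        windows.append((cur_begin, cur_end))
        cur_begin = cur_end
    return windows


windows = layerdrop_windows(num_layers=4, warmup_begin=1000.0, warmup_end=3000.0)
for i, (begin, end) in enumerate(windows):
    print(f"layer {i}: skip rate scheduled over batches {begin:.0f} to {end:.0f}")
```

Earlier layers therefore finish their schedule before later layers start theirs.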

forward(src, chunk_size=-1, feature_mask=1.0, attn_mask=None, src_key_padding_mask=None)

Pass the input through the encoder layers in turn.

Parameters:

Name Type Description Default
src Tensor

the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).

required
chunk_size int

the number of frames per chunk; must be >= 0, or -1 for no chunking.

-1
feature_mask Union[Tensor, float]

something that broadcasts with src, that we'll multiply src by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)

1.0
attn_mask Optional[Tensor]

the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). True means masked position. May be None.

None
src_key_padding_mask Optional[Tensor]

the mask for padding, of shape (batch_size, seq_len); True means masked position. May be None.

None

Returns: a Tensor with the same shape as src.

Source code in zipformer/modules/zipformer.py
def forward(
    self,
    src: torch.Tensor,
    chunk_size: int = -1,
    feature_mask: Union[torch.Tensor, float] = 1.0,
    attn_mask: Optional[torch.Tensor] = None,
    src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""Pass the input through the encoder layers in turn.

    Args:
        src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
        chunk_size: the number of frames per chunk; must be >= 0, or -1 for no chunking.
        feature_mask: something that broadcasts with src, that we'll multiply `src`
           by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
        attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
             interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
             True means masked position. May be None.
        src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
             masked position.  May be None.

    Returns: a Tensor with the same shape as src.
    """
    pos_emb = self.encoder_pos(src)
    output = src

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        output = output * feature_mask

    for i, mod in enumerate(self.layers):
        output = mod(
            output,
            pos_emb,
            chunk_size=chunk_size,
            attn_mask=attn_mask,
            src_key_padding_mask=src_key_padding_mask,
        )

        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            output = output * feature_mask

    return output

streaming_forward(src, states, left_context_len, src_key_padding_mask)

Pass the input through the encoder layers in turn.

Parameters:

Name Type Description Default
src Tensor

the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).

required
states List[Tensor]

list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).

required
left_context_len int

Number of left context frames.

required
src_key_padding_mask Tensor

the mask for padding, of shape (batch_size, left_context_len + seq_len); True means masked position. May be None.

required

Returns:

Type Description
Tensor
  • output, a Tensor with the same shape as src.
List[Tensor]
  • updated states
Source code in zipformer/modules/zipformer.py
def streaming_forward(
    self,
    src: torch.Tensor,
    states: List[torch.Tensor],
    left_context_len: int,
    src_key_padding_mask: torch.Tensor,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    r"""Pass the input through the encoder layers in turn.

    Args:
        src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
        states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
          (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
        left_context_len: Number of left context frames.
        src_key_padding_mask:  the mask for padding, of shape
          (batch_size, left_context_len + seq_len); True means masked position.
          May be None.

    Returns:
      - output, a Tensor with the same shape as src.
      - updated states
    """
    pos_emb = self.encoder_pos(src, left_context_len)
    output = src

    new_states = []
    for i, mod in enumerate(self.layers):
        (
            cached_key,
            cached_nonlin_attn,
            cached_val1,
            cached_val2,
            cached_conv1,
            cached_conv2,
        ) = states[i * 6 : (i + 1) * 6]
        (
            output,
            new_cached_key,
            new_cached_nonlin_attn,
            new_cached_val1,
            new_cached_val2,
            new_cached_conv1,
            new_cached_conv2,
        ) = mod.streaming_forward(
            output,
            pos_emb,
            cached_key=cached_key,
            cached_nonlin_attn=cached_nonlin_attn,
            cached_val1=cached_val1,
            cached_val2=cached_val2,
            cached_conv1=cached_conv1,
            cached_conv2=cached_conv2,
            left_context_len=left_context_len,
            src_key_padding_mask=src_key_padding_mask,
        )
        new_states += [
            new_cached_key,
            new_cached_nonlin_attn,
            new_cached_val1,
            new_cached_val2,
            new_cached_conv1,
            new_cached_conv2,
        ]

    return output, new_states
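
The flat `states` list used above stores six cache entries per layer, back to back. As a small illustration of that layout, the sketch below uses strings as stand-ins for the cached tensors; the helper name `layer_states` and the `NAMES` constant are hypothetical.

```python
from typing import List

NAMES = [
    "cached_key",
    "cached_nonlin_attn",
    "cached_val1",
    "cached_val2",
    "cached_conv1",
    "cached_conv2",
]


def layer_states(states: List[str], i: int) -> List[str]:
    """Return the six cache entries belonging to layer i."""
    return states[i * 6 : (i + 1) * 6]


num_layers = 2
# Strings stand in for the cached tensors.
states = [f"layer{i}.{name}" for i in range(num_layers) for name in NAMES]

second = layer_states(states, 1)
print(second)
```

The updated list returned by `streaming_forward` keeps the same ordering, so it can be fed back unchanged for the next chunk.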

BypassModule

Bases: Module

An nn.Module that implements a learnable bypass scale, and also randomized per-sequence layer-skipping. The bypass is limited during early stages of training to be close to "straight-through", i.e. to not do the bypass operation much initially, in order to force all the modules to learn something.

Source code in zipformer/modules/zipformer.py
class BypassModule(torch.nn.Module):
    """
    An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
    layer-skipping.  The bypass is limited during early stages of training to be close to
    "straight-through", i.e. to not do the bypass operation much initially, in order to
    force all the modules to learn something.
    """

    def __init__(
        self,
        embed_dim: int,
        skip_rate: FloatLike = 0.0,
        straight_through_rate: FloatLike = 0.0,
        scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
        scale_max: FloatLike = 1.0,
    ):
        super().__init__()
        self.bypass_scale = torch.nn.Parameter(torch.full((embed_dim,), 0.5))
        self.skip_rate = copy.deepcopy(skip_rate)
        self.straight_through_rate = copy.deepcopy(straight_through_rate)
        self.scale_min = copy.deepcopy(scale_min)
        self.scale_max = copy.deepcopy(scale_max)

    def _get_bypass_scale(self, batch_size: int):
        # returns bypass-scale of shape (num_channels,),
        # or (batch_size, num_channels,).  This is actually the
        # scale on the non-residual term, so 0 corresponds to bypassing
        # this module.
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return self.bypass_scale
        else:
            ans = limit_param_value(
                self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
            )
            skip_rate = float(self.skip_rate)
            if skip_rate != 0.0:
                mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
                ans = ans * mask
                # now ans is of shape (batch_size, num_channels), and is zero for sequences
                # on which we have randomly chosen to do layer-skipping.
            straight_through_rate = float(self.straight_through_rate)
            if straight_through_rate != 0.0:
                mask = (
                    torch.rand((batch_size, 1), device=ans.device)
                    < straight_through_rate
                )
                ans = torch.maximum(ans, mask.to(ans.dtype))
            return ans

    def forward(self, src_orig: torch.Tensor, src: torch.Tensor):
        """
        Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
        Returns: something with the same shape as src and src_orig
        """
        bypass_scale = self._get_bypass_scale(src.shape[1])
        return src_orig + (src - src_orig) * bypass_scale

forward(src_orig, src)

Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
Returns: something with the same shape as src and src_orig

Source code in zipformer/modules/zipformer.py
def forward(self, src_orig: torch.Tensor, src: torch.Tensor):
    """
    Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
    Returns: something with the same shape as src and src_orig
    """
    bypass_scale = self._get_bypass_scale(src.shape[1])
    return src_orig + (src - src_orig) * bypass_scale
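
The return expression above interpolates between the layer input and the layer output. A scale of 0 returns `src_orig` (the module is bypassed) and a scale of 1 returns `src`. A minimal sketch with plain lists instead of tensors and a single scalar scale (the real `bypass_scale` is per channel, and per sequence when skipping is active); the function name `bypass` and the values are illustrative only:

```python
from typing import List


def bypass(src_orig: List[float], src: List[float], scale: float) -> List[float]:
    """Element-wise src_orig + (src - src_orig) * scale."""
    return [o + (s - o) * scale for o, s in zip(src_orig, src)]


src_orig = [1.0, 2.0]
src = [3.0, 6.0]

bypassed = bypass(src_orig, src, 0.0)  # module fully bypassed: equals src_orig
kept = bypass(src_orig, src, 1.0)      # no bypass: equals src
halfway = bypass(src_orig, src, 0.5)   # midpoint of the two
print(bypassed, kept, halfway)
```

This is why the random skip mask in `_get_bypass_scale` zeroes the scale for skipped sequences: those sequences pass through the layer unchanged.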

DownsampledZipformerEncoder

Bases: Module

DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, after convolutional downsampling, and then upsampled again at the output, and combined with the original input, so that the output has the same shape as the input.

Source code in zipformer/modules/zipformer.py
class DownsampledZipformerEncoder(torch.nn.Module):
    r"""
    DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
    after convolutional downsampling, and then upsampled again at the output, and combined
    with the original input, so that the output has the same shape as the input.
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        dim: int,
        downsample: int,
        dropout: FloatLike,
        causal: bool,
    ):
        super(DownsampledZipformerEncoder, self).__init__()
        self.downsample_factor = downsample
        self.downsample = SimpleDownsample(dim, downsample, dropout, causal)
        self.num_layers = encoder.num_layers
        self.encoder = encoder
        self.upsample = SimpleUpsample(dim, downsample)
        self.out_combiner = BypassModule(dim, straight_through_rate=0)

    def forward(
        self,
        src: torch.Tensor,
        chunk_size: int = -1,
        feature_mask: Union[torch.Tensor, float] = 1.0,
        attn_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Downsample, go through encoder, upsample.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            chunk_size: the number of frames per chunk at the input frame rate; must be >= 0,
               or -1 for no chunking. It is divided by the downsampling factor before being
               passed to the encoder.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                 True means masked position. May be None.
            src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
                 masked position.  May be None.

        Returns: a Tensor with the same shape as src.
        """
        src_orig = src
        src = self.downsample(src)
        ds = self.downsample_factor
        if attn_mask is not None:
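            # subsample the last two (time) dimensions so that both the
            # (seq_len, seq_len) and (batch_size, seq_len, seq_len) layouts work
            attn_mask = attn_mask[..., ::ds, ::ds]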

        src = self.encoder(
            src,
            chunk_size=chunk_size // ds,
            feature_mask=feature_mask,
            attn_mask=attn_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        src = self.upsample(src)
        # remove any extra frames that are not a multiple of downsample_factor
        src = src[: src_orig.shape[0]]

        return self.out_combiner(src_orig, src)

    def streaming_forward(
        self,
        src: torch.Tensor,
        states: List[torch.Tensor],
        left_context_len: int,
        src_key_padding_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        r"""Downsample, go through encoder, upsample, in streaming forward mode.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
              (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
            left_context_len: Number of left context frames.
            src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
              True means masked position. May be None.

        Returns:
            - output, a Tensor with the same shape as src.
            - updated states
        """
        src_orig = src
        src = self.downsample(src)

        src, new_states = self.encoder.streaming_forward(
            src,
            states=states,
            left_context_len=left_context_len,
            src_key_padding_mask=src_key_padding_mask,
        )
        src = self.upsample(src)
        # remove any extra frames that are not a multiple of downsample_factor
        src = src[: src_orig.shape[0]]

        return self.out_combiner(src_orig, src), new_states

forward(src, chunk_size=-1, feature_mask=1.0, attn_mask=None, src_key_padding_mask=None)

Downsample, go through encoder, upsample.

Parameters:

Name Type Description Default
src Tensor

the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).

required
chunk_size int

the number of frames per chunk at the input frame rate; must be >= 0, or -1 for no chunking. It is divided by the downsampling factor before being passed to the encoder.

-1
feature_mask Union[Tensor, float]

something that broadcasts with src, that we'll multiply src by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)

1.0
attn_mask Optional[Tensor]

the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). True means masked position. May be None.

None
src_key_padding_mask Optional[Tensor]

the mask for padding, of shape (batch_size, seq_len); True means masked position. May be None.

None

Returns: a Tensor with the same shape as src.

Source code in zipformer/modules/zipformer.py
def forward(
    self,
    src: torch.Tensor,
    chunk_size: int = -1,
    feature_mask: Union[torch.Tensor, float] = 1.0,
    attn_mask: Optional[torch.Tensor] = None,
    src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""Downsample, go through encoder, upsample.

    Args:
        src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
        chunk_size: the number of frames per chunk at the input frame rate; must be >= 0,
           or -1 for no chunking. It is divided by the downsampling factor before being
           passed to the encoder.
        feature_mask: something that broadcasts with src, that we'll multiply `src`
           by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
        attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
             interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
             True means masked position. May be None.
        src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
             masked position.  May be None.

    Returns: a Tensor with the same shape as src.
    """
    src_orig = src
    src = self.downsample(src)
    ds = self.downsample_factor
    if attn_mask is not None:
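        # subsample the last two (time) dimensions so that both the
        # (seq_len, seq_len) and (batch_size, seq_len, seq_len) layouts work
        attn_mask = attn_mask[..., ::ds, ::ds]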

    src = self.encoder(
        src,
        chunk_size=chunk_size // ds,
        feature_mask=feature_mask,
        attn_mask=attn_mask,
        src_key_padding_mask=src_key_padding_mask,
    )
    src = self.upsample(src)
    # remove any extra frames that are not a multiple of downsample_factor
    src = src[: src_orig.shape[0]]

    return self.out_combiner(src_orig, src)
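
The trimming step `src[: src_orig.shape[0]]` in the listing above matters whenever `seq_len` is not a multiple of the downsampling factor. The sketch below walks through the length arithmetic with plain lists. The downsampled length follows the formula in the `SimpleDownsample` docstring; the upsampling behaviour (each frame repeated `ds` times) is an assumption, since `SimpleUpsample` is not shown in this section.

```python
def downsampled_len(seq_len: int, ds: int) -> int:
    """Length after downsampling: (seq_len + ds - 1) // ds."""
    return (seq_len + ds - 1) // ds


seq_len, ds = 7, 2
d_len = downsampled_len(seq_len, ds)

# Assumption: upsampling repeats each downsampled frame ds times.
frames = list(range(d_len))
upsampled = [f for f in frames for _ in range(ds)]

# Equivalent of src[: src_orig.shape[0]].
trimmed = upsampled[:seq_len]
print(d_len, len(upsampled), len(trimmed))
```

After trimming, the upsampled sequence has the same length as the original input, which is what lets `out_combiner` mix the two element-wise.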

streaming_forward(src, states, left_context_len, src_key_padding_mask)

Downsample, go through encoder, upsample, in streaming forward mode.

Parameters:

Name Type Description Default
src Tensor

the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).

required
states List[Tensor]

list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).

required
left_context_len int

Number of left context frames.

required
src_key_padding_mask Tensor

the mask for padding, of shape (batch_size, left_context_len+seq_len); True means masked position. May be None.

required

Returns:

Type Description
Tensor
  • output, a Tensor with the same shape as src.
List[Tensor]
  • updated states
Source code in zipformer/modules/zipformer.py
def streaming_forward(
    self,
    src: torch.Tensor,
    states: List[torch.Tensor],
    left_context_len: int,
    src_key_padding_mask: torch.Tensor,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    r"""Downsample, go through encoder, upsample, in streaming forward mode.

    Args:
        src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
        states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
          (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
        left_context_len: Number of left context frames.
        src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
          True means masked position. May be None.

    Returns:
        - output, a Tensor with the same shape as src.
        - updated states
    """
    src_orig = src
    src = self.downsample(src)

    src, new_states = self.encoder.streaming_forward(
        src,
        states=states,
        left_context_len=left_context_len,
        src_key_padding_mask=src_key_padding_mask,
    )
    src = self.upsample(src)
    # remove any extra frames that are not a multiple of downsample_factor
    src = src[: src_orig.shape[0]]

    return self.out_combiner(src_orig, src), new_states

SimpleDownsample

Bases: Module

Does downsampling with attention, by weighted sum, and a projection.
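
The weighted-sum step can be illustrated on a one-dimensional sequence of floats. This is a minimal pure-Python sketch that mirrors the forward() listing on this page (right-pad by repeating the last frame, softmax over the learned bias, sum over each group); the function names are hypothetical and the real module operates on (seq_len, batch_size, channels) tensors.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a short list."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def simple_downsample(frames: Sequence[float], bias: Sequence[float]) -> List[float]:
    """Downsample by len(bias), using softmax(bias) as the per-position weights."""
    ds = len(bias)
    d_seq_len = (len(frames) + ds - 1) // ds
    pad = d_seq_len * ds - len(frames)
    padded = list(frames) + [frames[-1]] * pad  # repeat the last frame
    weights = softmax(bias)
    return [
        sum(w * padded[i * ds + j] for j, w in enumerate(weights))
        for i in range(d_seq_len)
    ]


if __name__ == "__main__":
    # A zero bias gives uniform weights, i.e. a plain average of each pair.
    print(simple_downsample([1.0, 2.0, 3.0, 4.0, 5.0], bias=[0.0, 0.0]))
```

Because the bias is initialized to zeros, the module starts out as a plain average over each group of `downsample` frames.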

Source code in zipformer/modules/zipformer.py
class SimpleDownsample(torch.nn.Module):
    """
    Does downsampling with attention, by weighted sum, and a projection.
    """

    def __init__(
        self, channels: int, downsample: int, dropout: FloatLike, causal: bool
    ):
        super(SimpleDownsample, self).__init__()

        self.causal = causal
        self.bias = torch.nn.Parameter(torch.zeros(downsample))

        self.name = None  # will be set from training code
        self.dropout = copy.deepcopy(dropout)

        self.downsample = downsample

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        src: (seq_len, batch_size, in_channels)
        Returns a tensor of shape
           ( (seq_len+downsample-1)//downsample, batch_size, in_channels)
        """
        (seq_len, batch_size, in_channels) = src.shape
        ds = self.downsample
        d_seq_len = (seq_len + ds - 1) // ds

        # Pad to an exact multiple of self.downsample
        # right-pad src, repeating the last element.
        pad = d_seq_len * ds - seq_len

        if self.causal and torch.jit.is_tracing():
            assert pad == 0, (
                f"pad should be zero for exporting streaming models. Given {pad}"
            )

        # If we are exporting a streaming model, then we skip the if statement
        if not self.causal or not torch.jit.is_tracing():
            src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
            src = torch.cat((src, src_extra), dim=0)

        assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds)

        src = src.reshape(d_seq_len, ds, batch_size, in_channels)

        weights = self.bias.softmax(dim=0)
        # weights: (downsample, 1, 1)
        weights = weights.unsqueeze(-1).unsqueeze(-1)

        # weighted sum over the `downsample` axis: (d_seq_len, batch_size, in_channels)
        ans = (src * weights).sum(dim=1)

        return ans

forward(src)

src: (seq_len, batch_size, in_channels). Returns a tensor of shape ((seq_len+downsample-1)//downsample, batch_size, in_channels).

Source code in zipformer/modules/zipformer.py
def forward(self, src: torch.Tensor) -> torch.Tensor:
    """
    src: (seq_len, batch_size, in_channels)
    Returns a tensor of shape
       ( (seq_len+downsample-1)//downsample, batch_size, in_channels)
    """
    (seq_len, batch_size, in_channels) = src.shape
    ds = self.downsample
    d_seq_len = (seq_len + ds - 1) // ds

    # Pad to an exact multiple of self.downsample
    # right-pad src, repeating the last element.
    pad = d_seq_len * ds - seq_len

    if self.causal and torch.jit.is_tracing():
        assert pad == 0, (
            f"pad should be zero for exporting streaming models. Given {pad}"
        )

    # If we are exporting a streaming model, then we skip the if statement
    if not self.causal or not torch.jit.is_tracing():
        src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
        src = torch.cat((src, src_extra), dim=0)

    assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds)

    src = src.reshape(d_seq_len, ds, batch_size, in_channels)

    weights = self.bias.softmax(dim=0)
    # weights: (downsample, 1, 1)
    weights = weights.unsqueeze(-1).unsqueeze(-1)

    # weighted sum over the `downsample` axis: (d_seq_len, batch_size, in_channels)
    ans = (src * weights).sum(dim=1)

    return ans

SimpleUpsample

Bases: Module

A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias.
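
The repetition pattern can be shown on a plain list. This is a minimal sketch with a hypothetical function name; it mirrors the unsqueeze/expand/reshape sequence in the forward() listing on this page, which only repeats frames (no bias term appears in that listing).

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def simple_upsample(frames: Sequence[T], upsample: int) -> List[T]:
    """Repeat every frame `upsample` times, keeping the repeats adjacent."""
    return [frame for frame in frames for _ in range(upsample)]


if __name__ == "__main__":
    print(simple_upsample(["a", "b"], upsample=3))
```

Frame i of the input lands at output positions i*upsample through (i+1)*upsample - 1, which is the layout the expand-then-reshape produces along the time axis.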

Source code in zipformer/modules/zipformer.py
class SimpleUpsample(torch.nn.Module):
    """
    A very simple form of upsampling that mostly just repeats the input, but
    also adds a position-specific bias.
    """

    def __init__(self, num_channels: int, upsample: int):
        super(SimpleUpsample, self).__init__()
        self.upsample = upsample

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        src: (seq_len, batch_size, num_channels)
        Returns a tensor of shape
           ( (seq_len*upsample), batch_size, num_channels)
        """
        upsample = self.upsample
        (seq_len, batch_size, num_channels) = src.shape
        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
        src = src.reshape(seq_len * upsample, batch_size, num_channels)
        return src

forward(src)

src: (seq_len, batch_size, num_channels). Returns a tensor of shape ((seq_len*upsample), batch_size, num_channels).

Source code in zipformer/modules/zipformer.py
def forward(self, src: torch.Tensor) -> torch.Tensor:
    """
    src: (seq_len, batch_size, num_channels)
    Returns a tensor of shape
       ( (seq_len*upsample), batch_size, num_channels)
    """
    upsample = self.upsample
    (seq_len, batch_size, num_channels) = src.shape
    src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
    src = src.reshape(seq_len * upsample, batch_size, num_channels)
    return src

CompactRelPositionalEncoding

Bases: Module

Relative positional encoding module. This version is "compact" meaning it is able to encode the important information about the relative position in a relatively small number of dimensions. The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) make very little difference to the embedding. Such differences were potentially important when encoding absolute position, but not important when encoding relative position because there is now no need to compare two large offsets with each other.

Our embedding works by projecting the interval [-infinity, infinity] to a finite interval using the atan() function, before doing the Fourier transform of that fixed interval. On its own, the atan() function would compress the "long tails" too much, making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic function to compress large offsets to a smaller range before applying atan(). Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).

Parameters:

Name Type Description Default
embed_dim int

Embedding dimension.

required
dropout_rate FloatLike

Dropout rate.

required
max_len int

Maximum input length: just a heuristic for initialization.

1000
length_factor float

a heuristic scale (should be >= 1.0) which, if larger, gives less weight to small differences of offset near the origin.

1.0
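
The log-then-atan compression described above can be checked for a single scalar offset. This is a minimal pure-Python sketch that follows the formulas in the extend_pe() listing on this page; the function names are hypothetical, and the real module builds the whole table at once with tensor operations.

```python
import math
from typing import List


def compress_offset(offset: float, embed_dim: int) -> float:
    """Logarithmic compression whose slope is 1 at offset == 0."""
    c = math.sqrt(embed_dim)  # compression_length
    sign = (offset > 0) - (offset < 0)
    return c * sign * (math.log(abs(offset) + c) - math.log(c))


def encode_offset(offset: float, embed_dim: int, length_factor: float = 1.0) -> List[float]:
    """Embedding of one relative offset: interleaved cos/sin of the atan angle."""
    length_scale = length_factor * embed_dim / (2.0 * math.pi)
    angle = math.atan(compress_offset(offset, embed_dim) / length_scale)
    pe = [0.0] * embed_dim
    for i in range(embed_dim // 2):
        freq = i + 1
        pe[2 * i] = math.cos(angle * freq)
        pe[2 * i + 1] = math.sin(angle * freq)
    pe[-1] = 1.0  # last channel is overwritten and acts as a bias
    return pe


def embedding_distance(a: float, b: float, embed_dim: int = 64) -> float:
    """Euclidean distance between the embeddings of two offsets."""
    ea, eb = encode_offset(a, embed_dim), encode_offset(b, embed_dim)
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(ea, eb)))


if __name__ == "__main__":
    print("0 vs 1:      ", embedding_distance(0, 1))
    print("1000 vs 1001:", embedding_distance(1000, 1001))
```

The second distance comes out far smaller than the first, which is the stated goal: neighbouring offsets are well separated near the origin and nearly indistinguishable far from it.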
Source code in zipformer/modules/zipformer.py
class CompactRelPositionalEncoding(torch.nn.Module):
    """
    Relative positional encoding module.  This version is "compact" meaning it is able to encode
    the important information about the relative position in a relatively small number of dimensions.
    The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
    make very little difference to the embedding.   Such differences were potentially important
    when encoding absolute position, but not important when encoding relative position because there
    is now no need to compare two large offsets with each other.

    Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
    using the atan() function, before doing the Fourier transform of that fixed interval.  On its
    own, the atan() function would compress the "long tails" too much,
    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
    function to compress large offsets to a smaller range before applying atan().
    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).


    Args:
        embed_dim: Embedding dimension.
        dropout_rate: Dropout rate.
        max_len: Maximum input length: just a heuristic for initialization.
        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
           less weight to small differences of offset near the origin.
    """

    def __init__(
        self,
        embed_dim: int,
        dropout_rate: FloatLike,
        max_len: int = 1000,
        length_factor: float = 1.0,
    ) -> None:
        """Construct a CompactRelPositionalEncoding object."""
        super(CompactRelPositionalEncoding, self).__init__()
        self.embed_dim = embed_dim
        assert embed_dim % 2 == 0, embed_dim
        self.dropout = Dropout2(dropout_rate)
        self.pe = None
        assert length_factor >= 1.0, length_factor
        self.length_factor = length_factor
        self.extend_pe(torch.tensor(0.0).expand(max_len))

    def extend_pe(self, x: torch.Tensor, left_context_len: int = 0) -> None:
        """Reset the positional encodings."""
        T = x.size(0) + left_context_len

        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(0) >= T * 2 - 1:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return

        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
        x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)

        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)

        # `compression_length` is arbitrary/heuristic; if it is larger we have more resolution
        # for small time offsets but less resolution for large time offsets.
        compression_length = self.embed_dim**0.5
        # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity;
        # but it does so more slowly than T for large absolute values of T.
        # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which
        # is important.
        x_compressed = (
            compression_length
            * x.sign()
            * ((x.abs() + compression_length).log() - math.log(compression_length))
        )

        # if self.length_factor == 1.0, then length_scale is chosen so that the
        # FFT can exactly separate points close to the origin (T == 0).  So this
        # part of the formulation is not really heuristic.
        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)

        # note for machine implementations: if atan is not available, we can use:
        #   x.sign() * ((1 / (x.abs() + 1)) - 1)  * (-math.pi/2)
        #  check on wolframalpha.com: plot(sign(x) *  (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
        x_atan = (x_compressed / length_scale).atan()  # results between -pi/2 and pi/2

        cosines = (x_atan * freqs).cos()
        sines = (x_atan * freqs).sin()

        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
        pe[:, 0::2] = cosines
        pe[:, 1::2] = sines
        pe[:, -1] = 1.0  # for bias.

        self.pe = pe.to(dtype=x.dtype)

    def forward(self, x: torch.Tensor, left_context_len: int = 0) -> torch.Tensor:
        """Create positional encoding.

        Args:
            x (Tensor): Input tensor (time, batch, `*`).
            left_context_len (int): Length of cached left context.

        Returns:
            positional embedding, of shape (1, left_context_len + 2*time-1, `*`).
        """
        self.extend_pe(x, left_context_len)
        x_size_left = x.size(0) + left_context_len
        # length of positive side: x.size(0) + left_context_len
        # length of negative side: x.size(0)
        pos_emb = self.pe[
            self.pe.size(0) // 2 - x_size_left + 1 : self.pe.size(0) // 2  # noqa E203
            + x.size(0),
            :,
        ]
        pos_emb = pos_emb.unsqueeze(0)
        return self.dropout(pos_emb)

__init__(embed_dim, dropout_rate, max_len=1000, length_factor=1.0)

Construct a CompactRelPositionalEncoding object.

Source code in zipformer/modules/zipformer.py
def __init__(
    self,
    embed_dim: int,
    dropout_rate: FloatLike,
    max_len: int = 1000,
    length_factor: float = 1.0,
) -> None:
    """Construct a CompactRelPositionalEncoding object."""
    super(CompactRelPositionalEncoding, self).__init__()
    self.embed_dim = embed_dim
    assert embed_dim % 2 == 0, embed_dim
    self.dropout = Dropout2(dropout_rate)
    self.pe = None
    assert length_factor >= 1.0, length_factor
    self.length_factor = length_factor
    self.extend_pe(torch.tensor(0.0).expand(max_len))

extend_pe(x, left_context_len=0)

Reset the positional encodings.

Source code in zipformer/modules/zipformer.py
def extend_pe(self, x: torch.Tensor, left_context_len: int = 0) -> None:
    """Reset the positional encodings."""
    T = x.size(0) + left_context_len

    if self.pe is not None:
        # self.pe contains both positive and negative parts
        # the length of self.pe is 2 * input_len - 1
        if self.pe.size(0) >= T * 2 - 1:
            self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

    # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
    x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)

    freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)

    # `compression_length` is arbitrary/heuristic; if it is larger we have more resolution
    # for small time offsets but less resolution for large time offsets.
    compression_length = self.embed_dim**0.5
    # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity;
    # but it does so more slowly than T for large absolute values of T.
    # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which
    # is important.
    x_compressed = (
        compression_length
        * x.sign()
        * ((x.abs() + compression_length).log() - math.log(compression_length))
    )

    # if self.length_factor == 1.0, then length_scale is chosen so that the
    # FFT can exactly separate points close to the origin (T == 0).  So this
    # part of the formulation is not really heuristic.
    # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
    length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)

    # note for machine implementations: if atan is not available, we can use:
    #   x.sign() * ((1 / (x.abs() + 1)) - 1)  * (-math.pi/2)
    #  check on wolframalpha.com: plot(sign(x) *  (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
    x_atan = (x_compressed / length_scale).atan()  # results between -pi/2 and pi/2

    cosines = (x_atan * freqs).cos()
    sines = (x_atan * freqs).sin()

    pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
    pe[:, 0::2] = cosines
    pe[:, 1::2] = sines
    pe[:, -1] = 1.0  # for bias.

    self.pe = pe.to(dtype=x.dtype)

forward(x, left_context_len=0)

Create positional encoding.

Parameters:

Name Type Description Default
x Tensor

Input tensor (time, batch, *).

required
left_context_len int

Length of cached left context.

0

Returns:

Type Description
Tensor

positional embedding, of shape (1, left_context_len + 2*time-1, *).
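
The slice that forward() takes from the cached table can be worked out with integers alone. This is a minimal sketch with a hypothetical helper name; it assumes, as in the extend_pe() listing, that the table has odd length and that its middle row corresponds to relative offset 0.

```python
from typing import Tuple


def pos_emb_slice(pe_len: int, time: int, left_context_len: int = 0) -> Tuple[int, int]:
    """Return (start, stop) row indexes of the table slice used by forward()."""
    center = pe_len // 2                            # row for relative offset 0
    start = center - (time + left_context_len) + 1  # most negative offset needed
    stop = center + time                            # one past the most positive offset
    return start, stop


if __name__ == "__main__":
    pe_len = 2 * 1000 - 1  # table built for max_len=1000
    start, stop = pos_emb_slice(pe_len, time=4, left_context_len=2)
    print(start, stop, stop - start)
```

The slice length is always left_context_len + 2*time - 1, matching the second axis of the returned embedding.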

Source code in zipformer/modules/zipformer.py
def forward(self, x: torch.Tensor, left_context_len: int = 0) -> torch.Tensor:
    """Create positional encoding.

    Args:
        x (Tensor): Input tensor (time, batch, `*`).
        left_context_len (int): Length of cached left context.

    Returns:
        positional embedding, of shape (1, left_context_len + 2*time-1, `*`).
    """
    self.extend_pe(x, left_context_len)
    x_size_left = x.size(0) + left_context_len
    # length of positive side: x.size(0) + left_context_len
    # length of negative side: x.size(0)
    pos_emb = self.pe[
        self.pe.size(0) // 2 - x_size_left + 1 : self.pe.size(0) // 2  # noqa E203
        + x.size(0),
        :,
    ]
    pos_emb = pos_emb.unsqueeze(0)
    return self.dropout(pos_emb)

RelPositionMultiheadAttentionWeights

Bases: Module

Module that computes multi-head attention weights with relative position encoding. Various other modules consume the resulting attention weights: see, for example, the SimpleAttention module which allows you to compute conventional attention.

This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"; the differences still need to be written up.

Parameters:

Name Type Description Default
embed_dim int

number of channels at the input to this module, e.g. 256

required
pos_dim int

dimension of the positional encoding vectors, e.g. 128.

required
num_heads int

number of heads to compute weights for, e.g. 8

required
query_head_dim int

dimension of the query (and key), per head. e.g. 24.

required
pos_head_dim int

dimension of the projected positional encoding per head, e.g. 4.

required
dropout float

dropout probability for attn_output_weights.

0.0
pos_emb_skip_rate FloatLike

probability of skipping the pos_emb part of the scores on any given call to forward(), at training time.

ScheduledFloat((0.0, 0.5), (4000.0, 0.0))
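
Inside forward(), the positional scores are first indexed by relative position (2*seq_len - 1 columns) and then converted to absolute source positions. The index mapping can be sketched for a single head and batch element with nested lists; the function name `rel_to_abs` is hypothetical, and the mapping follows the torch.gather branch of the listing (row index counting down from seq_len - 1, plus the column index).

```python
from typing import List, Sequence


def rel_to_abs(pos_scores: Sequence[Sequence[float]]) -> List[List[float]]:
    """Convert (time1, 2*seq_len-1) relative scores to (time1, seq_len) absolute ones."""
    seq_len = len(pos_scores)
    assert all(len(row) == 2 * seq_len - 1 for row in pos_scores)
    return [
        [pos_scores[t1][(seq_len - 1) - t1 + t2] for t2 in range(seq_len)]
        for t1 in range(seq_len)
    ]


if __name__ == "__main__":
    seq_len = 3
    # Each entry holds its own relative-position column index, for readability.
    rel = [list(range(2 * seq_len - 1)) for _ in range(seq_len)]
    for row in rel_to_abs(rel):
        print(row)
```

Every diagonal entry (t1 == t2) reads column seq_len - 1, the middle column of the relative axis, so a query and key at the same time step always share the same relative-position score slot.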

Source code in zipformer/modules/zipformer.py
class RelPositionMultiheadAttentionWeights(torch.nn.Module):
    r"""Module that computes multi-head attention weights with relative position encoding.
    Various other modules consume the resulting attention weights: see, for example, the
    SimpleAttention module which allows you to compute conventional attention.

    This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
    the differences still need to be written up.


    Args:
           embed_dim: number of channels at the input to this module, e.g. 256
             pos_dim: dimension of the positional encoding vectors, e.g. 128.
           num_heads:  number of heads to compute weights for, e.g. 8
     query_head_dim: dimension of the query (and key), per head.  e.g. 24.
       pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
            dropout: dropout probability for attn_output_weights. Default: 0.0.
       pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
                     any given call to forward(), in training time.
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        dropout: float = 0.0,
        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
        self.pos_head_dim = pos_head_dim
        self.dropout = dropout
        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
        self.name = None  # will be overwritten in training code; for diagnostics.

        key_head_dim = query_head_dim
        in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads

        # the initial_scale is supposed to take over the "scaling" factor of
        # head_dim ** -0.5 that has been used in previous forms of attention,
        # dividing it between the query and key.   Note: this module is intended
        # to be used with the ScaledAdam optimizer; with most other optimizers,
        # it would be necessary to apply the scaling factor in the forward function.
        self.in_proj = ScaledLinear(
            embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
        )

        self.whiten_keys = Whiten(
            num_groups=num_heads,
            whitening_limit=_whitening_schedule(3.0),
            prob=(0.025, 0.25),
            grad_scale=0.025,
        )

        # add a balancer for the keys that runs with very small probability, and
        # tries to enforce that all dimensions have mean around zero.  The
        # weights produced by this module are invariant to adding a constant to
        # the keys, so the derivative of the bias is mathematically zero; but
        # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
        # bias because the small numerical roundoff tends to have a non-random
        # sign.  This module is intended to prevent that.  Use a very small
        # probability; that should be sufficient to fix the problem.
        self.balance_keys = Balancer(
            key_head_dim * num_heads,
            channel_dim=-1,
            min_positive=0.4,
            max_positive=0.6,
            min_abs=0.0,
            max_abs=100.0,
            prob=0.025,
        )

        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(
            pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
        )

        # the following are for diagnostics only, see --print-diagnostics option
        self.copy_pos_query = Identity()
        self.copy_query = Identity()

    def forward(
        self,
        x: torch.Tensor,
        pos_emb: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            x: input of shape (seq_len, batch_size, embed_dim)
            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
               are True in this mask will be ignored as sources in the attention weighting.
            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
               interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
               saying which positions are allowed to attend to which other positions.
        Returns:
           a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
           interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
        """
        x = self.in_proj(x)
        query_head_dim = self.query_head_dim
        pos_head_dim = self.pos_head_dim
        num_heads = self.num_heads

        seq_len, batch_size, _ = x.shape

        query_dim = query_head_dim * num_heads

        # self-attention
        q = x[..., 0:query_dim]
        k = x[..., query_dim : 2 * query_dim]
        # p is the position-encoding query
        p = x[..., 2 * query_dim :]
        assert p.shape[-1] == num_heads * pos_head_dim, (
            p.shape[-1],
            num_heads,
            pos_head_dim,
        )

        q = self.copy_query(q)  # for diagnostics only, does nothing.
        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.

        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
        k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)

        # time1 refers to target, time2 refers to source.
        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)

        attn_scores = torch.matmul(q, k)

        use_pos_scores = False
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # We can't put random.random() in the same line
            use_pos_scores = True
        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
            use_pos_scores = True

        if use_pos_scores:
            pos_emb = self.linear_pos(pos_emb)
            seq_len2 = 2 * seq_len - 1
            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
                2, 0, 3, 1
            )
            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
            #  [where seq_len2 represents relative position.]
            pos_scores = torch.matmul(p, pos_emb)
            # the following .as_strided() expression converts the last axis of pos_scores from relative
            # to absolute position.  I don't know whether I might have got the time-offsets backwards or
            # not, but let this code define which way round it is supposed to be.
            if torch.jit.is_tracing():
                (num_heads, batch_size, time1, n) = pos_scores.shape
                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
                cols = torch.arange(seq_len)
                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
                indexes = rows + cols
                pos_scores = pos_scores.reshape(-1, n)
                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
            else:
                pos_scores = pos_scores.as_strided(
                    (num_heads, batch_size, seq_len, seq_len),
                    (
                        pos_scores.stride(0),
                        pos_scores.stride(1),
                        pos_scores.stride(2) - pos_scores.stride(3),
                        pos_scores.stride(3),
                    ),
                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
                )

            attn_scores = attn_scores + pos_scores

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif self.training and random.random() < 0.1:
            # This is a harder way of limiting the attention scores to not be
            # too large.  It incurs a penalty if any of them has an absolute
            # value greater than the limit (25.0 below).  This should be outside the normal range
            # of the attention scores.  We use this mechanism instead of, say,
            # something added to the loss function involving the entropy,
            # because once the entropy gets very small gradients through the
            # softmax can become very small, and we'd get zero derivatives.  The
            # choices of 1.0e-04 as the scale on the penalty makes this
            # mechanism vulnerable to the absolute scale of the loss function,
            # but we view this as a failsafe to avoid "implausible" parameter
            # values rather than a regularization method that should be active
            # under normal circumstances.
            attn_scores = penalize_abs_values_gt(
                attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
            )

        assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)

        if attn_mask is not None:
            assert attn_mask.dtype == torch.bool
            # use -1000 to avoid nan's where attn_mask and key_padding_mask make
            # all scores zero.  It's important that this be large enough that exp(-1000)
            # is exactly zero, for reasons related to const_attention_rate, it
            # compares the final weights with zero.
            attn_scores = attn_scores.masked_fill(attn_mask, -1000)

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (
                batch_size,
                seq_len,
            ), key_padding_mask.shape
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.unsqueeze(1),
                -1000,
            )

        # We use our own version of softmax, defined in scaling.py, which should
        # save a little of the memory used in backprop: if we are in
        # automatic mixed precision mode (amp / autocast), it stores only the
        # half-precision output for backprop purposes.
        attn_weights = softmax(attn_scores, dim=-1)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif random.random() < 0.001 and not self.training:
            self._print_attn_entropy(attn_weights)

        attn_weights = torch.nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        return attn_weights

    def streaming_forward(
        self,
        x: torch.Tensor,
        pos_emb: torch.Tensor,
        cached_key: torch.Tensor,
        left_context_len: int,
        key_padding_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            x: input of shape (seq_len, batch_size, embed_dim)
            pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
            cached_key: cached attention key tensor of left context,
              of shape (left_context_len, batch_size, key_dim)
            left_context_len: number of left context frames.
            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
              are True in this mask will be ignored as sources in the attention weighting.

        Returns:
           - attention weights, of shape (num_heads, batch_size, seq_len, seq_len2),
             interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
           - updated cached attention key tensor of left context.
        """
        x = self.in_proj(x)
        query_head_dim = self.query_head_dim
        pos_head_dim = self.pos_head_dim
        num_heads = self.num_heads

        seq_len, batch_size, _ = x.shape

        query_dim = query_head_dim * num_heads

        # self-attention
        q = x[..., 0:query_dim]
        k = x[..., query_dim : 2 * query_dim]
        # p is the position-encoding query
        p = x[..., 2 * query_dim :]
        assert p.shape[-1] == num_heads * pos_head_dim

        # Pad cached left contexts
        assert cached_key.shape[0] == left_context_len, (
            cached_key.shape[0],
            left_context_len,
        )
        k = torch.cat([cached_key, k], dim=0)
        # Update cached left contexts
        cached_key = k[-left_context_len:, ...]

        # The length of key
        k_len = k.shape[0]

        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
        k = k.reshape(k_len, batch_size, num_heads, query_head_dim)

        # time1 refers to target, time2 refers to source.
        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)

        attn_scores = torch.matmul(q, k)

        pos_emb = self.linear_pos(pos_emb)
        seq_len2 = 2 * seq_len - 1 + left_context_len
        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
            2, 0, 3, 1
        )
        # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

        # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
        #  [where seq_len2 represents relative position.]
        pos_scores = torch.matmul(p, pos_emb)

        if torch.jit.is_tracing():
            (num_heads, batch_size, time1, n) = pos_scores.shape
            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
            cols = torch.arange(k_len)
            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
            indexes = rows + cols
            pos_scores = pos_scores.reshape(-1, n)
            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
        # the following .as_strided() expression converts the last axis of pos_scores from relative
        # to absolute position.  I don't know whether I might have got the time-offsets backwards or
        # not, but let this code define which way round it is supposed to be.
        else:
            pos_scores = pos_scores.as_strided(
                (num_heads, batch_size, seq_len, k_len),
                (
                    pos_scores.stride(0),
                    pos_scores.stride(1),
                    pos_scores.stride(2) - pos_scores.stride(3),
                    pos_scores.stride(3),
                ),
                storage_offset=pos_scores.stride(3) * (seq_len - 1),
            )

        attn_scores = attn_scores + pos_scores

        assert attn_scores.shape == (
            num_heads,
            batch_size,
            seq_len,
            k_len,
        ), attn_scores.shape

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.unsqueeze(1),
                -1000,
            )

        attn_weights = attn_scores.softmax(dim=-1)

        return attn_weights, cached_key

    def _print_attn_entropy(self, attn_weights: torch.Tensor):
        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

        with torch.no_grad():
            with torch_autocast(enabled=False):
                attn_weights = attn_weights.to(torch.float32)
                attn_weights_entropy = (
                    -((attn_weights + 1.0e-20).log() * attn_weights)
                    .sum(dim=-1)
                    .mean(dim=(1, 2))
                )
                logging.info(
                    f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
                )
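
The diagnostic in `_print_attn_entropy` sums `-(p + 1e-20).log() * p` over the source axis and then averages over batch and target positions, giving one entropy value per head. A minimal sketch of the per-row quantity, in plain Python rather than PyTorch (the helper name `attention_entropy` is hypothetical and not part of the module):

```python
import math

def attention_entropy(weights_row):
    """Entropy (in nats) of one row of attention weights that sums to 1."""
    # The 1e-20 offset mirrors the source and keeps log() finite when a weight is exactly zero.
    return -sum(math.log(w + 1.0e-20) * w for w in weights_row)

uniform = [0.25, 0.25, 0.25, 0.25]   # attention spread evenly over 4 frames
peaked = [1.0, 0.0, 0.0, 0.0]        # attention focused on a single frame

print(attention_entropy(uniform))    # close to log(4)
print(attention_entropy(peaked))     # close to 0
```

A low logged value therefore suggests that a head attends to very few frames, while a value near log(seq_len) suggests nearly uniform attention.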

forward(x, pos_emb, key_padding_mask=None, attn_mask=None)

Parameters:

x (Tensor, required): input of shape (seq_len, batch_size, embed_dim)
pos_emb (Tensor, required): positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
key_padding_mask (Optional[Tensor], default None): a bool tensor of shape (batch_size, seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting.
attn_mask (Optional[Tensor], default None): a bool mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), interpreted as ([batch_size,] tgt_seq_len, src_seq_len). Positions that are True are masked out, i.e. the target position may not attend to that source position.

Returns: a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len) interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
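
Both masks are applied by filling the masked scores with -1000 before the softmax. The following is a minimal sketch in plain Python (not the module's own `softmax` from scaling.py; `masked_softmax` is a hypothetical helper) showing the two properties the source comments rely on: a masked weight becomes exactly zero, and a fully masked row stays finite.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def masked_softmax(scores, mask, fill=-1000.0):
    """Replace scores where mask is True with `fill`, then apply softmax."""
    filled = [fill if is_masked else s for s, is_masked in zip(scores, mask)]
    return softmax(filled)

# One masked source position: its weight underflows to exactly 0.0.
partly_masked = masked_softmax([2.0, 1.0, 0.5], [False, False, True])

# Every source position masked: the row stays finite (uniform) instead of NaN.
fully_masked = masked_softmax([2.0, 1.0, 0.5], [True, True, True])
```

With a fill value of negative infinity the fully masked row would instead produce NaN, because the stable softmax would compute `-inf - (-inf)`.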

Source code in zipformer/modules/zipformer.py
def forward(
    self,
    x: torch.Tensor,
    pos_emb: torch.Tensor,
    key_padding_mask: Optional[torch.Tensor] = None,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""
    Args:
        x: input of shape (seq_len, batch_size, embed_dim)
        pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
        key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
           are True in this mask will be ignored as sources in the attention weighting.
        attn_mask: bool mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
           interpreted as ([batch_size,] tgt_seq_len, src_seq_len).  Positions that are
           True are masked out, i.e. the target may not attend to that source.
    Returns:
       a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
       interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
    """
    x = self.in_proj(x)
    query_head_dim = self.query_head_dim
    pos_head_dim = self.pos_head_dim
    num_heads = self.num_heads

    seq_len, batch_size, _ = x.shape

    query_dim = query_head_dim * num_heads

    # self-attention
    q = x[..., 0:query_dim]
    k = x[..., query_dim : 2 * query_dim]
    # p is the position-encoding query
    p = x[..., 2 * query_dim :]
    assert p.shape[-1] == num_heads * pos_head_dim, (
        p.shape[-1],
        num_heads,
        pos_head_dim,
    )

    q = self.copy_query(q)  # for diagnostics only, does nothing.
    k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
    p = self.copy_pos_query(p)  # for diagnostics only, does nothing.

    q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
    p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
    k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)

    # time1 refers to target, time2 refers to source.
    q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
    p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
    k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)

    attn_scores = torch.matmul(q, k)

    use_pos_scores = False
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        # We can't put random.random() in the same line
        use_pos_scores = True
    elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
        use_pos_scores = True

    if use_pos_scores:
        pos_emb = self.linear_pos(pos_emb)
        seq_len2 = 2 * seq_len - 1
        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
            2, 0, 3, 1
        )
        # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

        # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
        #  [where seq_len2 represents relative position.]
        pos_scores = torch.matmul(p, pos_emb)
        # the following .as_strided() expression converts the last axis of pos_scores from relative
        # to absolute position.  I don't know whether I might have got the time-offsets backwards or
        # not, but let this code define which way round it is supposed to be.
        if torch.jit.is_tracing():
            (num_heads, batch_size, time1, n) = pos_scores.shape
            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
            cols = torch.arange(seq_len)
            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
            indexes = rows + cols
            pos_scores = pos_scores.reshape(-1, n)
            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
        else:
            pos_scores = pos_scores.as_strided(
                (num_heads, batch_size, seq_len, seq_len),
                (
                    pos_scores.stride(0),
                    pos_scores.stride(1),
                    pos_scores.stride(2) - pos_scores.stride(3),
                    pos_scores.stride(3),
                ),
                storage_offset=pos_scores.stride(3) * (seq_len - 1),
            )

        attn_scores = attn_scores + pos_scores

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        pass
    elif self.training and random.random() < 0.1:
        # This is a harder way of limiting the attention scores to not be
        # too large.  It incurs a penalty if any of them has an absolute
        # value greater than the limit (25.0 here), which should be outside the
        # normal range of the attention scores.  We use this mechanism instead of,
        # say, something added to the loss function involving the entropy,
        # because once the entropy gets very small, gradients through the
        # softmax can become very small, and we'd get zero derivatives.  The
        # choice of 1.0e-04 as the scale on the penalty makes this
        # mechanism vulnerable to the absolute scale of the loss function,
        # but we view this as a failsafe to avoid "implausible" parameter
        # values rather than a regularization method that should be active
        # under normal circumstances.
        attn_scores = penalize_abs_values_gt(
            attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
        )

    assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)

    if attn_mask is not None:
        assert attn_mask.dtype == torch.bool
        # Use -1000 rather than -inf to avoid NaNs in rows where attn_mask and
        # key_padding_mask together mask every position.  It's important that
        # this be large enough in magnitude that exp(-1000) is exactly zero, for
        # reasons related to const_attention_rate, which compares the final
        # weights with zero.
        attn_scores = attn_scores.masked_fill(attn_mask, -1000)

    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            batch_size,
            seq_len,
        ), key_padding_mask.shape
        attn_scores = attn_scores.masked_fill(
            key_padding_mask.unsqueeze(1),
            -1000,
        )

    # We use our own version of softmax, defined in scaling.py, which should
    # save a little of the memory used in backprop: when we are in
    # automatic mixed precision mode (amp / autocast), it stores only the
    # half-precision output for backprop purposes.
    attn_weights = softmax(attn_scores, dim=-1)

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        pass
    elif random.random() < 0.001 and not self.training:
        self._print_attn_entropy(attn_weights)

    attn_weights = torch.nn.functional.dropout(
        attn_weights, p=self.dropout, training=self.training
    )

    return attn_weights
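
The relative-to-absolute conversion of `pos_scores` in the listing above can be checked on a small case. The sketch below uses plain Python lists for a single head and batch element, and assumes that `pos_scores` is contiguous with zero storage offset; the function names are hypothetical. It shows that the `gather` branch and the `as_strided` branch select the same elements, namely relative index `n = (seq_len - 1 - t1) + t2` for target `t1` and source `t2`.

```python
def rel_to_abs_gather(pos_scores, seq_len):
    """pos_scores[t1][n], n in [0, 2*seq_len - 1)  ->  out[t1][t2], as in the tracing branch."""
    out = []
    for t1 in range(seq_len):
        row_offset = seq_len - 1 - t1          # `rows` in the source
        out.append([pos_scores[t1][row_offset + t2] for t2 in range(seq_len)])
    return out

def rel_to_abs_strided(pos_scores, seq_len):
    """Same result via the flat-storage arithmetic that as_strided performs."""
    n = 2 * seq_len - 1
    flat = [v for row in pos_scores for v in row]   # contiguous storage: row stride n, column stride 1
    offset = seq_len - 1                            # storage_offset = stride(3) * (seq_len - 1)
    row_stride = n - 1                              # stride(2) - stride(3)
    return [
        [flat[offset + t1 * row_stride + t2] for t2 in range(seq_len)]
        for t1 in range(seq_len)
    ]

seq_len = 3
# Store the relative index n in every cell so the selected offsets are visible.
pos_scores = [list(range(2 * seq_len - 1)) for _ in range(seq_len)]

absolute = rel_to_abs_gather(pos_scores, seq_len)
print(absolute)  # [[2, 3, 4], [1, 2, 3], [0, 1, 2]]
```

The diagonal (a frame attending to itself) always reads relative index `seq_len - 1`, and each step to the right in the source axis increases the relative index by one.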

streaming_forward(x, pos_emb, cached_key, left_context_len, key_padding_mask)

Parameters:

x (Tensor, required): input of shape (seq_len, batch_size, embed_dim)
pos_emb (Tensor, required): positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
cached_key (Tensor, required): cached attention key tensor of left context, of shape (left_context_len, batch_size, key_dim)
left_context_len (int, required): number of left context frames.
key_padding_mask (Tensor, required): a bool tensor of shape (batch_size, left_context_len + seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting.

Returns:

Tensor: attention weights, of shape (num_heads, batch_size, seq_len, left_context_len + seq_len), interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
Tensor: updated cached attention key tensor of left context.

Source code in zipformer/modules/zipformer.py
def streaming_forward(
    self,
    x: torch.Tensor,
    pos_emb: torch.Tensor,
    cached_key: torch.Tensor,
    left_context_len: int,
    key_padding_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        x: input of shape (seq_len, batch_size, embed_dim)
        pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
        cached_key: cached attention key tensor of left context,
          of shape (left_context_len, batch_size, key_dim)
        left_context_len: number of left context frames.
        key_padding_mask: a bool tensor of shape (batch_size, left_context_len + seq_len).
          Positions that are True in this mask will be ignored as sources in the
          attention weighting.

    Returns:
       - attention weights, of shape (num_heads, batch_size, seq_len, left_context_len + seq_len),
         interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
       - updated cached attention key tensor of left context.
    """
    x = self.in_proj(x)
    query_head_dim = self.query_head_dim
    pos_head_dim = self.pos_head_dim
    num_heads = self.num_heads

    seq_len, batch_size, _ = x.shape

    query_dim = query_head_dim * num_heads

    # self-attention
    q = x[..., 0:query_dim]
    k = x[..., query_dim : 2 * query_dim]
    # p is the position-encoding query
    p = x[..., 2 * query_dim :]
    assert p.shape[-1] == num_heads * pos_head_dim

    # Pad cached left contexts
    assert cached_key.shape[0] == left_context_len, (
        cached_key.shape[0],
        left_context_len,
    )
    k = torch.cat([cached_key, k], dim=0)
    # Update cached left contexts
    cached_key = k[-left_context_len:, ...]

    # The length of key
    k_len = k.shape[0]

    q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
    p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
    k = k.reshape(k_len, batch_size, num_heads, query_head_dim)

    # time1 refers to target, time2 refers to source.
    q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
    p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
    k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)

    attn_scores = torch.matmul(q, k)

    pos_emb = self.linear_pos(pos_emb)
    seq_len2 = 2 * seq_len - 1 + left_context_len
    pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
        2, 0, 3, 1
    )
    # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

    # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
    #  [where seq_len2 represents relative position.]
    pos_scores = torch.matmul(p, pos_emb)

    if torch.jit.is_tracing():
        (num_heads, batch_size, time1, n) = pos_scores.shape
        rows = torch.arange(start=time1 - 1, end=-1, step=-1)
        cols = torch.arange(k_len)
        rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
        indexes = rows + cols
        pos_scores = pos_scores.reshape(-1, n)
        pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
        pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
    # the following .as_strided() expression converts the last axis of pos_scores from relative
    # to absolute position.  I don't know whether I might have got the time-offsets backwards or
    # not, but let this code define which way round it is supposed to be.
    else:
        pos_scores = pos_scores.as_strided(
            (num_heads, batch_size, seq_len, k_len),
            (
                pos_scores.stride(0),
                pos_scores.stride(1),
                pos_scores.stride(2) - pos_scores.stride(3),
                pos_scores.stride(3),
            ),
            storage_offset=pos_scores.stride(3) * (seq_len - 1),
        )

    attn_scores = attn_scores + pos_scores

    assert attn_scores.shape == (
        num_heads,
        batch_size,
        seq_len,
        k_len,
    ), attn_scores.shape

    if key_padding_mask is not None:
        assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
        attn_scores = attn_scores.masked_fill(
            key_padding_mask.unsqueeze(1),
            -1000,
        )

    attn_weights = attn_scores.softmax(dim=-1)

    return attn_weights, cached_key
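
The cache handling in `streaming_forward` amounts to a sliding window over the key frames. A minimal sketch with Python lists standing in for the time axis of the key tensor (the helper `update_key_cache`, the integer frame labels, and the zero-filled initial cache are illustrative assumptions, not part of the module):

```python
def update_key_cache(cached_key, new_keys, left_context_len):
    """Return (keys to attend over, updated cache), following the slicing in streaming_forward."""
    assert len(cached_key) == left_context_len, (len(cached_key), left_context_len)
    keys = cached_key + new_keys                 # torch.cat([cached_key, k], dim=0)
    new_cache = keys[-left_context_len:]         # keep only the most recent frames
    return keys, new_cache

left_context_len = 4
cache = [0] * left_context_len                   # hypothetical initial cache of placeholder frames

# Feed two chunks of 3 frames each; frames are labelled by integer ids for readability.
keys1, cache = update_key_cache(cache, [1, 2, 3], left_context_len)
keys2, cache = update_key_cache(cache, [4, 5, 6], left_context_len)

print(keys2)   # [0, 1, 2, 3, 4, 5, 6]
print(cache)   # [3, 4, 5, 6]
```

Each chunk therefore attends over `left_context_len + seq_len` key frames. Note that the `[-left_context_len:]` slice selects the whole sequence when `left_context_len` is 0, so this pattern assumes a positive left context.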

SelfAttention

Bases: Module

The simplest possible attention module. This one works with already-computed attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights.

Parameters:

embed_dim (int, required): the input and output embedding dimension
num_heads (int, required): the number of attention heads
value_head_dim (int, required): the value dimension per head

Source code in zipformer/modules/zipformer.py
class SelfAttention(torch.nn.Module):
    """
    The simplest possible attention module.  This one works with already-computed attention
    weights, e.g. as computed by RelPositionMultiheadAttentionWeights.

    Args:
          embed_dim: the input and output embedding dimension
          num_heads: the number of attention heads
          value_head_dim: the value dimension per head
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        value_head_dim: int,
    ) -> None:
        super().__init__()
        self.in_proj = torch.nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)

        self.out_proj = ScaledLinear(
            num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x: input tensor, of shape (seq_len, batch_size, embed_dim)
          attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
            with the last two axes interpreted as (tgt_seq_len, src_seq_len).  Expect
            attn_weights.sum(dim=-1) == 1.
        Returns:
           a tensor with the same shape as x.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)
        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)

        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)
        x = self.whiten(x)

        return x

    def streaming_forward(
        self,
        x: torch.Tensor,
        attn_weights: torch.Tensor,
        cached_val: torch.Tensor,
        left_context_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: input tensor, of shape (seq_len, batch_size, embed_dim)
            attn_weights: a tensor of shape
              (num_heads, batch_size, seq_len, left_context_len + seq_len),
              with the last two axes interpreted as (tgt_seq_len, src_seq_len).  Expect
              attn_weights.sum(dim=-1) == 1.
            cached_val: cached attention value tensor of left context,
              of shape (left_context_len, batch_size, value_dim)
            left_context_len: number of left context frames.

        Returns:
           - attention weighted output, a tensor with the same shape as x.
           - updated cached attention value tensor of left context.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        seq_len2 = seq_len + left_context_len
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)

        # Pad cached left contexts
        assert cached_val.shape[0] == left_context_len, (
            cached_val.shape[0],
            left_context_len,
        )
        x = torch.cat([cached_val, x], dim=0)
        # Update cached left contexts
        cached_val = x[-left_context_len:, ...]

        x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len2, value_head_dim)
        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)

        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)

        return x, cached_val

forward(x, attn_weights)

Parameters:

x (Tensor, required): input tensor, of shape (seq_len, batch_size, embed_dim)
attn_weights (Tensor, required): a tensor of shape (num_heads, batch_size, seq_len, seq_len), with the last two axes interpreted as (tgt_seq_len, src_seq_len). Expect attn_weights.sum(dim=-1) == 1.

Returns: a tensor with the same shape as x.

Source code in zipformer/modules/zipformer.py
def forward(
    self,
    x: torch.Tensor,
    attn_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Args:
      x: input tensor, of shape (seq_len, batch_size, embed_dim)
      attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
        with the last two axes interpreted as (tgt_seq_len, src_seq_len).  Expect
        attn_weights.sum(dim=-1) == 1.
    Returns:
       a tensor with the same shape as x.
    """
    (seq_len, batch_size, embed_dim) = x.shape
    num_heads = attn_weights.shape[0]
    assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

    x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
    x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
    # now x: (num_heads, batch_size, seq_len, value_head_dim)
    value_head_dim = x.shape[-1]

    # todo: see whether there is benefit in overriding matmul
    x = torch.matmul(attn_weights, x)
    # now x: (num_heads, batch_size, seq_len, value_head_dim)

    x = (
        x.permute(2, 1, 0, 3)
        .contiguous()
        .view(seq_len, batch_size, num_heads * value_head_dim)
    )

    # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
    x = self.out_proj(x)
    x = self.whiten(x)

    return x
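
The core of `SelfAttention.forward` is the product `torch.matmul(attn_weights, x)`, which replaces each target frame with a weighted average of the projected source frames. A minimal single-head sketch in plain Python, leaving out `in_proj`, `out_proj` and `whiten` (the helper `apply_attention` is hypothetical):

```python
def apply_attention(attn_weights, values):
    """attn_weights[t][s] (rows sum to 1), values[s][d]  ->  out[t][d]."""
    head_dim = len(values[0])
    return [
        [sum(w * v[d] for w, v in zip(row, values)) for d in range(head_dim)]
        for row in attn_weights
    ]

values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]       # 3 source frames, value_head_dim = 2
attn_weights = [
    [1.0, 0.0, 0.0],                                 # target 0 copies source 0
    [0.5, 0.5, 0.0],                                 # target 1 averages sources 0 and 1
    [0.0, 0.0, 1.0],                                 # target 2 copies source 2
]

out = apply_attention(attn_weights, values)
print(out)  # [[1.0, 0.0], [0.5, 0.5], [2.0, 2.0]]
```

Because each row of `attn_weights` sums to 1, every output frame is a convex combination of the value frames.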

streaming_forward(x, attn_weights, cached_val, left_context_len)

Parameters:

x (Tensor, required): input tensor, of shape (seq_len, batch_size, embed_dim)
attn_weights (Tensor, required): a tensor of shape (num_heads, batch_size, seq_len, left_context_len + seq_len), with the last two axes interpreted as (tgt_seq_len, src_seq_len). Expect attn_weights.sum(dim=-1) == 1.
cached_val (Tensor, required): cached attention value tensor of left context, of shape (left_context_len, batch_size, value_dim)
left_context_len (int, required): number of left context frames.

Returns:

Tensor: attention weighted output, a tensor with the same shape as x.
Tensor: updated cached attention value tensor of left context.

Source code in zipformer/modules/zipformer.py
def streaming_forward(
    self,
    x: torch.Tensor,
    attn_weights: torch.Tensor,
    cached_val: torch.Tensor,
    left_context_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        x: input tensor, of shape (seq_len, batch_size, embed_dim)
        attn_weights: a tensor of shape
          (num_heads, batch_size, seq_len, left_context_len + seq_len),
          with the last two axes interpreted as (tgt_seq_len, src_seq_len).  Expect
          attn_weights.sum(dim=-1) == 1.
        cached_val: cached attention value tensor of left context,
          of shape (left_context_len, batch_size, value_dim)
        left_context_len: number of left context frames.

    Returns:
       - attention weighted output, a tensor with the same shape as x.
       - updated cached attention value tensor of left context.
    """
    (seq_len, batch_size, embed_dim) = x.shape
    num_heads = attn_weights.shape[0]
    seq_len2 = seq_len + left_context_len
    assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)

    x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)

    # Pad cached left contexts
    assert cached_val.shape[0] == left_context_len, (
        cached_val.shape[0],
        left_context_len,
    )
    x = torch.cat([cached_val, x], dim=0)
    # Update cached left contexts
    cached_val = x[-left_context_len:, ...]

    x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
    # now x: (num_heads, batch_size, seq_len2, value_head_dim)
    value_head_dim = x.shape[-1]

    # todo: see whether there is benefit in overriding matmul
    x = torch.matmul(attn_weights, x)
    # now x: (num_heads, batch_size, seq_len, value_head_dim)

    x = (
        x.permute(2, 1, 0, 3)
        .contiguous()
        .view(seq_len, batch_size, num_heads * value_head_dim)
    )

    # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
    x = self.out_proj(x)

    return x, cached_val

FeedforwardModule

Bases: Module

Feedforward module in Zipformer model.

Source code in zipformer/modules/zipformer.py
class FeedforwardModule(torch.nn.Module):
    """Feedforward module in Zipformer model."""

    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
        super(FeedforwardModule, self).__init__()
        self.in_proj = torch.nn.Linear(embed_dim, feedforward_dim)

        self.hidden_balancer = Balancer(
            feedforward_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        # shared_dim=0 means we share the dropout mask along the time axis
        self.out_proj = ActivationDropoutAndLinear(
            feedforward_dim,
            embed_dim,
            activation="SwooshL",
            dropout_p=dropout,
            dropout_shared_dim=0,
            bias=True,
            initial_scale=0.1,
        )

        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(self, x: torch.Tensor):
        x = self.in_proj(x)
        x = self.hidden_balancer(x)
        # out_proj contains SwooshL activation, then dropout, then linear.
        x = self.out_proj(x)
        x = self.out_whiten(x)
        return x
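
The comment on `dropout_shared_dim=0` says the dropout mask is shared along the time axis. The sketch below illustrates what that means for an input of shape (seq_len, batch_size, channels), using nested Python lists. It is not the implementation of `ActivationDropoutAndLinear`, which is not shown on this page; the helper name and the inverted-dropout rescaling are assumptions.

```python
import random

def dropout_shared_over_time(x, p, rng):
    """x[t][b][c]; draw one keep/drop mask per (batch, channel) and reuse it at every time step t."""
    batch_size, channels = len(x[0]), len(x[0][0])
    scale = 1.0 / (1.0 - p)                       # assumed inverted-dropout rescaling
    mask = [
        [0.0 if rng.random() < p else scale for _ in range(channels)]
        for _ in range(batch_size)
    ]
    out = [
        [[x_t[b][c] * mask[b][c] for c in range(channels)] for b in range(batch_size)]
        for x_t in x
    ]
    return out, mask

rng = random.Random(0)
seq_len, batch_size, channels = 5, 2, 8
x = [[[1.0] * channels for _ in range(batch_size)] for _ in range(seq_len)]

y, mask = dropout_shared_over_time(x, p=0.5, rng=rng)
```

With an all-ones input, every time step of `y` shows the same pattern of dropped channels, so a channel is either kept for the whole sequence or dropped for the whole sequence.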

NonlinAttention

Bases: Module

This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed from the attention module) in place of actual convolution. We also took out the second nonlinearity, the one after the attention mechanism.

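Parameters:

channels (int, required): the number of input and output channels.
hidden_channels (int, required): the hidden dimension; in_proj produces 3 * hidden_channels features, which are split into s, x and y.

Source code in zipformer/modules/zipformer.py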
class NonlinAttention(torch.nn.Module):
    """This is like the ConvolutionModule, but refactored so that we use multiplication by
    attention weights (borrowed from the attention module) in place of actual convolution.
    We also took out the second nonlinearity, the one after the attention mechanism.

    Args:
        channels (int): the number of input and output channels.
        hidden_channels (int): the hidden dimension; in_proj produces 3 * hidden_channels
            features, which are split into s, x and y.
    """

    def __init__(
        self,
        channels: int,
        hidden_channels: int,
    ) -> None:
        super().__init__()

        self.hidden_channels = hidden_channels

        self.in_proj = torch.nn.Linear(channels, hidden_channels * 3, bias=True)

        # Balancer that goes before the tanh.  min_abs is set to 0.5 here.  The
        # motivation for a lower bound is that well-trained instances of this module
        # were observed to have abs-values before the nonlinearity starting from
        # about 3, while poorly-trained instances have smaller abs values.
        self.balancer = Balancer(
            hidden_channels,
            channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
            max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
            min_abs=0.5,
            max_abs=5.0,
        )
        self.tanh = torch.nn.Tanh()

        self.identity1 = Identity()  # for diagnostics.
        self.identity2 = Identity()  # for diagnostics.
        self.identity3 = Identity()  # for diagnostics.

        self.out_proj = ScaledLinear(
            hidden_channels, channels, bias=True, initial_scale=0.05
        )

        self.whiten1 = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(5.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.whiten2 = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(5.0, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: a Tensor of shape (seq_len, batch_size, num_channels)
            attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
        Returns:
            a Tensor with the same shape as x
        """
        x = self.in_proj(x)

        (seq_len, batch_size, _) = x.shape
        hidden_channels = self.hidden_channels

        s, x, y = x.chunk(3, dim=2)

        # s will go through tanh.

        s = self.balancer(s)
        s = self.tanh(s)

        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
        x = self.whiten1(x)
        x = x * s
        x = self.identity1(x)  # diagnostics only, it's the identity.

        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, head_dim)
        x = torch.matmul(attn_weights, x)
        # now x: (num_heads, batch_size, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

        y = self.identity2(y)
        x = x * y
        x = self.identity3(x)

        x = self.out_proj(x)
        x = self.whiten2(x)
        return x

    def streaming_forward(
        self,
        x: torch.Tensor,
        attn_weights: torch.Tensor,
        cached_x: torch.Tensor,
        left_context_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: a Tensor of shape (seq_len, batch_size, num_channels)
            attn_weights: a Tensor of shape
              (num_heads, batch_size, seq_len, left_context_len + seq_len)
            cached_x: left context, a Tensor of shape
              (num_heads, batch_size, left_context_len, head_dim)
            left_context_len: number of left context frames.
        Returns:
            - a Tensor with the same shape as x
            - updated left context with same shape as cached_x
        """
        x = self.in_proj(x)

        (seq_len, batch_size, _) = x.shape
        hidden_channels = self.hidden_channels

        s, x, y = x.chunk(3, dim=2)

        # s will go through tanh.
        s = self.tanh(s)

        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
        x = x * s

        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (
            num_heads,
            batch_size,
            seq_len,
            left_context_len + seq_len,
        )

        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, head_dim)

        # Pad cached tensor
        assert cached_x.shape[2] == left_context_len, (
            cached_x.shape[2],
            left_context_len,
        )
        x_pad = torch.cat([cached_x, x], dim=2)
        # Update cached tensor
        cached_x = x_pad[:, :, -left_context_len:, :]

        x = torch.matmul(attn_weights, x_pad)
        # now x: (num_heads, batch_size, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

        x = x * y

        x = self.out_proj(x)
        return x, cached_x

forward(x, attn_weights)

Args:
    x: a Tensor of shape (seq_len, batch_size, num_channels)
    attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)

Returns:
    a Tensor with the same shape as x

Source code in zipformer/modules/zipformer.py
def forward(
    self,
    x: torch.Tensor,
    attn_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Args:
        x: a Tensor of shape (seq_len, batch_size, num_channels)
        attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
    Returns:
        a Tensor with the same shape as x
    """
    x = self.in_proj(x)

    (seq_len, batch_size, _) = x.shape
    hidden_channels = self.hidden_channels

    s, x, y = x.chunk(3, dim=2)

    # s will go through tanh.

    s = self.balancer(s)
    s = self.tanh(s)

    s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
    x = self.whiten1(x)
    x = x * s
    x = self.identity1(x)  # diagnostics only, it's the identity.

    (seq_len, batch_size, embed_dim) = x.shape
    num_heads = attn_weights.shape[0]
    assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

    x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
    # now x: (num_heads, batch_size, seq_len, head_dim)
    x = torch.matmul(attn_weights, x)
    # now x: (num_heads, batch_size, seq_len, head_dim)
    x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

    y = self.identity2(y)
    x = x * y
    x = self.identity3(x)

    x = self.out_proj(x)
    x = self.whiten2(x)
    return x
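
The core of forward is two multiplicative gates around an attention-weighted sum over time. The following is a minimal sketch of that data flow for one head, one utterance, and one channel, using plain Python lists instead of torch tensors. The function name and the toy inputs are hypothetical, and the projections, balancer, and whitening modules are deliberately left out.

```python
import math


def gated_attention_single_head(s, x, y, attn_weights):
    """Toy single-head version of the gating in forward (assumed simplification).

    s, x, y: lists of length seq_len, one scalar channel each.
    attn_weights: seq_len x seq_len rows; row t weights every source frame.
    """
    # First gate: mirrors `s = self.tanh(s)` followed by `x = x * s`.
    gated = [xi * math.tanh(si) for xi, si in zip(x, s)]
    # Attention-weighted sum over time: the role of torch.matmul(attn_weights, x).
    mixed = [sum(w * g for w, g in zip(row, gated)) for row in attn_weights]
    # Second gate: mirrors `x = x * y`.
    return [m * yi for m, yi in zip(mixed, y)]


s = [0.0, 100.0, 100.0]   # tanh(0) closes the gate, tanh(100) opens it fully
x = [5.0, 2.0, 4.0]
y = [1.0, 1.0, 0.5]
attn = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.5, 0.5],
]
out = gated_attention_single_head(s, x, y, attn)
```

Frame 0 is zeroed by its tanh gate, and frame 2 averages frames 1 and 2 before the second gate halves the result.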

streaming_forward(x, attn_weights, cached_x, left_context_len)

Args:
    x: a Tensor of shape (seq_len, batch_size, num_channels)
    attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, left_context_len + seq_len)
    cached_x: left context, a Tensor of shape (num_heads, batch_size, left_context_len, head_dim)
    left_context_len: number of left context frames.

Returns:
    - a Tensor with the same shape as x
    - updated left context with same shape as cached_x

Source code in zipformer/modules/zipformer.py
def streaming_forward(
    self,
    x: torch.Tensor,
    attn_weights: torch.Tensor,
    cached_x: torch.Tensor,
    left_context_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        x: a Tensor of shape (seq_len, batch_size, num_channels)
        attn_weights: a Tensor of shape
          (num_heads, batch_size, seq_len, left_context_len + seq_len)
        cached_x: left context, a Tensor of shape
          (num_heads, batch_size, left_context_len, head_dim)
        left_context_len: number of left context frames.
    Returns:
        - a Tensor with the same shape as x
        - updated left context with same shape as cached_x
    """
    x = self.in_proj(x)

    (seq_len, batch_size, _) = x.shape
    hidden_channels = self.hidden_channels

    s, x, y = x.chunk(3, dim=2)

    # s will go through tanh.
    s = self.tanh(s)

    s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
    x = x * s

    (seq_len, batch_size, embed_dim) = x.shape
    num_heads = attn_weights.shape[0]
    assert attn_weights.shape == (
        num_heads,
        batch_size,
        seq_len,
        left_context_len + seq_len,
    )

    x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
    # now x: (num_heads, batch_size, seq_len, head_dim)

    # Pad cached tensor
    assert cached_x.shape[2] == left_context_len, (
        cached_x.shape[2],
        left_context_len,
    )
    x_pad = torch.cat([cached_x, x], dim=2)
    # Update cached tensor
    cached_x = x_pad[:, :, -left_context_len:, :]

    x = torch.matmul(attn_weights, x_pad)
    # now x: (num_heads, batch_size, seq_len, head_dim)
    x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

    x = x * y

    x = self.out_proj(x)
    return x, cached_x
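
The cache handling in streaming_forward can be followed on a single time axis. The sketch below uses Python lists of frame labels in place of the (num_heads, batch_size, time, head_dim) tensors; `update_left_context` is a hypothetical helper name, and the assumption `left_context_len > 0` is stated explicitly because a slice starting at `-0` would keep every frame.

```python
def update_left_context(cached, chunk, left_context_len):
    """Concatenate cache and new frames along time, then keep the newest frames."""
    # Assumption of this sketch: a non-empty left context.
    assert left_context_len > 0
    assert len(cached) == left_context_len
    # Mirrors `x_pad = torch.cat([cached_x, x], dim=2)`.
    padded = cached + chunk
    # Mirrors `cached_x = x_pad[:, :, -left_context_len:, :]`.
    new_cache = padded[-left_context_len:]
    return padded, new_cache


cache = [0, 0, 0]          # initial left context of three placeholder frames
history = []
for chunk in ([1, 2], [3, 4]):
    padded, cache = update_left_context(cache, chunk, left_context_len=3)
    history.append((padded, cache))
```

Each call attends over `left_context_len + seq_len` frames, which is why the assertion on attn_weights expects that size in its last dimension.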

ConvolutionModule

Bases: Module

ConvolutionModule in Zipformer model. Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py

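Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| channels | int | The number of channels of conv layers. | required |
| kernel_size | int | Kernel size of conv layers; must be odd. | required |
| causal | bool | If True, use a chunk-causal depthwise convolution. | required |

Source code in zipformer/modules/zipformer.py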
class ConvolutionModule(torch.nn.Module):
    """ConvolutionModule in Zipformer model.
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers; must be odd.
        causal (bool): If True, use ChunkCausalDepthwiseConv1d for the
            depthwise convolution; otherwise use torch.nn.Conv1d.

    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        causal: bool,
    ) -> None:
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        bottleneck_dim = channels
        self.causal = causal

        self.in_proj = torch.nn.Linear(
            channels,
            2 * bottleneck_dim,
        )
        # the gradients on in_proj are a little noisy, likely to do with the
        # sigmoid in glu.

        # after in_proj we put x through a gated linear unit (nn.functional.glu).
        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
        # between 50 and 100 for different channels.  This will cause very peaky and
        # sparse derivatives for the sigmoid gating function, which will tend to make
        # the loss function not learn effectively.  (for most layers the average absolute values
        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
        # layers, which likely breaks down as 0.5 for the "linear" half and
        # 0.2 to 0.3 for the part that goes into the sigmoid).  The idea is that if we
        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
        # it will be in a better position to start learning something, i.e. to latch onto
        # the correct range.
        self.balancer1 = Balancer(
            bottleneck_dim,
            channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
            max_positive=1.0,
            min_abs=1.5,
            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
        )

        self.activation1 = Identity()  # for diagnostics

        self.sigmoid = torch.nn.Sigmoid()

        self.activation2 = Identity()  # for diagnostics

        assert kernel_size % 2 == 1

        self.depthwise_conv = (
            ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
            if causal
            else torch.nn.Conv1d(
                in_channels=bottleneck_dim,
                out_channels=bottleneck_dim,
                groups=bottleneck_dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            )
        )

        self.balancer2 = Balancer(
            bottleneck_dim,
            channel_dim=1,
            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
            max_positive=1.0,
            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
            max_abs=10.0,
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.out_proj = ActivationDropoutAndLinear(
            bottleneck_dim,
            channels,
            activation="SwooshR",
            dropout_p=0.0,
            initial_scale=0.05,
        )

    def forward(
        self,
        x: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        chunk_size: int = -1,
    ) -> torch.Tensor:
        """Compute convolution module.

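        Args:
            x: Input tensor (#time, batch, channels).
            src_key_padding_mask: the mask for the src keys per batch (optional):
                (batch, #time), contains True in masked positions.
            chunk_size: if >= 0 (and not scripting or tracing), it is passed to
                the depthwise convolution, which requires causal=True;
                otherwise the convolution is called without it.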

        Returns:
            Tensor: Output tensor (#time, batch, channels).

        """

        x = self.in_proj(x)  # (time, batch, 2*channels)

        x, s = x.chunk(2, dim=2)
        s = self.balancer1(s)
        s = self.sigmoid(s)
        x = self.activation1(x)  # identity.
        x = x * s
        x = self.activation2(x)  # identity

        # (time, batch, channels)

        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        if (
            not torch.jit.is_scripting()
            and not torch.jit.is_tracing()
            and chunk_size >= 0
        ):
            # Exporting a model for simulated streaming decoding is not supported
            assert self.causal, (
                "Must initialize model with causal=True if you use chunk_size"
            )
            x = self.depthwise_conv(x, chunk_size=chunk_size)
        else:
            x = self.depthwise_conv(x)

        x = self.balancer2(x)
        x = x.permute(2, 0, 1)  # (time, batch, channels)

        x = self.whiten(x)  # (time, batch, channels)
        x = self.out_proj(x)  # (time, batch, channels)

        return x

    def streaming_forward(
        self,
        x: torch.Tensor,
        cache: torch.Tensor,
        src_key_padding_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module in streaming forward mode.

        Args:
            x: Input tensor (#time, batch, channels).
            cache: cached left context for depthwise_conv of shape
              (#batch, channels, left_pad)
            src_key_padding_mask: the mask for the src keys per batch (optional):
              (batch, #time), contains True in masked positions.

        Returns:
            - Output tensor (#time, batch, channels).
            - Updated cache (#batch, channels, left_pad)
        """

        x = self.in_proj(x)  # (time, batch, 2*channels)

        x, s = x.chunk(2, dim=2)
        s = self.sigmoid(s)
        x = x * s
        # (time, batch, channels)

        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)

        x = x.permute(2, 0, 1)  # (time, batch, channels)

        x = self.out_proj(x)  # (time, batch, channels)

        return x, cache

__init__(channels, kernel_size, causal)

Construct a ConvolutionModule object.

Source code in zipformer/modules/zipformer.py
def __init__(
    self,
    channels: int,
    kernel_size: int,
    causal: bool,
) -> None:
    """Construct a ConvolutionModule object."""
    super(ConvolutionModule, self).__init__()
    # kernel_size should be an odd number for 'SAME' padding
    assert (kernel_size - 1) % 2 == 0

    bottleneck_dim = channels
    self.causal = causal

    self.in_proj = torch.nn.Linear(
        channels,
        2 * bottleneck_dim,
    )
    # the gradients on in_proj are a little noisy, likely to do with the
    # sigmoid in glu.

    # after in_proj we put x through a gated linear unit (nn.functional.glu).
    # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
    # but sometimes, for some reason, for layer 0 the rms ends up being very large,
    # between 50 and 100 for different channels.  This will cause very peaky and
    # sparse derivatives for the sigmoid gating function, which will tend to make
    # the loss function not learn effectively.  (for most layers the average absolute values
    # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
    # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
    # layers, which likely breaks down as 0.5 for the "linear" half and
    # 0.2 to 0.3 for the part that goes into the sigmoid).  The idea is that if we
    # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
    # it will be in a better position to start learning something, i.e. to latch onto
    # the correct range.
    self.balancer1 = Balancer(
        bottleneck_dim,
        channel_dim=-1,
        min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
        max_positive=1.0,
        min_abs=1.5,
        max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
    )

    self.activation1 = Identity()  # for diagnostics

    self.sigmoid = torch.nn.Sigmoid()

    self.activation2 = Identity()  # for diagnostics

    assert kernel_size % 2 == 1

    self.depthwise_conv = (
        ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
        if causal
        else torch.nn.Conv1d(
            in_channels=bottleneck_dim,
            out_channels=bottleneck_dim,
            groups=bottleneck_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
    )

    self.balancer2 = Balancer(
        bottleneck_dim,
        channel_dim=1,
        min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
        max_positive=1.0,
        min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
        max_abs=10.0,
    )

    self.whiten = Whiten(
        num_groups=1,
        whitening_limit=_whitening_schedule(7.5),
        prob=(0.025, 0.25),
        grad_scale=0.01,
    )

    self.out_proj = ActivationDropoutAndLinear(
        bottleneck_dim,
        channels,
        activation="SwooshR",
        dropout_p=0.0,
        initial_scale=0.05,
    )
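
The odd-kernel assertion and `padding=kernel_size // 2` in the non-causal branch go together: only an odd kernel keeps the sequence length unchanged. A small sketch of the arithmetic follows, using the standard output-length formula for a 1-D convolution; the helper name `conv1d_output_length` is hypothetical.

```python
def conv1d_output_length(seq_len, kernel_size, padding, stride=1, dilation=1):
    """Output length of a 1-D convolution (standard formula, integer division)."""
    return (seq_len + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Odd kernel with padding = kernel_size // 2: length is preserved.
same_len = conv1d_output_length(seq_len=100, kernel_size=31, padding=31 // 2)

# Even kernel with the same padding rule: one extra frame appears.
even_len = conv1d_output_length(seq_len=100, kernel_size=4, padding=4 // 2)
```

With kernel_size=31 (the default cnn_module_kernel) the output stays at 100 frames, whereas an even kernel would yield 101.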

forward(x, src_key_padding_mask=None, chunk_size=-1)

Compute convolution module.

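Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor (#time, batch, channels). | required |
| src_key_padding_mask | Optional[Tensor] | The mask for the src keys per batch: (batch, #time), contains True in masked positions. | None |
| chunk_size | int | If >= 0 (and not scripting or tracing), passed to the depthwise convolution, which requires causal=True. | -1 |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor (#time, batch, channels). |

Source code in zipformer/modules/zipformer.py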
def forward(
    self,
    x: torch.Tensor,
    src_key_padding_mask: Optional[torch.Tensor] = None,
    chunk_size: int = -1,
) -> torch.Tensor:
    """Compute convolution module.

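    Args:
        x: Input tensor (#time, batch, channels).
        src_key_padding_mask: the mask for the src keys per batch (optional):
            (batch, #time), contains True in masked positions.
        chunk_size: if >= 0 (and not scripting or tracing), it is passed to
            the depthwise convolution, which requires causal=True;
            otherwise the convolution is called without it.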

    Returns:
        Tensor: Output tensor (#time, batch, channels).

    """

    x = self.in_proj(x)  # (time, batch, 2*channels)

    x, s = x.chunk(2, dim=2)
    s = self.balancer1(s)
    s = self.sigmoid(s)
    x = self.activation1(x)  # identity.
    x = x * s
    x = self.activation2(x)  # identity

    # (time, batch, channels)

    # exchange the temporal dimension and the feature dimension
    x = x.permute(1, 2, 0)  # (#batch, channels, time).

    if src_key_padding_mask is not None:
        x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

    if (
        not torch.jit.is_scripting()
        and not torch.jit.is_tracing()
        and chunk_size >= 0
    ):
        # Exporting a model for simulated streaming decoding is not supported
        assert self.causal, (
            "Must initialize model with causal=True if you use chunk_size"
        )
        x = self.depthwise_conv(x, chunk_size=chunk_size)
    else:
        x = self.depthwise_conv(x)

    x = self.balancer2(x)
    x = x.permute(2, 0, 1)  # (time, batch, channels)

    x = self.whiten(x)  # (time, batch, channels)
    x = self.out_proj(x)  # (time, batch, channels)

    return x
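
The first steps of forward amount to a gated linear unit: the projection is split in two and one half is scaled by the sigmoid of the other. The sketch below shows this on a single frame with plain Python floats, and also evaluates the sigmoid derivative to illustrate why the source comments worry about very large inputs to the gate. The function names are hypothetical, and balancer1 is omitted.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def sigmoid_grad(v):
    """Derivative of the sigmoid, s * (1 - s)."""
    s = sigmoid(v)
    return s * (1.0 - s)


def gated_linear_unit(projected):
    """Split one frame in half and gate the first half with the second.

    Mirrors `x, s = x.chunk(2, dim=2)` followed by `x = x * sigmoid(s)`.
    """
    half = len(projected) // 2
    values, gates = projected[:half], projected[half:]
    return [v * sigmoid(g) for v, g in zip(values, gates)]


frame = [2.0, -4.0, 0.0, 0.0]      # two value channels, two gate channels
gated = gated_linear_unit(frame)    # gates of 0 pass half of each value

grad_small = sigmoid_grad(0.0)      # largest possible slope
grad_large = sigmoid_grad(50.0)     # saturated gate, slope effectively zero
```

A gate input around 50 leaves almost no gradient, which matches the motivation given for constraining magnitudes with max_abs in balancer1.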

streaming_forward(x, cache, src_key_padding_mask)

Compute convolution module in streaming forward mode.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor (#time, batch, channels). | required |
| cache | Tensor | Cached left context for depthwise_conv, of shape (#batch, channels, left_pad). | required |
| src_key_padding_mask | Tensor | The mask for the src keys per batch: (batch, #time), contains True in masked positions. May be None. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor (#time, batch, channels). |
| Tensor | Updated cache (#batch, channels, left_pad). |

Source code in zipformer/modules/zipformer.py
def streaming_forward(
    self,
    x: torch.Tensor,
    cache: torch.Tensor,
    src_key_padding_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute convolution module in streaming forward mode.

    Args:
        x: Input tensor (#time, batch, channels).
        cache: cached left context for depthwise_conv of shape
          (#batch, channels, left_pad)
        src_key_padding_mask: the mask for the src keys per batch (optional):
          (batch, #time), contains True in masked positions.

    Returns:
        - Output tensor (#time, batch, channels).
        - Updated cache (#batch, channels, left_pad)
    """

    x = self.in_proj(x)  # (time, batch, 2*channels)

    x, s = x.chunk(2, dim=2)
    s = self.sigmoid(s)
    x = x * s
    # (time, batch, channels)

    # exchange the temporal dimension and the feature dimension
    x = x.permute(1, 2, 0)  # (#batch, channels, time).

    if src_key_padding_mask is not None:
        x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

    x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)

    x = x.permute(2, 0, 1)  # (time, batch, channels)

    x = self.out_proj(x)  # (time, batch, channels)

    return x, cache
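
The internals of `depthwise_conv.streaming_forward` are not shown in this section, so the following is only a generic illustration of how a left-context cache lets a causal convolution run chunk by chunk and still match the full-sequence result. It uses one channel, plain Python lists, and a cache of `kernel_size - 1` samples; these choices and the name `causal_conv_chunk` are assumptions of the sketch, not the behaviour of ChunkCausalDepthwiseConv1d.

```python
def causal_conv_chunk(chunk, cache, weights):
    """Causal 1-D convolution of one chunk using cached left context."""
    k = len(weights)
    assert len(cache) == k - 1
    padded = cache + chunk
    # Output t depends only on padded[t : t + k], i.e. current and past samples.
    out = [
        sum(w * padded[t + i] for i, w in enumerate(weights))
        for t in range(len(chunk))
    ]
    # Keep the last k - 1 samples as left context for the next chunk.
    new_cache = padded[len(padded) - (k - 1):]
    return out, new_cache


weights = [1, 2, 3]
signal = [1, 2, 3, 4, 5, 6]

# Full-sequence result, starting from an all-zero left context.
full, _ = causal_conv_chunk(signal, [0, 0], weights)

# Streaming result: the same signal in chunks of two samples.
streamed, cache = [], [0, 0]
for start in range(0, len(signal), 2):
    out, cache = causal_conv_chunk(signal[start:start + 2], cache, weights)
    streamed.extend(out)
```

Both paths produce the same outputs, which is the property a streaming cache is meant to preserve.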

Encoder

zipformer.modules.model

Decoder

Bases: Module

This class modifies the stateless decoder from the following paper:

RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

It removes the recurrent connection from the decoder, i.e., the prediction network. Different from the above paper, it adds an extra Conv1d right after the embedding layer.

Source code in zipformer/modules/model.py
class Decoder(torch.nn.Module):
    """This class modifies the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the prediction
    network. Different from the above paper, it adds an extra Conv1d
    right after the embedding layer.
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int,
        blank_id: int,
        context_size: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit including blank.
          decoder_dim:
            Dimension of the input embedding, and of the decoder output.
          blank_id:
            The ID of the blank symbol.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
        """
        super().__init__()

        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=decoder_dim,
        )
        # the balancers are to avoid any drift in the magnitude of the
        # embeddings, which would interact badly with parameter averaging.
        self.balancer = Balancer(
            decoder_dim,
            channel_dim=-1,
            min_positive=0.0,
            max_positive=1.0,
            min_abs=0.5,
            max_abs=1.0,
            prob=0.05,
        )

        self.blank_id = blank_id

        assert context_size >= 1, context_size
        self.context_size = context_size
        self.vocab_size = vocab_size

        if context_size > 1:
            self.conv = torch.nn.Conv1d(
                in_channels=decoder_dim,
                out_channels=decoder_dim,
                kernel_size=context_size,
                padding=0,
                groups=decoder_dim // 4,  # group size == 4
                bias=False,
            )
            self.balancer2 = Balancer(
                decoder_dim,
                channel_dim=-1,
                min_positive=0.0,
                max_positive=1.0,
                min_abs=0.5,
                max_abs=1.0,
                prob=0.05,
            )
        else:
            # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
            # when running inference with torch.jit.script and context_size == 1
            self.conv = torch.nn.Identity()
            self.balancer2 = torch.nn.Identity()

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U).
          need_pad:
            True to left pad the input. Should be True during training.
            False to not pad the input. Should be False during inference.
        Returns:
          Return a tensor of shape (N, U, decoder_dim).
        """
        y = y.to(torch.int64)
        # this stuff about clamp() is a temporary fix for a mismatch
        # at utterance start, we use negative ids in beam_search.py
        embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)

        embedding_out = self.balancer(embedding_out)

        if self.context_size > 1:
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad is True:
                embedding_out = torch.nn.functional.pad(
                    embedding_out, pad=(self.context_size - 1, 0)
                )
            else:
                # During inference time, there is no need to do extra padding
                # as we only need one output
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)
            embedding_out = torch.nn.functional.relu(embedding_out)
            embedding_out = self.balancer2(embedding_out)

        return embedding_out

__init__(vocab_size, decoder_dim, blank_id, context_size)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int | Number of tokens of the modeling unit including blank. | required |
| decoder_dim | int | Dimension of the input embedding, and of the decoder output. | required |
| blank_id | int | The ID of the blank symbol. | required |
| context_size | int | Number of previous words to use to predict the next word. 1 means bigram; 2 means trigram. n means (n+1)-gram. | required |

Source code in zipformer/modules/model.py
def __init__(
    self,
    vocab_size: int,
    decoder_dim: int,
    blank_id: int,
    context_size: int,
):
    """
    Args:
      vocab_size:
        Number of tokens of the modeling unit including blank.
      decoder_dim:
        Dimension of the input embedding, and of the decoder output.
      blank_id:
        The ID of the blank symbol.
      context_size:
        Number of previous words to use to predict the next word.
        1 means bigram; 2 means trigram. n means (n+1)-gram.
    """
    super().__init__()

    self.embedding = torch.nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=decoder_dim,
    )
    # the balancers are to avoid any drift in the magnitude of the
    # embeddings, which would interact badly with parameter averaging.
    self.balancer = Balancer(
        decoder_dim,
        channel_dim=-1,
        min_positive=0.0,
        max_positive=1.0,
        min_abs=0.5,
        max_abs=1.0,
        prob=0.05,
    )

    self.blank_id = blank_id

    assert context_size >= 1, context_size
    self.context_size = context_size
    self.vocab_size = vocab_size

    if context_size > 1:
        self.conv = torch.nn.Conv1d(
            in_channels=decoder_dim,
            out_channels=decoder_dim,
            kernel_size=context_size,
            padding=0,
            groups=decoder_dim // 4,  # group size == 4
            bias=False,
        )
        self.balancer2 = Balancer(
            decoder_dim,
            channel_dim=-1,
            min_positive=0.0,
            max_positive=1.0,
            min_abs=0.5,
            max_abs=1.0,
            prob=0.05,
        )
    else:
        # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
        # when running inference with torch.jit.script and context_size == 1
        self.conv = torch.nn.Identity()
        self.balancer2 = torch.nn.Identity()

forward(y, need_pad=True)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| y | Tensor | A 2-D tensor of shape (N, U). | required |
| need_pad | bool | True to left pad the input. Should be True during training. False to not pad the input. Should be False during inference. | True |

Returns: A tensor of shape (N, U, decoder_dim).

Source code in zipformer/modules/model.py
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
    """
    Args:
      y:
        A 2-D tensor of shape (N, U).
      need_pad:
        True to left pad the input. Should be True during training.
        False to not pad the input. Should be False during inference.
    Returns:
      Return a tensor of shape (N, U, decoder_dim).
    """
    y = y.to(torch.int64)
    # this stuff about clamp() is a temporary fix for a mismatch
    # at utterance start, we use negative ids in beam_search.py
    embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)

    embedding_out = self.balancer(embedding_out)

    if self.context_size > 1:
        embedding_out = embedding_out.permute(0, 2, 1)
        if need_pad is True:
            embedding_out = torch.nn.functional.pad(
                embedding_out, pad=(self.context_size - 1, 0)
            )
        else:
            # During inference time, there is no need to do extra padding
            # as we only need one output
            assert embedding_out.size(-1) == self.context_size
        embedding_out = self.conv(embedding_out)
        embedding_out = embedding_out.permute(0, 2, 1)
        embedding_out = torch.nn.functional.relu(embedding_out)
        embedding_out = self.balancer2(embedding_out)

    return embedding_out
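
With need_pad=True, the left padding of `context_size - 1` positions followed by an unpadded convolution of width `context_size` yields exactly one output per input token, and each output sees only the current and previous tokens. The sketch below shows the windows on token labels rather than embeddings. `context_windows` and the `"<pad>"` marker (standing in for a zero embedding) are hypothetical.

```python
def context_windows(tokens, context_size, pad="<pad>"):
    """Windows seen by a width-`context_size` convolution after left padding."""
    # Mirrors torch.nn.functional.pad(..., pad=(context_size - 1, 0)).
    padded = [pad] * (context_size - 1) + tokens
    # A convolution with padding=0 slides over every full window.
    return [padded[u:u + context_size] for u in range(len(tokens))]


tokens = [7, 8, 9]
windows = context_windows(tokens, context_size=2)
```

With need_pad=False the source instead asserts that the input already holds exactly `context_size` tokens, so the convolution produces a single output.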

Joiner

Bases: Module

Source code in zipformer/modules/model.py
class Joiner(torch.nn.Module):
    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
        self.output_linear = torch.nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C).
          project_input:
            If true, apply input projections encoder_proj and decoder_proj.
            If this is false, it is the user's responsibility to do this
            manually.
        Returns:
          Return a tensor of shape (N, T, s_range, vocab_size).
        """
        assert encoder_out.ndim == decoder_out.ndim, (
            encoder_out.shape,
            decoder_out.shape,
        )

        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            logit = encoder_out + decoder_out

        logit = self.output_linear(torch.tanh(logit))

        return logit

forward(encoder_out, decoder_out, project_input=True)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| encoder_out | Tensor | Output from the encoder. Its shape is (N, T, s_range, C). | required |
| decoder_out | Tensor | Output from the decoder. Its shape is (N, T, s_range, C). | required |
| project_input | bool | If true, apply input projections encoder_proj and decoder_proj. If this is false, it is the user's responsibility to do this manually. | True |

Returns: A tensor of shape (N, T, s_range, vocab_size).

Source code in zipformer/modules/model.py
def forward(
    self,
    encoder_out: torch.Tensor,
    decoder_out: torch.Tensor,
    project_input: bool = True,
) -> torch.Tensor:
    """
    Args:
      encoder_out:
        Output from the encoder. Its shape is (N, T, s_range, C).
      decoder_out:
        Output from the decoder. Its shape is (N, T, s_range, C).
      project_input:
        If true, apply input projections encoder_proj and decoder_proj.
        If this is false, it is the user's responsibility to do this
        manually.
    Returns:
      Return a tensor of shape (N, T, s_range, vocab_size).
    """
    assert encoder_out.ndim == decoder_out.ndim, (
        encoder_out.shape,
        decoder_out.shape,
    )

    if project_input:
        logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
    else:
        logit = encoder_out + decoder_out

    logit = self.output_linear(torch.tanh(logit))

    return logit
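
For a single (t, u) position, the joiner adds the two already-projected vectors, applies tanh, and maps the result to one logit per vocabulary entry. The following sketch reproduces that for the `project_input=False` path with plain Python lists. The function name `join_one_position` and the toy weights are hypothetical and stand in for `output_linear`.

```python
import math


def join_one_position(encoder_vec, decoder_vec, out_weight, out_bias):
    """Combine one encoder frame and one decoder state into vocabulary logits."""
    # Mirrors `logit = encoder_out + decoder_out` followed by torch.tanh.
    hidden = [math.tanh(e + d) for e, d in zip(encoder_vec, decoder_vec)]
    # Mirrors output_linear: one row of weights and one bias per vocabulary entry.
    return [
        sum(w * h for w, h in zip(row, hidden)) + b
        for row, b in zip(out_weight, out_bias)
    ]


encoder_vec = [0.5, -1.0]            # joiner_dim = 2
decoder_vec = [-0.5, 1.0]            # cancels the encoder vector exactly
out_weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # vocab_size = 3
out_bias = [0.1, 0.2, 0.3]

logits = join_one_position(encoder_vec, decoder_vec, out_weight, out_bias)
```

Because the two inputs cancel, tanh yields zeros and the logits reduce to the bias, which makes the output size of `vocab_size` easy to see.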

AsrModel

Bases: Module

Source code in zipformer/modules/model.py
class AsrModel(torch.nn.Module):
    def __init__(
        self,
        feature_dim: int = 80,
        downsampling_factor: Tuple[int] = (2, 4),
        encoder_dim: Union[int, Tuple[int]] = 384,
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
        query_head_dim: Union[int, Tuple[int]] = 24,
        pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
        dropout: FloatLike = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches: float = 4000.0,
        causal: bool = False,
        chunk_size: Tuple[int] = [-1],
        left_context_frames: Tuple[int] = [-1],
        use_ctc: bool = False,
        blank_id: int = 0,
        vocab_size: int = 500,
        use_transducer: bool = True,
        decoder_dim: int = 512,
        context_size: int = 2,
        joiner_dim: int = 512,
        use_attention_decoder: bool = False,
        attention_decoder_dim: int = 512,
        attention_decoder_num_layers: int = 2,
        attention_decoder_attention_dim: int = 512,
        attention_decoder_num_heads: int = 8,
        attention_decoder_feedforward_dim: int = 2048,
        sos_id: int = 1,
        eos_id: int = 2,
        ignore_id: int = -100,
        label_smoothing: float = 0.0,
    ):
        """A joint CTC & Transducer ASR model.

        - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
        - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
        - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)

        Args:
          encoder_embed:
            It is a convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of shape
            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
          encoder:
            It is the transcription network in the paper. It accepts
            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
            It is used when use_transducer is True.
          joiner:
            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
            Its output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
            It is used when use_transducer is True.
          use_transducer:
            Whether to use the transducer head. Default: True.
          use_ctc:
            Whether to use the CTC head. Default: False.
          use_attention_decoder:
            Whether to use the attention-decoder head. Default: False.
        """
        super().__init__()

        assert use_transducer or use_ctc, (
            f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
        )

        self.blank_id = blank_id
        self.vocab_size = vocab_size

        # encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, (T - 7) // 2, encoder_dims).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> (T - 7) // 2
        #   (2) embedding: num_features -> encoder_dims
        # In the normal configuration, we will downsample once more at the end
        # by a factor of 2, and most of the encoder stacks will run at a lower
        # sampling rate.
        self.encoder_embed = Conv2dSubsampling(
            in_channels=feature_dim,
            out_channels=_to_int_tuple(encoder_dim)[0],
            dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        )

        self.encoder = Zipformer(
            output_downsampling_factor=2,
            downsampling_factor=_to_int_tuple(downsampling_factor),
            num_encoder_layers=_to_int_tuple(num_encoder_layers),
            encoder_dim=_to_int_tuple(encoder_dim),
            encoder_unmasked_dim=_to_int_tuple(encoder_unmasked_dim),
            query_head_dim=_to_int_tuple(query_head_dim),
            pos_head_dim=_to_int_tuple(pos_head_dim),
            value_head_dim=_to_int_tuple(value_head_dim),
            pos_dim=pos_dim,
            num_heads=_to_int_tuple(num_heads),
            feedforward_dim=_to_int_tuple(feedforward_dim),
            cnn_module_kernel=_to_int_tuple(cnn_module_kernel),
            dropout=dropout,
            warmup_batches=warmup_batches,
            causal=causal,
            chunk_size=_to_int_tuple(chunk_size),
            left_context_frames=_to_int_tuple(left_context_frames),
        )

        self.use_transducer = use_transducer
        self.encoder_out_dim = max(_to_int_tuple(encoder_dim))
        if use_transducer:
            self.decoder = Decoder(
                vocab_size=vocab_size,
                decoder_dim=decoder_dim,
                blank_id=blank_id,
                context_size=context_size,
            )
            self.joiner = Joiner(
                encoder_dim=self.encoder_out_dim,
                decoder_dim=decoder_dim,
                joiner_dim=joiner_dim,
                vocab_size=vocab_size,
            )
            self.simple_am_proj = ScaledLinear(
                self.encoder_out_dim, vocab_size, initial_scale=0.25
            )
            self.simple_lm_proj = ScaledLinear(
                decoder_dim, vocab_size, initial_scale=0.25
            )
        else:
            self.decoder = None
            self.joiner = None

        self.use_ctc = use_ctc
        if use_ctc:
            self.ctc_output = torch.nn.Sequential(
                torch.nn.Dropout(p=0.1),
                torch.nn.Linear(self.encoder_out_dim, vocab_size),
                torch.nn.LogSoftmax(dim=-1),
            )
        else:
            self.ctc_output = None

        self.use_attention_decoder = use_attention_decoder
        if use_attention_decoder:
            self.attention_decoder = AttentionDecoderModel(
                vocab_size=vocab_size,
                decoder_dim=attention_decoder_dim,
                num_decoder_layers=attention_decoder_num_layers,
                attention_dim=attention_decoder_attention_dim,
                num_heads=attention_decoder_num_heads,
                feedforward_dim=attention_decoder_feedforward_dim,
                memory_dim=self.encoder_out_dim,
                sos_id=sos_id,
                eos_id=eos_id,
                ignore_id=ignore_id,
                label_smoothing=label_smoothing,
            )
        else:
            self.attention_decoder = None

    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
        x, x_lens = self.encoder_embed(x, x_lens)
        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)

        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        return encoder_out, encoder_out_lens

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        reduction: str = "sum",
    ) -> torch.Tensor:
        """Compute CTC loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction=reduction,
        )
        return ctc_loss

    def forward_cr_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        reduction: str = "sum",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute CTC loss with consistency regularization loss.
        Args:
          encoder_out:
            Encoder output, of shape (2 * N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (2 * N,).
          targets:
            Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
        """
        # Compute CTC loss
        ctc_output = self.ctc_output(encoder_out)  # (2 * N, T, C)
        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, 2 * N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction=reduction,
        )

        # Compute consistency regularization loss
        batch_size = ctc_output.shape[0]
        assert batch_size % 2 == 0, batch_size
        # exchange: [x1, x2] -> [x2, x1]
        exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
        cr_loss = torch.nn.functional.kl_div(
            input=ctc_output,
            target=exchanged_targets,
            reduction="none",
            log_target=True,
        )  # (2 * N, T, C)
        length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
        cr_loss = cr_loss.masked_fill(length_mask, 0.0)

        if reduction == "sum":
            cr_loss = cr_loss.sum()
        elif reduction == "mean":
            cr_loss = cr_loss.mean()

        return ctc_loss, cr_loss

    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: List[List[int]],
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        reduction: str = "sum",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A list of token ID lists, one per utterance.
          prune_range:
            The prune range for the RNN-T loss, i.e. how many symbols (context)
            are considered for each frame when computing the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network) part.
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network) part.
        """
        # Now for the decoder, i.e., the prediction network
        blank_id = self.blank_id
        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded, _ = pad_sequences(
            y, padding_value=blank_id, sos_id=blank_id, device=encoder_out.device
        )
        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded, y_lens = pad_sequences(y, padding_value=0, device=encoder_out.device)

        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        with torch_autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction=reduction,
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]
        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch_autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction=reduction,
            )

        return simple_loss, pruned_loss

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: List[List[int]],
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        use_cr_ctc: bool = False,
        use_spec_aug: bool = False,
        spec_augment: Optional[SpecAugment] = None,
        supervision_segments: Optional[torch.Tensor] = None,
        time_warp_factor: Optional[int] = 80,
        reduction: str = "sum",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A list of token ID lists, one per utterance.
          prune_range:
            The prune range for the RNN-T loss, i.e. how many symbols (context)
            are considered for each frame when computing the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
          use_cr_ctc:
            Whether to use consistency-regularized CTC.
          use_spec_aug:
            Whether to apply spec-augment manually; used only if use_cr_ctc is True.
          spec_augment:
            The SpecAugment instance that returns time masks,
            used only if use_cr_ctc is True.
          supervision_segments:
            An int tensor of shape ``(S, 3)``. ``S`` is the number of
            supervision segments that exist in ``features``.
            Used only if use_cr_ctc is True.
          time_warp_factor:
            Parameter for the time warping; larger values mean more warping.
            Set to ``None``, or less than ``1``, to disable.
            Used only if use_cr_ctc is True.

        Returns:
          Return the transducer losses, CTC loss, AED loss,
          and consistency-regularization loss in form of
          (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss)

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        assert x.size(0) == x_lens.size(0) == len(y), (x.shape, x_lens.shape, len(y))

        if use_cr_ctc:
            assert self.use_ctc
            if use_spec_aug:
                assert spec_augment is not None and spec_augment.time_warp_factor < 1
                # Apply time warping before input duplicating
                assert supervision_segments is not None
                x = time_warp(
                    x,
                    time_warp_factor=time_warp_factor,
                    supervision_segments=supervision_segments,
                )
                # Independently apply frequency masking and time masking to the two copies
                x = spec_augment(x.repeat(2, 1, 1))
            else:
                x = x.repeat(2, 1, 1)
            x_lens = x_lens.repeat(2)
            # Build a new list; `y += y` would extend the caller's list in place.
            y = y + y

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        if self.use_transducer:
            # Compute transducer loss
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
                reduction=reduction,
            )
            if use_cr_ctc:
                simple_loss = simple_loss * 0.5
                pruned_loss = pruned_loss * 0.5
        else:
            simple_loss = torch.empty(0)
            pruned_loss = torch.empty(0)

        if self.use_ctc:
            # Compute CTC loss
            targets, target_length = pad_sequences(
                y, padding_value=0, device=encoder_out.device
            )
            if not use_cr_ctc:
                ctc_loss = self.forward_ctc(
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=target_length,
                    reduction=reduction,
                )
                cr_loss = torch.empty(0)
            else:
                ctc_loss, cr_loss = self.forward_cr_ctc(
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=target_length,
                    reduction=reduction,
                )
                ctc_loss = ctc_loss * 0.5
                cr_loss = cr_loss * 0.5
        else:
            ctc_loss = torch.empty(0)
            cr_loss = torch.empty(0)

        if self.use_attention_decoder:
            attention_decoder_loss = self.attention_decoder.calc_att_loss(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                ys=y,
                reduction=reduction,
            )
            if use_cr_ctc:
                attention_decoder_loss = attention_decoder_loss * 0.5
        else:
            attention_decoder_loss = torch.empty(0)

        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
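The consistency-regularization step in forward_cr_ctc can be illustrated without PyTorch. The following is a plain-Python sketch, not the model's code: `swap_halves` mirrors what rolling a batch laid out as [x1, x2] by half its size along dimension 0 does, and `kl_log_target` uses the pointwise form that `kl_div` applies when `log_target=True`, namely exp(target) * (target - input). Both helper names and the example probabilities are hypothetical.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def swap_halves(batch: Sequence[T]) -> List[T]:
    """Exchange the two halves of a batch: [x1, x2] -> [x2, x1]."""
    assert len(batch) % 2 == 0, len(batch)
    half = len(batch) // 2
    return list(batch[half:]) + list(batch[:half])


def kl_log_target(log_input: Sequence[float], log_target: Sequence[float]) -> float:
    """KL(target || input) for one frame, both given as log-probabilities."""
    return sum(math.exp(t) * (t - i) for i, t in zip(log_input, log_target))


# Two augmented copies of a 2-utterance batch, one frame each, as log-probs.
copy_a = [[math.log(0.7), math.log(0.3)], [math.log(0.5), math.log(0.5)]]
copy_b = [[math.log(0.6), math.log(0.4)], [math.log(0.5), math.log(0.5)]]

batch = copy_a + copy_b       # shape (2 * N, C) with N = 2
targets = swap_halves(batch)  # each copy is pulled toward the other copy
cr_loss = sum(kl_log_target(x, t) for x, t in zip(batch, targets))
```

The second utterance has identical distributions in both copies, so it contributes nothing to `cr_loss`; only the first utterance, where the copies disagree, is penalized. In the real method the targets are additionally detached, and padded frames are zeroed out before the reduction.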

__init__(feature_dim=80, downsampling_factor=(2, 4), encoder_dim=384, num_encoder_layers=4, encoder_unmasked_dim=256, query_head_dim=24, pos_head_dim=4, value_head_dim=12, num_heads=8, feedforward_dim=1536, cnn_module_kernel=31, pos_dim=192, dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), warmup_batches=4000.0, causal=False, chunk_size=[-1], left_context_frames=[-1], use_ctc=False, blank_id=0, vocab_size=500, use_transducer=True, decoder_dim=512, context_size=2, joiner_dim=512, use_attention_decoder=False, attention_decoder_dim=512, attention_decoder_num_layers=2, attention_decoder_attention_dim=512, attention_decoder_num_heads=8, attention_decoder_feedforward_dim=2048, sos_id=1, eos_id=2, ignore_id=-100, label_smoothing=0.0)

A joint CTC & Transducer ASR model.

  • Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
  • Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
  • Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)

Parameters:

Name Type Description Default
encoder_embed

It is a convolutional 2D subsampling module. It converts an input of shape (N, T, idim) to an output of shape (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.

required
encoder

It is the transcription network in the paper. It accepts two inputs: x of shape (N, T, encoder_dim) and x_lens of shape (N,). It returns two tensors: logits of shape (N, T, encoder_dim) and logit_lens of shape (N,).

required
decoder

It is the prediction network in the paper. Its input shape is (N, U) and its output shape is (N, U, decoder_dim). It should contain one attribute: blank_id. It is used when use_transducer is True.

required
joiner

It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. It is used when use_transducer is True.

required
use_transducer bool

Whether to use the transducer head. Default: True.

True
use_ctc bool

Whether to use the CTC head. Default: False.

False
use_attention_decoder bool

Whether to use the attention-decoder head. Default: False.

False
Source code in zipformer/modules/model.py
def __init__(
    self,
    feature_dim: int = 80,
    downsampling_factor: Tuple[int] = (2, 4),
    encoder_dim: Union[int, Tuple[int]] = 384,
    num_encoder_layers: Union[int, Tuple[int]] = 4,
    encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
    query_head_dim: Union[int, Tuple[int]] = 24,
    pos_head_dim: Union[int, Tuple[int]] = 4,
    value_head_dim: Union[int, Tuple[int]] = 12,
    num_heads: Union[int, Tuple[int]] = 8,
    feedforward_dim: Union[int, Tuple[int]] = 1536,
    cnn_module_kernel: Union[int, Tuple[int]] = 31,
    pos_dim: int = 192,
    dropout: FloatLike = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    warmup_batches: float = 4000.0,
    causal: bool = False,
    chunk_size: Tuple[int] = [-1],
    left_context_frames: Tuple[int] = [-1],
    use_ctc: bool = False,
    blank_id: int = 0,
    vocab_size: int = 500,
    use_transducer: bool = True,
    decoder_dim: int = 512,
    context_size: int = 2,
    joiner_dim: int = 512,
    use_attention_decoder: bool = False,
    attention_decoder_dim: int = 512,
    attention_decoder_num_layers: int = 2,
    attention_decoder_attention_dim: int = 512,
    attention_decoder_num_heads: int = 8,
    attention_decoder_feedforward_dim: int = 2048,
    sos_id: int = 1,
    eos_id: int = 2,
    ignore_id: int = -100,
    label_smoothing: float = 0.0,
):
    """A joint CTC & Transducer ASR model.

    - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
    - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
    - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)

    Args:
      encoder_embed:
        It is a convolutional 2D subsampling module. It converts
        an input of shape (N, T, idim) to an output of shape
        (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
      encoder:
        It is the transcription network in the paper. It accepts
        two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
        It returns two tensors: `logits` of shape (N, T, encoder_dim) and
        `logit_lens` of shape (N,).
      decoder:
        It is the prediction network in the paper. Its input shape
        is (N, U) and its output shape is (N, U, decoder_dim).
        It should contain one attribute: `blank_id`.
        It is used when use_transducer is True.
      joiner:
        It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
        Its output shape is (N, T, U, vocab_size). Note that its output contains
        unnormalized probs, i.e., not processed by log-softmax.
        It is used when use_transducer is True.
      use_transducer:
        Whether to use the transducer head. Default: True.
      use_ctc:
        Whether to use the CTC head. Default: False.
      use_attention_decoder:
        Whether to use the attention-decoder head. Default: False.
    """
    super().__init__()

    assert use_transducer or use_ctc, (
        f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
    )

    self.blank_id = blank_id
    self.vocab_size = vocab_size

    # encoder_embed converts the input of shape (N, T, num_features)
    # to the shape (N, (T - 7) // 2, encoder_dims).
    # That is, it does two things simultaneously:
    #   (1) subsampling: T -> (T - 7) // 2
    #   (2) embedding: num_features -> encoder_dims
    # In the normal configuration, we will downsample once more at the end
    # by a factor of 2, and most of the encoder stacks will run at a lower
    # sampling rate.
    self.encoder_embed = Conv2dSubsampling(
        in_channels=feature_dim,
        out_channels=_to_int_tuple(encoder_dim)[0],
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    )

    self.encoder = Zipformer(
        output_downsampling_factor=2,
        downsampling_factor=_to_int_tuple(downsampling_factor),
        num_encoder_layers=_to_int_tuple(num_encoder_layers),
        encoder_dim=_to_int_tuple(encoder_dim),
        encoder_unmasked_dim=_to_int_tuple(encoder_unmasked_dim),
        query_head_dim=_to_int_tuple(query_head_dim),
        pos_head_dim=_to_int_tuple(pos_head_dim),
        value_head_dim=_to_int_tuple(value_head_dim),
        pos_dim=pos_dim,
        num_heads=_to_int_tuple(num_heads),
        feedforward_dim=_to_int_tuple(feedforward_dim),
        cnn_module_kernel=_to_int_tuple(cnn_module_kernel),
        dropout=dropout,
        warmup_batches=warmup_batches,
        causal=causal,
        chunk_size=_to_int_tuple(chunk_size),
        left_context_frames=_to_int_tuple(left_context_frames),
    )

    self.use_transducer = use_transducer
    self.encoder_out_dim = max(_to_int_tuple(encoder_dim))
    if use_transducer:
        self.decoder = Decoder(
            vocab_size=vocab_size,
            decoder_dim=decoder_dim,
            blank_id=blank_id,
            context_size=context_size,
        )
        self.joiner = Joiner(
            encoder_dim=self.encoder_out_dim,
            decoder_dim=decoder_dim,
            joiner_dim=joiner_dim,
            vocab_size=vocab_size,
        )
        self.simple_am_proj = ScaledLinear(
            self.encoder_out_dim, vocab_size, initial_scale=0.25
        )
        self.simple_lm_proj = ScaledLinear(
            decoder_dim, vocab_size, initial_scale=0.25
        )
    else:
        self.decoder = None
        self.joiner = None

    self.use_ctc = use_ctc
    if use_ctc:
        self.ctc_output = torch.nn.Sequential(
            torch.nn.Dropout(p=0.1),
            torch.nn.Linear(self.encoder_out_dim, vocab_size),
            torch.nn.LogSoftmax(dim=-1),
        )
    else:
        self.ctc_output = None

    self.use_attention_decoder = use_attention_decoder
    if use_attention_decoder:
        self.attention_decoder = AttentionDecoderModel(
            vocab_size=vocab_size,
            decoder_dim=attention_decoder_dim,
            num_decoder_layers=attention_decoder_num_layers,
            attention_dim=attention_decoder_attention_dim,
            num_heads=attention_decoder_num_heads,
            feedforward_dim=attention_decoder_feedforward_dim,
            memory_dim=self.encoder_out_dim,
            sos_id=sos_id,
            eos_id=eos_id,
            ignore_id=ignore_id,
            label_smoothing=label_smoothing,
        )
    else:
        self.attention_decoder = None
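The frame-count arithmetic described in the constructor comments can be checked with a few lines of plain Python. This is a sketch, not part of the module: the function names are hypothetical, and the rounding used for the final factor-of-2 output downsampling is an assumption (round up), since that rule is not stated in this section.

```python
def embed_out_len(num_frames: int) -> int:
    """Frames after encoder_embed: T' = (T - 3) // 2 - 2 = (T - 7) // 2."""
    return (num_frames - 7) // 2


def encoder_out_len(num_frames: int) -> int:
    """Frames after the final 2x output downsampling.

    ASSUMPTION: the halving rounds up, i.e. (T' + 1) // 2.
    """
    return (embed_out_len(num_frames) + 1) // 2


for t in (9, 100, 1000):
    # The two forms given in the docstring agree for every T.
    assert (t - 3) // 2 - 2 == embed_out_len(t)
    print(t, embed_out_len(t), encoder_out_len(t))
```

Under these assumptions an utterance needs at least 9 input frames to produce a non-empty encoder output, which is the condition that forward_encoder asserts with `encoder_out_lens > 0`.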

forward_encoder(x, x_lens)

Compute encoder outputs.

Args:
  x: A 3-D tensor of shape (N, T, C).
  x_lens: A 1-D tensor of shape (N,). It contains the number of frames in x before padding.

Returns:

Name Type Description
encoder_out Tensor

Encoder output, of shape (N, T, C).

encoder_out_lens Tensor

Encoder output lengths, of shape (N,).

Source code in zipformer/modules/model.py
def forward_encoder(
    self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute encoder outputs.
    Args:
      x:
        A 3-D tensor of shape (N, T, C).
      x_lens:
        A 1-D tensor of shape (N,). It contains the number of frames in `x`
        before padding.

    Returns:
      encoder_out:
        Encoder output, of shape (N, T, C).
      encoder_out_lens:
        Encoder output lengths, of shape (N,).
    """
    # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
    x, x_lens = self.encoder_embed(x, x_lens)
    # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

    src_key_padding_mask = make_pad_mask(x_lens)
    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

    encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)

    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
    assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

    return encoder_out, encoder_out_lens
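forward_encoder derives `src_key_padding_mask` from the per-utterance lengths. The implementation of `make_pad_mask` is not shown here, so the following is only a plain-Python sketch of the assumed behavior: positions at or beyond an utterance's length are marked True (padding). That convention is an assumption, though it is consistent with forward_cr_ctc zeroing the loss wherever the mask is set. The name `pad_mask` is hypothetical.

```python
from typing import List


def pad_mask(lengths: List[int]) -> List[List[bool]]:
    """Return an (N, max_len) mask that is True at padded positions."""
    max_len = max(lengths)
    return [[t >= n for t in range(max_len)] for n in lengths]


# Three utterances with 1, 3 and 2 valid frames, padded to 3 frames.
mask = pad_mask([1, 3, 2])
```

Each row corresponds to one utterance; the longest utterance has no True entries because it needs no padding.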

forward_ctc(encoder_out, encoder_out_lens, targets, target_lengths, reduction='sum')

Compute CTC loss.

Args:
  encoder_out: Encoder output, of shape (N, T, C).
  encoder_out_lens: Encoder output lengths, of shape (N,).
  targets: Target tensor of shape (sum(target_lengths)). The targets are assumed to be un-padded and concatenated within 1 dimension.

Source code in zipformer/modules/model.py
def forward_ctc(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    targets: torch.Tensor,
    target_lengths: torch.Tensor,
    reduction: str = "sum",
) -> torch.Tensor:
    """Compute CTC loss.
    Args:
      encoder_out:
        Encoder output, of shape (N, T, C).
      encoder_out_lens:
        Encoder output lengths, of shape (N,).
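      targets:
        Target Tensor of shape (sum(target_lengths)). The targets are assumed
        to be un-padded and concatenated within 1 dimension.
      target_lengths:
        Length of each target sequence, of shape (N,).
      reduction:
        Reduction passed to torch.nn.functional.ctc_loss, e.g. "sum" or "mean".
    """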
    # Compute CTC log-prob
    ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

    ctc_loss = torch.nn.functional.ctc_loss(
        log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
        targets=targets.cpu(),
        input_lengths=encoder_out_lens.cpu(),
        target_lengths=target_lengths.cpu(),
        reduction=reduction,
    )
    return ctc_loss
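
The docstring describes targets as un-padded and concatenated along one dimension. The sketch below shows that layout, assuming the labels start out as a list of token-ID lists; `flatten_targets` is a hypothetical helper and is not part of the module.

```python
from typing import List, Tuple


def flatten_targets(y: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Concatenate per-utterance labels and record each utterance's length."""
    targets = [token for utt in y for token in utt]
    target_lengths = [len(utt) for utt in y]
    return targets, target_lengths


# Two utterances with 3 labels and 1 label respectively.
targets, target_lengths = flatten_targets([[5, 9, 2], [7]])
```

The invariant worth checking in real code is that len(targets) equals sum(target_lengths).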

forward_cr_ctc(encoder_out, encoder_out_lens, targets, target_lengths, reduction='sum')

Compute CTC loss with consistency regularization loss.

Args:
  encoder_out: Encoder output, of shape (2 * N, T, C).
  encoder_out_lens: Encoder output lengths, of shape (2 * N,).
  targets: Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed to be un-padded and concatenated within 1 dimension.
  target_lengths: Length of each target sequence, of shape (2 * N,).
  reduction: Reduction applied to both losses, 'sum' or 'mean'.

Source code in zipformer/modules/model.py
def forward_cr_ctc(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    targets: torch.Tensor,
    target_lengths: torch.Tensor,
    reduction: str = "sum",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute CTC loss with consistency regularization loss.
    Args:
      encoder_out:
        Encoder output, of shape (2 * N, T, C).
      encoder_out_lens:
        Encoder output lengths, of shape (2 * N,).
      targets:
        Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
        to be un-padded and concatenated within 1 dimension.
    """
    # Compute CTC loss
    ctc_output = self.ctc_output(encoder_out)  # (2 * N, T, C)
    ctc_loss = torch.nn.functional.ctc_loss(
        log_probs=ctc_output.permute(1, 0, 2),  # (T, 2 * N, C)
        targets=targets.cpu(),
        input_lengths=encoder_out_lens.cpu(),
        target_lengths=target_lengths.cpu(),
        reduction=reduction,
    )

    # Compute consistency regularization loss
    batch_size = ctc_output.shape[0]
    assert batch_size % 2 == 0, batch_size
    # exchange: [x1, x2] -> [x2, x1]
    exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
    cr_loss = torch.nn.functional.kl_div(
        input=ctc_output,
        target=exchanged_targets,
        reduction="none",
        log_target=True,
    )  # (2 * N, T, C)
    length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
    cr_loss = cr_loss.masked_fill(length_mask, 0.0)

    if reduction == "sum":
        cr_loss = cr_loss.sum()
    elif reduction == "mean":
        cr_loss = cr_loss.mean()

    return ctc_loss, cr_loss
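
Two details of the consistency term can be checked by hand: rolling the batch by half its size pairs each augmented copy with its counterpart, and with log_target=True the pointwise KL term is exp(target) * (target - input). The following is a pure-Python sketch of both for a single frame; the helper names are hypothetical and the probabilities are made up.

```python
import math
from typing import List


def roll_half(batch: List[str]) -> List[str]:
    """Mimic torch.roll(x, len(x) // 2, dims=0): [x1, x2] -> [x2, x1]."""
    half = len(batch) // 2
    return batch[-half:] + batch[:-half]


def kl_div_log_target(log_p: List[float], log_q: List[float]) -> float:
    """Sum over classes of exp(log_q) * (log_q - log_p); log_q is the target."""
    return sum(math.exp(q) * (q - p) for p, q in zip(log_p, log_q))


# Two utterances (a, b), each present as two augmented copies.
exchanged = roll_half(["a1", "b1", "a2", "b2"])

log_p = [math.log(0.5), math.log(0.5)]  # prediction from one copy
log_q = [math.log(0.9), math.log(0.1)]  # detached prediction from the other copy
loss = kl_div_log_target(log_p, log_q)
```

The term is zero when both copies produce the same distribution and positive otherwise.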

forward_transducer(encoder_out, encoder_out_lens, y, prune_range=5, am_scale=0.0, lm_scale=0.0, reduction='sum')

Compute Transducer loss.

Args:
  encoder_out: Encoder output, of shape (N, T, C).
  encoder_out_lens: Encoder output lengths, of shape (N,).
  y: A list of token-ID lists. It contains the labels of each utterance.
  prune_range: The prune range for the RNN-T loss, i.e. how many symbols (context) are considered for each frame when computing the loss.
  am_scale: The scale used to smooth the loss with the am (output of encoder network) part.
  lm_scale: The scale used to smooth the loss with the lm (output of predictor network) part.

Source code in zipformer/modules/model.py
def forward_transducer(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    y: List[List[int]],
    prune_range: int = 5,
    am_scale: float = 0.0,
    lm_scale: float = 0.0,
    reduction: str = "sum",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute Transducer loss.
    Args:
      encoder_out:
        Encoder output, of shape (N, T, C).
      encoder_out_lens:
        Encoder output lengths, of shape (N,).
      y:
        A list of token id list. It contains labels of each utterance.
      prune_range:
        The prune range for rnnt loss, it means how many symbols(context)
        we are considering for each frame to compute the loss.
      am_scale:
        The scale to smooth the loss with am (output of encoder network) part.
      lm_scale:
        The scale to smooth the loss with lm (output of predictor network) part.
    """
    # Now for the decoder, i.e., the prediction network
    blank_id = self.blank_id
    # sos_y_padded: [B, S + 1], start with SOS.
    sos_y_padded, _ = pad_sequences(
        y, padding_value=blank_id, sos_id=blank_id, device=encoder_out.device
    )
    # decoder_out: [B, S + 1, decoder_dim]
    decoder_out = self.decoder(sos_y_padded)

    # Note: y does not start with SOS
    # y_padded : [B, S]
    y_padded, y_lens = pad_sequences(y, padding_value=0, device=encoder_out.device)

    boundary = torch.zeros(
        (encoder_out.size(0), 4),
        dtype=torch.int64,
        device=encoder_out.device,
    )
    boundary[:, 2] = y_lens
    boundary[:, 3] = encoder_out_lens

    lm = self.simple_lm_proj(decoder_out)
    am = self.simple_am_proj(encoder_out)

    with torch_autocast(enabled=False):
        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
            lm=lm.float(),
            am=am.float(),
            symbols=y_padded,
            termination_symbol=blank_id,
            lm_only_scale=lm_scale,
            am_only_scale=am_scale,
            boundary=boundary,
            reduction=reduction,
            return_grad=True,
        )

    # ranges : [B, T, prune_range]
    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad,
        py_grad=py_grad,
        boundary=boundary,
        s_range=prune_range,
    )

    # am_pruned : [B, T, prune_range, encoder_dim]
    # lm_pruned : [B, T, prune_range, decoder_dim]
    am_pruned, lm_pruned = k2.do_rnnt_pruning(
        am=self.joiner.encoder_proj(encoder_out),
        lm=self.joiner.decoder_proj(decoder_out),
        ranges=ranges,
    )

    # logits : [B, T, prune_range, vocab_size]
    # project_input=False since the joiner's encoder and decoder input
    # projections were already applied before do_rnnt_pruning (a speed optimization).
    logits = self.joiner(am_pruned, lm_pruned, project_input=False)

    with torch_autocast(enabled=False):
        pruned_loss = k2.rnnt_loss_pruned(
            logits=logits.float(),
            symbols=y_padded,
            ranges=ranges,
            termination_symbol=blank_id,
            boundary=boundary,
            reduction=reduction,
        )

    return simple_loss, pruned_loss
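
The code above prepares two inputs before calling k2: labels padded with a leading SOS (the blank id), and a boundary tensor whose columns 2 and 3 hold the label length and the number of encoder frames. The sketch below mirrors both with plain lists. The helper names are hypothetical, and reading the four boundary columns as [begin_symbol, begin_frame, end_symbol, end_frame] is an assumption taken from k2's convention rather than from this page.

```python
from typing import List


def build_boundary(y: List[List[int]], encoder_out_lens: List[int]) -> List[List[int]]:
    """One row per utterance: [0, 0, num_symbols, num_frames]."""
    return [[0, 0, len(utt), frames] for utt, frames in zip(y, encoder_out_lens)]


def add_sos_and_pad(y: List[List[int]], blank_id: int) -> List[List[int]]:
    """Prepend blank as SOS and right-pad with blank, giving shape (B, S + 1)."""
    max_len = max(len(utt) for utt in y)
    return [[blank_id] + utt + [blank_id] * (max_len - len(utt)) for utt in y]


y = [[4, 8, 3], [6]]
boundary = build_boundary(y, encoder_out_lens=[20, 12])
sos_y_padded = add_sos_and_pad(y, blank_id=0)
```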

forward(x, x_lens, y, prune_range=5, am_scale=0.0, lm_scale=0.0, use_cr_ctc=False, use_spec_aug=False, spec_augment=None, supervision_segments=None, time_warp_factor=80, reduction='sum')

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | A 3-D tensor of shape (N, T, C). | required |
| x_lens | Tensor | A 1-D tensor of shape (N,). It contains the number of frames in x before padding. | required |
| y | List[List[int]] | A list of token-ID lists. It contains the labels of each utterance. | required |
| prune_range | int | The prune range for the RNN-T loss, i.e. how many symbols (context) are considered for each frame when computing the loss. | 5 |
| am_scale | float | The scale used to smooth the loss with the am (output of encoder network) part. | 0.0 |
| lm_scale | float | The scale used to smooth the loss with the lm (output of predictor network) part. | 0.0 |
| use_cr_ctc | bool | Whether to use consistency-regularized CTC. | False |
| use_spec_aug | bool | Whether to apply spec-augment manually. Used only if use_cr_ctc is True. | False |
| spec_augment | Optional[SpecAugment] | The SpecAugment instance that returns time masks. Used only if use_cr_ctc is True. | None |
| supervision_segments | Optional[Tensor] | An int tensor of shape (S, 3). S is the number of supervision segments that exist in features. Used only if use_cr_ctc is True. | None |
| time_warp_factor | Optional[int] | Parameter for the time warping; larger values mean more warping. Set to None, or less than 1, to disable. Used only if use_cr_ctc is True. | 80 |
| reduction | str | Reduction passed through to each loss function, e.g. 'sum' or 'mean'. | 'sum' |

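Returns:

| Type | Description |
| --- | --- |
| Tuple[Tensor, Tensor, Tensor, Tensor, Tensor] | The transducer losses, CTC loss, AED loss, and consistency-regularization loss, in the form (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss). A loss whose branch is disabled is returned as an empty tensor. |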

Note

Regarding am_scale & lm_scale, it will make the loss-function one of the form: lm_scale * lm_probs + am_scale * am_probs + (1-lm_scale-am_scale) * combined_probs
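
As a quick check of the note above, the three weights always sum to 1, and the defaults (both scales 0.0) leave only the combined term. The function below is an illustrative sketch with scalar stand-ins for the three score terms; `smoothed_score` is a hypothetical name and the numbers are made up.

```python
def smoothed_score(
    lm_probs: float,
    am_probs: float,
    combined_probs: float,
    lm_scale: float = 0.0,
    am_scale: float = 0.0,
) -> float:
    """Interpolate the lm-only, am-only and combined terms as in the note."""
    combined_scale = 1.0 - lm_scale - am_scale
    return lm_scale * lm_probs + am_scale * am_probs + combined_scale * combined_probs


default_score = smoothed_score(2.0, 4.0, 8.0)
mixed_score = smoothed_score(2.0, 4.0, 8.0, lm_scale=0.25, am_scale=0.25)
```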

Source code in zipformer/modules/model.py
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: List[List[int]],
    prune_range: int = 5,
    am_scale: float = 0.0,
    lm_scale: float = 0.0,
    use_cr_ctc: bool = False,
    use_spec_aug: bool = False,
    spec_augment: Optional[SpecAugment] = None,
    supervision_segments: Optional[torch.Tensor] = None,
    time_warp_factor: Optional[int] = 80,
    reduction: str = "sum",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Args:
      x:
        A 3-D tensor of shape (N, T, C).
      x_lens:
        A 1-D tensor of shape (N,). It contains the number of frames in `x`
        before padding.
      y:
        A list of token id list. It contains labels of each utterance.
      prune_range:
        The prune range for rnnt loss, it means how many symbols(context)
        we are considering for each frame to compute the loss.
      am_scale:
        The scale to smooth the loss with am (output of encoder network)
        part
      lm_scale:
        The scale to smooth the loss with lm (output of predictor network)
        part
      use_cr_ctc:
        Whether to use consistency-regularized CTC.
      use_spec_aug:
        Whether to apply spec-augment manually; used only if use_cr_ctc is True.
      spec_augment:
        The SpecAugment instance that returns time masks,
        used only if use_cr_ctc is True.
      supervision_segments:
        An int tensor of shape ``(S, 3)``. ``S`` is the number of
        supervision segments that exist in ``features``.
        Used only if use_cr_ctc is True.
      time_warp_factor:
        Parameter for the time warping; larger values mean more warping.
        Set to ``None``, or less than ``1``, to disable.
        Used only if use_cr_ctc is True.

    Returns:
      Return the transducer losses, CTC loss, AED loss,
      and consistency-regularization loss in form of
      (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss)

    Note:
       Regarding am_scale & lm_scale, it will make the loss-function one of
       the form:
          lm_scale * lm_probs + am_scale * am_probs +
          (1-lm_scale-am_scale) * combined_probs
    """
    assert x.ndim == 3, x.shape
    assert x_lens.ndim == 1, x_lens.shape

    assert x.size(0) == x_lens.size(0) == len(y), (x.shape, x_lens.shape, len(y))

    if use_cr_ctc:
        assert self.use_ctc
        if use_spec_aug:
            assert spec_augment is not None and spec_augment.time_warp_factor < 1
            # Apply time warping before input duplicating
            assert supervision_segments is not None
            x = time_warp(
                x,
                time_warp_factor=time_warp_factor,
                supervision_segments=supervision_segments,
            )
            # Independently apply frequency masking and time masking to the two copies
            x = spec_augment(x.repeat(2, 1, 1))
        else:
            x = x.repeat(2, 1, 1)
        x_lens = x_lens.repeat(2)
        # Build a new list so the caller's `y` is not extended in place.
        y = y + y

    # Compute encoder outputs
    encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

    if self.use_transducer:
        # Compute transducer loss
        simple_loss, pruned_loss = self.forward_transducer(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            y=y,
            prune_range=prune_range,
            am_scale=am_scale,
            lm_scale=lm_scale,
            reduction=reduction,
        )
        if use_cr_ctc:
            simple_loss = simple_loss * 0.5
            pruned_loss = pruned_loss * 0.5
    else:
        simple_loss = torch.empty(0)
        pruned_loss = torch.empty(0)

    if self.use_ctc:
        # Compute CTC loss
        targets, target_length = pad_sequences(
            y, padding_value=0, device=encoder_out.device
        )
        if not use_cr_ctc:
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=targets,
                target_lengths=target_length,
                reduction=reduction,
            )
            cr_loss = torch.empty(0)
        else:
            ctc_loss, cr_loss = self.forward_cr_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=targets,
                target_lengths=target_length,
                reduction=reduction,
            )
            ctc_loss = ctc_loss * 0.5
            cr_loss = cr_loss * 0.5
    else:
        ctc_loss = torch.empty(0)
        cr_loss = torch.empty(0)

    if self.use_attention_decoder:
        attention_decoder_loss = self.attention_decoder.calc_att_loss(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            ys=y,
            reduction=reduction,
        )
        if use_cr_ctc:
            attention_decoder_loss = attention_decoder_loss * 0.5
    else:
        attention_decoder_loss = torch.empty(0)

    return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
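
One plausible reading of the 0.5 factors: with use_cr_ctc the batch is duplicated, so a summed loss accumulates over 2 * N sequences, and halving it brings the value back to the scale of the original N. The sketch below illustrates this with made-up per-utterance losses, under the simplifying assumption that both copies yield identical losses (in practice the two copies are augmented independently).

```python
# Made-up per-utterance losses for a batch of N = 3.
per_utt_loss = [1.5, 2.0, 0.5]
plain_total = sum(per_utt_loss)

# Duplicated batch, as produced by x.repeat(2, 1, 1) together with doubled labels.
duplicated = per_utt_loss + per_utt_loss
cr_total = 0.5 * sum(duplicated)
```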

EncoderWrapper

Bases: Module

A wrapper for encoder and encoder_embed (non-streaming JIT export).

Source code in zipformer/modules/model.py
class EncoderWrapper(torch.nn.Module):
    """A wrapper for encoder and encoder_embed (non-streaming JIT export)."""

    def __init__(
        self, encoder: torch.nn.Module, encoder_embed: torch.nn.Module
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed

    def forward(
        self, features: torch.Tensor, feature_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, x_lens = self.encoder_embed(features, feature_lengths)
        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)
        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)
        return encoder_out, encoder_out_lens

StreamingEncoderWrapper

Bases: Module

A wrapper for encoder and encoder_embed (streaming JIT export).

Source code in zipformer/modules/model.py
class StreamingEncoderWrapper(torch.nn.Module):
    """A wrapper for encoder and encoder_embed (streaming JIT export)."""

    def __init__(
        self, encoder: torch.nn.Module, encoder_embed: torch.nn.Module
    ) -> None:
        super().__init__()
        assert len(encoder.chunk_size) == 1, encoder.chunk_size
        assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
        self.chunk_size = encoder.chunk_size[0]
        self.left_context_len = encoder.left_context_frames[0]
        self.pad_length = 7 + 2 * 3
        self.encoder = encoder
        self.encoder_embed = encoder_embed

    def forward(
        self,
        features: torch.Tensor,
        feature_lengths: torch.Tensor,
        states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        chunk_size = self.chunk_size
        left_context_len = self.left_context_len

        cached_embed_left_pad = states[-2]
        x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
            x=features,
            x_lens=feature_lengths,
            cached_left_pad=cached_embed_left_pad,
        )
        assert x.size(1) == chunk_size, (x.size(1), chunk_size)

        src_key_padding_mask = make_pad_mask(x_lens)

        processed_mask = torch.arange(left_context_len, device=x.device).expand(
            x.size(0), left_context_len
        )
        processed_lens = states[-1]
        processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
        new_processed_lens = processed_lens + x_lens

        src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

        x = x.permute(1, 0, 2)
        encoder_states = states[:-2]

        (
            encoder_out,
            encoder_out_lens,
            new_encoder_states,
        ) = self.encoder.streaming_forward(
            x=x,
            x_lens=x_lens,
            states=encoder_states,
            src_key_padding_mask=src_key_padding_mask,
        )
        encoder_out = encoder_out.permute(1, 0, 2)

        new_states = new_encoder_states + [
            new_cached_embed_left_pad,
            new_processed_lens,
        ]
        return encoder_out, encoder_out_lens, new_states

    @torch.jit.export
    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[torch.Tensor]:
        states = self.encoder.get_init_states(batch_size, device)
        embed_states = self.encoder_embed.get_init_states(batch_size, device)
        states.append(embed_states)
        processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
        states.append(processed_lens)
        return states
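
The processed_mask logic in forward can be traced by hand: a left-context slot is masked while fewer frames than its distance from the current chunk have been processed, and the flip puts the most recent frames on the right, next to the current chunk. Below is a pure-Python sketch of the same comparison and flip; `left_context_mask` is a hypothetical helper used only for illustration.

```python
from typing import List


def left_context_mask(processed_lens: List[int], left_context_len: int) -> List[List[bool]]:
    """True marks a left-context slot that holds no real frame yet."""
    positions = range(left_context_len)
    rows = []
    for n in processed_lens:
        row = [n <= p for p in positions]  # processed_lens.unsqueeze(1) <= arange
        rows.append(row[::-1])  # .flip(1): valid slots end up on the right
    return rows


# Streams that have processed 0, 2 and 10 frames, with a left context of 4.
mask = left_context_mask([0, 2, 10], left_context_len=4)
```

A fresh stream masks its whole left context, and the mask disappears once processed_lens reaches left_context_len.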

OnnxEncoderWrapper

Bases: Module

A wrapper for Zipformer and the encoder_proj from the joiner (non-streaming).

Source code in zipformer/modules/model.py
class OnnxEncoderWrapper(torch.nn.Module):
    """A wrapper for Zipformer and the encoder_proj from the joiner (non-streaming)."""

    def __init__(self, encoder, encoder_embed, encoder_proj):
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.encoder_proj = encoder_proj

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, x_lens = self.encoder_embed(x, x_lens)
        src_key_padding_mask = make_pad_mask(x_lens, x.shape[1])
        x = x.permute(1, 0, 2)
        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)
        encoder_out = self.encoder_proj(encoder_out)
        return encoder_out, encoder_out_lens

OnnxDecoderWrapper

Bases: Module

A wrapper for Decoder and the decoder_proj from the joiner.

Source code in zipformer/modules/model.py
class OnnxDecoderWrapper(torch.nn.Module):
    """A wrapper for Decoder and the decoder_proj from the joiner."""

    def __init__(self, decoder, decoder_proj):
        super().__init__()
        self.decoder = decoder
        self.decoder_proj = decoder_proj

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        need_pad = False
        decoder_output = self.decoder(y, need_pad=need_pad)
        decoder_output = decoder_output.squeeze(1)
        output = self.decoder_proj(decoder_output)
        return output

OnnxJoinerWrapper

Bases: Module

A wrapper for the joiner.

Source code in zipformer/modules/model.py
class OnnxJoinerWrapper(torch.nn.Module):
    """A wrapper for the joiner."""

    def __init__(self, output_linear):
        super().__init__()
        self.output_linear = output_linear

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        logit = encoder_out + decoder_out
        logit = self.output_linear(torch.tanh(logit))
        return logit
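
The wrapper's forward adds the two inputs, applies tanh, and then the output linear layer. The sketch below reproduces that arithmetic with plain lists, assuming both inputs were already projected to the same dimension (as the ONNX encoder and decoder wrappers do). The `joiner` function, weights and bias here are made up for illustration.

```python
import math
from typing import List


def joiner(
    encoder_out: List[float],
    decoder_out: List[float],
    weight: List[List[float]],
    bias: List[float],
) -> List[float]:
    """Compute linear(tanh(encoder_out + decoder_out)) for one frame."""
    hidden = [math.tanh(e + d) for e, d in zip(encoder_out, decoder_out)]
    return [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(weight, bias)]


# The inputs cancel out, so tanh gives zeros and only the bias remains.
logits = joiner(
    [0.5, -1.0],
    [-0.5, 1.0],
    weight=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    bias=[0.0, 0.0, 2.0],
)
```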

OnnxCtcWrapper

Bases: Module

A wrapper for encoder_embed, Zipformer, and ctc_output layer (non-streaming).

Source code in zipformer/modules/model.py
class OnnxCtcWrapper(torch.nn.Module):
    """A wrapper for encoder_embed, Zipformer, and ctc_output layer (non-streaming)."""

    def __init__(self, encoder, encoder_embed, ctc_output):
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.ctc_output = ctc_output

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, x_lens = self.encoder_embed(x, x_lens)
        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)
        encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)
        log_probs = self.ctc_output(encoder_out)
        return log_probs, log_probs_len

OnnxStreamingEncoderWrapper

Bases: Module

A wrapper for Zipformer and the encoder_proj from the joiner (streaming).

Source code in zipformer/modules/model.py
class OnnxStreamingEncoderWrapper(torch.nn.Module):
    """A wrapper for Zipformer and the encoder_proj from the joiner (streaming)."""

    def __init__(self, encoder, encoder_embed, encoder_proj):
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.encoder_proj = encoder_proj
        self.chunk_size = encoder.chunk_size[0]
        self.left_context_len = encoder.left_context_frames[0]
        self.pad_length = 7 + 2 * 3

    def forward(
        self, x: torch.Tensor, states: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        N = x.size(0)
        T = self.chunk_size * 2 + self.pad_length
        x_lens = torch.tensor([T] * N, device=x.device)
        left_context_len = self.left_context_len

        cached_embed_left_pad = states[-2]
        x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
            x=x,
            x_lens=x_lens,
            cached_left_pad=cached_embed_left_pad,
        )
        assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)

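        # Create the mask on the same device as x so the torch.cat below cannot mix devices.
        src_key_padding_mask = torch.zeros(
            N, self.chunk_size, dtype=torch.bool, device=x.device
        )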

        processed_mask = torch.arange(left_context_len, device=x.device).expand(
            x.size(0), left_context_len
        )
        processed_lens = states[-1]
        processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
        new_processed_lens = processed_lens + x_lens
        src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

        x = x.permute(1, 0, 2)
        encoder_states = states[:-2]
        logging.info(f"len_encoder_states={len(encoder_states)}")
        (
            encoder_out,
            encoder_out_lens,
            new_encoder_states,
        ) = self.encoder.streaming_forward(
            x=x,
            x_lens=x_lens,
            states=encoder_states,
            src_key_padding_mask=src_key_padding_mask,
        )
        encoder_out = encoder_out.permute(1, 0, 2)
        encoder_out = self.encoder_proj(encoder_out)

        new_states = new_encoder_states + [
            new_cached_embed_left_pad,
            new_processed_lens,
        ]
        return encoder_out, new_states

    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[torch.Tensor]:
        states = self.encoder.get_init_states(batch_size, device)
        embed_states = self.encoder_embed.get_init_states(batch_size, device)
        states.append(embed_states)
        processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
        states.append(processed_lens)
        return states
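
The input length T = chunk_size * 2 + pad_length, with pad_length = 7 + 2 * 3, is chosen so that the frontend returns exactly chunk_size frames, which the assertion in forward checks. The sketch below works through that arithmetic. The frontend formula in it, (T - 7) // 2 followed by dropping 3 frames, is an assumption inferred from the constants and is not confirmed by this page.

```python
def input_frames_per_chunk(chunk_size: int, pad_length: int = 7 + 2 * 3) -> int:
    """Number of feature frames fed to the wrapper for one chunk."""
    return chunk_size * 2 + pad_length


def frames_after_frontend(num_input_frames: int) -> int:
    # Assumption: subsampling maps T to (T - 7) // 2, and streaming drops 3 more frames.
    return (num_input_frames - 7) // 2 - 3


frames_in = input_frames_per_chunk(16)
frames_out = frames_after_frontend(frames_in)
```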

OnnxStreamingCtcWrapper

Bases: Module

A wrapper for Zipformer and the ctc_head (streaming).

Source code in zipformer/modules/model.py
class OnnxStreamingCtcWrapper(torch.nn.Module):
    """A wrapper for Zipformer and the ctc_head (streaming)."""

    def __init__(
        self,
        encoder: torch.nn.Module,
        encoder_embed: torch.nn.Module,
        ctc_output: torch.nn.Module,
    ):
        """
        Args:
          encoder:
            A Zipformer encoder.
          encoder_embed:
            The frontend module applied to the input features before the encoder.
          ctc_output:
            The ctc head.
        """
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.ctc_output = ctc_output
        self.chunk_size = encoder.chunk_size[0]
        self.left_context_len = encoder.left_context_frames[0]
        self.pad_length = 7 + 2 * 3

    def forward(
        self,
        x: torch.Tensor,
        states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        N = x.size(0)
        T = self.chunk_size * 2 + self.pad_length
        x_lens = torch.tensor([T] * N, device=x.device)
        left_context_len = self.left_context_len

        cached_embed_left_pad = states[-2]
        x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
            x=x,
            x_lens=x_lens,
            cached_left_pad=cached_embed_left_pad,
        )
        assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)

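        # Create the mask on the same device as x so the torch.cat below cannot mix devices.
        src_key_padding_mask = torch.zeros(
            N, self.chunk_size, dtype=torch.bool, device=x.device
        )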

        # processed_mask is used to mask out initial states
        processed_mask = torch.arange(left_context_len, device=x.device).expand(
            x.size(0), left_context_len
        )
        processed_lens = states[-1]  # (batch,)
        # (batch, left_context_size)
        processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
        # Update processed lengths
        new_processed_lens = processed_lens + x_lens
        # (batch, left_context_size + chunk_size)
        src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

        x = x.permute(1, 0, 2)
        encoder_states = states[:-2]
        logging.info(f"len_encoder_states={len(encoder_states)}")
        (
            encoder_out,
            encoder_out_lens,
            new_encoder_states,
        ) = self.encoder.streaming_forward(
            x=x,
            x_lens=x_lens,
            states=encoder_states,
            src_key_padding_mask=src_key_padding_mask,
        )
        encoder_out = encoder_out.permute(1, 0, 2)
        encoder_out = self.ctc_output(encoder_out)
        # Now encoder_out is of shape (N, T, ctc_output_dim)

        new_states = new_encoder_states + [
            new_cached_embed_left_pad,
            new_processed_lens,
        ]

        return encoder_out, new_states

    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[torch.Tensor]:
        """
        Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
        is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
        states[-2] is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs)
        states[-1] is processed_lens of shape (batch,), which records the number
        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
        """
        states = self.encoder.get_init_states(batch_size, device)

        embed_states = self.encoder_embed.get_init_states(batch_size, device)

        states.append(embed_states)

        processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
        states.append(processed_lens)

        return states

__init__(encoder, encoder_embed, ctc_output)

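Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| encoder | Module | A Zipformer encoder. | required |
| encoder_embed | Module | The frontend module applied to the input features before the encoder. | required |
| ctc_output | Module | The CTC head. | required |
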
Source code in zipformer/modules/model.py
def __init__(
    self,
    encoder: torch.nn.Module,
    encoder_embed: torch.nn.Module,
    ctc_output: torch.nn.Module,
):
    """
    Args:
      encoder:
        A Zipformer encoder.
      encoder_embed:
        The frontend module applied to the input features before the encoder.
      ctc_output:
        The ctc head.
    """
    super().__init__()
    self.encoder = encoder
    self.encoder_embed = encoder_embed
    self.ctc_output = ctc_output
    self.chunk_size = encoder.chunk_size[0]
    self.left_context_len = encoder.left_context_frames[0]
    self.pad_length = 7 + 2 * 3

get_init_states(batch_size=1, device=torch.device('cpu'))

Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). states[-2] is the cached left padding for the ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs). states[-1] is processed_lens of shape (batch,), which records the number of processed frames (at 50 Hz frame rate, after encoder_embed) for each sample in the batch.

Source code in zipformer/modules/model.py
def get_init_states(
    self,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
    """
    Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
    is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
    states[-2] is the cached left padding for ConvNeXt module,
    of shape (batch_size, num_channels, left_pad, num_freqs)
    states[-1] is processed_lens of shape (batch,), which records the number
    of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
    """
    states = self.encoder.get_init_states(batch_size, device)

    embed_states = self.encoder_embed.get_init_states(batch_size, device)

    states.append(embed_states)

    processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
    states.append(processed_lens)

    return states
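
The flat state list described above can be unpacked by slicing in groups of six and reading the last two entries separately. The sketch below shows that layout, with placeholder strings standing in for tensors; `split_states` and the returned dictionary keys are hypothetical names chosen for this example.

```python
from typing import Dict, Sequence

STATE_NAMES = (
    "cached_key",
    "cached_nonlin_attn",
    "cached_val1",
    "cached_val2",
    "cached_conv1",
    "cached_conv2",
)


def split_states(states: Sequence[object]) -> Dict[str, object]:
    """Group states[i*6:(i+1)*6] per layer and name the two trailing entries."""
    encoder_states = states[:-2]
    assert len(encoder_states) % len(STATE_NAMES) == 0
    layers = [
        dict(zip(STATE_NAMES, encoder_states[i : i + 6]))
        for i in range(0, len(encoder_states), 6)
    ]
    return {"layers": layers, "embed_left_pad": states[-2], "processed_lens": states[-1]}


# Two layers * six caches, plus the two trailing entries.
fake_states = [f"layer{i}_{name}" for i in range(2) for name in STATE_NAMES] + ["embed", "lens"]
parsed = split_states(fake_states)
```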

OnnxTransducerModel

Non-streaming ONNX transducer model (encoder + decoder + joiner).

Source code in zipformer/modules/model.py
class OnnxTransducerModel:
    """Non-streaming ONNX transducer model (encoder + decoder + joiner)."""

    def __init__(self, encoder_filename, decoder_filename, joiner_filename):
        import onnxruntime as ort

        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.encoder = ort.InferenceSession(
            encoder_filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.decoder = ort.InferenceSession(
            decoder_filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.joiner = ort.InferenceSession(
            joiner_filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
        self.context_size = int(decoder_meta["context_size"])
        self.vocab_size = int(decoder_meta["vocab_size"])

    def run_encoder(self, x, x_lens):
        out = self.encoder.run(
            [self.encoder.get_outputs()[0].name, self.encoder.get_outputs()[1].name],
            {
                self.encoder.get_inputs()[0].name: x.numpy(),
                self.encoder.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(out[0]), torch.from_numpy(out[1])

    def run_decoder(self, decoder_input):
        out = self.decoder.run(
            [self.decoder.get_outputs()[0].name],
            {self.decoder.get_inputs()[0].name: decoder_input.numpy()},
        )[0]
        return torch.from_numpy(out)

    def run_joiner(self, encoder_out, decoder_out):
        out = self.joiner.run(
            [self.joiner.get_outputs()[0].name],
            {
                self.joiner.get_inputs()[0].name: encoder_out.numpy(),
                self.joiner.get_inputs()[1].name: decoder_out.numpy(),
            },
        )[0]
        return torch.from_numpy(out)

OnnxCtcModel

Non-streaming ONNX CTC model.

Source code in zipformer/modules/model.py
class OnnxCtcModel:
    """Non-streaming ONNX CTC model."""

    def __init__(self, nn_model):
        import onnxruntime as ort

        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.model = ort.InferenceSession(
            nn_model,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def __call__(self, x, x_lens):
        out = self.model.run(
            [self.model.get_outputs()[0].name, self.model.get_outputs()[1].name],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(out[0]), torch.from_numpy(out[1])
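
The model call returns two tensors. Assuming the first holds per-frame log-probabilities over the vocabulary and that blank has ID 0 (neither is stated in the listing above), a greedy CTC decode can be sketched in plain Python as follows. `ctc_greedy_decode` is an illustrative helper, not part of this module.

```python
from typing import List, Sequence


def ctc_greedy_decode(
    log_probs: Sequence[Sequence[float]],
    blank_id: int = 0,  # assumption: the blank token has ID 0
) -> List[int]:
    """Pick the best token per frame, merge repeats, then drop blanks."""
    best = [max(range(len(frame)), key=frame.__getitem__) for frame in log_probs]
    tokens: List[int] = []
    prev = None
    for tok in best:
        if tok != prev and tok != blank_id:
            tokens.append(tok)
        prev = tok
    return tokens


# Four frames over a 3-token vocabulary; the per-frame argmax is 1, 1, 0, 2.
frames = [
    [-2.0, -0.1, -3.0],
    [-2.0, -0.2, -3.0],
    [-0.1, -2.0, -3.0],
    [-3.0, -2.0, -0.1],
]
decoded = ctc_greedy_decode(frames)
```

A blank between two identical tokens keeps both, which is why repeats are merged before blanks are removed.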

OnnxStreamingTransducerModel

Streaming ONNX transducer model with state management.

Source code in zipformer/modules/model.py
class OnnxStreamingTransducerModel:
    """Streaming ONNX transducer model with state management."""

    def __init__(self, encoder_filename, decoder_filename, joiner_filename):
        import onnxruntime as ort

        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.encoder = ort.InferenceSession(
            encoder_filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.decoder = ort.InferenceSession(
            decoder_filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.joiner = ort.InferenceSession(
            joiner_filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
        self.context_size = int(decoder_meta["context_size"])
        self.vocab_size = int(decoder_meta["vocab_size"])

        self._init_encoder_states()

    def _init_encoder_states(self, batch_size=1):
        meta = self.encoder.get_modelmeta().custom_metadata_map
        self.segment = int(meta["T"])
        self.offset = int(meta["decode_chunk_len"])

        def to_int_list(s):
            return list(map(int, s.split(",")))

        num_encoder_layers = to_int_list(meta["num_encoder_layers"])
        encoder_dims = to_int_list(meta["encoder_dims"])
        cnn_module_kernels = to_int_list(meta["cnn_module_kernels"])
        left_context_len = to_int_list(meta["left_context_len"])
        query_head_dims = to_int_list(meta["query_head_dims"])
        value_head_dims = to_int_list(meta["value_head_dims"])
        num_heads = to_int_list(meta["num_heads"])

        self.states = []
        for i in range(len(num_encoder_layers)):
            key_dim = query_head_dims[i] * num_heads[i]
            embed_dim = encoder_dims[i]
            nonlin_attn_head_dim = 3 * embed_dim // 4
            value_dim = value_head_dims[i] * num_heads[i]
            conv_left_pad = cnn_module_kernels[i] // 2

            for _ in range(num_encoder_layers[i]):
                self.states += [
                    np.zeros(
                        (left_context_len[i], batch_size, key_dim), dtype=np.float32
                    ),
                    np.zeros(
                        (1, batch_size, left_context_len[i], nonlin_attn_head_dim),
                        dtype=np.float32,
                    ),
                    np.zeros(
                        (left_context_len[i], batch_size, value_dim), dtype=np.float32
                    ),
                    np.zeros(
                        (left_context_len[i], batch_size, value_dim), dtype=np.float32
                    ),
                    np.zeros((batch_size, embed_dim, conv_left_pad), dtype=np.float32),
                    np.zeros((batch_size, embed_dim, conv_left_pad), dtype=np.float32),
                ]
        self.states.append(np.zeros((batch_size, 128, 3, 19), dtype=np.float32))
        self.states.append(np.zeros(batch_size, dtype=np.int64))

    def reset_states(self):
        self._init_encoder_states()

    def _build_encoder_io(self, x):
        encoder_input = {"x": x.numpy()}
        encoder_output = ["encoder_out"]

        for i in range(len(self.states[:-2]) // 6):
            tensors = self.states[i * 6 : (i + 1) * 6]
            for j, prefix in enumerate(
                [
                    "cached_key",
                    "cached_nonlin_attn",
                    "cached_val1",
                    "cached_val2",
                    "cached_conv1",
                    "cached_conv2",
                ]
            ):
                name = f"{prefix}_{i}"
                encoder_input[name] = tensors[j]
                encoder_output.append(f"new_{name}")

        encoder_input["embed_states"] = self.states[-2]
        encoder_output.append("new_embed_states")
        encoder_input["processed_lens"] = self.states[-1]
        encoder_output.append("new_processed_lens")

        return encoder_input, encoder_output

    def run_encoder(self, x):
        encoder_input, encoder_output_names = self._build_encoder_io(x)
        out = self.encoder.run(encoder_output_names, encoder_input)
        self.states = out[1:]
        return torch.from_numpy(out[0])

    def run_decoder(self, decoder_input):
        out = self.decoder.run(
            [self.decoder.get_outputs()[0].name],
            {self.decoder.get_inputs()[0].name: decoder_input.numpy()},
        )[0]
        return torch.from_numpy(out)

    def run_joiner(self, encoder_out, decoder_out):
        out = self.joiner.run(
            [self.joiner.get_outputs()[0].name],
            {
                self.joiner.get_inputs()[0].name: encoder_out.numpy(),
                self.joiner.get_inputs()[1].name: decoder_out.numpy(),
            },
        )[0]
        return torch.from_numpy(out)
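
The state layout built by `_init_encoder_states` is easier to follow when only the shapes are listed. The sketch below mirrors that arithmetic in plain Python: six cached tensors per encoder layer, followed by `embed_states` and `processed_lens`. The metadata values are hypothetical and chosen only for illustration.

```python
from typing import Dict, List, Tuple


def state_shapes(meta: Dict[str, str], batch_size: int = 1) -> List[Tuple[int, ...]]:
    """Return the shape of every cached tensor, in the order the model expects."""

    def to_int_list(s: str) -> List[int]:
        return [int(v) for v in s.split(",")]

    num_layers = to_int_list(meta["num_encoder_layers"])
    dims = to_int_list(meta["encoder_dims"])
    kernels = to_int_list(meta["cnn_module_kernels"])
    left = to_int_list(meta["left_context_len"])
    q_dims = to_int_list(meta["query_head_dims"])
    v_dims = to_int_list(meta["value_head_dims"])
    heads = to_int_list(meta["num_heads"])

    shapes: List[Tuple[int, ...]] = []
    for i, n in enumerate(num_layers):
        key_dim = q_dims[i] * heads[i]
        value_dim = v_dims[i] * heads[i]
        nonlin_dim = 3 * dims[i] // 4
        pad = kernels[i] // 2
        per_layer = [
            (left[i], batch_size, key_dim),        # cached_key
            (1, batch_size, left[i], nonlin_dim),  # cached_nonlin_attn
            (left[i], batch_size, value_dim),      # cached_val1
            (left[i], batch_size, value_dim),      # cached_val2
            (batch_size, dims[i], pad),            # cached_conv1
            (batch_size, dims[i], pad),            # cached_conv2
        ]
        shapes += per_layer * n  # every layer in a stack shares the same shapes
    shapes.append((batch_size, 128, 3, 19))  # embed_states
    shapes.append((batch_size,))             # processed_lens
    return shapes


# Hypothetical metadata for a two-stack encoder.
example_meta = {
    "num_encoder_layers": "2,3",
    "encoder_dims": "192,256",
    "cnn_module_kernels": "31,31",
    "left_context_len": "128,64",
    "query_head_dims": "32,32",
    "value_head_dims": "12,12",
    "num_heads": "4,4",
}
shapes = state_shapes(example_meta)
```

With five layers in total, the list has 6 * 5 + 2 = 32 entries, which is why `_build_encoder_io` walks `self.states[:-2]` in groups of six.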

OnnxStreamingCtcModel

Streaming ONNX CTC model with state management.

Source code in zipformer/modules/model.py
class OnnxStreamingCtcModel:
    """Streaming ONNX CTC model with state management."""

    def __init__(self, model_filename):
        import onnxruntime as ort

        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.model = ort.InferenceSession(
            model_filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )
        self._init_states()

    def _init_states(self, batch_size=1):
        meta = self.model.get_modelmeta().custom_metadata_map
        self.segment = int(meta["T"])
        self.offset = int(meta["decode_chunk_len"])

        def to_int_list(s):
            return list(map(int, s.split(",")))

        num_encoder_layers = to_int_list(meta["num_encoder_layers"])
        encoder_dims = to_int_list(meta["encoder_dims"])
        cnn_module_kernels = to_int_list(meta["cnn_module_kernels"])
        left_context_len = to_int_list(meta["left_context_len"])
        query_head_dims = to_int_list(meta["query_head_dims"])
        value_head_dims = to_int_list(meta["value_head_dims"])
        num_heads = to_int_list(meta["num_heads"])

        self.states = []
        for i in range(len(num_encoder_layers)):
            key_dim = query_head_dims[i] * num_heads[i]
            embed_dim = encoder_dims[i]
            nonlin_attn_head_dim = 3 * embed_dim // 4
            value_dim = value_head_dims[i] * num_heads[i]
            conv_left_pad = cnn_module_kernels[i] // 2

            for _ in range(num_encoder_layers[i]):
                self.states += [
                    np.zeros(
                        (left_context_len[i], batch_size, key_dim), dtype=np.float32
                    ),
                    np.zeros(
                        (1, batch_size, left_context_len[i], nonlin_attn_head_dim),
                        dtype=np.float32,
                    ),
                    np.zeros(
                        (left_context_len[i], batch_size, value_dim), dtype=np.float32
                    ),
                    np.zeros(
                        (left_context_len[i], batch_size, value_dim), dtype=np.float32
                    ),
                    np.zeros((batch_size, embed_dim, conv_left_pad), dtype=np.float32),
                    np.zeros((batch_size, embed_dim, conv_left_pad), dtype=np.float32),
                ]
        self.states.append(np.zeros((batch_size, 128, 3, 19), dtype=np.float32))
        self.states.append(np.zeros(batch_size, dtype=np.int64))

    def reset_states(self):
        self._init_states()

    def _build_model_io(self, x):
        model_input = {"x": x.numpy()}
        model_output = ["log_probs"]

        for i in range(len(self.states[:-2]) // 6):
            tensors = self.states[i * 6 : (i + 1) * 6]
            for j, prefix in enumerate(
                [
                    "cached_key",
                    "cached_nonlin_attn",
                    "cached_val1",
                    "cached_val2",
                    "cached_conv1",
                    "cached_conv2",
                ]
            ):
                name = f"{prefix}_{i}"
                model_input[name] = tensors[j]
                model_output.append(f"new_{name}")

        model_input["embed_states"] = self.states[-2]
        model_output.append("new_embed_states")
        model_input["processed_lens"] = self.states[-1]
        model_output.append("new_processed_lens")

        return model_input, model_output

    def __call__(self, x):
        model_input, model_output_names = self._build_model_io(x)
        out = self.model.run(model_output_names, model_input)
        self.states = out[1:]
        return torch.from_numpy(out[0])
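
The `segment` and `offset` attributes read from the model metadata suggest a sliding-window feeding scheme. The sketch below assumes that `segment` (metadata key `T`) is the number of feature frames consumed per call and that `offset` (`decode_chunk_len`) is the hop between calls. This interpretation is not confirmed by the listing above, and tail padding is left out.

```python
from typing import List, Tuple


def chunk_windows(num_frames: int, segment: int, offset: int) -> List[Tuple[int, int]]:
    """Return (start, end) frame ranges for successive streaming chunks."""
    windows = []
    start = 0
    # Each chunk needs `segment` frames; the window then slides by `offset`,
    # so consecutive chunks overlap by `segment - offset` frames.
    while start + segment <= num_frames:
        windows.append((start, start + segment))
        start += offset
    return windows


# Hypothetical values: 45 frames per chunk, advancing 32 frames per call.
windows = chunk_windows(num_frames=120, segment=45, offset=32)
```

Each window would be sliced from the feature matrix and passed to the model, which updates `self.states` between calls. `reset_states()` should be called before starting a new utterance.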

Attention Decoder

zipformer.modules.attention_decoder

AttentionDecoderModel

Bases: Module

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vocab_size | int | Number of classes. | required |
| decoder_dim | int | Decoder dimension. | 512 |
| attention_dim | int | Total dimension of multi-head attention. | 512 |
| num_heads | int | Number of attention heads. | 8 |
| feedforward_dim | int | Hidden dimension of the feed-forward module. | 2048 |
| num_decoder_layers | int | Number of decoder layers. | 6 |
| dropout | float | Dropout rate. | 0.1 |

Source code in zipformer/modules/attention_decoder.py
class AttentionDecoderModel(torch.nn.Module):
    """
    Args:
        vocab_size (int): Number of classes.
        decoder_dim (int): decoder dimension
        attention_dim (int): total dimension of multi head attention
        num_heads (int): number of attention heads
        feedforward_dim (int): hidden dimension of feed_forward module
        num_decoder_layers (int): number of decoder layers
        dropout (float): dropout rate
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int = 512,
        num_decoder_layers: int = 6,
        attention_dim: int = 512,
        num_heads: int = 8,
        feedforward_dim: int = 2048,
        memory_dim: int = 512,
        sos_id: int = 1,
        eos_id: int = 1,
        dropout: float = 0.1,
        ignore_id: int = -1,
        label_smoothing: float = 0.1,
    ):
        super().__init__()
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.ignore_id = ignore_id

        # Transformer decoder; it attends to the encoder output, which is passed
        # in as `memory` (see calc_att_loss and nll).
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            d_model=decoder_dim,
            num_decoder_layers=num_decoder_layers,
            attention_dim=attention_dim,
            num_heads=num_heads,
            feedforward_dim=feedforward_dim,
            memory_dim=memory_dim,
            dropout=dropout,
        )

        # Used to calculate attention-decoder loss
        self.loss_fun = LabelSmoothingLoss(
            ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="none"
        )

    def _pre_ys_in_out(
        self, ys: List[List[int]], device: Optional[torch.device] = None
    ):
        """Prepare ys_in_pad and ys_out_pad."""
        # [B, S+1], start with SOS
        ys_in_pad, ys_in_lens = pad_sequences(
            ys, padding_value=self.eos_id, sos_id=self.sos_id, device=device
        )
        # [B, S+1], end with EOS
        ys_out_pad, _ = pad_sequences(
            ys, padding_value=self.ignore_id, eos_id=self.eos_id, device=device
        )
        return ys_in_pad, ys_in_lens, ys_out_pad

    def calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys: List[List[int]],
        reduction: str = "sum",
    ) -> torch.Tensor:
        """Calculate attention-decoder loss.
        Args:
          encoder_out: (batch, num_frames, encoder_dim)
          encoder_out_lens: (batch,)
          ys: A list of token ID lists.

        Return: The attention-decoder loss.
        """
        ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(
            ys, device=encoder_out.device
        )

        # decoder forward
        decoder_out = self.decoder(
            x=ys_in_pad,
            x_lens=ys_in_lens,
            memory=encoder_out,
            memory_lens=encoder_out_lens,
        )

        loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
        if reduction == "sum":
            loss = loss.sum()
        else:
            assert reduction == "none", reduction

        return loss

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys: List[List[int]],
    ) -> torch.Tensor:
        """Compute the negative log-likelihood (NLL) from the attention decoder.
        Args:
          encoder_out: (batch, num_frames, encoder_dim)
          encoder_out_lens: (batch,)
          ys: A list of token ID lists.

        Return: A tensor of shape (batch, num_tokens).
        """
        ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(
            ys, device=encoder_out.device
        )

        # decoder forward
        decoder_out = self.decoder(
            x=ys_in_pad,
            x_lens=ys_in_lens,
            memory=encoder_out,
            memory_lens=encoder_out_lens,
        )

        batch_size, _, num_classes = decoder_out.size()
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, num_classes),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        return nll
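
`_pre_ys_in_out` builds two views of the same labels: a decoder input that starts with SOS and a target that ends with EOS. The implementation of `pad_sequences` is not shown here, so the sketch below uses a hypothetical pure-Python stand-in, `pad_with_markers`, to illustrate the intended result on plain lists.

```python
from typing import List, Optional, Tuple


def pad_with_markers(
    ys: List[List[int]],
    padding_value: int,
    sos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> Tuple[List[List[int]], List[int]]:
    """Optionally prepend SOS / append EOS, then right-pad to equal length."""
    seqs = []
    for y in ys:
        seq = list(y)
        if sos_id is not None:
            seq = [sos_id] + seq
        if eos_id is not None:
            seq = seq + [eos_id]
        seqs.append(seq)
    lens = [len(s) for s in seqs]
    max_len = max(lens)
    padded = [s + [padding_value] * (max_len - len(s)) for s in seqs]
    return padded, lens


ys = [[7, 8, 9], [5]]
SOS = EOS = 1  # the constructor defaults shown above
IGNORE = -1

# Decoder input: SOS + tokens, padded with EOS.
ys_in, ys_in_lens = pad_with_markers(ys, padding_value=EOS, sos_id=SOS)
# Training target: tokens + EOS, padded with the ignore ID.
ys_out, _ = pad_with_markers(ys, padding_value=IGNORE, eos_id=EOS)
```

Padding the target with `ignore_id` lets the loss skip those positions, while the input padding value does not matter because padded positions are masked out.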

calc_att_loss(encoder_out, encoder_out_lens, ys, reduction='sum')

Calculate attention-decoder loss.

Args:
  encoder_out: (batch, num_frames, encoder_dim)
  encoder_out_lens: (batch,)
  ys: A list of token ID lists.
  reduction: "sum" to return the summed loss, or "none" to return the unreduced loss.

Return: The attention-decoder loss.

Source code in zipformer/modules/attention_decoder.py
def calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys: List[List[int]],
    reduction: str = "sum",
) -> torch.Tensor:
    """Calculate attention-decoder loss.
    Args:
      encoder_out: (batch, num_frames, encoder_dim)
      encoder_out_lens: (batch,)
      ys: A list of token ID lists.

    Return: The attention-decoder loss.
    """
    ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(
        ys, device=encoder_out.device
    )

    # decoder forward
    decoder_out = self.decoder(
        x=ys_in_pad,
        x_lens=ys_in_lens,
        memory=encoder_out,
        memory_lens=encoder_out_lens,
    )

    loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
    if reduction == "sum":
        loss = loss.sum()
    else:
        assert reduction == "none", reduction

    return loss
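
The loss above is computed by `LabelSmoothingLoss`, whose implementation is not part of this listing. As a rough illustration, the sketch below shows the common uniform label-smoothing formulation for a single token: the one-hot target is mixed with a uniform distribution over all classes. The library's version may differ in detail, for example in how the smoothing mass is distributed.

```python
import math
from typing import List


def label_smoothing_loss(
    logits: List[float], target: int, label_smoothing: float = 0.1
) -> float:
    """Cross-entropy against a smoothed target distribution for one token."""
    # Numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    log_probs = [v - log_z for v in logits]

    num_classes = len(logits)
    uniform = label_smoothing / num_classes
    loss = 0.0
    for k, lp in enumerate(log_probs):
        # Target distribution: (1 - eps) on the true class plus eps / K everywhere.
        q = uniform + (1.0 - label_smoothing if k == target else 0.0)
        loss -= q * lp
    return loss


plain = label_smoothing_loss([2.0, 0.5, -1.0], target=0, label_smoothing=0.0)
smoothed = label_smoothing_loss([2.0, 0.5, -1.0], target=0, label_smoothing=0.1)
```

With `label_smoothing=0.0` this reduces to ordinary cross-entropy. When the model already favours the correct class, smoothing raises the loss slightly, which discourages over-confident predictions.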

nll(encoder_out, encoder_out_lens, ys)

Compute the negative log-likelihood (NLL) from the attention decoder.

Args:
  encoder_out: (batch, num_frames, encoder_dim)
  encoder_out_lens: (batch,)
  ys: A list of token ID lists.

Return: A tensor of shape (batch, num_tokens).

Source code in zipformer/modules/attention_decoder.py
def nll(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys: List[List[int]],
) -> torch.Tensor:
    """Compute the negative log-likelihood (NLL) from the attention decoder.
    Args:
      encoder_out: (batch, num_frames, encoder_dim)
      encoder_out_lens: (batch,)
      ys: A list of token ID lists.

    Return: A tensor of shape (batch, num_tokens).
    """
    ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(
        ys, device=encoder_out.device
    )

    # decoder forward
    decoder_out = self.decoder(
        x=ys_in_pad,
        x_lens=ys_in_lens,
        memory=encoder_out,
        memory_lens=encoder_out_lens,
    )

    batch_size, _, num_classes = decoder_out.size()
    nll = torch.nn.functional.cross_entropy(
        decoder_out.view(-1, num_classes),
        ys_out_pad.view(-1),
        ignore_index=self.ignore_id,
        reduction="none",
    )
    nll = nll.view(batch_size, -1)
    return nll
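
The per-token values returned by `nll` can be reproduced by hand for small inputs. The following pure-Python sketch computes cross-entropy with `reduction="none"` and an ignore index for one utterance: positions whose target equals `ignore_id` contribute 0. `token_nll` is an illustrative helper, not part of this module.

```python
import math
from typing import List


def token_nll(
    logits: List[List[float]], targets: List[int], ignore_id: int = -1
) -> List[float]:
    """Per-token negative log-likelihood; ignored positions contribute 0."""
    out = []
    for row, tgt in zip(logits, targets):
        if tgt == ignore_id:
            out.append(0.0)
            continue
        # -log softmax(row)[tgt], computed stably via log-sum-exp.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        out.append(log_z - row[tgt])
    return out


# One utterance with two real tokens and one padded position.
nll_row = token_nll(
    logits=[[0.0, 0.0], [3.0, 1.0], [5.0, 5.0]],
    targets=[1, 0, -1],
)
utterance_score = sum(nll_row)
```

Summing a row gives an utterance-level score, which is one way the (batch, num_tokens) output can be used for rescoring hypotheses.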

TransformerDecoder

Bases: Module

Transformer decoder module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vocab_size | int | Output dimension. | required |
| d_model | int | Decoder dimension. | 512 |
| num_decoder_layers | int | Number of decoder layers. | 6 |
| attention_dim | int | Total dimension of multi-head attention. | 512 |
| num_heads | int | Number of attention heads. | 8 |
| feedforward_dim | int | Hidden dimension of the feed-forward module. | 2048 |
| dropout | float | Dropout rate. | 0.1 |

Source code in zipformer/modules/attention_decoder.py
class TransformerDecoder(torch.nn.Module):
    """Transformer decoder module.

    Args:
        vocab_size: output dim
        d_model: decoder dimension
        num_decoder_layers: number of decoder layers
        attention_dim: total dimension of multi head attention
        num_heads: number of attention heads
        feedforward_dim: hidden dimension of feed_forward module
        dropout: dropout rate
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_decoder_layers: int = 6,
        attention_dim: int = 512,
        num_heads: int = 8,
        feedforward_dim: int = 2048,
        memory_dim: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed = torch.nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=d_model
        )

        # Absolute positional encoding
        self.pos = PositionalEncoding(d_model, dropout_rate=0.1)

        self.num_layers = num_decoder_layers
        self.layers = torch.nn.ModuleList(
            [
                DecoderLayer(
                    d_model=d_model,
                    attention_dim=attention_dim,
                    num_heads=num_heads,
                    feedforward_dim=feedforward_dim,
                    memory_dim=memory_dim,
                    dropout=dropout,
                )
                for _ in range(num_decoder_layers)
            ]
        )

        self.output_layer = torch.nn.Linear(d_model, vocab_size)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        memory: Optional[torch.Tensor] = None,
        memory_lens: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
          x: Input tensor of shape (batch, tgt_len).
          x_lens: A tensor of shape (batch,) containing the number of tokens in `x`
            before padding.
          memory:
            Memory sequence of shape (batch, src_len, memory_dim).
          memory_lens:
            A tensor of shape (batch,) containing the number of frames in
            `memory` before padding.

        Returns:
            Decoded token logits before softmax (batch, tgt_len, vocab_size)
        """
        x = self.embed(x)  # (batch, tgt_len, embed_dim)
        x = self.pos(x)  # (batch, tgt_len, embed_dim)

        x = x.permute(1, 0, 2)  # (tgt_len, batch, embed_dim)

        # construct attn_mask for self-attn modules
        padding_mask = make_pad_mask(x_lens)  # (batch, tgt_len)
        causal_mask = subsequent_mask(x.shape[0], device=x.device)  # (seq_len, seq_len)
        attn_mask = torch.logical_or(
            padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
            torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
        )  # (batch, seq_len, seq_len)

        if memory is not None:
            memory = memory.permute(1, 0, 2)  # (src_len, batch, memory_dim)
            # construct memory_attn_mask for cross-attn modules
            memory_padding_mask = make_pad_mask(memory_lens)  # (batch, src_len)
            memory_attn_mask = memory_padding_mask.unsqueeze(1)  # (batch, 1, src_len)
        else:
            memory_attn_mask = None

        for i, mod in enumerate(self.layers):
            x = mod(
                x,
                attn_mask=attn_mask,
                memory=memory,
                memory_attn_mask=memory_attn_mask,
            )

        x = x.permute(1, 0, 2)  # (batch, tgt_len, embed_dim)
        x = self.output_layer(x)  # (batch, tgt_len, vocab_size)

        return x

forward(x, x_lens, memory=None, memory_lens=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor of shape (batch, tgt_len). | required |
| x_lens | Tensor | A tensor of shape (batch,) containing the number of tokens in x before padding. | required |
| memory | Optional[Tensor] | Memory sequence of shape (batch, src_len, memory_dim). | None |
| memory_lens | Optional[Tensor] | A tensor of shape (batch,) containing the number of frames in memory before padding. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Decoded token logits before softmax, of shape (batch, tgt_len, vocab_size). |

Source code in zipformer/modules/attention_decoder.py
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    memory: Optional[torch.Tensor] = None,
    memory_lens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Args:
      x: Input tensor of shape (batch, tgt_len).
      x_lens: A tensor of shape (batch,) containing the number of tokens in `x`
        before padding.
      memory:
        Memory sequence of shape (batch, src_len, memory_dim).
      memory_lens:
        A tensor of shape (batch,) containing the number of frames in
        `memory` before padding.

    Returns:
        Decoded token logits before softmax (batch, tgt_len, vocab_size)
    """
    x = self.embed(x)  # (batch, tgt_len, embed_dim)
    x = self.pos(x)  # (batch, tgt_len, embed_dim)

    x = x.permute(1, 0, 2)  # (tgt_len, batch, embed_dim)

    # construct attn_mask for self-attn modules
    padding_mask = make_pad_mask(x_lens)  # (batch, tgt_len)
    causal_mask = subsequent_mask(x.shape[0], device=x.device)  # (seq_len, seq_len)
    attn_mask = torch.logical_or(
        padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
        torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
    )  # (batch, seq_len, seq_len)

    if memory is not None:
        memory = memory.permute(1, 0, 2)  # (src_len, batch, memory_dim)
        # construct memory_attn_mask for cross-attn modules
        memory_padding_mask = make_pad_mask(memory_lens)  # (batch, src_len)
        memory_attn_mask = memory_padding_mask.unsqueeze(1)  # (batch, 1, src_len)
    else:
        memory_attn_mask = None

    for i, mod in enumerate(self.layers):
        x = mod(
            x,
            attn_mask=attn_mask,
            memory=memory,
            memory_attn_mask=memory_attn_mask,
        )

    x = x.permute(1, 0, 2)  # (batch, tgt_len, embed_dim)
    x = self.output_layer(x)  # (batch, tgt_len, vocab_size)

    return x
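
The self-attention mask in `forward` combines a padding mask with a causal mask. The sketch below reproduces that combination on nested Python lists. It assumes the padding mask is True at padded positions and the causal mask is True where attention is allowed, which is what the `logical_not` in the source implies. `pad_mask` and `causal_mask` are pure-Python stand-ins for `make_pad_mask` and `subsequent_mask`, whose implementations are not shown here.

```python
from typing import List


def pad_mask(lens: List[int], max_len: int) -> List[List[bool]]:
    """True marks padded positions (assumed convention)."""
    return [[t >= n for t in range(max_len)] for n in lens]


def causal_mask(size: int) -> List[List[bool]]:
    """True where query position i may attend to key position j (j <= i)."""
    return [[j <= i for j in range(size)] for i in range(size)]


def build_attn_mask(lens: List[int], tgt_len: int) -> List[List[List[bool]]]:
    """Return a (batch, tgt_len, tgt_len) mask; True means 'fill with -inf'."""
    pad = pad_mask(lens, tgt_len)
    causal = causal_mask(tgt_len)
    return [
        [
            [pad[b][j] or not causal[i][j] for j in range(tgt_len)]
            for i in range(tgt_len)
        ]
        for b in range(len(lens))
    ]


# One sequence with two real tokens, padded to length 3.
mask = build_attn_mask(lens=[2], tgt_len=3)
```

Row i of the result blocks every future position j > i as well as every padded key, so position 2 (padding) is masked in all rows.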

DecoderLayer

Bases: Module

Single decoder layer module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| d_model | int | Equal to decoder_dim, the total dimension of the decoder. | 512 |
| attention_dim | int | Total dimension of multi-head attention. | 512 |
| num_heads | int | Number of attention heads. | 8 |
| feedforward_dim | int | Hidden dimension of the feed-forward module. | 2048 |
| dropout | float | Dropout rate. | 0.1 |

Source code in zipformer/modules/attention_decoder.py
class DecoderLayer(torch.nn.Module):
    """Single decoder layer module.

    Args:
        d_model: equal to decoder_dim, total dimension of the decoder
        attention_dim: total dimension of multi head attention
        num_heads: number of attention heads
        feedforward_dim: hidden dimension of feed_forward module
        dropout: dropout rate
    """

    def __init__(
        self,
        d_model: int = 512,
        attention_dim: int = 512,
        num_heads: int = 8,
        feedforward_dim: int = 2048,
        memory_dim: int = 512,
        dropout: float = 0.1,
    ):
        """Construct a DecoderLayer object."""
        super(DecoderLayer, self).__init__()

        self.norm_self_attn = torch.nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(
            d_model, attention_dim, num_heads, dropout=0.0
        )

        self.norm_src_attn = torch.nn.LayerNorm(d_model)
        self.src_attn = MultiHeadAttention(
            d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0
        )

        self.norm_ff = torch.nn.LayerNorm(d_model)
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(d_model, feedforward_dim),
            Swish(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(feedforward_dim, d_model),
        )

        self.dropout = torch.nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        memory: Optional[torch.Tensor] = None,
        memory_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Input sequence of shape (seq_len, batch, embed_dim).
            attn_mask: A binary mask for self-attention module indicating which
                elements will be filled with -inf.
                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
            memory: Memory sequence of shape (seq_len, batch, memory_dim).
            memory_attn_mask: A binary mask for cross-attention module indicating which
                elements will be filled with -inf.
                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
        """
        # self-attn module
        qkv = self.norm_self_attn(x)
        self_attn_out = self.self_attn(
            query=qkv, key=qkv, value=qkv, attn_mask=attn_mask
        )
        x = x + self.dropout(self_attn_out)

        # cross-attn module
        q = self.norm_src_attn(x)
        src_attn_out = self.src_attn(
            query=q, key=memory, value=memory, attn_mask=memory_attn_mask
        )
        x = x + self.dropout(src_attn_out)

        # feed-forward module
        x = x + self.dropout(self.feed_forward(self.norm_ff(x)))

        return x

__init__(d_model=512, attention_dim=512, num_heads=8, feedforward_dim=2048, memory_dim=512, dropout=0.1)

Construct a DecoderLayer object.

Source code in zipformer/modules/attention_decoder.py
def __init__(
    self,
    d_model: int = 512,
    attention_dim: int = 512,
    num_heads: int = 8,
    feedforward_dim: int = 2048,
    memory_dim: int = 512,
    dropout: float = 0.1,
):
    """Construct a DecoderLayer object."""
    super(DecoderLayer, self).__init__()

    self.norm_self_attn = torch.nn.LayerNorm(d_model)
    self.self_attn = MultiHeadAttention(
        d_model, attention_dim, num_heads, dropout=0.0
    )

    self.norm_src_attn = torch.nn.LayerNorm(d_model)
    self.src_attn = MultiHeadAttention(
        d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0
    )

    self.norm_ff = torch.nn.LayerNorm(d_model)
    self.feed_forward = torch.nn.Sequential(
        torch.nn.Linear(d_model, feedforward_dim),
        Swish(),
        torch.nn.Dropout(dropout),
        torch.nn.Linear(feedforward_dim, d_model),
    )

    self.dropout = torch.nn.Dropout(dropout)

forward(x, attn_mask=None, memory=None, memory_attn_mask=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input sequence of shape (seq_len, batch, embed_dim). | required |
| attn_mask | Optional[Tensor] | A binary mask for the self-attention module indicating which elements will be filled with -inf. Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). | None |
| memory | Optional[Tensor] | Memory sequence of shape (seq_len, batch, memory_dim). | None |
| memory_attn_mask | Optional[Tensor] | A binary mask for the cross-attention module indicating which elements will be filled with -inf. Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). | None |

Source code in zipformer/modules/attention_decoder.py
def forward(
    self,
    x: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    memory: Optional[torch.Tensor] = None,
    memory_attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Args:
        x: Input sequence of shape (seq_len, batch, embed_dim).
        attn_mask: A binary mask for self-attention module indicating which
            elements will be filled with -inf.
            Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
        memory: Memory sequence of shape (seq_len, batch, memory_dim).
        memory_attn_mask: A binary mask for cross-attention module indicating which
            elements will be filled with -inf.
            Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
    """
    # self-attn module
    qkv = self.norm_self_attn(x)
    self_attn_out = self.self_attn(
        query=qkv, key=qkv, value=qkv, attn_mask=attn_mask
    )
    x = x + self.dropout(self_attn_out)

    # cross-attn module
    q = self.norm_src_attn(x)
    src_attn_out = self.src_attn(
        query=q, key=memory, value=memory, attn_mask=memory_attn_mask
    )
    x = x + self.dropout(src_attn_out)

    # feed-forward module
    x = x + self.dropout(self.feed_forward(self.norm_ff(x)))

    return x
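
Each of the three sub-blocks in `DecoderLayer.forward` follows the same pre-norm residual pattern: normalize, apply the sublayer, then add the result back to the input. The sketch below shows that pattern on a single vector in plain Python. Dropout and the learnable LayerNorm scale and bias are left out, and the sublayers are placeholder callables rather than real attention or feed-forward modules.

```python
import math
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (no learnable scale or bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_norm_residual(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    """Compute x + sublayer(layer_norm(x)); dropout is omitted."""
    y = sublayer(layer_norm(x))
    return [a + b for a, b in zip(x, y)]


def decoder_layer(x: Vector, self_attn, cross_attn, feed_forward) -> Vector:
    # Same order as DecoderLayer.forward: self-attention, cross-attention, FFN.
    for sublayer in (self_attn, cross_attn, feed_forward):
        x = pre_norm_residual(x, sublayer)
    return x


def zero(v: Vector) -> Vector:
    return [0.0] * len(v)


# With sublayers that output zeros, only the residual path remains.
out = decoder_layer([1.0, 2.0, 3.0], zero, zero, zero)
```

Because normalization is applied to the sublayer input rather than to the sum, the residual path carries `x` through the layer unchanged.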

MultiHeadAttention

Bases: Module

Multi-Head Attention layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dim | int | Total dimension of the model. | required |
| attention_dim | int | Dimension in the attention module; must be a multiple of num_heads. | required |
| num_heads | int | Number of parallel attention heads. | required |
| memory_dim | Optional[int] | Dimension of the memory embedding, optional. | None |
| dropout | float | Dropout probability applied to attn_output_weights. | 0.0 |

Source code in zipformer/modules/attention_decoder.py
class MultiHeadAttention(torch.nn.Module):
    """Multi-Head Attention layer.

    Args:
        embed_dim: total dimension of the model.
        attention_dim: dimension in the attention module; must be a multiple of num_heads.
        num_heads: number of parallel attention heads.
        memory_dim: dimension of memory embedding, optional.
        dropout: dropout probability applied to attn_output_weights.
    """

    def __init__(
        self,
        embed_dim: int,
        attention_dim: int,
        num_heads: int,
        memory_dim: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.head_dim = attention_dim // num_heads
        assert self.head_dim * num_heads == attention_dim, (
            self.head_dim,
            num_heads,
            attention_dim,
        )
        self.dropout = dropout
        self.name = None  # will be overwritten in training code; for diagnostics.

        self.linear_q = torch.nn.Linear(embed_dim, attention_dim, bias=True)
        self.linear_k = torch.nn.Linear(
            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
        )
        self.linear_v = torch.nn.Linear(
            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
        )

        self.out_proj = torch.nn.Linear(attention_dim, embed_dim, bias=True)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute dot product attention.

        Args:
            query: Query tensor of shape (tgt_len, batch, embed_dim).
            key: Key tensor of shape (src_len, batch, embed_dim or memory_dim).
            value: Value tensor of shape (src_len, batch, embed_dim or memory_dim).
            key_padding_mask: A binary mask indicating which elements are padding.
                Its shape is (batch, src_len).
            attn_mask: A binary mask indicating which elements will be filled with -inf.
                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).

        Returns:
            Output tensor of shape (tgt_len, batch, embed_dim).
        """
        num_heads = self.num_heads
        head_dim = self.head_dim

        tgt_len, batch, _ = query.shape
        src_len = key.shape[0]

        q = self.linear_q(query)  # (tgt_len, batch, num_heads * head_dim)
        k = self.linear_k(key)  # (src_len, batch, num_heads * head_dim)
        v = self.linear_v(value)  # (src_len, batch, num_heads * head_dim)

        q = q.reshape(tgt_len, batch, num_heads, head_dim)
        q = q.permute(1, 2, 0, 3)  # (batch, head, tgt_len, head_dim)
        k = k.reshape(src_len, batch, num_heads, head_dim)
        k = k.permute(1, 2, 3, 0)  # (batch, head, head_dim, src_len)
        v = v.reshape(src_len, batch, num_heads, head_dim)
        v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1)

        # Note: could remove the scaling operation when using ScaledAdam
        # (batch, head, tgt_len, src_len)
        attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)

        # From zipformer.py:
        # This is a harder way of limiting the attention scores to not be too large.
        # It incurs a penalty if any of them has an absolute value greater than 50.0.
        # this should be outside the normal range of the attention scores.  We use
        # this mechanism instead of, say, a limit on entropy, because once the entropy
        # gets very small gradients through the softmax can become very small, and
        # some mechanisms like that become ineffective.
        attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04)

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float("-inf"),
            )

        if attn_mask is not None:
            assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == (
                batch,
                tgt_len,
                src_len,
            ), attn_mask.shape
            attn_weights = attn_weights.masked_fill(
                attn_mask.unsqueeze(1), float("-inf")
            )

        attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_weights = torch.nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        # (batch * head, tgt_len, head_dim)
        attn_output = torch.bmm(attn_weights, v)
        assert attn_output.shape == (
            batch * num_heads,
            tgt_len,
            head_dim,
        ), attn_output.shape

        attn_output = attn_output.transpose(0, 1).contiguous()
        attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)

        # (tgt_len, batch, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output

forward(query, key, value, key_padding_mask=None, attn_mask=None)

Compute dot product attention.

Parameters:

query (Tensor): Query tensor of shape (tgt_len, batch, embed_dim). Required.
key (Tensor): Key tensor of shape (src_len, batch, embed_dim or memory_dim). Required.
value (Tensor): Value tensor of shape (src_len, batch, embed_dim or memory_dim). Required.
key_padding_mask (Optional[Tensor]): A binary mask indicating which elements are padding. Its shape is (batch, src_len). Default: None.
attn_mask (Optional[Tensor]): A binary mask indicating which elements will be filled with -inf. Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). Default: None.

Returns:

Tensor: Output tensor of shape (tgt_len, batch, embed_dim).

Source code in zipformer/modules/attention_decoder.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_padding_mask: Optional[torch.Tensor] = None,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute dot product attention.

    Args:
        query: Query tensor of shape (tgt_len, batch, embed_dim).
        key: Key tensor of shape (src_len, batch, embed_dim or memory_dim).
        value: Value tensor of shape (src_len, batch, embed_dim or memory_dim).
        key_padding_mask: A binary mask indicating which elements are padding.
            Its shape is (batch, src_len).
        attn_mask: A binary mask indicating which elements will be filled with -inf.
            Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).

    Returns:
        Output tensor of shape (tgt_len, batch, embed_dim).
    """
    num_heads = self.num_heads
    head_dim = self.head_dim

    tgt_len, batch, _ = query.shape
    src_len = key.shape[0]

    q = self.linear_q(query)  # (tgt_len, batch, num_heads * head_dim)
    k = self.linear_k(key)  # (src_len, batch, num_heads * head_dim)
    v = self.linear_v(value)  # (src_len, batch, num_heads * head_dim)

    q = q.reshape(tgt_len, batch, num_heads, head_dim)
    q = q.permute(1, 2, 0, 3)  # (batch, head, tgt_len, head_dim)
    k = k.reshape(src_len, batch, num_heads, head_dim)
    k = k.permute(1, 2, 3, 0)  # (batch, head, head_dim, src_len)
    v = v.reshape(src_len, batch, num_heads, head_dim)
    v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1)

    # Note: could remove the scaling operation when using ScaledAdam
    # (batch, head, tgt_len, src_len)
    attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)

    # From zipformer.py:
    # This is a harder way of limiting the attention scores to not be too large.
    # It incurs a penalty if any of them has an absolute value greater than 50.0.
    # this should be outside the normal range of the attention scores.  We use
    # this mechanism instead of, say, a limit on entropy, because once the entropy
    # gets very small gradients through the softmax can become very small, and
    # some mechanisms like that become ineffective.
    attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04)

    if key_padding_mask is not None:
        assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float("-inf"),
        )

    if attn_mask is not None:
        assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == (
            batch,
            tgt_len,
            src_len,
        ), attn_mask.shape
        attn_weights = attn_weights.masked_fill(
            attn_mask.unsqueeze(1), float("-inf")
        )

    attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

    attn_weights = torch.nn.functional.dropout(
        attn_weights, p=self.dropout, training=self.training
    )

    # (batch * head, tgt_len, head_dim)
    attn_output = torch.bmm(attn_weights, v)
    assert attn_output.shape == (
        batch * num_heads,
        tgt_len,
        head_dim,
    ), attn_output.shape

    attn_output = attn_output.transpose(0, 1).contiguous()
    attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)

    # (tgt_len, batch, embed_dim)
    attn_output = self.out_proj(attn_output)

    return attn_output
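
The masking steps in forward can be illustrated on a single row of attention scores. The following is a minimal sketch in plain Python rather than torch; masked_softmax and the example scores are hypothetical names and values introduced only for illustration. It shows that a position filled with -inf receives exactly zero weight after the softmax, while the remaining weights still sum to one.

```python
import math


def masked_softmax(scores, mask):
    """Softmax over one row of attention scores.

    scores: list of floats, one per source position.
    mask:   list of bools; True marks a position to exclude (filled with -inf).
    """
    filled = [float("-inf") if m else s for s, m in zip(scores, mask)]
    peak = max(filled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in filled]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


head_dim = 4
raw_scores = [2.0, 4.0, 6.0]  # hypothetical q.k dot products for one query
scaled = [s / math.sqrt(head_dim) for s in raw_scores]  # scale by 1/sqrt(head_dim)
weights = masked_softmax(scaled, mask=[False, False, True])
```

Here the third source position is masked, so weights[2] is 0.0 and the first two weights share the full probability mass.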

PositionalEncoding

Bases: Module

Positional encoding. Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35.

Parameters:

d_model (int): Embedding dimension. Required.
dropout_rate (float): Dropout rate. Required.
max_len (int): Maximum input length. Default: 5000.

Source code in zipformer/modules/attention_decoder.py
class PositionalEncoding(torch.nn.Module):
    """Positional encoding.
    Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)

__init__(d_model, dropout_rate, max_len=5000)

Construct a PositionalEncoding object.

Source code in zipformer/modules/attention_decoder.py
def __init__(self, d_model, dropout_rate, max_len=5000):
    """Construct a PositionalEncoding object."""
    super(PositionalEncoding, self).__init__()
    self.d_model = d_model
    self.xscale = math.sqrt(self.d_model)
    self.dropout = torch.nn.Dropout(p=dropout_rate)
    self.pe = None
    self.extend_pe(torch.tensor(0.0).expand(1, max_len))

extend_pe(x)

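Reset the positional encodings: rebuild the cached table when it is shorter than x, otherwise only move it to the dtype and device of x.

Source code in zipformer/modules/attention_decoder.py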
def extend_pe(self, x):
    """Reset the positional encodings."""
    if self.pe is not None:
        if self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
    pe = torch.zeros(x.size(1), self.d_model)
    position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, self.d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / self.d_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)
    self.pe = pe.to(device=x.device, dtype=x.dtype)

forward(x)

Add positional encoding.

Parameters:

x (Tensor): Input tensor (batch, time, *). Required.

Returns:

Tensor: Encoded tensor (batch, time, *).

Source code in zipformer/modules/attention_decoder.py
def forward(self, x: torch.Tensor):
    """Add positional encoding.

    Args:
        x (torch.Tensor): Input tensor (batch, time, `*`).

    Returns:
        torch.Tensor: Encoded tensor (batch, time, `*`).
    """
    self.extend_pe(x)
    x = x * self.xscale + self.pe[:, : x.size(1)]
    return self.dropout(x)
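
The table built by extend_pe and the scaling applied in forward can be reproduced with the standard library alone. The sketch below is an illustration, not the module itself: sinusoidal_table and add_positional_encoding are hypothetical helper names, d_model is assumed to be even, and dropout is omitted.

```python
import math


def sinusoidal_table(max_len, d_model):
    """Return a max_len x d_model table: sin on even columns, cos on odd columns."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            # Same frequency term as div_term in extend_pe, for even column i.
            angle = pos * math.exp(-math.log(10000.0) * i / d_model)
            table[pos][i] = math.sin(angle)
            table[pos][i + 1] = math.cos(angle)
    return table


def add_positional_encoding(x, table):
    """x: list of time steps, each a list of d_model floats (one sequence)."""
    xscale = math.sqrt(len(x[0]))  # inputs are scaled by sqrt(d_model) first
    return [
        [value * xscale + pe for value, pe in zip(frame, table[t])]
        for t, frame in enumerate(x)
    ]


table = sinusoidal_table(max_len=4, d_model=4)
encoded = add_positional_encoding([[1.0, 1.0, 1.0, 1.0]], table)
```

At position 0 every sine is 0 and every cosine is 1, so a frame of ones with d_model = 4 is scaled by 2 and becomes [2.0, 3.0, 2.0, 3.0].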

Swish

Bases: Module

Construct a Swish object.

Source code in zipformer/modules/attention_decoder.py
class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Swish activation of x."""
        return x * torch.sigmoid(x)

forward(x)

Return the Swish activation of x.

Source code in zipformer/modules/attention_decoder.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the Swish activation of x."""
    return x * torch.sigmoid(x)
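
For a single scalar, the same activation can be written with the standard library. This is a small sketch for intuition only; swish is a hypothetical free function, and the formula is valid for moderate inputs (math.exp overflows for very large negative x).

```python
import math


def swish(x):
    """Swish activation for one float: x * sigmoid(x)."""
    sigmoid = 1.0 / (1.0 + math.exp(-x))
    return x * sigmoid


values = [swish(v) for v in (-10.0, 0.0, 1.0, 10.0)]
```

The function is zero at the origin, close to the identity for large positive inputs, and slightly negative for negative inputs.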

subsequent_mask(size, device='cpu', dtype=torch.bool)

Create mask for subsequent steps (size, size).

:param int size: size of mask
:param str device: "cpu" or "cuda" or torch.Tensor.device
:param torch.dtype dtype: result dtype
:rtype: torch.Tensor

>>> subsequent_mask(3)
[[1, 0, 0],
 [1, 1, 0],
 [1, 1, 1]]

Source code in zipformer/modules/attention_decoder.py
def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """Create mask for subsequent steps (size, size).

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    >>> subsequent_mask(3)
    [[1, 0, 0],
     [1, 1, 0],
     [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=dtype)
    return torch.tril(ret, out=ret)
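
Note that subsequent_mask marks the positions a query may attend to, whereas MultiHeadAttention.forward documents attn_mask as marking the positions to be filled with -inf. The call site is not shown in this section, so the following is an assumption: a caller would presumably invert the mask (for example with torch.logical_not) before passing it on. The sketch uses nested Python lists and hypothetical helper names to make the two conventions explicit.

```python
def subsequent_mask_rows(size):
    """Lower-triangular mask as nested lists: True where key index <= query index."""
    return [[j <= i for j in range(size)] for i in range(size)]


def to_fill_mask(allowed):
    """Invert an 'allowed' mask into a 'fill with -inf' mask."""
    return [[not a for a in row] for row in allowed]


allowed = subsequent_mask_rows(3)  # what a query may see
blocked = to_fill_mask(allowed)    # what would be filled with -inf
```

Row i of allowed lets query i see keys 0..i; the inverted mask blocks exactly the future positions.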

Subsampling

zipformer.modules.subsampling

ConvNeXt

Bases: Module

Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf

Source code in zipformer/modules/subsampling.py
class ConvNeXt(torch.nn.Module):
    """
    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
    """

    def __init__(
        self,
        channels: int,
        hidden_ratio: int = 3,
        kernel_size: Tuple[int, int] = (7, 7),
        layerdrop_rate: FloatLike = None,
    ):
        super().__init__()
        self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        hidden_channels = channels * hidden_ratio
        if layerdrop_rate is None:
            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
        self.layerdrop_rate = layerdrop_rate

        self.depthwise_conv = torch.nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=self.padding,
        )

        self.pointwise_conv1 = torch.nn.Conv2d(
            in_channels=channels, out_channels=hidden_channels, kernel_size=1
        )

        self.hidden_balancer = Balancer(
            hidden_channels,
            channel_dim=1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        self.activation = SwooshL()
        self.pointwise_conv2 = ScaledConv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            initial_scale=0.01,
        )

        self.out_balancer = Balancer(
            channels,
            channel_dim=1,
            min_positive=0.4,
            max_positive=0.6,
            min_abs=1.0,
            max_abs=6.0,
        )
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=5.0,
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return self.forward_internal(x)
        layerdrop_rate = float(self.layerdrop_rate)

        if layerdrop_rate != 0.0:
            batch_size = x.shape[0]
            mask = (
                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                > layerdrop_rate
            )
        else:
            mask = None
        # turns out this caching idea does not work with --world-size > 1
        # return caching_eval(self.forward_internal, x, mask)
        return self.forward_internal(x, mask)

    def forward_internal(
        self, x: torch.Tensor, layer_skip_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

        The returned value has the same shape as x.
        """
        bypass = x
        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        if layer_skip_mask is not None:
            x = x * layer_skip_mask

        x = bypass + x
        x = self.out_balancer(x)

        if x.requires_grad:
            x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
            x = self.out_whiten(x)
            x = x.transpose(1, 3)  # (N, C, H, W)

        return x

    def streaming_forward(
        self,
        x: torch.Tensor,
        cached_left_pad: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
            cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)

        Returns:
            - The output, of shape (batch_size, num_channels, num_frames - left_pad, num_freqs).
            - Updated cached_left_pad.
        """
        padding = self.padding

        # The length without right padding for depth-wise conv
        T = x.size(2) - padding[0]

        bypass = x[:, :, :T, :]

        # Pad left side
        assert cached_left_pad.size(2) == padding[0], (
            cached_left_pad.size(2),
            padding[0],
        )
        x = torch.cat([cached_left_pad, x], dim=2)
        # Update cached left padding
        cached_left_pad = x[:, :, T : padding[0] + T, :]

        # depthwise_conv
        x = torch.nn.functional.conv2d(
            x,
            weight=self.depthwise_conv.weight,
            bias=self.depthwise_conv.bias,
            padding=(0, padding[1]),
            groups=self.depthwise_conv.groups,
        )
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        x = bypass + x
        return x, cached_left_pad

forward_internal(x, layer_skip_mask=None)

x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

The returned value has the same shape as x.

Source code in zipformer/modules/subsampling.py
def forward_internal(
    self, x: torch.Tensor, layer_skip_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

    The returned value has the same shape as x.
    """
    bypass = x
    x = self.depthwise_conv(x)
    x = self.pointwise_conv1(x)
    x = self.hidden_balancer(x)
    x = self.activation(x)
    x = self.pointwise_conv2(x)

    if layer_skip_mask is not None:
        x = x * layer_skip_mask

    x = bypass + x
    x = self.out_balancer(x)

    if x.requires_grad:
        x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
        x = self.out_whiten(x)
        x = x.transpose(1, 3)  # (N, C, H, W)

    return x
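
The layer-drop logic in forward draws one keep/drop decision per batch element and applies it to the convolutional branch only, so a dropped sample falls back to the bypass path. Below is a minimal sketch of that idea in plain Python; the helper names are hypothetical, and per-sample scalars stand in for the (C, H, W) tensors.

```python
import random


def layerdrop_keep_mask(batch_size, layerdrop_rate, rng):
    """One keep/drop decision per batch element; None means keep everything."""
    if layerdrop_rate == 0.0:
        return None
    return [rng.random() > layerdrop_rate for _ in range(batch_size)]


def residual_with_mask(bypass, branch, keep_mask):
    """Add the branch output to the bypass, zeroing the branch for dropped samples."""
    if keep_mask is None:
        return [b + r for b, r in zip(bypass, branch)]
    return [b + (r if keep else 0.0) for b, r, keep in zip(bypass, branch, keep_mask)]


rng = random.Random(0)  # seeded only to make the illustration repeatable
mask = layerdrop_keep_mask(batch_size=4, layerdrop_rate=0.2, rng=rng)
out = residual_with_mask([1.0, 2.0], [0.5, 0.5], [True, False])
```

As in the source, kept samples are not rescaled; the mask simply multiplies the branch output before the residual addition.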

streaming_forward(x, cached_left_pad)

Parameters:

x (Tensor): Layout (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs). Required.
cached_left_pad (Tensor): Shape (batch_size, num_channels, left_pad, num_freqs). Required.

Returns:

Tensor: The output, of shape (batch_size, num_channels, num_frames - left_pad, num_freqs).
Tensor: Updated cached_left_pad.

Source code in zipformer/modules/subsampling.py
def streaming_forward(
    self,
    x: torch.Tensor,
    cached_left_pad: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
        cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)

    Returns:
        - The output, of shape (batch_size, num_channels, num_frames - left_pad, num_freqs).
        - Updated cached_left_pad.
    """
    padding = self.padding

    # The length without right padding for depth-wise conv
    T = x.size(2) - padding[0]

    bypass = x[:, :, :T, :]

    # Pad left side
    assert cached_left_pad.size(2) == padding[0], (
        cached_left_pad.size(2),
        padding[0],
    )
    x = torch.cat([cached_left_pad, x], dim=2)
    # Update cached left padding
    cached_left_pad = x[:, :, T : padding[0] + T, :]

    # depthwise_conv
    x = torch.nn.functional.conv2d(
        x,
        weight=self.depthwise_conv.weight,
        bias=self.depthwise_conv.bias,
        padding=(0, padding[1]),
        groups=self.depthwise_conv.groups,
    )
    x = self.pointwise_conv1(x)
    x = self.hidden_balancer(x)
    x = self.activation(x)
    x = self.pointwise_conv2(x)

    x = bypass + x
    return x, cached_left_pad
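
The frame bookkeeping in streaming_forward can be checked with simple arithmetic. The sketch below is an illustration under the default (7, 7) kernel; streaming_convnext_lengths is a hypothetical helper that tracks only the time axis.

```python
def streaming_convnext_lengths(chunk_frames, kernel_time=7):
    """Return (output frames, conv output frames, cache slice) for one chunk."""
    pad = (kernel_time - 1) // 2              # 3 for the default kernel
    out_frames = chunk_frames - pad           # T in the source: right context is consumed
    padded = pad + chunk_frames               # after concatenating cached_left_pad and x
    conv_frames = padded - (kernel_time - 1)  # depthwise conv with no time padding
    cache_slice = (out_frames, out_frames + pad)  # slice of the padded input kept as cache
    return out_frames, conv_frames, cache_slice


out_frames, conv_frames, cache_slice = streaming_convnext_lengths(16)
```

For a 16-frame chunk the convolution yields 13 frames, which matches the length of the bypass slice x[:, :, :T, :], and frames 13 to 16 of the padded input become the next cached_left_pad.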

Conv2dSubsampling

Bases: Module

Convolutional 2D subsampling (to 1/2 length).

Convert an input of shape (N, T, idim) to an output with shape (N, T', odim), where T' = (T-3)//2 - 2 == (T-7)//2

It is based on https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py
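
The identity T' = (T-3)//2 - 2 == (T-7)//2 follows from the standard output-length formula for a convolution applied to the three Conv2d layers of this module (kernel size 3; time strides 1, 2, 1; no time padding). The sketch below checks it in plain Python; conv_out_len and subsampled_shape are hypothetical helper names. The ConvNeXt layer that follows keeps the shape unchanged in the non-streaming path.

```python
def conv_out_len(n, kernel, stride=1, padding=0):
    """Output length of a 1-D convolution along one axis."""
    return (n + 2 * padding - kernel) // stride + 1


def subsampled_shape(T, idim):
    """Return (time frames, frequency bins) after the three Conv2d layers."""
    # Time axis: strides 1, 2, 1 and no padding.
    t = conv_out_len(T, 3)
    t = conv_out_len(t, 3, stride=2)
    t = conv_out_len(t, 3)
    # Frequency axis: padding 1 in the first layer, then strides 2 and 2.
    f = conv_out_len(idim, 3, padding=1)
    f = conv_out_len(f, 3, stride=2)
    f = conv_out_len(f, 3, stride=2)
    return t, f


frames, width = subsampled_shape(T=100, idim=80)
```

With 100 input frames of 80-dimensional features this gives 46 frames and a frequency width of 19, which agrees with out_width = (((in_channels - 1) // 2) - 1) // 2.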

Source code in zipformer/modules/subsampling.py
class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = (T-3)//2 - 2 == (T-7)//2

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
        dropout: FloatLike = 0.1,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >= 7, in_channels >= 7
          out_channels:
            Output dim. The output shape is (N, (T-7)//2, out_channels)
          layer1_channels:
            Number of channels in layer1
          layer2_channels:
            Number of channels in layer2
          layer3_channels:
            Number of channels in layer3
          dropout:
            Dropout rate applied to the output
        """
        assert in_channels >= 7
        super().__init__()

        # The ScaleGrad module is there to prevent the gradients
        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
        # exceeding the range of fp16 when using automatic mixed precision (amp)
        # training.  (The second one is necessary to stop its bias from getting
        # a too-large gradient).

        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),  # (time, freq)
            ),
            ScaleGrad(0.2),
            Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
            SwooshR(),
            torch.nn.Conv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            ),
            Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
            torch.nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ),
            Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
        )

        # just one convnext layer
        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

        # (in_channels-3)//4
        self.out_width = (((in_channels - 1) // 2) - 1) // 2
        self.layer3_channels = layer3_channels

        self.out = torch.nn.Linear(self.out_width * layer3_channels, out_channels)
        # use a larger than normal grad_scale on this whitening module; there is
        # only one such module, so there is not a concern about adding together
        # many copies of this extra gradient term.
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
            prob=(0.025, 0.25),
            grad_scale=0.02,
        )

        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
        # getting large, there is an unnecessary degree of freedom.
        self.out_norm = BiasNorm(out_channels)
        self.dropout = Dropout3(dropout, shared_dim=1)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            x before padding.

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
        # gradients.
        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (N, layer3_channels, (T-7)//2, (idim-3)//4)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, (T-7)//2, out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, (T-7)//2, odim)
        x = self.out_whiten(x)
        x = self.out_norm(x)
        x = self.dropout(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            x_lens = (x_lens - 7) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                x_lens = (x_lens - 7) // 2
        assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

        return x, x_lens

    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        cached_left_pad: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            x before padding.

        Returns:
          - a tensor of shape (N, (T-7)//2 - 3, odim)
          - output lengths, of shape (batch_size,)
          - updated cache
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)

        # T' = (T-7)//2
        x = self.conv(x)

        # T' = (T-7)//2-3
        x, cached_left_pad = self.convnext.streaming_forward(
            x, cached_left_pad=cached_left_pad
        )

        # Now x is of shape (N, layer3_channels, T', ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, T', out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, T', odim)
        x = self.out_norm(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert self.convnext.padding[0] == 3
            # The ConvNeXt module needs 3 frames of right padding after subsampling
            x_lens = (x_lens - 7) // 2 - 3
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # The ConvNeXt module needs 3 frames of right padding after subsampling
                assert self.convnext.padding[0] == 3
                x_lens = (x_lens - 7) // 2 - 3

        assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())

        return x, x_lens, cached_left_pad

    @torch.jit.export
    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """Get initial states for Conv2dSubsampling module.
        It is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs)
        """
        left_pad = self.convnext.padding[0]
        freq = self.out_width
        channels = self.layer3_channels
        cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
            device
        )

        return cached_embed_left_pad

__init__(in_channels, out_channels, layer1_channels=8, layer2_channels=32, layer3_channels=128, dropout=0.1)

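Parameters:

in_channels (int): Number of channels in. The input shape is (N, T, in_channels). Caution: It requires: T >= 7, in_channels >= 7. Required.
out_channels (int): Output dim. The output shape is (N, (T-7)//2, out_channels). Required.
layer1_channels (int): Number of channels in layer1. Default: 8.
layer2_channels (int): Number of channels in layer2. Default: 32.
layer3_channels (int): Number of channels in layer3. Default: 128.
dropout (FloatLike): Dropout rate applied to the output. Default: 0.1.

Source code in zipformer/modules/subsampling.py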
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    layer1_channels: int = 8,
    layer2_channels: int = 32,
    layer3_channels: int = 128,
    dropout: FloatLike = 0.1,
) -> None:
    """
    Args:
      in_channels:
        Number of channels in. The input shape is (N, T, in_channels).
        Caution: It requires: T >= 7, in_channels >= 7
      out_channels:
        Output dim. The output shape is (N, (T-7)//2, out_channels)
      layer1_channels:
        Number of channels in layer1
      layer2_channels:
        Number of channels in layer2
      layer3_channels:
        Number of channels in layer3
      dropout:
        Dropout rate applied to the output
    """
    assert in_channels >= 7
    super().__init__()

    # The ScaleGrad module is there to prevent the gradients
    # w.r.t. the weight or bias of the first Conv2d module in self.conv from
    # exceeding the range of fp16 when using automatic mixed precision (amp)
    # training.  (The second one is necessary to stop its bias from getting
    # a too-large gradient).

    self.conv = torch.nn.Sequential(
        torch.nn.Conv2d(
            in_channels=1,
            out_channels=layer1_channels,
            kernel_size=3,
            padding=(0, 1),  # (time, freq)
        ),
        ScaleGrad(0.2),
        Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
        SwooshR(),
        torch.nn.Conv2d(
            in_channels=layer1_channels,
            out_channels=layer2_channels,
            kernel_size=3,
            stride=2,
            padding=0,
        ),
        Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
        SwooshR(),
        torch.nn.Conv2d(
            in_channels=layer2_channels,
            out_channels=layer3_channels,
            kernel_size=3,
            stride=(1, 2),  # (time, freq)
        ),
        Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
        SwooshR(),
    )

    # just one convnext layer
    self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

    # (in_channels-3)//4
    self.out_width = (((in_channels - 1) // 2) - 1) // 2
    self.layer3_channels = layer3_channels

    self.out = torch.nn.Linear(self.out_width * layer3_channels, out_channels)
    # use a larger than normal grad_scale on this whitening module; there is
    # only one such module, so there is not a concern about adding together
    # many copies of this extra gradient term.
    self.out_whiten = Whiten(
        num_groups=1,
        whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
        prob=(0.025, 0.25),
        grad_scale=0.02,
    )

    # max_log_eps=0.0 is to prevent both eps and the output of self.out from
    # getting large, there is an unnecessary degree of freedom.
    self.out_norm = BiasNorm(out_channels)
    self.dropout = Dropout3(dropout, shared_dim=1)

forward(x, x_lens)

Subsample x.

Parameters:

x (Tensor): Its shape is (N, T, idim). Required.
x_lens (Tensor): A tensor of shape (batch_size,) containing the number of frames in x before padding. Required.

Returns:

Tensor: a tensor of shape (N, (T-7)//2, odim)
Tensor: output lengths, of shape (batch_size,)

Source code in zipformer/modules/subsampling.py
def forward(
    self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Subsample x.

    Args:
      x:
        Its shape is (N, T, idim).
      x_lens:
        A tensor of shape (batch_size,) containing the number of frames in

    Returns:
      - a tensor of shape (N, (T-7)//2, odim)
      - output lengths, of shape (batch_size,)
    """
    # On entry, x is (N, T, idim)
    x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
    # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
    # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
    # gradients.
    x = self.conv(x)
    x = self.convnext(x)

    # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
    b, c, t, f = x.size()

    x = x.transpose(1, 2).reshape(b, t, c * f)
    # now x: (N, (T-7)//2, out_width * layer3_channels))

    x = self.out(x)
    # Now x is of shape (N, (T-7)//2, odim)
    x = self.out_whiten(x)
    x = self.out_norm(x)
    x = self.dropout(x)

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        x_lens = (x_lens - 7) // 2
    else:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x_lens = (x_lens - 7) // 2
    assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

    return x, x_lens
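
The length arithmetic used in forward() can be checked in isolation. The following is a minimal sketch, not part of the library: the helper names `subsampled_len` and `out_width` are hypothetical, and the formulas are copied from the shape comments and from the `self.out_width` expression in the source above.

```python
def subsampled_len(num_frames: int) -> int:
    """Number of output frames of forward() for num_frames input frames."""
    return (num_frames - 7) // 2


def out_width(in_channels: int) -> int:
    """Frequency bins left after the conv stack (same expression as self.out_width)."""
    return (((in_channels - 1) // 2) - 1) // 2


print(subsampled_len(100))  # 46
print(out_width(80))        # 19
```

For non-negative integers the nested expression in `out_width` equals `(in_channels - 3) // 4`, which is why the source comment uses the shorter form.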

streaming_forward(x, x_lens, cached_left_pad)

Subsample x.

Parameters:

Name Type Description Default
x Tensor

Its shape is (N, T, idim).

required
x_lens Tensor

A tensor of shape (batch_size,) containing the number of frames in

required

Returns:

Type Description
Tensor
  • a tensor of shape (N, (T-7)//2, odim)
Tensor
  • output lengths, of shape (batch_size,)
Tensor
  • updated cache
Source code in zipformer/modules/subsampling.py
def streaming_forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    cached_left_pad: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Subsample x.

    Args:
      x:
        Its shape is (N, T, idim).
      x_lens:
        A tensor of shape (batch_size,) containing the number of frames in

    Returns:
      - a tensor of shape (N, (T-7)//2, odim)
      - output lengths, of shape (batch_size,)
      - updated cache
    """
    # On entry, x is (N, T, idim)
    x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)

    # T' = (T-7)//2
    x = self.conv(x)

    # T' = (T-7)//2-3
    x, cached_left_pad = self.convnext.streaming_forward(
        x, cached_left_pad=cached_left_pad
    )

    # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
    b, c, t, f = x.size()

    x = x.transpose(1, 2).reshape(b, t, c * f)
    # now x: (N, T', out_width * layer3_channels))

    x = self.out(x)
    # Now x is of shape (N, T', odim)
    x = self.out_norm(x)

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        assert self.convnext.padding[0] == 3
        # The ConvNeXt module needs 3 frames of right padding after subsampling
        x_lens = (x_lens - 7) // 2 - 3
    else:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # The ConvNeXt module needs 3 frames of right padding after subsampling
            assert self.convnext.padding[0] == 3
            x_lens = (x_lens - 7) // 2 - 3

    assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())

    return x, x_lens, cached_left_pad
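
In streaming mode the output is 3 frames shorter than in forward(), because the ConvNeXt module consumes 3 frames of right padding. The sketch below works through that arithmetic under stated assumptions: the helper names are hypothetical, and the constant 3 is taken from the assertion on `self.convnext.padding[0]` in the source above.

```python
CONVNEXT_RIGHT_PAD = 3  # assumption: matches self.convnext.padding[0] == 3


def streaming_out_len(num_frames: int) -> int:
    """Output frames of streaming_forward() for num_frames input frames."""
    return (num_frames - 7) // 2 - CONVNEXT_RIGHT_PAD


def min_input_frames(num_out_frames: int) -> int:
    """Smallest input length that yields num_out_frames output frames."""
    return 2 * (num_out_frames + CONVNEXT_RIGHT_PAD) + 7


print(min_input_frames(16))   # 45
print(streaming_out_len(45))  # 16
```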

get_init_states(batch_size=1, device=torch.device('cpu'))

Get initial states for Conv2dSubsampling module. It is the cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs)

Source code in zipformer/modules/subsampling.py
@torch.jit.export
def get_init_states(
    self,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Get initial states for Conv2dSubsampling module.
    It is the cached left padding for ConvNeXt module,
    of shape (batch_size, num_channels, left_pad, num_freqs)
    """
    left_pad = self.convnext.padding[0]
    freq = self.out_width
    channels = self.layer3_channels
    cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
        device
    )

    return cached_embed_left_pad

Utilities

zipformer.utils.utils

SymbolTable dataclass

Bases: Generic[Symbol]

SymbolTable that maps symbol IDs, found on the FSA arcs to actual objects. These objects can be arbitrary Python objects that can serve as keys in a dictionary (i.e. they need to be hashable and immutable).

The SymbolTable can only be read to/written from disk if the symbols are strings.

Source code in zipformer/utils/utils.py
@dataclass(repr=False)
class SymbolTable(Generic[Symbol]):
    """SymbolTable that maps symbol IDs, found on the FSA arcs to
    actual objects. These objects can be arbitrary Python objects
    that can serve as keys in a dictionary (i.e. they need to be
    hashable and immutable).

    The SymbolTable can only be read to/written from disk if the
    symbols are strings.
    """

    _id2sym: Dict[int, Symbol] = field(default_factory=dict)
    """Map an integer to a symbol.
    """

    _sym2id: Dict[Symbol, int] = field(default_factory=dict)
    """Map a symbol to an integer.
    """

    _next_available_id: int = 1
    """A helper internal field that helps adding new symbols
    to the table efficiently.
    """

    eps: Symbol = "<eps>"
    """Null symbol, always mapped to index 0.
    """

    def __post_init__(self):
        for idx, sym in self._id2sym.items():
            assert self._sym2id[sym] == idx
            assert idx >= 0

        for sym, idx in self._sym2id.items():
            assert idx >= 0
            assert self._id2sym[idx] == sym

        if 0 not in self._id2sym:
            self._id2sym[0] = self.eps
            self._sym2id[self.eps] = 0
        else:
            assert self._id2sym[0] == self.eps
            assert self._sym2id[self.eps] == 0

        self._next_available_id = max(self._id2sym) + 1

    @staticmethod
    def from_str(s: str) -> "SymbolTable":
        """Build a symbol table from a string.

        The string consists of lines. Every line has two fields separated
        by space(s), tab(s) or both. The first field is the symbol and the
        second the integer id of the symbol.

        Args:
          s:
            The input string with the format described above.
        Returns:
          An instance of :class:`SymbolTable`.
        """
        id2sym: Dict[int, str] = dict()
        sym2id: Dict[str, int] = dict()

        for line in s.split("\n"):
            fields = line.split()
            if len(fields) == 0:
                continue  # skip empty lines
            assert len(fields) == 2, (
                f"Expect a line with 2 fields. Given: {len(fields)}"
            )
            sym, idx = fields[0], int(fields[1])
            assert sym not in sym2id, f"Duplicated symbol {sym}"
            assert idx not in id2sym, f"Duplicated id {idx}"
            id2sym[idx] = sym
            sym2id[sym] = idx

        eps = id2sym.get(0, "<eps>")

        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)

    @staticmethod
    def from_file(filename: str) -> "SymbolTable":
        """Build a symbol table from file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.

        Returns:
          An instance of :class:`SymbolTable`.

        """
        with open(filename, "r", encoding="utf-8") as f:
            return SymbolTable.from_str(f.read().strip())

    def to_str(self) -> str:
        """
        Returns:
          Return a string representation of this object. You can pass
          it to the method ``from_str`` to recreate an identical object.
        """
        s = ""
        for idx, symbol in sorted(self._id2sym.items()):
            s += f"{symbol} {idx}\n"
        return s

    def to_file(self, filename: str):
        """Serialize the SymbolTable to a file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.
        """
        with open(filename, "w", encoding="utf-8") as f:
            for idx, symbol in sorted(self._id2sym.items()):
                print(symbol, idx, file=f)

    def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
        """Add a new symbol to the SymbolTable.

        Args:
            symbol:
                The symbol to be added.
            index:
                Optional int id to which the symbol should be assigned.
                If it is not available, a ValueError will be raised.

        Returns:
            The int id to which the symbol has been assigned.
        """
        # Already in the table? Return its ID.
        if symbol in self._sym2id:
            return self._sym2id[symbol]
        # Specific ID not provided - use next available.
        if index is None:
            index = self._next_available_id
        # Specific ID provided but not available.
        if index in self._id2sym:
            raise ValueError(
                f"Cannot assign id '{index}' to '{symbol}' - "
                f"already occupied by {self._id2sym[index]}"
            )
        self._sym2id[symbol] = index
        self._id2sym[index] = symbol

        # Update next available ID if needed
        if self._next_available_id <= index:
            self._next_available_id = index + 1

        return index

    def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
        """Get a symbol for an id or get an id for a symbol

        Args:
          k:
            If it is an id, it tries to find the symbol corresponding
            to the id; if it is a symbol, it tries to find the id
            corresponding to the symbol.

        Returns:
          An id or a symbol depending on the given `k`.
        """
        if isinstance(k, int):
            return self._id2sym[k]
        else:
            return self._sym2id[k]

    def merge(self, other: "SymbolTable") -> "SymbolTable":
        """Create a union of two SymbolTables.
        Raises an AssertionError if the same IDs are occupied by
        different symbols.

        Args:
            other:
                A symbol table to merge with ``self``.

        Returns:
            A new symbol table.
        """
        self._check_compatible(other)

        id2sym = {**self._id2sym, **other._id2sym}
        sym2id = {**self._sym2id, **other._sym2id}

        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)

    def _check_compatible(self, other: "SymbolTable") -> None:
        # Epsilon compatibility
        assert self.eps == other.eps, (
            f"Mismatched epsilon symbol: {self.eps} != {other.eps}"
        )
        # IDs compatibility
        common_ids = set(self._id2sym).intersection(other._id2sym)
        for idx in common_ids:
            assert self[idx] == other[idx], (
                f"ID conflict for id: {idx}, "
                f'self[idx] = "{self[idx]}", '
                f'other[idx] = "{other[idx]}"'
            )
        # Symbols compatibility
        common_symbols = set(self._sym2id).intersection(other._sym2id)
        for sym in common_symbols:
            assert self[sym] == other[sym], (
                f"ID conflict for symbol: {sym}, "
                f'self[sym] = "{self[sym]}", '
                f'other[sym] = "{other[sym]}"'
            )

    def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
        return self.get(item)

    def __contains__(self, item: Union[int, Symbol]) -> bool:
        if isinstance(item, int):
            return item in self._id2sym
        else:
            return item in self._sym2id

    def __len__(self) -> int:
        return len(self._id2sym)

    def __eq__(self, other: "SymbolTable") -> bool:
        if len(self) != len(other):
            return False

        for s in self.symbols:
            if self[s] != other[s]:
                return False

        return True

    @property
    def ids(self) -> List[int]:
        """Returns a list of integer IDs corresponding to the symbols."""
        ans = list(self._id2sym.keys())
        ans.sort()
        return ans

    @property
    def symbols(self) -> List[Symbol]:
        """Returns a list of symbols (e.g., strings) corresponding to
        the integer IDs.
        """
        ans = list(self._sym2id.keys())
        ans.sort()
        return ans

eps = '<eps>' class-attribute instance-attribute

Null symbol, always mapped to index 0.

ids property

Returns a list of integer IDs corresponding to the symbols.

symbols property

Returns a list of symbols (e.g., strings) corresponding to the integer IDs.

from_str(s) staticmethod

Build a symbol table from a string.

The string consists of lines. Every line has two fields separated by space(s), tab(s) or both. The first field is the symbol and the second the integer id of the symbol.

Parameters:

Name Type Description Default
s str

The input string with the format described above.

required

Returns: An instance of SymbolTable.

Source code in zipformer/utils/utils.py
@staticmethod
def from_str(s: str) -> "SymbolTable":
    """Build a symbol table from a string.

    The string consists of lines. Every line has two fields separated
    by space(s), tab(s) or both. The first field is the symbol and the
    second the integer id of the symbol.

    Args:
      s:
        The input string with the format described above.
    Returns:
      An instance of :class:`SymbolTable`.
    """
    id2sym: Dict[int, str] = dict()
    sym2id: Dict[str, int] = dict()

    for line in s.split("\n"):
        fields = line.split()
        if len(fields) == 0:
            continue  # skip empty lines
        assert len(fields) == 2, (
            f"Expect a line with 2 fields. Given: {len(fields)}"
        )
        sym, idx = fields[0], int(fields[1])
        assert sym not in sym2id, f"Duplicated symbol {sym}"
        assert idx not in id2sym, f"Duplicated id {idx}"
        id2sym[idx] = sym
        sym2id[sym] = idx

    eps = id2sym.get(0, "<eps>")

    return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)

from_file(filename) staticmethod

Build a symbol table from file.

Every line in the symbol table file has two fields separated by space(s), tab(s) or both. The following is an example file:

<eps> 0
a 1
b 2
c 3

Parameters:

Name Type Description Default
filename str

Name of the symbol table file. Its format is documented above.

required

Returns:

Type Description
SymbolTable

An instance of SymbolTable.

Source code in zipformer/utils/utils.py
@staticmethod
def from_file(filename: str) -> "SymbolTable":
    """Build a symbol table from file.

    Every line in the symbol table file has two fields separated by
    space(s), tab(s) or both. The following is an example file:

    .. code-block::

        <eps> 0
        a 1
        b 2
        c 3

    Args:
      filename:
        Name of the symbol table file. Its format is documented above.

    Returns:
      An instance of :class:`SymbolTable`.

    """
    with open(filename, "r", encoding="utf-8") as f:
        return SymbolTable.from_str(f.read().strip())

to_str()

Returns:

Type Description
str

A string representation of this object. You can pass it to the method from_str to recreate an identical object.

Source code in zipformer/utils/utils.py
def to_str(self) -> str:
    """
    Returns:
      Return a string representation of this object. You can pass
      it to the method ``from_str`` to recreate an identical object.
    """
    s = ""
    for idx, symbol in sorted(self._id2sym.items()):
        s += f"{symbol} {idx}\n"
    return s
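
The text format shared by to_str and from_str is one "symbol id" pair per line, sorted by id. The following sketch shows the round trip on plain dictionaries. It is an illustration only: `table_to_str` and `table_from_str` are hypothetical stand-ins that mirror the logic shown in the source listings, not the SymbolTable methods themselves.

```python
from typing import Dict


def table_to_str(id2sym: Dict[int, str]) -> str:
    """Serialize as one 'symbol id' line per entry, sorted by id."""
    return "".join(f"{sym} {idx}\n" for idx, sym in sorted(id2sym.items()))


def table_from_str(s: str) -> Dict[int, str]:
    """Parse the same format back, skipping empty lines."""
    id2sym: Dict[int, str] = {}
    for line in s.split("\n"):
        fields = line.split()
        if not fields:
            continue
        assert len(fields) == 2, f"Expect a line with 2 fields. Given: {len(fields)}"
        id2sym[int(fields[1])] = fields[0]
    return id2sym


table = {0: "<eps>", 2: "b", 1: "a"}
text = table_to_str(table)
print(text, end="")                    # lines: "<eps> 0", "a 1", "b 2"
print(table_from_str(text) == table)   # True
```

Because each line is split on whitespace and must contain exactly two fields, a symbol that itself contains a space would not survive this round trip.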

to_file(filename)

Serialize the SymbolTable to a file.

Every line in the symbol table file has two fields separated by space(s), tab(s) or both. The following is an example file:

<eps> 0
a 1
b 2
c 3

Parameters:

Name Type Description Default
filename str

Name of the symbol table file. Its format is documented above.

required
Source code in zipformer/utils/utils.py
def to_file(self, filename: str):
    """Serialize the SymbolTable to a file.

    Every line in the symbol table file has two fields separated by
    space(s), tab(s) or both. The following is an example file:

    .. code-block::

        <eps> 0
        a 1
        b 2
        c 3

    Args:
      filename:
        Name of the symbol table file. Its format is documented above.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for idx, symbol in sorted(self._id2sym.items()):
            print(symbol, idx, file=f)

add(symbol, index=None)

Add a new symbol to the SymbolTable.

Parameters:

Name Type Description Default
symbol Symbol

The symbol to be added.

required
index Optional[int]

Optional int id to which the symbol should be assigned. If it is not available, a ValueError will be raised.

None

Returns:

Type Description
int

The int id to which the symbol has been assigned.

Source code in zipformer/utils/utils.py
def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
    """Add a new symbol to the SymbolTable.

    Args:
        symbol:
            The symbol to be added.
        index:
            Optional int id to which the symbol should be assigned.
            If it is not available, a ValueError will be raised.

    Returns:
        The int id to which the symbol has been assigned.
    """
    # Already in the table? Return its ID.
    if symbol in self._sym2id:
        return self._sym2id[symbol]
    # Specific ID not provided - use next available.
    if index is None:
        index = self._next_available_id
    # Specific ID provided but not available.
    if index in self._id2sym:
        raise ValueError(
            f"Cannot assign id '{index}' to '{symbol}' - "
            f"already occupied by {self._id2sym[index]}"
        )
    self._sym2id[symbol] = index
    self._id2sym[index] = symbol

    # Update next available ID if needed
    if self._next_available_id <= index:
        self._next_available_id = index + 1

    return index
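
The id-assignment rules of add() are easiest to see on a small example. The sketch below uses a hypothetical stand-in class, `TinyTable`, that copies only the logic shown in the listing above. It is not the SymbolTable class, and the symbols used are made up for illustration.

```python
from typing import Dict, Optional


class TinyTable:
    """Hypothetical stand-in that mirrors the add() logic shown above."""

    def __init__(self) -> None:
        self.id2sym: Dict[int, str] = {0: "<eps>"}
        self.sym2id: Dict[str, int] = {"<eps>": 0}
        self.next_id = 1

    def add(self, symbol: str, index: Optional[int] = None) -> int:
        if symbol in self.sym2id:       # already present: return existing id
            return self.sym2id[symbol]
        if index is None:               # no id requested: use next available
            index = self.next_id
        if index in self.id2sym:        # requested id is taken
            raise ValueError(f"Cannot assign id '{index}' to '{symbol}'")
        self.sym2id[symbol] = index
        self.id2sym[index] = symbol
        if self.next_id <= index:
            self.next_id = index + 1
        return index


t = TinyTable()
print(t.add("a"))       # 1
print(t.add("b", 5))    # 5
print(t.add("c"))       # 6
print(t.add("a", 9))    # 1
```

Note that when the symbol is already present, the requested index is ignored and the existing id is returned without an error.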

get(k)

Get a symbol for an id or get an id for a symbol

Parameters:

Name Type Description Default
k Union[int, Symbol]

If it is an id, it tries to find the symbol corresponding to the id; if it is a symbol, it tries to find the id corresponding to the symbol.

required

Returns:

Type Description
Union[Symbol, int]

An id or a symbol depending on the given k.

Source code in zipformer/utils/utils.py
def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
    """Get a symbol for an id or get an id for a symbol

    Args:
      k:
        If it is an id, it tries to find the symbol corresponding
        to the id; if it is a symbol, it tries to find the id
        corresponding to the symbol.

    Returns:
      An id or a symbol depending on the given `k`.
    """
    if isinstance(k, int):
        return self._id2sym[k]
    else:
        return self._sym2id[k]

merge(other)

Create a union of two SymbolTables. Raises an AssertionError if the same IDs are occupied by different symbols.

Parameters:

Name Type Description Default
other SymbolTable

A symbol table to merge with self.

required

Returns:

Type Description
SymbolTable

A new symbol table.

Source code in zipformer/utils/utils.py
def merge(self, other: "SymbolTable") -> "SymbolTable":
    """Create a union of two SymbolTables.
    Raises an AssertionError if the same IDs are occupied by
    different symbols.

    Args:
        other:
            A symbol table to merge with ``self``.

    Returns:
        A new symbol table.
    """
    self._check_compatible(other)

    id2sym = {**self._id2sym, **other._id2sym}
    sym2id = {**self._sym2id, **other._sym2id}

    return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
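
The compatibility check that precedes the union can be sketched on plain id-to-symbol dictionaries. This is an illustration under assumptions: `merge_id_maps` is a hypothetical helper that mirrors the id and symbol checks of `_check_compatible`, and it omits the epsilon check.

```python
from typing import Dict


def merge_id_maps(a: Dict[int, str], b: Dict[int, str]) -> Dict[int, str]:
    """Union of two id->symbol maps; fails if they disagree anywhere."""
    # Same id must map to the same symbol in both tables.
    for idx in set(a) & set(b):
        assert a[idx] == b[idx], f"ID conflict for id: {idx}"
    # Same symbol must map to the same id in both tables.
    inv_a = {sym: idx for idx, sym in a.items()}
    inv_b = {sym: idx for idx, sym in b.items()}
    for sym in set(inv_a) & set(inv_b):
        assert inv_a[sym] == inv_b[sym], f"ID conflict for symbol: {sym}"
    return {**a, **b}


merged = merge_id_maps({0: "<eps>", 1: "a"}, {0: "<eps>", 2: "b"})
print(merged)  # {0: '<eps>', 1: 'a', 2: 'b'}
```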

num_tokens(token_table, disambig_pattern=re.compile('^#\\d+$'))

Return the number of tokens excluding those from disambiguation symbols.

Caution

0 is not a token ID so it is excluded from the return value.

Source code in zipformer/utils/utils.py
def num_tokens(
    token_table: SymbolTable, disambig_pattern: re.Pattern = re.compile(r"^#\d+$")
) -> int:
    """Return the number of tokens excluding those from
    disambiguation symbols.

    Caution:
      0 is not a token ID so it is excluded from the return value.
    """
    symbols = token_table.symbols
    ans = []
    for s in symbols:
        if not disambig_pattern.match(s):
            ans.append(token_table[s])
    num_tokens = len(ans)
    if 0 in ans:
        num_tokens -= 1
    return num_tokens
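
As a worked example of the counting rule, the sketch below applies the same default pattern to a plain symbol-to-id dictionary. The helper `count_tokens` and the sample symbols are hypothetical. The real function takes a SymbolTable.

```python
import re
from typing import Dict

DISAMBIG = re.compile(r"^#\d+$")  # same default pattern as num_tokens


def count_tokens(sym2id: Dict[str, int]) -> int:
    """Count ids of non-disambiguation symbols, excluding id 0."""
    ids = [idx for sym, idx in sym2id.items() if not DISAMBIG.match(sym)]
    return len(ids) - (1 if 0 in ids else 0)


# "#0" and "#1" are disambiguation symbols; id 0 is not counted as a token.
print(count_tokens({"<blk>": 0, "a": 1, "b": 2, "#0": 3, "#1": 4}))  # 2
```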

token_ids_to_text(token_ids, token_table)

Convert token IDs to text using a SymbolTable.

Supports byte-level BPE tokens in the format <0xNN>.

Source code in zipformer/utils/utils.py
def token_ids_to_text(token_ids: List[int], token_table: SymbolTable) -> str:
    """Convert token IDs to text using a SymbolTable.

    Supports byte-level BPE tokens in the format <0xNN>.
    """
    text = b""
    for i in token_ids:
        token = token_table[i]
        if len(token) >= 4 and token[:3] == "<0x" and token[-1] == ">":
            byte_val = int(token[1:-1], base=16)
            text += byte_val.to_bytes(1, byteorder="little")
        else:
            text += token.encode(encoding="utf-8")
    return text.decode(encoding="utf-8").replace("▁", " ").strip()
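
The byte-level handling is worth a concrete example: consecutive `<0xNN>` tokens are joined into raw bytes and only then decoded as UTF-8. The sketch below is a stand-alone illustration with a hypothetical helper `ids_to_text` and a made-up token table stored as a plain dictionary.

```python
from typing import Dict, List


def ids_to_text(token_ids: List[int], id2token: Dict[int, str]) -> str:
    """Join tokens into bytes, decode as UTF-8, and turn '▁' into spaces."""
    data = b""
    for i in token_ids:
        token = id2token[i]
        if len(token) >= 4 and token.startswith("<0x") and token.endswith(">"):
            data += bytes([int(token[3:-1], 16)])  # one raw byte per token
        else:
            data += token.encode("utf-8")
    return data.decode("utf-8").replace("▁", " ").strip()


# Hypothetical table: the three byte tokens spell the UTF-8 encoding of "你".
id2token = {1: "▁HE", 2: "LLO", 3: "<0xE4>", 4: "<0xBD>", 5: "<0xA0>"}
print(ids_to_text([1, 2, 3, 4, 5], id2token))  # HELLO你
```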

is_module_available(*modules)

Returns whether every given top-level module exists, without importing it. This is generally safer than a try/except block around an import X statement.

Note: "borrowed" from torchaudio.

Source code in zipformer/utils/utils.py
def is_module_available(*modules: str) -> bool:
    r"""Returns if a top-level module with :attr:`name` exists *without**
    importing it. This is generally safer than try-catch block around a
    `import X`.

    Note: "borrowed" from torchaudio:
    """
    import importlib.util

    return all(importlib.util.find_spec(m) is not None for m in modules)
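
A short usage sketch of the same check, written as a stand-alone function so that it runs without the package installed. The name `modules_available` and the deliberately nonexistent module name are illustrative assumptions.

```python
import importlib.util


def modules_available(*modules: str) -> bool:
    """True only if every named top-level module can be found."""
    return all(importlib.util.find_spec(m) is not None for m in modules)


print(modules_available("json", "re"))                       # True
print(modules_available("json", "surely_not_a_module_xyz"))  # False
```

A typical use is guarding an optional dependency before importing it, so that a missing package produces a clear message instead of an ImportError deep inside a call stack.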

stack_states(state_list)

Stack a list of zipformer states that correspond to separate utterances into a single zipformer state, so that it can be used as an input for zipformer when those utterances are formed into a batch.

Parameters:

Name Type Description Default
state_list List[List[Tensor]]

Each element in state_list corresponds to the internal state of the zipformer model for a single utterance. For element-n, state_list[n] is a list of cached tensors of all encoder layers:
  • For layer-i, state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
  • state_list[n][-2] is the cached left padding for the ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs).
  • state_list[n][-1] is processed_lens of shape (batch,), which records the number of processed frames (at 50 Hz frame rate, after encoder_embed) for each sample in the batch.

required
Note

It is the inverse of unstack_states.

Source code in zipformer/utils/utils.py
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    """Stack list of zipformer states that correspond to separate utterances
    into a single zipformer state, so that it can be used as an input for
    zipformer when those utterances are formed into a batch.

    Args:
      state_list:
        Each element in state_list corresponding to the internal state
        of the zipformer model for a single utterance. For element-n,
        state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
        state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
        cached_val2, cached_conv1, cached_conv2).
        state_list[n][-2] is the cached left padding for ConvNeXt module,
          of shape (batch_size, num_channels, left_pad, num_freqs)
        state_list[n][-1] is processed_lens of shape (batch,), which records the number
        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.

    Note:
      It is the inverse of :func:`unstack_states`.
    """
    batch_size = len(state_list)
    assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
    tot_num_layers = (len(state_list[0]) - 2) // 6

    batch_states = []
    for layer in range(tot_num_layers):
        layer_offset = layer * 6
        # cached_key: (left_context_len, batch_size, key_dim)
        cached_key = torch.cat(
            [state_list[i][layer_offset] for i in range(batch_size)], dim=1
        )
        # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
        cached_nonlin_attn = torch.cat(
            [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
        )
        # cached_val1: (left_context_len, batch_size, value_dim)
        cached_val1 = torch.cat(
            [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
        )
        # cached_val2: (left_context_len, batch_size, value_dim)
        cached_val2 = torch.cat(
            [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
        )
        # cached_conv1: (#batch, channels, left_pad)
        cached_conv1 = torch.cat(
            [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
        )
        # cached_conv2: (#batch, channels, left_pad)
        cached_conv2 = torch.cat(
            [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
        )
        batch_states += [
            cached_key,
            cached_nonlin_attn,
            cached_val1,
            cached_val2,
            cached_conv1,
            cached_conv2,
        ]

    cached_embed_left_pad = torch.cat(
        [state_list[i][-2] for i in range(batch_size)], dim=0
    )
    batch_states.append(cached_embed_left_pad)

    processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
    batch_states.append(processed_lens)

    return batch_states
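
The flat state list has a fixed layout: six tensors per encoder layer, followed by the ConvNeXt left-padding cache and processed_lens. The sketch below makes that indexing explicit using names only, with no tensors involved. The helpers `state_names` and `num_layers_from_len` are hypothetical and exist purely to illustrate the layout described above.

```python
from typing import List

PER_LAYER = [
    "cached_key",
    "cached_nonlin_attn",
    "cached_val1",
    "cached_val2",
    "cached_conv1",
    "cached_conv2",
]


def state_names(num_layers: int) -> List[str]:
    """Names of the entries of a flat state list, in order."""
    names = [f"layer{i}.{n}" for i in range(num_layers) for n in PER_LAYER]
    return names + ["cached_embed_left_pad", "processed_lens"]


def num_layers_from_len(num_entries: int) -> int:
    """Recover the layer count, using the same check as stack_states."""
    assert (num_entries - 2) % 6 == 0, num_entries
    return (num_entries - 2) // 6


names = state_names(2)
print(len(names))                     # 14
print(names[1 * 6 : (1 + 1) * 6][0])  # layer1.cached_key
print(names[-2:])                     # ['cached_embed_left_pad', 'processed_lens']
print(num_layers_from_len(len(names)))  # 2
```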

unstack_states(batch_states)

Unstack the zipformer state corresponding to a batch of utterances into a list of states, where the i-th entry is the state from the i-th utterance in the batch.

Note

It is the inverse of stack_states.

Parameters:

Name Type Description Default
batch_states List[Tensor]

A list of cached tensors of all encoder layers:
  • For layer-i, batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
  • batch_states[-2] is the cached left padding for the ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs).
  • batch_states[-1] is processed_lens of shape (batch,), which records the number of processed frames (at 50 Hz frame rate, after encoder_embed) for each sample in the batch.

required

Returns:

Name Type Description
state_list List[List[Tensor]]

A list of lists. Each element in state_list corresponds to the internal state of the zipformer model for a single utterance.

Source code in zipformer/utils/utils.py
def unstack_states(batch_states: List[torch.Tensor]) -> List[List[torch.Tensor]]:
    """Unstack the zipformer state corresponding to a batch of utterances
    into a list of states, where the i-th entry is the state from the i-th
    utterance in the batch.

    Note:
      It is the inverse of :func:`stack_states`.

    Args:
        batch_states: A list of cached tensors of all encoder layers. For layer-i,
          batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
          cached_conv1, cached_conv2).
          batch_states[-2] is the cached left padding for ConvNeXt module,
          of shape (batch_size, num_channels, left_pad, num_freqs)
          batch_states[-1] is processed_lens of shape (batch,), which records the number
          of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.

    Returns:
        state_list: A list of list. Each element in state_list corresponding to the internal state
        of the zipformer model for a single utterance.
    """
    assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
    tot_num_layers = (len(batch_states) - 2) // 6

    processed_lens = batch_states[-1]
    batch_size = processed_lens.shape[0]

    state_list = [[] for _ in range(batch_size)]

    for layer in range(tot_num_layers):
        layer_offset = layer * 6
        # cached_key: (left_context_len, batch_size, key_dim)
        cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
        # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
        cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
            chunks=batch_size, dim=1
        )
        # cached_val1: (left_context_len, batch_size, value_dim)
        cached_val1_list = batch_states[layer_offset + 2].chunk(
            chunks=batch_size, dim=1
        )
        # cached_val2: (left_context_len, batch_size, value_dim)
        cached_val2_list = batch_states[layer_offset + 3].chunk(
            chunks=batch_size, dim=1
        )
        # cached_conv1: (#batch, channels, left_pad)
        cached_conv1_list = batch_states[layer_offset + 4].chunk(
            chunks=batch_size, dim=0
        )
        # cached_conv2: (#batch, channels, left_pad)
        cached_conv2_list = batch_states[layer_offset + 5].chunk(
            chunks=batch_size, dim=0
        )
        for i in range(batch_size):
            state_list[i] += [
                cached_key_list[i],
                cached_nonlin_attn_list[i],
                cached_val1_list[i],
                cached_val2_list[i],
                cached_conv1_list[i],
                cached_conv2_list[i],
            ]

    cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
    for i in range(batch_size):
        state_list[i].append(cached_embed_left_pad_list[i])

    processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
    for i in range(batch_size):
        state_list[i].append(processed_lens_list[i])

    return state_list
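
Stacking and unstacking differ per tensor only in which dimension carries the batch. The sketch below records those dimensions, as read from the `dim=` arguments in the two listings, and computes the stacked shape from a per-utterance shape. The helper `stacked_shape` and the example sizes are hypothetical, and the sketch assumes each per-utterance tensor has size 1 along its batch dimension.

```python
from typing import Tuple

# Batch dimension of each cached tensor, as used by torch.cat / chunk above.
BATCH_DIM = {
    "cached_key": 1,
    "cached_nonlin_attn": 1,
    "cached_val1": 1,
    "cached_val2": 1,
    "cached_conv1": 0,
    "cached_conv2": 0,
    "cached_embed_left_pad": 0,
    "processed_lens": 0,
}


def stacked_shape(name: str, shape: Tuple[int, ...], batch_size: int) -> Tuple[int, ...]:
    """Shape after concatenating batch_size single-utterance tensors."""
    dim = BATCH_DIM[name]
    assert shape[dim] == 1, "expected a single-utterance tensor"
    return shape[:dim] + (batch_size,) + shape[dim + 1 :]


# Illustrative sizes only.
print(stacked_shape("cached_key", (64, 1, 128), 4))    # (64, 4, 128)
print(stacked_shape("cached_conv1", (1, 256, 15), 4))  # (4, 256, 15)
print(stacked_shape("processed_lens", (1,), 4))        # (4,)
```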

pad_sequences(seq, padding_value, sos_id=None, eos_id=None, device=None)

Pad a list of sequences to the same length with a specified padding value. Optionally, add SOS and EOS tokens.

Parameters:

Name Type Description Default
seq List[List[int]]

A list of sequences, where each sequence is a list of integers.

required
padding_value int

The value to use for padding.

required
sos_id Optional[int]

If not None, the ID to use for the start-of-sequence token. If None, no SOS token will be added.

None
eos_id Optional[int]

If not None, the ID to use for the end-of-sequence token. If None, no EOS token will be added.

None
device Optional[device]

The device on which to create the output tensor. If None, the output tensor will be created on the CPU.

None

Returns:

Type Description
Tuple[Tensor, Tensor]

A tuple of two tensors:
  • A tensor of shape (batch_size, max_len) with the padded sequences.
  • A tensor of shape (batch_size,) with the lengths of each sequence.
Source code in zipformer/utils/utils.py
def pad_sequences(
    seq: List[List[int]],
    padding_value: int,
    sos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pad a list of sequences to the same length with a specified padding value.
    Optionally, add SOS and EOS tokens.
    Args:
        seq: A list of sequences, where each sequence is a list of integers.
        padding_value: The value to use for padding.
        sos_id: If not None, the ID to use for the start-of-sequence token.
                If None, no SOS token will be added.
        eos_id: If not None, the ID to use for the end-of-sequence token.
                If None, no EOS token will be added.
        device: The device on which to create the output tensor.
                If None, the output tensor will be created on the CPU.

    Returns:
        A tuple of two tensors:
        - A tensor of shape (batch_size, max_len) with the padded sequences.
        - A tensor of shape (batch_size,) with the lengths of each sequence.
    """
    batch_size = len(seq)
    seq_lens = []
    max_len = 0
    for s in seq:
        length = len(s)
        if sos_id is not None:
            length += 1
        if eos_id is not None:
            length += 1
        seq_lens.append(length)
        if length > max_len:
            max_len = length
    out = torch.full(
        (batch_size, max_len),
        fill_value=padding_value,
        dtype=torch.int64,
        device=device,
    )
    for i, s in enumerate(seq):
        if sos_id is not None:
            out[i, 0] = sos_id
        if len(s) > 0:
            out[
                i,
                (1 if sos_id is not None else 0) : (
                    len(s) + (1 if sos_id is not None else 0)
                ),
            ] = torch.tensor(s, dtype=torch.int64, device=device)
        if eos_id is not None:
            out[i, len(s) + (1 if sos_id is not None else 0)] = eos_id
    out_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device)
    return out, out_lens
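
To make the padding layout concrete, here is a minimal sketch of the same behaviour on plain Python lists, so it runs without torch. The helper `pad_lists` is hypothetical. The real function returns int64 tensors on the requested device.

```python
from typing import List, Optional, Tuple


def pad_lists(
    seq: List[List[int]],
    padding_value: int,
    sos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> Tuple[List[List[int]], List[int]]:
    """Pad sequences to a common length, optionally adding SOS/EOS."""
    rows = []
    for s in seq:
        row = list(s)
        if sos_id is not None:
            row = [sos_id] + row
        if eos_id is not None:
            row = row + [eos_id]
        rows.append(row)
    lens = [len(r) for r in rows]  # lengths include SOS/EOS, as in pad_sequences
    max_len = max(lens, default=0)
    padded = [r + [padding_value] * (max_len - len(r)) for r in rows]
    return padded, lens


padded, lens = pad_lists([[1, 2, 3], [4]], padding_value=0, sos_id=10, eos_id=11)
print(padded)  # [[10, 1, 2, 3, 11], [10, 4, 11, 0, 0]]
print(lens)    # [5, 3]
```

The returned lengths count the SOS and EOS tokens, and padding is always appended after the EOS token.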
