From e81849277ede018e575007820c7573c5db13c480 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Fri, 8 Jul 2022 09:36:26 +0000
Subject: [PATCH] att cache for streaming asr

---
 .../local/rtf_from_log.py                     |  2 +-
 paddlespeech/s2t/models/u2_st/u2_st.py        | 47 ++++++++++++-------
 paddlespeech/s2t/modules/attention.py         |  7 ++-
 paddlespeech/s2t/modules/encoder.py           |  3 +-
 .../engine/asr/online/python/asr_engine.py    | 12 ++---
 5 files changed, 41 insertions(+), 30 deletions(-)

diff --git a/demos/streaming_asr_server/local/rtf_from_log.py b/demos/streaming_asr_server/local/rtf_from_log.py
index 4f30d640..4b89b48f 100755
--- a/demos/streaming_asr_server/local/rtf_from_log.py
+++ b/demos/streaming_asr_server/local/rtf_from_log.py
@@ -38,4 +38,4 @@ if __name__ == '__main__':
             T += m['T']
             P += m['P']
 
-    print(f"RTF: {P/T}")
+    print(f"RTF: {P/T}, utts: {n}")
diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py
index 00ded912..e86bbedf 100644
--- a/paddlespeech/s2t/models/u2_st/u2_st.py
+++ b/paddlespeech/s2t/models/u2_st/u2_st.py
@@ -401,29 +401,42 @@ class U2STBaseModel(nn.Layer):
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
-            subsampling_cache: Optional[paddle.Tensor]=None,
-            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
-            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
-            paddle.Tensor]]:
+            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Export interface for c++ call, give input chunk xs, and return
             output from time 0 to current chunk.
+
         Args:
-            xs (paddle.Tensor): chunk input
-            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
-            elayers_output_cache (Optional[List[paddle.Tensor]]):
-                transformer/conformer encoder layers output cache
-            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
-                cnn cache
+            xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim),
+                where `time == (chunk_size - 1) * subsample_rate + \
+                        subsample.right_context + 1`
+            offset (int): current offset in encoder output time stamp
+            required_cache_size (int): cache size required for next chunk
+                computation
+                >=0: actual cache size
+                <0: means all history cache is required
+            att_cache (paddle.Tensor): cache tensor for KEY & VALUE in
+                transformer/conformer attention, with shape
+                (elayers, head, cache_t1, d_k * 2), where
+                `head * d_k == hidden-dim` and
+                `cache_t1 == chunk_size * num_decoding_left_chunks`.
+                `d_k * 2` for att key & value.
+            cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
+                (elayers, b=1, hidden-dim, cache_t2), where
+                `cache_t2 == cnn.lorder - 1`
+
         Returns:
-            paddle.Tensor: output, it ranges from time 0 to current chunk.
-            paddle.Tensor: subsampling cache
-            List[paddle.Tensor]: attention cache
-            List[paddle.Tensor]: conformer cnn cache
+            paddle.Tensor: output of current input xs,
+                with shape (b=1, chunk_size, hidden-dim).
+            paddle.Tensor: new attention cache required for next chunk, with
+                dynamic shape (elayers, head, T(?), d_k * 2)
+                depending on required_cache_size.
+            paddle.Tensor: new conformer cnn cache required for next chunk, with
+                same shape as the original cnn_cache.
         """
         return self.encoder.forward_chunk(
-            xs, offset, required_cache_size, subsampling_cache,
-            elayers_output_cache, conformer_cnn_cache)
+            xs, offset, required_cache_size, att_cache, cnn_cache)
 
     # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index c0b76f08..454f9c14 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -181,8 +181,7 @@ class MultiHeadedAttention(nn.Layer):
         #   >>> torch.equal(d[0], d[1])        # True
         if paddle.shape(cache)[0] > 0:
             # last dim `d_k * 2` for (key, val)
-            key_cache, value_cache = paddle.split(
-                cache, paddle.shape(cache)[-1] // 2, axis=-1)
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
             k = paddle.concat([key_cache, k], axis=2)
             v = paddle.concat([value_cache, v], axis=2)
         # We do cache slicing in encoder.forward_chunk, since it's
@@ -289,8 +288,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         #   >>> d = torch.split(a, 2, dim=-1)
         #   >>> torch.equal(d[0], d[1])        # True
         if paddle.shape(cache)[0] > 0:
-            key_cache, value_cache = paddle.split(
-                cache, paddle.shape(cache)[-1] // 2, axis=-1)
+            # last dim `d_k * 2` for (key, val)
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
             k = paddle.concat([key_cache, k], axis=2)
             v = paddle.concat([value_cache, v], axis=2)
         # We do cache slicing in encoder.forward_chunk, since it's
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index e05d0cc4..72300579 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -230,7 +230,8 @@ class BaseEncoder(nn.Layer):
         xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
         # after embed, xs=(B=1, chunk_size, hidden-dim)
 
-        elayers, cache_t1 = paddle.shape(att_cache)[0], paddle.shape(att_cache)[2]
+        elayers = paddle.shape(att_cache)[0]
+        cache_t1 = paddle.shape(att_cache)[2]
         chunk_size = paddle.shape(xs)[1]
         attention_key_size = cache_t1 + chunk_size
diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py
index 2bacfecd..4df38f09 100644
--- a/paddlespeech/server/engine/asr/online/python/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py
@@ -130,9 +130,9 @@ class PaddleASRConnectionHanddler:
 
         ## conformer
         # cache for conformer online
-        self.subsampling_cache = None
-        self.elayers_output_cache = None
-        self.conformer_cnn_cache = None
+        self.att_cache = paddle.zeros([0,0,0,0])
+        self.cnn_cache = paddle.zeros([0,0,0,0])
+        self.encoder_out = None
 
         # conformer decoding state
         self.offset = 0  # global offset in decoding frame unit
@@ -474,11 +474,9 @@ class PaddleASRConnectionHanddler:
             # cur chunk
             chunk_xs = self.cached_feat[:, cur:end, :]
             # forward chunk
-            (y, self.subsampling_cache, self.elayers_output_cache,
-             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
-                 chunk_xs, self.offset, required_cache_size,
-                 self.subsampling_cache, self.elayers_output_cache,
-                 self.conformer_cnn_cache)
+            (y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk(
+                chunk_xs, self.offset, required_cache_size,
+                self.att_cache, self.cnn_cache)
             outputs.append(y)
 
             # update the global offset, in decoding frame unit
--
GitLab