From 1cdd41bd03488b38c6082c766bf819b6bc94f61c Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 24 May 2022 09:46:12 +0000 Subject: [PATCH] fix pad_sequence, test=asr --- paddlespeech/s2t/utils/tensor_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddlespeech/s2t/utils/tensor_utils.py b/paddlespeech/s2t/utils/tensor_utils.py index 0dbaa0b6..e105253c 100644 --- a/paddlespeech/s2t/utils/tensor_utils.py +++ b/paddlespeech/s2t/utils/tensor_utils.py @@ -82,7 +82,7 @@ def pad_sequence(sequences: List[paddle.Tensor], max_size = sequences[0].size() # (TODO Hui Zhang): slice not supprot `end==start` # trailing_dims = max_size[1:] - trailing_dims = max_size[1:] if max_size.ndim >= 2 else () + trailing_dims = tuple(max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else () max_len = max([s.shape[0] for s in sequences]) if batch_first: out_dims = (len(sequences), max_len) + trailing_dims @@ -99,7 +99,7 @@ def pad_sequence(sequences: List[paddle.Tensor], if batch_first: # TODO (Hui Zhang): set_value op not supprot `end==start` # TODO (Hui Zhang): set_value op not support int16 - # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] + # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] # out_tensor[i, :length, ...] = tensor if length != 0: out_tensor[i, :length] = tensor @@ -145,7 +145,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, [ 4, 5, 6, 11, -1, -1], [ 7, 8, 9, 11, -1, -1]]) """ - # TODO(Hui Zhang): using comment code, + # TODO(Hui Zhang): using comment code, #_sos = paddle.to_tensor( # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) #_eos = paddle.to_tensor( -- GitLab