diff --git a/paddlespeech/s2t/utils/tensor_utils.py b/paddlespeech/s2t/utils/tensor_utils.py index 0dbaa0b6b77031d4b8e8aa29fcc9246458b8ab99..e105253c2116f65d4e24508a8cc281d84d7926b3 100644 --- a/paddlespeech/s2t/utils/tensor_utils.py +++ b/paddlespeech/s2t/utils/tensor_utils.py @@ -82,7 +82,7 @@ def pad_sequence(sequences: List[paddle.Tensor], max_size = sequences[0].size() # (TODO Hui Zhang): slice not supprot `end==start` # trailing_dims = max_size[1:] - trailing_dims = max_size[1:] if max_size.ndim >= 2 else () + trailing_dims = tuple(max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else () max_len = max([s.shape[0] for s in sequences]) if batch_first: out_dims = (len(sequences), max_len) + trailing_dims @@ -99,7 +99,7 @@ def pad_sequence(sequences: List[paddle.Tensor], if batch_first: # TODO (Hui Zhang): set_value op not supprot `end==start` # TODO (Hui Zhang): set_value op not support int16 - # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] + # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] # out_tensor[i, :length, ...] = tensor if length != 0: out_tensor[i, :length] = tensor @@ -145,7 +145,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, [ 4, 5, 6, 11, -1, -1], [ 7, 8, 9, 11, -1, -1]]) """ - # TODO(Hui Zhang): using comment code, + # TODO(Hui Zhang): using comment code, #_sos = paddle.to_tensor( # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) #_eos = paddle.to_tensor(