diff --git a/PaddleNLP/preprocess/padding.py b/PaddleNLP/preprocess/padding.py index 6094562d396181349bebac9e883f6fca9dc71afc..82171e68eb3af3513eaf4655c740a06bb1112d57 100644 --- a/PaddleNLP/preprocess/padding.py +++ b/PaddleNLP/preprocess/padding.py @@ -69,7 +69,7 @@ def pad_batch_data(insts, if return_seq_lens: seq_lens = np.array([len(inst) for inst in insts]) - return_list += [seq_lens.astype("int64").reshape([-1, 1])] + return_list += [seq_lens.astype("int64").reshape([-1])] return return_list if len(return_list) > 1 else return_list[0]