diff --git a/paddlehub/datasets/base_nlp_dataset.py b/paddlehub/datasets/base_nlp_dataset.py index c4cebdda83cd96595c41c480d700944292fb4061..c9e425ad7c3ccd212f90ccdf526e359f0f9d676a 100644 --- a/paddlehub/datasets/base_nlp_dataset.py +++ b/paddlehub/datasets/base_nlp_dataset.py @@ -276,7 +276,7 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset): if Version(paddlenlp.__version__) >= Version('2.0.0rc5'): token_type_ids = np.array(record['token_type_ids']) else: - token_type_ids = record['segment_ids'] + token_type_ids = np.array(record['segment_ids']) if 'label' in record.keys(): return input_ids, token_type_ids, np.array(record['label'], dtype=np.int64)