未验证 提交 cd14bbe0 编写于 作者: S Steffy-zxf 提交者: GitHub

Fix data type error

上级 3eb4d334
......@@ -276,7 +276,7 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
token_type_ids = np.array(record['token_type_ids'])
else:
token_type_ids = record['segment_ids']
token_type_ids = np.array(record['segment_ids'])
if 'label' in record.keys():
return input_ids, token_type_ids, np.array(record['label'], dtype=np.int64)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册