提交 bbf69912 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2363 fix TextFildDataset and CLUEDataset does not support to_device

Merge pull request !2363 from yanghaitao/yht_fix_textfiledataset
......@@ -1001,7 +1001,7 @@ class Dataset:
if isinstance(sampler, samplers.DistributedSampler):
dev_id = sampler.shard_id
return "", dev_id
if isinstance(output_dataset, TFRecordDataset):
if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)):
if output_dataset.shard_id is not None:
dev_id = output_dataset.shard_id
return "", dev_id
......
......@@ -344,6 +344,15 @@ def test_clue_wsc():
})
assert len(buffer) == 3
def test_clue_to_device():
"""
Test CLUE with to_device
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
data = data.to_device()
data.send()
if __name__ == "__main__":
test_clue()
......
......@@ -89,6 +89,10 @@ def test_textline_dataset_get_datasetsize():
size = data.get_dataset_size()
assert size == 3
def test_textline_dataset_to_device():
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
data = data.to_device()
data.send()
if __name__ == "__main__":
test_textline_dataset_one_file()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册