diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 7f85aba69427486a7e99eae79945f16d683b4b96..27a3d478b2ee9556ecab15890fca0b6c2718bc12 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1001,7 +1001,7 @@ class Dataset: if isinstance(sampler, samplers.DistributedSampler): dev_id = sampler.shard_id return "", dev_id - if isinstance(output_dataset, TFRecordDataset): + if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)): if output_dataset.shard_id is not None: dev_id = output_dataset.shard_id return "", dev_id diff --git a/tests/ut/python/dataset/test_datasets_clue.py b/tests/ut/python/dataset/test_datasets_clue.py index c49db45abe3b2efff27949ef94eeb79aade4e11e..e1959acb4267e2e2f803bbc5a24bd5e482dcd1e3 100644 --- a/tests/ut/python/dataset/test_datasets_clue.py +++ b/tests/ut/python/dataset/test_datasets_clue.py @@ -344,6 +344,15 @@ def test_clue_wsc(): }) assert len(buffer) == 3 +def test_clue_to_device(): + """ + Test CLUE with to_device + """ + TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' + data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False) + data = data.to_device() + data.send() + if __name__ == "__main__": test_clue() diff --git a/tests/ut/python/dataset/test_datasets_textfileop.py b/tests/ut/python/dataset/test_datasets_textfileop.py index a1d19d88e40d5a8255404152809b3515efbca844..1732c1817d7c1743d7d3e959140d7c0e2f60963e 100644 --- a/tests/ut/python/dataset/test_datasets_textfileop.py +++ b/tests/ut/python/dataset/test_datasets_textfileop.py @@ -89,6 +89,10 @@ def test_textline_dataset_get_datasetsize(): size = data.get_dataset_size() assert size == 3 +def test_textline_dataset_to_device(): + data = ds.TextFileDataset(DATA_FILE, shuffle=False) + data = data.to_device() + data.send() if __name__ == "__main__": test_textline_dataset_one_file()