提交 ea297c08 编写于 作者: A anthonyaje

Fix dataset serdes for MindDataset

上级 c0c0b098
...@@ -127,9 +127,12 @@ def serialize_operations(node_repr, key, val): ...@@ -127,9 +127,12 @@ def serialize_operations(node_repr, key, val):
def serialize_sampler(node_repr, val):
    """Serialize sampler object to dictionary.

    Stores the result under ``node_repr['sampler']``.

    Args:
        node_repr (dict): Dictionary describing the dataset node; updated in place.
        val: The sampler object to serialize, or None when the node has no sampler.

    Returns:
        None. The outcome is written into `node_repr`.
    """
    if val is None:
        node_repr['sampler'] = None
        return

    # Work on a (shallow) copy of the instance attributes: writing the metadata
    # keys straight into val.__dict__ would add 'sampler_module' and
    # 'sampler_name' as attributes of the live sampler object.
    sampler_repr = dict(val.__dict__)
    sampler_repr['sampler_module'] = type(val).__module__
    sampler_repr['sampler_name'] = type(val).__name__
    node_repr['sampler'] = sampler_repr
def traverse(node): def traverse(node):
...@@ -253,9 +256,10 @@ def create_node(node): ...@@ -253,9 +256,10 @@ def create_node(node):
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
elif dataset_op == 'MindDataset': elif dataset_op == 'MindDataset':
pyobj = pyclass(node['dataset_file'], node.get('column_list'), sampler = construct_sampler(node.get('sampler'))
pyobj = pyclass(node['dataset_file'], node.get('columns_list'),
node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'), node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'),
node.get('shard_id'), node.get('block_reader')) node.get('shard_id'), node.get('block_reader'), sampler)
elif dataset_op == 'TFRecordDataset': elif dataset_op == 'TFRecordDataset':
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'), pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
...@@ -341,24 +345,25 @@ def create_node(node): ...@@ -341,24 +345,25 @@ def create_node(node):
def construct_sampler(in_sampler):
    """Instantiate Sampler object based on the information from dictionary['sampler'].

    Args:
        in_sampler (dict or None): The serialized sampler. It must contain
            'sampler_name' and 'sampler_module' plus the constructor arguments
            read for that sampler type below. None means no sampler was stored.

    Returns:
        The constructed sampler instance, or None when `in_sampler` is None.

    Raises:
        ValueError: If 'sampler_name' is not one of the sampler types handled here.
    """
    if in_sampler is None:
        return None

    sampler_name = in_sampler['sampler_name']
    sampler_module = in_sampler['sampler_module']
    # The class is looked up in an already-imported module; a module that has
    # not been imported yet raises KeyError here.
    sampler_class = getattr(sys.modules[sampler_module], sampler_name)

    if sampler_name == 'DistributedSampler':
        sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
    elif sampler_name == 'PKSampler':
        # 'shuffle' must be read with .get(); calling the dict raised TypeError.
        sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler.get('shuffle'))
    elif sampler_name == 'RandomSampler':
        sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
    elif sampler_name == 'SequentialSampler':
        sampler = sampler_class()
    elif sampler_name == 'SubsetRandomSampler':
        sampler = sampler_class(in_sampler['indices'])
    elif sampler_name == 'WeightedRandomSampler':
        sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
    else:
        raise ValueError("Sampler type is unknown: " + sampler_name)

    return sampler
......
...@@ -19,7 +19,7 @@ import filecmp ...@@ -19,7 +19,7 @@ import filecmp
import glob import glob
import json import json
import os import os
import pytest
import numpy as np import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
...@@ -28,7 +28,6 @@ import mindspore.dataset.transforms.vision.c_transforms as vision ...@@ -28,7 +28,6 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore.dataset.transforms.vision import Inter from mindspore.dataset.transforms.vision import Inter
from mindspore import log as logger from mindspore import log as logger
def test_imagefolder(remove_json_files=True): def test_imagefolder(remove_json_files=True):
""" """
Test simulating resnet50 dataset pipeline. Test simulating resnet50 dataset pipeline.
...@@ -217,6 +216,38 @@ def delete_json_files(): ...@@ -217,6 +216,38 @@ def delete_json_files():
except IOError: except IOError:
logger.info("Error while deleting: {}".format(f)) logger.info("Error while deleting: {}".format(f))
# Test serializing (saving) and deserializing (loading) a MindDataset pipeline
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME, FILES_NUM, \
FileWriter, Inter
def test_minddataset(add_and_remove_cv_file):
    """Check that a MindDataset pipeline with a sampler survives a serialize/deserialize round trip.

    Args:
        add_and_remove_cv_file: Imported from test_minddataset_sampler; presumably a
            pytest fixture that prepares and cleans up the CV files (confirm there).
    """
    columns = ["data", "file_name", "label"]
    reader_count = 4
    subset_sampler = ds.SubsetRandomSampler([1, 2, 3, 5, 7])
    pipeline = ds.MindDataset(CV_FILE_NAME + "0", columns, reader_count,
                              sampler=subset_sampler)

    # Serialize to a python dictionary, then to a canonical (key-sorted) json string
    first_dict = ds.serialize(pipeline)
    first_json = json.dumps(first_dict, sort_keys=True)

    # Rebuild the pipeline from the dictionary and serialize it once more
    pipeline = ds.deserialize(input_dict=first_dict)
    second_json = json.dumps(ds.serialize(pipeline), sort_keys=True)

    # Both serialized forms must match exactly
    assert first_json == second_json

    # The original test makes this call but never uses its result
    get_data(CV_DIR_NAME)
    assert pipeline.get_dataset_size() == 10

    # Count the rows produced by the rebuilt pipeline
    row_count = sum(1 for _ in pipeline.create_dict_iterator())
    assert row_count == 5
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册