提交 50b5c803 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4701 fix concat get_dataset_size error

Merge pull request !4701 from guozhijian/fix_concat_get_dataset_size_err
......@@ -2343,8 +2343,10 @@ class ConcatDataset(DatasetOp):
Number, number of batches.
"""
if self.dataset_size is None:
children_sizes = [c.get_dataset_size() for c in self.children]
self.dataset_size = sum(children_sizes)
num_rows = 0
for _ in self.create_dict_iterator():
num_rows += 1
self.dataset_size = num_rows
return self.dataset_size
def use_sampler(self, sampler):
......
......@@ -1115,7 +1115,8 @@ class RandomAffine:
- Inter.BICUBIC, means resample method is bicubic interpolation.
fill_value (Union[tuple, int], optional): Optional fill_value to fill the area outside the transform
in the output image. Used only in Pillow versions > 5.0.0 (default=0, filling is performed).
in the output image. There must be three elements in tuple and the value of single element is [0, 255].
Used only in Pillow versions > 5.0.0 (default=0, filling is performed).
Raises:
ValueError: If degrees is negative.
......@@ -1127,6 +1128,7 @@ class RandomAffine:
TypeError: If translate is specified but is not list or a tuple of length 2.
TypeError: If scale is not a list or tuple of length 2.
TypeError: If shear is not a list or tuple of length 2 or 4.
TypeError: If fill_value is not a single integer or a 3-tuple.
Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(),
......
......@@ -225,31 +225,63 @@ def test_imagefolder_padded():
assert verify_list[9] == 6
def test_imagefolder_padded_with_decode():
DATA_DIR = "../data/dataset/testPK/data"
data = ds.ImageFolderDatasetV2(DATA_DIR)
num_shards = 5
count = 0
for shard_id in range(num_shards):
DATA_DIR = "../data/dataset/testPK/data"
data = ds.ImageFolderDatasetV2(DATA_DIR)
white_io = BytesIO()
Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
padded_sample = {}
padded_sample['image'] = np.array(bytearray(white_io), dtype='uint8')
padded_sample['label'] = np.array(-1, np.int32)
white_io = BytesIO()
Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
padded_sample = {}
padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
padded_sample['label'] = np.array(-1, np.int32)
white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
data2 = ds.PaddedDataset(white_samples)
data3 = data + data2
white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
data2 = ds.PaddedDataset(white_samples)
data3 = data + data2
testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
data3.use_sampler(testsampler)
data3 = data3.map(input_columns="image", operations=V_C.Decode())
shard_sample_count = 0
for ele in data3.create_dict_iterator():
print("label: {}".format(ele['label']))
count += 1
shard_sample_count += 1
assert shard_sample_count in (9, 10)
assert count == 48
def test_imagefolder_padded_with_decode_and_get_dataset_size():
num_shards = 5
count = 0
for shard_id in range(num_shards):
DATA_DIR = "../data/dataset/testPK/data"
data = ds.ImageFolderDatasetV2(DATA_DIR)
white_io = BytesIO()
Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
padded_sample = {}
padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
padded_sample['label'] = np.array(-1, np.int32)
white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
data2 = ds.PaddedDataset(white_samples)
data3 = data + data2
testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
data3.use_sampler(testsampler)
data3.map(input_columns="image", operations=V_C.Decode())
shard_dataset_size = data3.get_dataset_size()
data3 = data3.map(input_columns="image", operations=V_C.Decode())
shard_sample_count = 0
for ele in data3.create_dict_iterator():
print("label: {}".format(ele['label']))
count += 1
shard_sample_count += 1
assert shard_sample_count in (9, 10)
assert shard_dataset_size == shard_sample_count
assert count == 48
def test_more_shard_padded():
result_list = []
for i in range(8):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册