未验证 提交 aeaf7886 编写于 作者: ruri 提交者: GitHub

Fix DALI crop size bug: read the crop size from settings.image_shape[1] instead of settings.crop_size (#4011)

上级 c08160b8
...@@ -26,19 +26,32 @@ from paddle import fluid ...@@ -26,19 +26,32 @@ from paddle import fluid
class HybridTrainPipe(Pipeline): class HybridTrainPipe(Pipeline):
def __init__(self, file_root, file_list, batch_size, resize_shorter, def __init__(self,
crop, min_area, lower, upper, interp, mean, std, file_root,
device_id, shard_id=0, num_shards=1, random_shuffle=True, file_list,
num_threads=4, seed=42): batch_size,
super(HybridTrainPipe, self).__init__(batch_size, resize_shorter,
num_threads, crop,
device_id, min_area,
seed=seed) lower,
self.input = ops.FileReader(file_root=file_root, upper,
file_list=file_list, interp,
shard_id=shard_id, mean,
num_shards=num_shards, std,
random_shuffle=random_shuffle) device_id,
shard_id=0,
num_shards=1,
random_shuffle=True,
num_threads=4,
seed=42):
super(HybridTrainPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed)
self.input = ops.FileReader(
file_root=file_root,
file_list=file_list,
shard_id=shard_id,
num_shards=num_shards,
random_shuffle=random_shuffle)
# set internal nvJPEG buffers size to handle full-sized ImageNet images # set internal nvJPEG buffers size to handle full-sized ImageNet images
# without additional reallocations # without additional reallocations
device_memory_padding = 211025920 device_memory_padding = 211025920
...@@ -51,10 +64,8 @@ class HybridTrainPipe(Pipeline): ...@@ -51,10 +64,8 @@ class HybridTrainPipe(Pipeline):
random_aspect_ratio=[lower, upper], random_aspect_ratio=[lower, upper],
random_area=[min_area, 1.0], random_area=[min_area, 1.0],
num_attempts=100) num_attempts=100)
self.res = ops.Resize(device='gpu', self.res = ops.Resize(
resize_x=crop, device='gpu', resize_x=crop, resize_y=crop, interp_type=interp)
resize_y=crop,
interp_type=interp)
self.cmnp = ops.CropMirrorNormalize( self.cmnp = ops.CropMirrorNormalize(
device="gpu", device="gpu",
output_dtype=types.FLOAT, output_dtype=types.FLOAT,
...@@ -79,23 +90,32 @@ class HybridTrainPipe(Pipeline): ...@@ -79,23 +90,32 @@ class HybridTrainPipe(Pipeline):
class HybridValPipe(Pipeline): class HybridValPipe(Pipeline):
def __init__(self, file_root, file_list, batch_size, def __init__(self,
resize_shorter, crop, interp, mean, std, file_root,
device_id, shard_id=0, num_shards=1, random_shuffle=False, file_list,
num_threads=4, seed=42): batch_size,
super(HybridValPipe, self).__init__(batch_size, resize_shorter,
num_threads, crop,
device_id, interp,
seed=seed) mean,
self.input = ops.FileReader(file_root=file_root, std,
file_list=file_list, device_id,
shard_id=shard_id, shard_id=0,
num_shards=num_shards, num_shards=1,
random_shuffle=random_shuffle) random_shuffle=False,
num_threads=4,
seed=42):
super(HybridValPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed)
self.input = ops.FileReader(
file_root=file_root,
file_list=file_list,
shard_id=shard_id,
num_shards=num_shards,
random_shuffle=random_shuffle)
self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB) self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
self.res = ops.Resize(device="gpu", self.res = ops.Resize(
resize_shorter=resize_shorter, device="gpu", resize_shorter=resize_shorter, interp_type=interp)
interp_type=interp)
self.cmnp = ops.CropMirrorNormalize( self.cmnp = ops.CropMirrorNormalize(
device="gpu", device="gpu",
output_dtype=types.FLOAT, output_dtype=types.FLOAT,
...@@ -134,7 +154,7 @@ def build(settings, mode='train'): ...@@ -134,7 +154,7 @@ def build(settings, mode='train'):
mean = [v * 255 for v in settings.image_mean] mean = [v * 255 for v in settings.image_mean]
std = [v * 255 for v in settings.image_std] std = [v * 255 for v in settings.image_std]
crop = settings.crop_size crop = settings.image_shape[1]
resize_shorter = settings.resize_short_size resize_shorter = settings.resize_short_size
min_area = settings.lower_scale min_area = settings.lower_scale
lower = settings.lower_ratio lower = settings.lower_ratio
...@@ -142,9 +162,9 @@ def build(settings, mode='train'): ...@@ -142,9 +162,9 @@ def build(settings, mode='train'):
interp = settings.interpolation or 1 # default to linear interp = settings.interpolation or 1 # default to linear
interp_map = { interp_map = {
0: types.INTERP_NN, # cv2.INTER_NEAREST 0: types.INTERP_NN, # cv2.INTER_NEAREST
1: types.INTERP_LINEAR, # cv2.INTER_LINEAR 1: types.INTERP_LINEAR, # cv2.INTER_LINEAR
2: types.INTERP_CUBIC, # cv2.INTER_CUBIC 2: types.INTERP_CUBIC, # cv2.INTER_CUBIC
4: types.INTERP_LANCZOS3, # XXX use LANCZOS3 for cv2.INTER_LANCZOS4 4: types.INTERP_LANCZOS3, # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
} }
assert interp in interp_map, "interpolation method not supported by DALI" assert interp in interp_map, "interpolation method not supported by DALI"
...@@ -159,14 +179,23 @@ def build(settings, mode='train'): ...@@ -159,14 +179,23 @@ def build(settings, mode='train'):
if not os.path.exists(file_list): if not os.path.exists(file_list):
file_list = None file_list = None
file_root = os.path.join(file_root, 'val') file_root = os.path.join(file_root, 'val')
pipe = HybridValPipe(file_root, file_list, batch_size, pipe = HybridValPipe(
resize_shorter, crop, interp, mean, std, file_root,
device_id=device_id) file_list,
batch_size,
resize_shorter,
crop,
interp,
mean,
std,
device_id=device_id)
pipe.build() pipe.build()
return DALIGenericIterator(pipe, ['feed_image', 'feed_label'], return DALIGenericIterator(
size=len(pipe), dynamic_shape=True, pipe, ['feed_image', 'feed_label'],
fill_last_batch=False, size=len(pipe),
last_batch_padded=True) dynamic_shape=True,
fill_last_batch=False,
last_batch_padded=True)
file_list = os.path.join(file_root, 'train_list.txt') file_list = os.path.join(file_root, 'train_list.txt')
if not os.path.exists(file_list): if not os.path.exists(file_list):
...@@ -177,11 +206,22 @@ def build(settings, mode='train'): ...@@ -177,11 +206,22 @@ def build(settings, mode='train'):
shard_id = int(env['PADDLE_TRAINER_ID']) shard_id = int(env['PADDLE_TRAINER_ID'])
num_shards = int(env['PADDLE_TRAINERS_NUM']) num_shards = int(env['PADDLE_TRAINERS_NUM'])
device_id = int(env['FLAGS_selected_gpus']) device_id = int(env['FLAGS_selected_gpus'])
pipe = HybridTrainPipe(file_root, file_list, batch_size, pipe = HybridTrainPipe(
resize_shorter, crop, min_area, file_root,
lower, upper, interp, mean, std, file_list,
device_id, shard_id, num_shards, batch_size,
seed=42 + shard_id) resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id,
shard_id,
num_shards,
seed=42 + shard_id)
pipe.build() pipe.build()
pipelines = [pipe] pipelines = [pipe]
sample_per_shard = len(pipe) // num_shards sample_per_shard = len(pipe) // num_shards
...@@ -194,10 +234,21 @@ def build(settings, mode='train'): ...@@ -194,10 +234,21 @@ def build(settings, mode='train'):
place.set_place(p) place.set_place(p)
device_id = place.gpu_device_id() device_id = place.gpu_device_id()
pipe = HybridTrainPipe( pipe = HybridTrainPipe(
file_root, file_list, batch_size, file_root,
resize_shorter, crop, min_area, file_list,
lower, upper, interp, mean, std, batch_size,
device_id, idx, num_shards, seed=42 + idx) resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id,
idx,
num_shards,
seed=42 + idx)
pipe.build() pipe.build()
pipelines.append(pipe) pipelines.append(pipe)
sample_per_shard = len(pipelines[0]) sample_per_shard = len(pipelines[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册