未验证 提交 25877ca3 编写于 作者: W wangguanzhong 提交者: GitHub

fix mixup & cutmix in multiprocessing (#1319)

上级 a8670536
......@@ -183,6 +183,7 @@ class Reader(object):
inputs_def (dict): network input definition use to get input fields,
which is used to determine the order of returned data.
devices_num (int): number of devices.
num_trainers (int): number of trainers. Default 1.
"""
def __init__(self,
......@@ -203,7 +204,8 @@ class Reader(object):
bufsize=-1,
memsize='3G',
inputs_def=None,
devices_num=1):
devices_num=1,
num_trainers=1):
self._dataset = dataset
self._roidbs = self._dataset.get_roidb()
self._fields = copy.deepcopy(inputs_def[
......@@ -244,8 +246,8 @@ class Reader(object):
self._drop_empty = drop_empty
# sampling
self._mixup_epoch = mixup_epoch
self._cutmix_epoch = cutmix_epoch
self._mixup_epoch = mixup_epoch // num_trainers
self._cutmix_epoch = cutmix_epoch // num_trainers
self._class_aware_sampling = class_aware_sampling
self._load_img = False
......@@ -415,7 +417,11 @@ class Reader(object):
self._parallel.stop()
def create_reader(cfg, max_iter=0, global_cfg=None, devices_num=1):
def create_reader(cfg,
max_iter=0,
global_cfg=None,
devices_num=1,
num_trainers=1):
"""
Return iterable data reader.
......@@ -431,6 +437,7 @@ def create_reader(cfg, max_iter=0, global_cfg=None, devices_num=1):
'use_fine_grained_loss', False)
cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80)
cfg['devices_num'] = devices_num
cfg['num_trainers'] = num_trainers
reader = Reader(**cfg)()
def _reader():
......
......@@ -56,6 +56,7 @@ def main():
FLAGS.dist = 'PADDLE_TRAINER_ID' in env \
and 'PADDLE_TRAINERS_NUM' in env \
and int(env['PADDLE_TRAINERS_NUM']) > 1
num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID'])
local_seed = (99 + trainer_id)
......@@ -203,7 +204,8 @@ def main():
train_reader = create_reader(
cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num,
cfg,
devices_num=devices_num)
devices_num=devices_num,
num_trainers=num_trainers)
train_loader.set_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册