未验证 提交 3bb95624 编写于 作者: W wangguanzhong 提交者: GitHub

fix mixup & cutmix in multiprocessing (#1320)

上级 c764ccdb
...@@ -183,6 +183,7 @@ class Reader(object): ...@@ -183,6 +183,7 @@ class Reader(object):
inputs_def (dict): network input definition use to get input fields, inputs_def (dict): network input definition use to get input fields,
which is used to determine the order of returned data. which is used to determine the order of returned data.
devices_num (int): number of devices. devices_num (int): number of devices.
num_trainers (int): number of trainers. Default 1.
""" """
def __init__(self, def __init__(self,
...@@ -203,7 +204,8 @@ class Reader(object): ...@@ -203,7 +204,8 @@ class Reader(object):
bufsize=-1, bufsize=-1,
memsize='3G', memsize='3G',
inputs_def=None, inputs_def=None,
devices_num=1): devices_num=1,
num_trainers=1):
self._dataset = dataset self._dataset = dataset
self._roidbs = self._dataset.get_roidb() self._roidbs = self._dataset.get_roidb()
self._fields = copy.deepcopy(inputs_def[ self._fields = copy.deepcopy(inputs_def[
...@@ -244,8 +246,8 @@ class Reader(object): ...@@ -244,8 +246,8 @@ class Reader(object):
self._drop_empty = drop_empty self._drop_empty = drop_empty
# sampling # sampling
self._mixup_epoch = mixup_epoch self._mixup_epoch = mixup_epoch // num_trainers
self._cutmix_epoch = cutmix_epoch self._cutmix_epoch = cutmix_epoch // num_trainers
self._class_aware_sampling = class_aware_sampling self._class_aware_sampling = class_aware_sampling
self._load_img = False self._load_img = False
...@@ -415,7 +417,11 @@ class Reader(object): ...@@ -415,7 +417,11 @@ class Reader(object):
self._parallel.stop() self._parallel.stop()
def create_reader(cfg, max_iter=0, global_cfg=None, devices_num=1): def create_reader(cfg,
max_iter=0,
global_cfg=None,
devices_num=1,
num_trainers=1):
""" """
Return iterable data reader. Return iterable data reader.
...@@ -431,6 +437,7 @@ def create_reader(cfg, max_iter=0, global_cfg=None, devices_num=1): ...@@ -431,6 +437,7 @@ def create_reader(cfg, max_iter=0, global_cfg=None, devices_num=1):
'use_fine_grained_loss', False) 'use_fine_grained_loss', False)
cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80) cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80)
cfg['devices_num'] = devices_num cfg['devices_num'] = devices_num
cfg['num_trainers'] = num_trainers
reader = Reader(**cfg)() reader = Reader(**cfg)()
def _reader(): def _reader():
......
...@@ -56,6 +56,7 @@ def main(): ...@@ -56,6 +56,7 @@ def main():
FLAGS.dist = 'PADDLE_TRAINER_ID' in env \ FLAGS.dist = 'PADDLE_TRAINER_ID' in env \
and 'PADDLE_TRAINERS_NUM' in env \ and 'PADDLE_TRAINERS_NUM' in env \
and int(env['PADDLE_TRAINERS_NUM']) > 1 and int(env['PADDLE_TRAINERS_NUM']) > 1
num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
if FLAGS.dist: if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID']) trainer_id = int(env['PADDLE_TRAINER_ID'])
local_seed = (99 + trainer_id) local_seed = (99 + trainer_id)
...@@ -203,7 +204,8 @@ def main(): ...@@ -203,7 +204,8 @@ def main():
train_reader = create_reader( train_reader = create_reader(
cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num,
cfg, cfg,
devices_num=devices_num) devices_num=devices_num,
num_trainers=num_trainers)
train_loader.set_sample_list_generator(train_reader, place) train_loader.set_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer # whether output bbox is normalized in model output layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册