From 3bb956246d364714d5a56f1c447dcf3a1a92b231 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 9 Sep 2020 20:49:14 +0800 Subject: [PATCH] fix mixup & cutmix in multiprocessing (#1320) --- ppdet/data/reader.py | 15 +++++++++++---- tools/train.py | 4 +++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 7d808b589..d6e4786a6 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -183,6 +183,7 @@ class Reader(object): inputs_def (dict): network input definition use to get input fields, which is used to determine the order of returned data. devices_num (int): number of devices. + num_trainers (int): number of trainers. Default 1. """ def __init__(self, @@ -203,7 +204,8 @@ class Reader(object): bufsize=-1, memsize='3G', inputs_def=None, - devices_num=1): + devices_num=1, + num_trainers=1): self._dataset = dataset self._roidbs = self._dataset.get_roidb() self._fields = copy.deepcopy(inputs_def[ @@ -244,8 +246,8 @@ class Reader(object): self._drop_empty = drop_empty # sampling - self._mixup_epoch = mixup_epoch - self._cutmix_epoch = cutmix_epoch + self._mixup_epoch = mixup_epoch // num_trainers + self._cutmix_epoch = cutmix_epoch // num_trainers self._class_aware_sampling = class_aware_sampling self._load_img = False @@ -415,7 +417,11 @@ class Reader(object): self._parallel.stop() -def create_reader(cfg, max_iter=0, global_cfg=None, devices_num=1): +def create_reader(cfg, + max_iter=0, + global_cfg=None, + devices_num=1, + num_trainers=1): """ Return iterable data reader. @@ -431,6 +437,7 @@ def create_reader(cfg, max_iter=0, global_cfg=None, devices_num=1): 'use_fine_grained_loss', False) cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80) cfg['devices_num'] = devices_num + cfg['num_trainers'] = num_trainers reader = Reader(**cfg)() def _reader(): diff --git a/tools/train.py b/tools/train.py index 7f058c528..dd2edbd43 100644 --- a/tools/train.py +++ b/tools/train.py @@ -56,6 +56,7 @@ def main(): FLAGS.dist = 'PADDLE_TRAINER_ID' in env \ and 'PADDLE_TRAINERS_NUM' in env \ and int(env['PADDLE_TRAINERS_NUM']) > 1 + num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1)) if FLAGS.dist: trainer_id = int(env['PADDLE_TRAINER_ID']) local_seed = (99 + trainer_id) @@ -203,7 +204,8 @@ def main(): train_reader = create_reader( cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg, - devices_num=devices_num) + devices_num=devices_num, + num_trainers=num_trainers) train_loader.set_sample_list_generator(train_reader, place) # whether output bbox is normalized in model output layer -- GitLab