From 7246b3e5d91d1dff2c4c59bb3c8806fec7c61b38 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Sun, 27 Sep 2020 20:27:30 +0800 Subject: [PATCH] move train_batch_size sync from load_config to merge_config (#1511) * move train_batch_size sync from load_config to merge_config --- ppdet/core/workspace.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ppdet/core/workspace.py b/ppdet/core/workspace.py index a5124b5cd..03c53df58 100644 --- a/ppdet/core/workspace.py +++ b/ppdet/core/workspace.py @@ -98,14 +98,6 @@ def load_config(file_path): merge_config(cfg) - # NOTE: training batch size defined only in TrainReader, sychornized - # batch size config to global, models can get batch size config - # from global config when building model. - # batch size in evaluation or inference can also be added here - if 'TrainReader' in global_config: - global_config['train_batch_size'] = global_config['TrainReader'][ - 'batch_size'] - return global_config @@ -141,7 +133,16 @@ def merge_config(config, another_cfg=None): """ global global_config dct = another_cfg if another_cfg is not None else global_config - return dict_merge(dct, config) + dct = dict_merge(dct, config) + + # NOTE: training batch size defined only in TrainReader, sychornized + # batch size config to global, models can get batch size config + # from global config when building model. + # batch size in evaluation or inference can also be added here + if 'TrainReader' in dct: + dct['train_batch_size'] = dct['TrainReader']['batch_size'] + + return dct def get_registered_modules(): -- GitLab