diff --git a/ppdet/core/workspace.py b/ppdet/core/workspace.py index a5124b5cd9526d9e77391a5af29a6da0cc1feb28..03c53df58a9d00feb1578b46dfb4eeaf345a519e 100644 --- a/ppdet/core/workspace.py +++ b/ppdet/core/workspace.py @@ -98,14 +98,6 @@ def load_config(file_path): merge_config(cfg) - # NOTE: training batch size defined only in TrainReader, sychornized - # batch size config to global, models can get batch size config - # from global config when building model. - # batch size in evaluation or inference can also be added here - if 'TrainReader' in global_config: - global_config['train_batch_size'] = global_config['TrainReader'][ - 'batch_size'] - return global_config @@ -141,7 +133,16 @@ def merge_config(config, another_cfg=None): """ global global_config dct = another_cfg if another_cfg is not None else global_config - return dict_merge(dct, config) + dct = dict_merge(dct, config) + + # NOTE: training batch size defined only in TrainReader, sychornized + # batch size config to global, models can get batch size config + # from global config when building model. + # batch size in evaluation or inference can also be added here + if 'TrainReader' in dct: + dct['train_batch_size'] = dct['TrainReader']['batch_size'] + + return dct def get_registered_modules():