diff --git a/tools/train.py b/tools/train.py index 274cdff9fd94c93e690f352e51900254b5b8c13f..718a46b7bcf5ccd4088aeb97eb2822a05176b4d0 100644 --- a/tools/train.py +++ b/tools/train.py @@ -126,8 +126,9 @@ def main(): build_strategy.memory_optimize = False build_strategy.enable_inplace = True sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' - # only enable sync_bn in multi-devices - build_strategy.sync_batch_norm = sync_bn and devices_num > 1 + # only enable sync_bn in multi GPU devices + build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \ + and cfg.use_gpu train_compile_program = fluid.compiler.CompiledProgram( train_prog).with_data_parallel( loss_name=loss.name, build_strategy=build_strategy)