diff --git a/tools/train.py b/tools/train.py index 2428469863c14736c5dc000c47b1eef9553f7240..037decdcc4139ca7336d6bb9df1772d09564033f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -126,7 +126,8 @@ def main(): build_strategy.memory_optimize = False build_strategy.enable_inplace = True sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' - build_strategy.sync_batch_norm = sync_bn + # only enable sync_bn in multi-devices + build_strategy.sync_batch_norm = sync_bn and devices_num > 1 train_compile_program = fluid.compiler.CompiledProgram( train_prog).with_data_parallel( loss_name=loss.name, build_strategy=build_strategy)