diff --git a/tools/train.py b/tools/train.py index 159cacc586f4a1fd20bd333c25603b6138a347fd..b9099210edecf41f4ff548166ae36c043dfc59b5 100644 --- a/tools/train.py +++ b/tools/train.py @@ -158,6 +158,7 @@ def main(): # compile program for multi-devices build_strategy = fluid.BuildStrategy() build_strategy.fuse_all_optimizer_ops = False + build_strategy.fuse_elewise_add_act_ops = True # only enable sync_bn in multi GPU devices sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \