diff --git a/python/paddle/incubate/distributed/fleet/collective.py b/python/paddle/incubate/distributed/fleet/collective.py
index 5e135ced868303e3816ffbe96a381135181c37d7..68c77e36f4acce72d940d885f9d15254b822b8f4 100644
--- a/python/paddle/incubate/distributed/fleet/collective.py
+++ b/python/paddle/incubate/distributed/fleet/collective.py
@@ -18,6 +18,7 @@ import paddle
 import paddle.distributed.transpiler.distribute_transpiler as dist_transpiler
 import paddle.fluid as fluid
 import paddle.fluid.io as io
+from paddle.distributed.fleet.meta_optimizers import RawProgramOptimizer
 from paddle.fluid.compiler import CompiledProgram
 from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program
@@ -472,13 +473,19 @@ class CollectiveOptimizer(DistributedOptimizer):
         self._strategy.trainers_endpoints = fleet.worker_endpoints()
         self._strategy.enable_backward_optimizer_op_deps = True
 
-        self._compiled_program = CompiledProgram(main_program)
-
-        self._compiled_program.with_data_parallel(
-            loss_name=self._loss.name,
-            build_strategy=self._strategy,
-            exec_strategy=self._strategy.exec_strategy,
-            share_vars_from=None,
+        comm_opt = RawProgramOptimizer(self._optimizer)
+        comm_opt.fuse_all_reduce_ops = True
+        comm_opt.fuse_grad_size_in_num = True
+        comm_opt.endpoints = self._strategy.trainers_endpoints
+        comm_opt.current_endpoint = comm_opt.endpoints[fleet.worker_index()]
+        comm_opt.rank = fleet.worker_index()
+        comm_opt.nranks = fleet.worker_num()
+        comm_opt.main_program = main_program
+        if comm_opt.nranks > 1:
+            comm_opt._transpile_main_program(self._loss)
+
+        self._compiled_program = CompiledProgram(
+            comm_opt.main_program, build_strategy=self._strategy
         )
 
         return self._compiled_program