From 11e383871adf4cf0cc31ddc62d28fe9b5c9216e8 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Wed, 8 Mar 2023 20:34:13 +0800
Subject: [PATCH] [with_data_parallel][part6.3] remove with_data_parallel in collective_optimizer (#51032)

* remove with_data_parallel in collective_optimizer

* add comm op

* fix collective optimizer

* remove check_err_log=True
---
 .../incubate/distributed/fleet/collective.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/python/paddle/incubate/distributed/fleet/collective.py b/python/paddle/incubate/distributed/fleet/collective.py
index 5e135ced868..68c77e36f4a 100644
--- a/python/paddle/incubate/distributed/fleet/collective.py
+++ b/python/paddle/incubate/distributed/fleet/collective.py
@@ -18,6 +18,7 @@ import paddle
 import paddle.distributed.transpiler.distribute_transpiler as dist_transpiler
 import paddle.fluid as fluid
 import paddle.fluid.io as io
+from paddle.distributed.fleet.meta_optimizers import RawProgramOptimizer
 from paddle.fluid.compiler import CompiledProgram
 from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program
@@ -472,13 +473,19 @@ class CollectiveOptimizer(DistributedOptimizer):
         self._strategy.trainers_endpoints = fleet.worker_endpoints()
         self._strategy.enable_backward_optimizer_op_deps = True
 
-        self._compiled_program = CompiledProgram(main_program)
-
-        self._compiled_program.with_data_parallel(
-            loss_name=self._loss.name,
-            build_strategy=self._strategy,
-            exec_strategy=self._strategy.exec_strategy,
-            share_vars_from=None,
+        comm_opt = RawProgramOptimizer(self._optimizer)
+        comm_opt.fuse_all_reduce_ops = True
+        comm_opt.fuse_grad_size_in_num = True
+        comm_opt.endpoints = self._strategy.trainers_endpoints
+        comm_opt.current_endpoint = comm_opt.endpoints[fleet.worker_index()]
+        comm_opt.rank = fleet.worker_index()
+        comm_opt.nranks = fleet.worker_num()
+        comm_opt.main_program = main_program
+        if comm_opt.nranks > 1:
+            comm_opt._transpile_main_program(self._loss)
+
+        self._compiled_program = CompiledProgram(
+            comm_opt.main_program, build_strategy=self._strategy
         )
 
         return self._compiled_program
-- 
GitLab