Unverified commit 11e38387, authored by kangguangli, committed by GitHub

[with_data_parallel][part6.3] remove with_data_parallel in collective_optimizer (#51032)

* remove with_data_parallel in collective_optimizer

* add comm op

* fix collective optimizer

* remove check_err_log=True
Parent b4a500bc
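In short, this change stops CollectiveOptimizer from relying on `CompiledProgram.with_data_parallel` and instead has `RawProgramOptimizer` insert the collective communication ops directly into the main program, so the program is compiled with only a `build_strategy`. A condensed before/after view of the compile step (the names `main_program`, `loss`, and `build_strategy` are placeholders for the member variables used in the diff below, and other arguments are omitted):

```python
# Before #51032: data parallelism was applied when compiling the program.
compiled = fluid.CompiledProgram(main_program).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy
)

# After #51032: the communication ops already live in main_program (inserted by
# RawProgramOptimizer), so CompiledProgram only needs a build_strategy.
compiled = fluid.CompiledProgram(main_program, build_strategy=build_strategy)
```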
@@ -18,6 +18,7 @@ import paddle
 import paddle.distributed.transpiler.distribute_transpiler as dist_transpiler
 import paddle.fluid as fluid
 import paddle.fluid.io as io
+from paddle.distributed.fleet.meta_optimizers import RawProgramOptimizer
 from paddle.fluid.compiler import CompiledProgram
 from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program
@@ -472,13 +473,19 @@ class CollectiveOptimizer(DistributedOptimizer):
         self._strategy.trainers_endpoints = fleet.worker_endpoints()
         self._strategy.enable_backward_optimizer_op_deps = True
-        self._compiled_program = CompiledProgram(main_program)
-        self._compiled_program.with_data_parallel(
-            loss_name=self._loss.name,
-            build_strategy=self._strategy,
-            exec_strategy=self._strategy.exec_strategy,
-            share_vars_from=None,
-        )
+        comm_opt = RawProgramOptimizer(self._optimizer)
+        comm_opt.fuse_all_reduce_ops = True
+        comm_opt.fuse_grad_size_in_num = True
+        comm_opt.endpoints = self._strategy.trainers_endpoints
+        comm_opt.current_endpoint = comm_opt.endpoints[fleet.worker_index()]
+        comm_opt.rank = fleet.worker_index()
+        comm_opt.nranks = fleet.worker_num()
+        comm_opt.main_program = main_program
+        if comm_opt.nranks > 1:
+            comm_opt._transpile_main_program(self._loss)
+        self._compiled_program = CompiledProgram(
+            comm_opt.main_program, build_strategy=self._strategy
+        )
         return self._compiled_program
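For context, the new path can be read as a standalone sketch: configure a `RawProgramOptimizer` so that `_transpile_main_program` inserts the collective communication ops (e.g. allreduce on gradients) into the main program, then compile with a `build_strategy`. This is only an illustration of the pattern in the diff above; the function name `compile_with_comm_ops` and the parameters `inner_opt`, `loss`, `main_program`, and `strategy` are hypothetical placeholders, and the communication setup only takes effect when the job is launched with more than one trainer (e.g. via `paddle.distributed.launch`).

```python
from paddle.distributed.fleet.meta_optimizers import RawProgramOptimizer
from paddle.fluid.compiler import CompiledProgram


def compile_with_comm_ops(inner_opt, loss, main_program, strategy, fleet):
    # Mirror of the new CollectiveOptimizer logic: RawProgramOptimizer rewrites
    # main_program in place, adding the collective communication ops.
    comm_opt = RawProgramOptimizer(inner_opt)
    comm_opt.fuse_all_reduce_ops = True
    comm_opt.fuse_grad_size_in_num = True
    comm_opt.endpoints = fleet.worker_endpoints()
    comm_opt.current_endpoint = comm_opt.endpoints[fleet.worker_index()]
    comm_opt.rank = fleet.worker_index()
    comm_opt.nranks = fleet.worker_num()
    comm_opt.main_program = main_program
    if comm_opt.nranks > 1:
        # Insert the communication ops; a single-trainer run needs no rewrite.
        comm_opt._transpile_main_program(loss)
    # Passing build_strategy here replaces the removed with_data_parallel call.
    return CompiledProgram(comm_opt.main_program, build_strategy=strategy)
```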