Unverified · Commit 8b793d0e authored by gongweibao, committed by GitHub

Fix DGC bug. (#16697)

Parent 3fe8cb0d
...@@ -53,6 +53,10 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
       this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
     }
   }
+  // TODO(gongwb): polish them!
+  if (is_encoded) {
+    VLOG(1) << "Use dgc allreduce mode";
+  }
 }
 #else
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
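For context, DGC (Deep Gradient Compression) replaces the dense gradient allreduce with an exchange of top-k sparsified gradients, and `is_encoded` is what routes an `AllReduceOpHandle` onto that path. A minimal Python sketch of the dispatch, using hypothetical stand-in names rather than Paddle's actual API:

```python
# Hypothetical stand-ins for the two execution paths (not Paddle's API).
def run_impl_encoded(grads, nranks):
    raise NotImplementedError  # allgather top-k (index, value) pairs, decode, average

def run_impl_normal(grads):
    raise NotImplementedError  # dense ring allreduce over the full tensors

def all_reduce(grads, is_encoded, nranks):
    if is_encoded:
        # DGC mode, the case the new VLOG(1) above announces.
        return run_impl_encoded(grads, nranks)
    return run_impl_normal(grads)
```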
...@@ -86,7 +90,7 @@ void AllReduceOpHandle::RunImplEncoded() {
         paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
     auto encode_var_name = original_name + g_dgc_encoded;
     auto *in_var = local_scope->FindVar(encode_var_name);
-    PADDLE_ENFORCE_NOT_NULL(in_var);
+    PADDLE_ENFORCE_NOT_NULL(in_var, "%s should not be null", encode_var_name);
     auto &in = in_var->Get<LoDTensor>();
     ins.emplace_back(&in);
...
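The behavioral content of this hunk is purely diagnostic: the null check now names the encoded variable it failed to find. A hedged Python sketch of the same lookup pattern (the `__dgc_encoded__` suffix is an assumption about `g_dgc_encoded`'s value, and `find_var` stands in for `Scope::FindVar`):

```python
def find_encoded_grad(scope, grad_name, suffix="__dgc_encoded__"):
    # Recover the parameter name from the gradient name, then append the
    # DGC suffix to locate the encoded (sparse) tensor in the scope.
    original_name = grad_name.replace("@GRAD", "")
    encode_var_name = original_name + suffix
    var = scope.find_var(encode_var_name)
    if var is None:
        # The fix: report which variable is missing instead of failing
        # anonymously, so a mis-wired DGC graph is diagnosable from the log.
        raise RuntimeError("%s should not be null" % encode_var_name)
    return var
```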
...@@ -251,7 +251,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
   CreatePassesFromStrategy(false);

   for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
-    VLOG(3) << "apply " << pass->Type();
+    VLOG(3) << "BuildStrategy::Apply pass:" << pass->Type();
     if (IsMultiDevPass(pass->Type())) {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
...
...@@ -752,7 +752,7 @@ class DGCMomentumOptimizer(MomentumOptimizer):
             force_cpu=True)

         for param_var, grad_var in param_and_grads:
-            var_numel = reduce(lambda x, y: x * y, param_var.shape)
+            var_numel = abs(reduce(lambda x, y: x * y, param_var.shape))
             if var_numel < 16384 or \
                param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
                grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
...
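The added `abs()` matters because a shape containing an inferred dimension (Paddle encodes these as -1) makes the raw product negative, so the `var_numel < 16384` small-tensor test would always pass and the variable would be wrongly excluded from DGC. A quick worked example:

```python
from functools import reduce

shape = (-1, 1024, 64)  # a shape whose first dimension is inferred

raw = reduce(lambda x, y: x * y, shape)  # -65536, always "small"
fixed = abs(raw)                         # 65536, the real element count

print(raw < 16384)    # True  -> skipped by DGC before the fix
print(fixed < 16384)  # False -> stays on the DGC path after the fix
```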
...@@ -104,6 +104,7 @@ class ParallelExecutor(object):
         self._scope = scope if scope is not None else executor.global_scope()

         if main_program is not None and main_program._enable_dgc:
+            assert num_trainers > 1
             assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce
             assert num_trainers * len(
                 self._places) > 1, "dgc is not useful for single card training"
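The new `assert num_trainers > 1` makes the multi-trainer requirement explicit: DGC compresses gradient traffic between trainers, so it buys nothing for single-trainer runs. A hedged usage sketch (keyword arguments follow the `ParallelExecutor` signature; `loss`, `dgc_program`, `build_strategy`, and `trainer_id` are assumed to come from earlier setup):

```python
import paddle.fluid as fluid

# A DGC-enabled program must now declare more than one trainer;
# num_trainers=1 fails fast at construction time instead of later.
pe = fluid.ParallelExecutor(
    use_cuda=True,
    loss_name=loss.name,            # loss from the DGC-enabled program
    main_program=dgc_program,       # built with DGCMomentumOptimizer
    build_strategy=build_strategy,  # reduce_strategy must be AllReduce
    num_trainers=2,
    trainer_id=trainer_id)
```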
...@@ -123,6 +124,11 @@ class ParallelExecutor(object):
             exec_strategy=exec_strategy,
             share_vars_from=share_vars_from._compiled_program
             if share_vars_from else None)

+        # FIXME(gongwb): I will move dgc from dist mode to allreduce mode in next pr.
+        if main_program._enable_dgc:
+            self._compiled_program._build_strategy.is_distribution = True
+
         self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
         self._exe = executor.Executor(self._place)
         self._compiled_program._compile(place=self._place, scope=self._scope)
...
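Per the FIXME, DGC still rides the distributed code path, so compilation must see `is_distribution = True` even on a single node; the workaround flips the flag automatically rather than asking users to do it. Conceptually it is equivalent to this hedged configuration sketch (attribute names as used in the diff above):

```python
import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# Precondition asserted earlier: DGC requires plain AllReduce reduction.
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
# What ParallelExecutor now forces when _enable_dgc is set: build the
# graph through the distributed (multi-trainer) path.
build_strategy.is_distribution = True
```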