提交 f1ef3f22 编写于 作者: W Wu Yi 提交者: typhoonzero

Merge pull request #11728 from typhoonzero/fix_paraexe_bcast

Fix dist train broadcasting bug
上级 fac1d477
......@@ -483,6 +483,9 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
}
} else if (op.Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else {
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
......
......@@ -30,7 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
......@@ -16,6 +16,8 @@
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
......@@ -33,6 +35,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
return graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
}
bool IsValidGraph(const SSAGraph* graph) const;
private:
......
......@@ -15,6 +15,7 @@
#pragma once
#include <iosfwd>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
......@@ -55,6 +56,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
return graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
}
private:
std::unique_ptr<SSAGraphPrinter> printer_;
std::unique_ptr<SSAGraphBuilder> builder_;
......
......@@ -133,17 +133,18 @@ ParallelExecutor::ParallelExecutor(
void ParallelExecutor::BCastParamsToGPUs(
const std::unordered_set<std::string> &vars) const {
// the the initialize bcast, all vars would be bcast from device(0), otherwise
// the the initializing bcast, all vars would be bcast from device(0),
// otherwise
// bcast from the specified device.
bool initialize = builder_.get() == nullptr ? true : false;
bool initializing = builder_.get() == nullptr ? true : false;
for (auto &var : vars) {
int var_dev_id =
builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
if (!initialize && var_dev_id == -1) continue;
if (!initializing && var_dev_id == -1) continue;
framework::Variable *main_var = nullptr;
if (initialize) {
if (initializing) {
main_var = member_->local_scopes_[0]->FindVar(var);
} else {
main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
......@@ -164,7 +165,8 @@ void ParallelExecutor::BCastParamsToGPUs(
auto place = member_->places_[i];
void *buffer;
if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
if ((initializing && i == 0) ||
(!initializing && static_cast<int>(i) == var_dev_id)) {
buffer = const_cast<void *>(main_tensor.data<void>());
} else {
auto local_scope = member_->local_scopes_[i];
......@@ -181,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs(
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
if (initializing) {
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
} else {
if (var_dev_id >= 0) {
platform::dynload::ncclBcast(buffers[i], numel, data_type,
var_dev_id, nccl_ctx.comm_,
nccl_ctx.stream());
}
}
}
member_->nccl_ctxs_->WaitAll();
}
......
......@@ -302,7 +302,6 @@ class DistributeTranspiler(object):
"""
# remove optimize ops and add a send op to main_program
delete_ops(self.origin_program.global_block(), self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.origin_program.__str__()
return self.origin_program
......@@ -383,11 +382,12 @@ class DistributeTranspiler(object):
if self._is_adam_connected_op(op):
global_ops.append(op)
def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
lr_ops):
if self._is_optimizer_op(op):
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
self.origin_program, merged_var)
else:
elif op not in lr_ops:
self._append_pserver_non_opt_ops(block, op)
def __op_have_grad_input__(op):
......@@ -447,7 +447,7 @@ class DistributeTranspiler(object):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block, grad_to_block_id,
merged_var)
merged_var, lr_ops)
# append global ops
if global_ops:
......@@ -455,7 +455,7 @@ class DistributeTranspiler(object):
pserver_program.num_blocks - 1)
for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block,
grad_to_block_id, None)
grad_to_block_id, None, lr_ops)
# process distributed lookup_table
prefetch_var_name_to_block_id = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册