提交 ff599b92 编写于 作者: C chengduoZH

use Reduce and Broadcast

上级 0441c2cc
...@@ -111,6 +111,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -111,6 +111,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
var_types[var->Name()] = var->GetType(); var_types[var->Name()] = var->GetType();
} }
auto graph = new SSAGraph(); auto graph = new SSAGraph();
SSAGraph &result = *graph; SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
...@@ -120,13 +121,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -120,13 +121,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size()); places_.size());
size_t cur_dev_id = 0;
std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
sparse_var_name_on_devices.resize(places_.size());
bcast_sparse_var_name_set.resize(places_.size());
// Find the "send" op first, since split is in front of send. // Find the "send" op first, since split is in front of send.
OpDesc *send_op = GetSendOpDesc(program); OpDesc *send_op = GetSendOpDesc(program);
...@@ -145,27 +139,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -145,27 +139,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(sparse_var_name_on_devices, *op); CreateComputationalOps(&result, *op, places_.size());
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
sparse_var_name_on_devices[op_dev_id].emplace(var_name);
}
}
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) { for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
if (IsSparseGradient(var_types, og)) { if (IsSparseGradient(var_types, og)) {
CreateReduceOp(&result, cur_dev_id, og); CreateReduceOp(&result, og, 0);
sparse_var_name_on_devices[cur_dev_id].emplace(og); CreateBroadcastOp(&result, og, 0);
bcast_sparse_var_name_set[cur_dev_id].emplace(
og.substr(0, og.size() - strlen(kGradVarSuffix)));
cur_dev_id = (cur_dev_id + 1) % places_.size();
} else { } else {
InsertNCCLAllReduceOp(&result, og); InsertNCCLAllReduceOp(&result, og);
} }
...@@ -175,14 +157,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -175,14 +157,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
} }
// Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_sparse_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_sparse_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
}
/* /*
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
hazards that need to be handled. hazards that need to be handled.
...@@ -213,26 +187,9 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient( ...@@ -213,26 +187,9 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
return false; return false;
} }
// Looks up which per-device name set contains one of `op`'s input arguments.
//
// Inputs are examined in the order returned by op.InputArgumentNames(); for
// each input the device sets are scanned from index 0 upward. The index of the
// first set containing the name is returned, so the earliest matching input
// decides the result.
//
// @param sparse_var_name_on_devices  one set of variable names per device,
//                                    indexed by device id.
// @param op                          the operator whose inputs are inspected.
// @return the matching device index, or -1 when no input name appears in any
//         of the sets.
int MultiDevSSAGraphBuilder::GetOpDeviceID(
    const std::vector<std::unordered_set<std::string>>
        &sparse_var_name_on_devices,
    const OpDesc &op) const {
  for (auto &input_name : op.InputArgumentNames()) {
    for (size_t dev = 0; dev < sparse_var_name_on_devices.size(); ++dev) {
      const auto &names_on_dev = sparse_var_name_on_devices[dev];
      if (names_on_dev.find(input_name) != names_on_dev.end()) {
        // First hit wins; no later input or device is consulted.
        return static_cast<int>(dev);
      }
    }
  }
  // No input of this op is recorded in any device set.
  return -1;
}
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
const std::string &p_name, const std::string &p_name,
size_t dev_id) const { size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_); auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
#else #else
...@@ -240,11 +197,11 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, ...@@ -240,11 +197,11 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
#endif #endif
result->ops_.emplace_back(op_handle); result->ops_.emplace_back(op_handle);
auto *in = result->vars_.at(dev_id).at(p_name).back().get(); auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_.at(dev_id).at(p_name); auto &vars = result->vars_.at(i).at(p_name);
auto &p = places_[i]; auto &p = places_[i];
auto *out_var = new VarHandle(vars.size(), i, p_name, p); auto *out_var = new VarHandle(vars.size(), i, p_name, p);
vars.emplace_back(out_var); vars.emplace_back(out_var);
...@@ -345,8 +302,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, ...@@ -345,8 +302,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
} }
} }
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp( VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
SSAGraph *result, int dst_dev_id, const std::string &og) const { const std::string &og,
int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back( result->ops_.emplace_back(
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
......
...@@ -75,8 +75,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -75,8 +75,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const; void CreateScaleLossGradOp(SSAGraph *result) const;
VarHandle *CreateReduceOp(SSAGraph *result, int dst_dev_id, VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
const std::string &og) const; int dst_dev_id) const;
void CreateComputationalOp(SSAGraph *result, const OpDesc &op, void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
int dev_id) const; int dev_id) const;
...@@ -87,11 +87,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -87,11 +87,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t dev_id) const; size_t src_dev_id) const;
int GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
/** /**
* Get send op in the global block of program. * Get send op in the global block of program.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册