Commit ff599b92 authored by chengduoZH

use Reduce and Broadcast

Parent: 0441c2cc
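This commit switches sparse-gradient synchronization from a round-robin Reduce plus a deferred broadcast pass to an immediate Reduce-then-Broadcast pair rooted at device 0, while dense gradients keep using NCCL allreduce. Stitched together from the hunks below (a simplification for reading, not the verbatim implementation), the per-gradient decision now looks like this:

    // New per-gradient synchronization choice (simplified):
    for (auto &og : op->OutputArgumentNames()) {
      if (!IsParameterGradientOnce(og, &og_has_been_broadcast)) continue;
      if (IsSparseGradient(var_types, og)) {
        CreateReduceOp(&result, og, 0);     // gather the sparse gradient on device 0
        CreateBroadcastOp(&result, og, 0);  // then fan the reduced value back out
      } else {
        InsertNCCLAllReduceOp(&result, og); // dense gradients: NCCL allreduce
      }
    }

Because every sparse gradient is now reduced to and broadcast from device 0, the per-device bookkeeping (sparse_var_name_on_devices, bcast_sparse_var_name_set, cur_dev_id) and the GetOpDeviceID helper below become dead code and are removed.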
@@ -111,6 +111,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
for (auto *var : program.Block(0).AllVars()) {
var_types[var->Name()] = var->GetType();
}
auto graph = new SSAGraph();
SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast;
@@ -120,13 +121,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());
size_t cur_dev_id = 0;
std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
sparse_var_name_on_devices.resize(places_.size());
bcast_sparse_var_name_set.resize(places_.size());
// Find the "send" op first, because the "split" op is placed in front of "send".
OpDesc *send_op = GetSendOpDesc(program);
@@ -145,27 +139,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(sparse_var_name_on_devices, *op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
sparse_var_name_on_devices[op_dev_id].emplace(var_name);
}
}
if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once a gradient is generated, it can be
// broadcast, and each gradient is broadcast only once.
for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
if (IsSparseGradient(var_types, og)) {
CreateReduceOp(&result, cur_dev_id, og);
sparse_var_name_on_devices[cur_dev_id].emplace(og);
bcast_sparse_var_name_set[cur_dev_id].emplace(
og.substr(0, og.size() - strlen(kGradVarSuffix)));
cur_dev_id = (cur_dev_id + 1) % places_.size();
CreateReduceOp(&result, og, 0);
CreateBroadcastOp(&result, og, 0);
} else {
InsertNCCLAllReduceOp(&result, og);
}
@@ -175,14 +157,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_sparse_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_sparse_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
}
/*
Dependency graph has been constructed. However, there are still data
hazards that need to be handled.
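The polish pass that resolves these hazards is outside this diff; conceptually, a write-after-read hazard is broken by threading a data-free dependency variable from the reader to the writer. A toy sketch under that assumption (ToyVar and ToyOp are illustrative stand-ins, not Paddle's real handle types):

    #include <memory>
    #include <vector>

    struct ToyVar {};  // carries no payload, only an ordering edge

    struct ToyOp {
      std::vector<ToyVar *> inputs, outputs;
    };

    // If `writer` will overwrite a variable that `reader` still consumes,
    // thread a dummy variable from reader to writer so the scheduler cannot
    // start `writer` before `reader` finishes (write-after-read ordering).
    void AddWriteAfterReadEdge(ToyOp &reader, ToyOp &writer,
                               std::vector<std::unique_ptr<ToyVar>> &owned) {
      owned.push_back(std::make_unique<ToyVar>());
      reader.outputs.push_back(owned.back().get());
      writer.inputs.push_back(owned.back().get());
    }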
@@ -213,26 +187,9 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
return false;
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector<std::unordered_set<std::string>>
&sparse_var_name_on_devices,
const OpDesc &op) const {
int var_dev_id = -1;
for (auto &var_name : op.InputArgumentNames()) {
if (var_dev_id != -1) break;
for (size_t i = 0; i < sparse_var_name_on_devices.size(); ++i) {
if (sparse_var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
}
}
return var_dev_id;
}
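For reference, the deleted GetOpDeviceID amounted to a first-match scan of the op's inputs over the per-device sparse-variable sets. A compact standalone restatement with a hypothetical name (the commit removes the real method together with its only call site):

    #include <string>
    #include <unordered_set>
    #include <vector>

    // Returns the first device whose sparse-var set contains any input name,
    // or -1 when no input is pinned to one device (run the op on all devices).
    int FirstDeviceHoldingAnInput(
        const std::vector<std::unordered_set<std::string>> &vars_on_devices,
        const std::vector<std::string> &input_names) {
      for (const auto &name : input_names) {
        for (size_t i = 0; i < vars_on_devices.size(); ++i) {
          if (vars_on_devices[i].count(name)) return static_cast<int>(i);
        }
      }
      return -1;
    }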
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
const std::string &p_name,
size_t dev_id) const {
size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
#else
@@ -240,11 +197,11 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
#endif
result->ops_.emplace_back(op_handle);
auto *in = result->vars_.at(dev_id).at(p_name).back().get();
auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_.at(dev_id).at(p_name);
auto &vars = result->vars_.at(i).at(p_name);
auto &p = places_[i];
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
vars.emplace_back(out_var);
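The broadcast wiring above follows the builder's versioned-variable convention: the newest VarHandle of p_name on src_dev_id is the single input, and every place, source included, receives a fresh next-version VarHandle as an output. A toy model of that bookkeeping (hypothetical standalone types; the real VarHandle and SSAGraph live elsewhere in the codebase):

    #include <cstddef>
    #include <map>
    #include <string>
    #include <vector>

    // Toy versioned variable: (version, device, name).
    struct ToyVarHandle { std::size_t version, place; std::string name; };

    int main() {
      const std::size_t num_places = 2, src_dev_id = 0;
      // vars[dev][name] is the stack of versions of `name` on device `dev`.
      std::vector<std::map<std::string, std::vector<ToyVarHandle>>> vars(num_places);
      const std::string p_name = "w@GRAD";
      for (std::size_t d = 0; d < num_places; ++d)
        vars[d][p_name].push_back({0, d, p_name});  // version 0 exists everywhere

      // Input: the latest version on the source device.
      const ToyVarHandle &in = vars[src_dev_id][p_name].back();
      (void)in;
      // Outputs: a fresh next version on every device, source included.
      for (std::size_t d = 0; d < num_places; ++d) {
        auto &stack = vars[d][p_name];
        stack.push_back({stack.size(), d, p_name});
      }
      return 0;
    }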
@@ -345,8 +302,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
}
}
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(
SSAGraph *result, int dst_dev_id, const std::string &og) const {
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
const std::string &og,
int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back(
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
......
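ReduceOpHandle's body is collapsed here. For a sparse (SelectedRows) gradient, reducing plausibly amounts to gathering every device's touched rows onto the destination device rather than summing dense tensors; a hedged toy version under that assumption (ToySelectedRows is an illustrative stand-in):

    #include <cstdint>
    #include <vector>

    // Toy stand-in for SelectedRows: a list of touched row indices plus one
    // value row per index (duplicates allowed; they are summed when applied).
    struct ToySelectedRows {
      std::vector<int64_t> rows;
      std::vector<std::vector<float>> values;
    };

    // Gather the per-device sparse gradients into one logical result, the
    // SelectedRows analogue of a dense reduce-sum.
    ToySelectedRows GatherSparse(const std::vector<ToySelectedRows> &per_device) {
      ToySelectedRows out;
      for (const auto &g : per_device) {
        out.rows.insert(out.rows.end(), g.rows.begin(), g.rows.end());
        out.values.insert(out.values.end(), g.values.begin(), g.values.end());
      }
      return out;
    }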
@@ -75,8 +75,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const;
VarHandle *CreateReduceOp(SSAGraph *result, int dst_dev_id,
const std::string &og) const;
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
int dev_id) const;
@@ -87,11 +87,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t dev_id) const;
int GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
size_t src_dev_id) const;
/**
* Get the "send" op in the global block of the program.
......