提交 9b960330 编写于 作者: X Xin Pan

graph attrs

上级 2eeaa8d5
......@@ -70,8 +70,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op,
size_t place_id) const {
auto p = places_[place_id];
auto *op_handle =
boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
auto *op_handle = result->Get<GraphOps>("ops").back().get();
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
......@@ -179,13 +178,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize. It is a bug of GCC 4.8
result.attrs["vars"] = new std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());
result.attrs["dep_vars"] =
new std::unordered_set<std::unique_ptr<VarHandleBase>>();
result.attrs["ops"] = new std::vector<std::unique_ptr<OpHandleBase>>();
result.Set("vars", new GraphVars(places_.size()));
result.Set("dep_vars", new GraphDepVars);
result.Set("ops", new GraphOps);
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
auto send_vars = FindDistTrainSendVars(program);
......@@ -308,13 +303,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps(&result);
std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
ssa_graph->vars_ =
std::move(*boost::any_cast<GraphVars *>(graph->attrs["vars"]));
ssa_graph->ops_ =
std::move(*boost::any_cast<GraphOps *>(graph->attrs["ops"]));
ssa_graph->dep_vars_ =
std::move(*boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"]));
ssa_graph->vars_ = std::move(*graph->Erase<GraphVars>("vars"));
ssa_graph->ops_ = std::move(*graph->Erase<GraphOps>("ops"));
ssa_graph->dep_vars_ = std::move(*graph->Erase<GraphDepVars>("dep_vars"));
return std::move(ssa_graph);
}
......@@ -347,20 +338,15 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
#else
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
#endif
boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
auto *in = boost::any_cast<GraphVars *>(result->attrs["vars"])
->at(src_dev_id)
.at(p_name)
.back()
.get();
result->Get<GraphOps>("ops").emplace_back(op_handle);
auto *in =
result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars =
boost::any_cast<GraphVars *>(result->attrs["vars"])->at(i).at(p_name);
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
......@@ -370,8 +356,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
const OpDesc &op,
int dev_id) const {
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(
result->Get<GraphOps>("ops").emplace_back(
new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, op, dev_id);
}
......@@ -379,19 +364,18 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>("ops").emplace_back(
new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(new AllReduceOpHandle(local_scopes_, places_));
result->Get<GraphOps>("ops").emplace_back(
new AllReduceOpHandle(local_scopes_, places_));
#endif
auto *op_handle =
boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
auto &vars = result->Get<GraphVars>("vars")[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
......@@ -405,21 +389,18 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(
result->Get<GraphOps>("ops").emplace_back(
new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
result->Get<GraphOps>("ops").emplace_back(
new DataBalanceOpHandle(local_scopes_, places_));
#endif
auto *op_handle =
boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) {
auto &vars =
(*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][d_name];
auto &vars = result->Get<GraphVars>("vars")[i][d_name];
PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get());
auto var = new VarHandle(vars.size(), i, d_name, p);
......@@ -480,7 +461,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
auto *op_handle =
new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
places_[i], communication_dev_ctx);
boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
result->Get<GraphOps>("ops").emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
......@@ -499,8 +480,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(new ComputationOpHandle(op, s, p));
result->Get<GraphOps>("ops").emplace_back(
new ComputationOpHandle(op, s, p));
CreateOpHandleIOs(result, op, scope_idx);
}
}
......@@ -509,25 +490,23 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
const std::string &og,
int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>("ops").emplace_back(
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(new ReduceOpHandle(local_scopes_, places_));
result->Get<GraphOps>("ops").emplace_back(
new ReduceOpHandle(local_scopes_, places_));
#endif
auto *op_handle =
boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
auto &vars = result->Get<GraphVars>("vars")[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
}
auto &vars =
(*boost::any_cast<GraphVars *>(result->attrs["vars"]))[dst_dev_id][og];
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
vars.emplace_back(var);
op_handle->AddOutput(var);
......@@ -538,12 +517,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
// on it.
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : (*boost::any_cast<GraphOps *>(result->attrs["ops"]))) {
for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle();
prev_op->AddOutput(dep_var);
boost::any_cast<GraphDepVars *>(result->attrs["dep_vars"])
->emplace(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
op->AddInput(dep_var);
}
}
......@@ -579,8 +557,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
CreateComputationalOp(result, op, op_dev_id);
if (op.Type() == "concat") {
ConnectOp(result,
boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
"fetch_barrier");
}
}
......@@ -615,22 +592,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type());
boost::any_cast<GraphOps *>(result->attrs["ops"])
->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(),
places_[op_dev_id]));
result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") {
ConnectOp(result,
boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
"send");
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
} else if (op.Type() == "recv") {
ConnectOp(result,
boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
"send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result,
boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
"recv");
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
} else if (op.Type() == "send") {
// do nothing
} else {
......
......@@ -18,7 +18,7 @@ namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
for (auto &var_map : *boost::any_cast<GraphVars *>(graph->attrs["vars"])) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
......@@ -40,8 +40,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
auto *dep_var = new DummyVarHandle();
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
->emplace(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
}
}
}
......@@ -51,8 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
Graph *graph, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) {
auto &var_holders =
(*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset];
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
......@@ -68,9 +66,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place,
size_t place_offset) {
auto &vars =
(*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset]
[each_var_name];
auto &vars = graph->Get<GraphVars>("vars")[place_offset][each_var_name];
size_t version = vars.size();
auto var = new VarHandle(version, place_offset, each_var_name, place);
vars.emplace_back(var);
......@@ -78,15 +74,14 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
}
void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
GraphOps &all_ops = *boost::any_cast<GraphOps *>(graph->attrs["ops"]);
GraphOps &all_ops = graph->Get<GraphOps>("ops");
for (auto &op : all_ops) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle();
boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
->emplace(dummy_leaf);
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}
......
......@@ -20,18 +20,77 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
class Graph;
template <typename AttrType>
struct AnyAttr {
public:
explicit AnyAttr(AttrType* attr) : attr_(attr) {}
AttrType& Get() { return *boost::any_cast<AttrType*>(attr_); }
private:
friend Graph;
AttrType* Release() {
released_ = true;
return boost::any_cast<AttrType*>(attr_);
}
void Delete() {
if (!released_) {
delete boost::any_cast<AttrType*>(attr_);
}
}
bool released_ = false;
boost::any attr_;
};
class Graph {
public:
std::map<std::string, boost::any> attrs;
virtual ~Graph() {
for (auto& attr : attrs) {
attr_dels[attr.first]();
}
attrs.clear();
attr_dels.clear();
}
template <typename AttrType>
AttrType& Get(const std::string& attr_name) {
return boost::any_cast<AnyAttr<AttrType>>(attrs[attr_name]).Get();
}
template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) {
AnyAttr<AttrType> any_attr = AnyAttr<AttrType>(attr);
attrs[attr_name] = any_attr;
attr_dels[attr_name] = [&any_attr]() { any_attr.Delete(); };
}
std::vector<Node *> inputs;
std::vector<Node *> outputs;
template <typename AttrType>
AttrType* Erase(const std::string& attr_name) {
AnyAttr<AttrType> attr_type =
boost::any_cast<AnyAttr<AttrType>>(attrs[attr_name]);
attrs.erase(attr_name);
attr_dels.erase(attr_name);
return attr_type.Release();
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
std::vector<std::unique_ptr<Node>> nodes;
std::map<std::string, boost::any> attrs;
std::map<std::string, std::function<void(void)>> attr_dels;
private:
};
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册