提交 ab72d28a 编写于 作者: X Xin Pan

clean up and correctness check

上级 aa1085dd
...@@ -75,7 +75,12 @@ can also fuse some `Graph`'s `Node`s. ...@@ -75,7 +75,12 @@ can also fuse some `Graph`'s `Node`s.
class Pass { class Pass {
public: public:
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0; std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
// Some correctness check.
auto new_graph = ApplyImpl(std::move(graph));
// Some correctness check.
return new_graph;
}
// Get a reference to the attributed previously set. // Get a reference to the attributed previously set.
template <typename AttrType> template <typename AttrType>
...@@ -89,6 +94,9 @@ class Pass { ...@@ -89,6 +94,9 @@ class Pass {
// should delete the attribute. // should delete the attribute.
template <typename AttrType> template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr); void SetNotOwned(const std::string &attr_name, AttrType *attr);
protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
}; };
// In my_pass.cc // In my_pass.cc
......
...@@ -34,16 +34,22 @@ namespace paddle { ...@@ -34,16 +34,22 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
static const char kLossVarName[] = "loss_var_name";
static const char kPlaces[] = "places";
static const char kParams[] = "params";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
void MultiDevSSAGraphBuilder::Init() const { void MultiDevSSAGraphBuilder::Init() const {
loss_var_name_ = Get<const std::string>("loss_var_name"); loss_var_name_ = Get<const std::string>(kLossVarName);
places_ = Get<const std::vector<platform::Place>>("places"); places_ = Get<const std::vector<platform::Place>>(kPlaces);
local_scopes_ = Get<const std::vector<Scope *>>("local_scopes"); local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
strategy_ = Get<const BuildStrategy>("strategy"); strategy_ = Get<const BuildStrategy>(kStrategy);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs"); nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
#endif #endif
for (auto &p : Get<const std::unordered_set<std::string>>("params")) { for (auto &p : Get<const std::unordered_set<std::string>>(kParams)) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
} }
balance_vars_.resize(places_.size(), 0); balance_vars_.resize(places_.size(), 0);
...@@ -58,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, ...@@ -58,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node, ir::Node *node,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
...@@ -225,7 +231,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { ...@@ -225,7 +231,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
return sorted_ret; return sorted_ret;
} }
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
Init(); Init();
// Give the topology sort order and rebuild the graph structure. // Give the topology sort order and rebuild the graph structure.
...@@ -241,10 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -241,10 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize. It is a bug of GCC 4.8 // We cannot invoke resize. It is a bug of GCC 4.8
result.Set("vars", new GraphVars(places_.size())); result.Set(kGraphVars, new GraphVars(places_.size()));
result.Set("dep_vars", new GraphDepVars); result.Set(kGraphDepVars, new GraphDepVars);
result.Set("ops", new GraphOps); result.Set(kGraphOps, new GraphOps);
result.Set("sharded_var_device", new ShardedVarDevice); result.Set(kShardedVarDevice, new ShardedVarDevice);
// find send/recv vars so that we can place the distributed training // find send/recv vars so that we can place the distributed training
// realted op in the place 0 // realted op in the place 0
...@@ -281,7 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -281,7 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id); CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) { for (ir::Node *n : node->outputs) {
graph->Get<ShardedVarDevice>("sharded_var_device") graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(n->Name(), op_dev_id); .emplace(n->Name(), op_dev_id);
} }
} else { } else {
...@@ -319,7 +325,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -319,7 +325,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name}); cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
graph->Get<ShardedVarDevice>("sharded_var_device") graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(g_name, cur_device_id); .emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
break; break;
...@@ -406,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, ...@@ -406,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_); local_scopes_, places_);
#endif #endif
result->Get<GraphOps>("ops").emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
auto *in = auto *in =
result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get(); result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name); auto &vars = result->Get<GraphVars>(kGraphVars).at(i).at(p_name);
auto *out_var = new VarHandle( auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(), result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
i, p_name, p); i, p_name, p);
...@@ -427,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, ...@@ -427,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node, ir::Node *node,
int dev_id) const { int dev_id) const {
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id])); local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, node, dev_id); CreateOpHandleIOs(result, node, dev_id);
...@@ -436,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ...@@ -436,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
const std::string &og) const { const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_)); local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars")[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
...@@ -465,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, ...@@ -465,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
ir::Graph *result, const std::vector<std::string> &datas) const { ir::Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_)); local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) { for (const std::string &d_name : datas) {
auto &vars = result->Get<GraphVars>("vars")[i][d_name]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get()); op_handle->AddInput(vars.back().get());
auto var = new VarHandle( auto var = new VarHandle(
...@@ -524,7 +530,7 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, ...@@ -524,7 +530,7 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
const std::string &varname) const { const std::string &varname) const {
auto &sharded_var_device = graph.Get<ShardedVarDevice>("sharded_var_device"); auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
auto got = sharded_var_device.find(varname); auto got = sharded_var_device.find(varname);
return got == sharded_var_device.end() ? -1 : got->second; return got == sharded_var_device.end() ? -1 : got->second;
} }
...@@ -544,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { ...@@ -544,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i], local_scopes_.size(), local_scopes_[i], places_[i],
communication_dev_ctx); communication_dev_ctx);
result->Get<GraphOps>("ops").emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators. // factor. So it does not depend on any other operators.
...@@ -565,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, ...@@ -565,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
CreateOpHandleIOs(result, node, scope_idx); CreateOpHandleIOs(result, node, scope_idx);
} }
...@@ -575,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -575,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const std::string &og, const std::string &og,
int dst_dev_id) const { int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_)); local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars")[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
} }
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
auto var = auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.size(), dst_dev_id, og, places_[dst_dev_id]);
...@@ -606,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -606,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
// on it. // on it.
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op, void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) { for (auto &prev_op : result->Get<GraphOps>(kGraphOps)) {
if (prev_op->Name() == prev_op_name) { if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle(result->CreateControlDepVar()); auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
prev_op->AddOutput(dep_var); prev_op->AddOutput(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var); result->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
op->AddInput(dep_var); op->AddInput(dep_var);
} }
} }
...@@ -635,18 +641,18 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -635,18 +641,18 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device") result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device") result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} else if (node->Op()->Type() == "concat") { } else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device") result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} else { } else {
...@@ -661,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -661,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
CreateComputationalOp(result, node, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
"fetch_barrier"); "fetch_barrier");
} }
} }
...@@ -687,7 +693,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -687,7 +693,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
} }
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device") result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} }
...@@ -698,7 +704,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -698,7 +704,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
} }
op_dev_id = GetAppropriateDeviceID(output_var_names); op_dev_id = GetAppropriateDeviceID(output_var_names);
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device") result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} else { } else {
...@@ -709,17 +715,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -709,17 +715,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
node->Op()->Type()); node->Op()->Type());
result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id])); node->Op()->Type(), places_[op_dev_id]));
if (node->Op()->Type() == "send_barrier") { if (node->Op()->Type() == "send_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send"); ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
} else if (node->Op()->Type() == "recv") { } else if (node->Op()->Type() == "recv") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
"send_barrier"); "send_barrier");
} else if (node->Op()->Type() == "fetch_barrier") { } else if (node->Op()->Type() == "fetch_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv"); ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
} else if (node->Op()->Type() == "send") { } else if (node->Op()->Type() == "send") {
// do nothing // do nothing
} else { } else {
...@@ -743,4 +749,9 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { ...@@ -743,4 +749,9 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_device_pass, REGISTER_PASS(multi_device_pass,
paddle::framework::details::MultiDevSSAGraphBuilder); paddle::framework::details::MultiDevSSAGraphBuilder)
.RequirePassAttr(paddle::framework::details::kLossVarName)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kParams)
.RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy);
...@@ -31,8 +31,8 @@ class Scope; ...@@ -31,8 +31,8 @@ class Scope;
namespace details { namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder { class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public: protected:
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
private: private:
......
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) { if (name_pair.second.size() <= 1) {
continue; continue;
...@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { ...@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var); read_op->AddOutput(dep_var);
write_op->AddInput(dep_var); write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var); graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
} }
} }
} }
...@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { ...@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place, ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset]; auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[node->Name()]; auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
...@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, ...@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, ir::Node *new_node,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()]; auto &vars =
graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
size_t version = vars.size(); size_t version = vars.size();
auto var = auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place); new VarHandle(new_node, version, place_offset, new_node->Name(), place);
...@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, ...@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
} }
void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) { for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
} }
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf); graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
} }
......
...@@ -39,15 +39,19 @@ namespace details { ...@@ -39,15 +39,19 @@ namespace details {
typedef std::vector< typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>> std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars; typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is // all operators. NOTE that even we use a vector here, the operators is
// unordered. // unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps; typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice; typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
class SSAGraphBuilder : public ir::Pass { class SSAGraphBuilder : public ir::Pass {
public: public:
......
...@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} }
}; };
for (auto &var_map : graph->Get<GraphVars>("vars")) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get()); insert_pending_var(version_pair.get());
...@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} }
} }
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) { for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get()); insert_pending_var(var.get());
} }
for (auto &op : graph->Get<GraphOps>("ops")) { for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
...@@ -87,4 +87,8 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -87,4 +87,8 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_device_check_pass, REGISTER_PASS(multi_device_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker); paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
...@@ -23,8 +23,8 @@ namespace framework { ...@@ -23,8 +23,8 @@ namespace framework {
namespace details { namespace details {
class SSAGraghBuilderWithChecker : public SSAGraphBuilder { class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public: protected:
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get())); PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph; return graph;
......
...@@ -22,7 +22,7 @@ namespace details { ...@@ -22,7 +22,7 @@ namespace details {
template <typename Callback> template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) { static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) { for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
callback(*pair2); callback(*pair2);
...@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) { ...@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
} }
} }
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) { for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
callback(*var); callback(*var);
} }
} }
...@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>("ops")) { for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
......
...@@ -36,8 +36,8 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { ...@@ -36,8 +36,8 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
}; };
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public: protected:
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
std::unique_ptr<std::ostream> fout( std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<const std::string>("debug_graphviz_path"))); new std::ofstream(Get<const std::string>("debug_graphviz_path")));
......
...@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops; std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars // Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, &ready_vars, var.get()); InsertPendingVar(&pending_vars, &ready_vars, var.get());
} }
for (auto &op : graph_->Get<details::GraphOps>("ops")) { for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
...@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
......
...@@ -40,10 +40,14 @@ class Graph { ...@@ -40,10 +40,14 @@ class Graph {
attr_dels_.clear(); attr_dels_.clear();
} }
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
}
template <typename AttrType> template <typename AttrType>
AttrType &Get(const std::string &attr_name) const { AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
"%s attr not registered for graph.", attr_name); attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name)); return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} }
......
...@@ -20,10 +20,11 @@ limitations under the License. */ ...@@ -20,10 +20,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kGraphVizPath[] = "graph_viz_path";
std::unique_ptr<ir::Graph> GraphVizPass::Apply( std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
const std::string graph_viz_path = Get<std::string>("graph_viz_path"); const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path)); std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout; std::ostream& sout = *fout;
...@@ -67,4 +68,5 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply( ...@@ -67,4 +68,5 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass); REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
...@@ -28,8 +28,8 @@ namespace framework { ...@@ -28,8 +28,8 @@ namespace framework {
namespace ir { namespace ir {
class GraphVizPass : public Pass { class GraphVizPass : public Pass {
public: protected:
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
}; };
......
...@@ -17,6 +17,22 @@ limitations under the License. */ ...@@ -17,6 +17,22 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
"Required pass atrribute %s not registered.", attr);
}
for (const std::string& attr : required_graph_attrs_) {
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not exist.",
attr);
}
auto applied_graph = ApplyImpl(std::move(graph));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*applied_graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
return applied_graph;
}
PassRegistry& PassRegistry::Instance() { PassRegistry& PassRegistry::Instance() {
static PassRegistry g_pass_info_map; static PassRegistry g_pass_info_map;
return g_pass_info_map; return g_pass_info_map;
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
...@@ -26,6 +27,8 @@ limitations under the License. */ ...@@ -26,6 +27,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
template <typename PassType>
struct PassRegistrar;
class Pass { class Pass {
public: public:
...@@ -40,7 +43,7 @@ class Pass { ...@@ -40,7 +43,7 @@ class Pass {
attr_dels_.clear(); attr_dels_.clear();
} }
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0; std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
// Get a reference to the attributed previously set. // Get a reference to the attributed previously set.
template <typename AttrType> template <typename AttrType>
...@@ -69,7 +72,25 @@ class Pass { ...@@ -69,7 +72,25 @@ class Pass {
attrs_[attr_name] = attr; attrs_[attr_name] = attr;
} }
protected:
virtual std::unique_ptr<Graph> ApplyImpl(
std::unique_ptr<Graph> graph) const = 0;
private: private:
template <typename PassType>
friend struct PassRegistrar;
void RegisterRequiredPassAttrs(const std::unordered_set<std::string> &attrs) {
required_pass_attrs_.insert(attrs.begin(), attrs.end());
}
void RegisterRequiredGraphAttrs(
const std::unordered_set<std::string> &attrs) {
required_graph_attrs_.insert(attrs.begin(), attrs.end());
}
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
}; };
...@@ -119,10 +140,28 @@ struct PassRegistrar : public Registrar { ...@@ -119,10 +140,28 @@ struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) { explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type), PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type); "'%s' is registered more than once.", pass_type);
PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr<Pass> { PassRegistry::Instance().Insert(
return std::unique_ptr<Pass>(new PassType()); pass_type, [this]() -> std::unique_ptr<Pass> {
}); std::unique_ptr<Pass> pass(new PassType());
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
return pass;
});
} }
PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
required_pass_attrs_.insert(attr);
return *this;
}
PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
required_graph_attrs_.insert(attr);
return *this;
}
private:
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
}; };
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ #define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
...@@ -132,16 +171,19 @@ struct PassRegistrar : public Registrar { ...@@ -132,16 +171,19 @@ struct PassRegistrar : public Registrar {
msg) msg)
// Register a new pass that can be applied on the IR. // Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \ #define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \ __reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \ "REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \ static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \ __pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \ int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \ __pass_registrar_##pass_type##__.Touch(); \
return 0; \ return 0; \
} } \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \ #define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
......
...@@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
if (member_->executor_) { if (member_->executor_) {
auto &sharded_var_device = auto &sharded_var_device =
member_->executor_->Graph().Get<details::ShardedVarDevice>( member_->executor_->Graph().Get<details::ShardedVarDevice>(
"sharded_var_device"); details::kShardedVarDevice);
if (sharded_var_device.find(var) != sharded_var_device.end()) { if (sharded_var_device.find(var) != sharded_var_device.end()) {
var_dev_id = sharded_var_device.at(var); var_dev_id = sharded_var_device.at(var);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册