Commit ab72d28a authored by Xin Pan

clean up and correctness check

Parent aa1085dd
@@ -75,7 +75,12 @@ can also fuse some `Graph`'s `Node`s.
class Pass {
public:
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
// Some correctness check.
auto new_graph = ApplyImpl(std::move(graph));
// Some correctness check.
return new_graph;
}
// Get a reference to the attribute previously set.
template <typename AttrType>
@@ -89,6 +94,9 @@ class Pass {
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr);
protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
};
// In my_pass.cc
......
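For orientation, here is a minimal sketch of what a concrete pass looks like under the reworked interface (the pass name is hypothetical and not part of this commit): subclasses override the protected ApplyImpl hook, while the framework's non-virtual Apply wraps it with the correctness checks shown above.

// my_pass.h (hypothetical example)
class MyPass : public Pass {
 protected:
  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override {
    // Transform or analyze the graph here. The non-virtual Pass::Apply()
    // performs its checks before and after this hook runs.
    return graph;
  }
};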
@@ -34,16 +34,22 @@ namespace paddle {
namespace framework {
namespace details {
static const char kLossVarName[] = "loss_var_name";
static const char kPlaces[] = "places";
static const char kParams[] = "params";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
void MultiDevSSAGraphBuilder::Init() const {
loss_var_name_ = Get<const std::string>("loss_var_name");
places_ = Get<const std::vector<platform::Place>>("places");
local_scopes_ = Get<const std::vector<Scope *>>("local_scopes");
strategy_ = Get<const BuildStrategy>("strategy");
loss_var_name_ = Get<const std::string>(kLossVarName);
places_ = Get<const std::vector<platform::Place>>(kPlaces);
local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
strategy_ = Get<const BuildStrategy>(kStrategy);
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
#endif
for (auto &p : Get<const std::unordered_set<std::string>>("params")) {
for (auto &p : Get<const std::unordered_set<std::string>>(kParams)) {
grad_names_.insert(GradVarName(p));
}
balance_vars_.resize(places_.size(), 0);
@@ -58,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node,
size_t place_id) const {
auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>("ops").back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
@@ -225,7 +231,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
return sorted_ret;
}
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
Init();
// Give the topology sort order and rebuild the graph structure.
@@ -241,10 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize; it triggers a bug in GCC 4.8.
result.Set("vars", new GraphVars(places_.size()));
result.Set("dep_vars", new GraphDepVars);
result.Set("ops", new GraphOps);
result.Set("sharded_var_device", new ShardedVarDevice);
result.Set(kGraphVars, new GraphVars(places_.size()));
result.Set(kGraphDepVars, new GraphDepVars);
result.Set(kGraphOps, new GraphOps);
result.Set(kShardedVarDevice, new ShardedVarDevice);
// find send/recv vars so that we can place the distributed training
// related ops on place 0
@@ -281,7 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
graph->Get<ShardedVarDevice>("sharded_var_device")
graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(n->Name(), op_dev_id);
}
} else {
@@ -319,7 +325,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
graph->Get<ShardedVarDevice>("sharded_var_device")
graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
break;
@@ -406,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_);
#endif
result->Get<GraphOps>("ops").emplace_back(op_handle);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
auto *in =
result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
auto &vars = result->Get<GraphVars>(kGraphVars).at(i).at(p_name);
auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
i, p_name, p);
@@ -427,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node,
int dev_id) const {
result->Get<GraphOps>("ops").emplace_back(
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, node, dev_id);
@@ -436,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>("ops").back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars")[i][og];
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
@@ -465,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
ir::Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>("ops").back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) {
auto &vars = result->Get<GraphVars>("vars")[i][d_name];
auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get());
auto var = new VarHandle(
@@ -524,7 +530,7 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
const std::string &varname) const {
auto &sharded_var_device = graph.Get<ShardedVarDevice>("sharded_var_device");
auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
auto got = sharded_var_device.find(varname);
return got == sharded_var_device.end() ? -1 : got->second;
}
@@ -544,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i],
communication_dev_ctx);
result->Get<GraphOps>("ops").emplace_back(op_handle);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only uses device_count as the scale
// factor, so it does not depend on any other operators.
@@ -565,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->Get<GraphOps>("ops").emplace_back(
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
CreateOpHandleIOs(result, node, scope_idx);
}
@@ -575,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const std::string &og,
int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>("ops").back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars")[i][og];
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
}
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), dst_dev_id, og, places_[dst_dev_id]);
@@ -606,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
// on it.
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) {
for (auto &prev_op : result->Get<GraphOps>(kGraphOps)) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
prev_op->AddOutput(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
result->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
op->AddInput(dep_var);
}
}
@@ -635,18 +641,18 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device")
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
}
}
for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device")
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
}
} else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device")
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
}
} else {
@@ -661,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
CreateComputationalOp(result, node, op_dev_id);
if (node->Op()->Type() == "concat") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
"fetch_barrier");
}
}
@@ -687,7 +693,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device")
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
}
}
@@ -698,7 +704,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id = GetAppropriateDeviceID(output_var_names);
for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>("sharded_var_device")
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
}
} else {
@@ -709,17 +715,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
PADDLE_ENFORCE(op_dev_id != -1, "cannot find the right place for rpc op: %s",
node->Op()->Type());
result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id]));
if (node->Op()->Type() == "send_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
} else if (node->Op()->Type() == "recv") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
"send_barrier");
} else if (node->Op()->Type() == "fetch_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
} else if (node->Op()->Type() == "send") {
// do nothing
} else {
@@ -743,4 +749,9 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
} // namespace paddle
REGISTER_PASS(multi_device_pass,
paddle::framework::details::MultiDevSSAGraphBuilder);
paddle::framework::details::MultiDevSSAGraphBuilder)
.RequirePassAttr(paddle::framework::details::kLossVarName)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kParams)
.RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy);
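Caller-side, these declarations mean every listed attribute must be attached to the pass before Apply() is invoked. A hedged sketch follows — the registry lookup Get() and the local variable names are assumptions for illustration, not APIs confirmed by this diff:

// Hypothetical caller (namespaces paddle::framework and ::details assumed).
// Some way to obtain the registered pass is assumed, written here as
// PassRegistry::Instance().Get("multi_device_pass").
auto pass = ir::PassRegistry::Instance().Get("multi_device_pass");
pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->SetNotOwned<const std::unordered_set<std::string>>(kParams, &param_names);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes, &local_scopes);
pass->SetNotOwned<const BuildStrategy>(kStrategy, &build_strategy);
// Pass::Apply() first PADDLE_ENFORCEs that each RequirePassAttr name is set,
// then calls ApplyImpl().
graph = pass->Apply(std::move(graph));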
@@ -31,8 +31,8 @@ class Scope;
namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:
......
@@ -18,7 +18,7 @@ namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
}
}
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr;
if (var_holder.empty()) {
@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
auto &vars =
graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
size_t version = vars.size();
auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place);
@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
}
void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) {
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}
......
@@ -39,15 +39,19 @@ namespace details {
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// All operators. NOTE that even though we use a vector here, the operators
// are unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
class SSAGraphBuilder : public ir::Pass {
public:
......
@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
};
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) {
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get());
}
for (auto &op : graph->Get<GraphOps>("ops")) {
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
} else {
@@ -87,4 +87,8 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} // namespace paddle
REGISTER_PASS(multi_device_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker);
paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
@@ -23,8 +23,8 @@ namespace framework {
namespace details {
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
......
@@ -22,7 +22,7 @@ namespace details {
template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
callback(*var);
}
}
@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>("ops")) {
for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......
@@ -36,8 +36,8 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
};
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<const std::string>("debug_graphviz_path")));
......
@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
}
}
}
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, &ready_vars, var.get());
}
for (auto &op : graph_->Get<details::GraphOps>("ops")) {
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
......
@@ -40,10 +40,14 @@ class Graph {
attr_dels_.clear();
}
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
}
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for graph.", attr_name);
PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}
......
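The new Graph::Has() lets callers probe for an attribute instead of tripping the PADDLE_ENFORCE inside Get(). A small usage sketch (the surrounding context is assumed):

if (graph->Has(details::kShardedVarDevice)) {
  // Safe: Get() would PADDLE_ENFORCE-fail if the attribute were missing.
  auto &sharded_var_device =
      graph->Get<details::ShardedVarDevice>(details::kShardedVarDevice);
  // ... look up a variable's device id ...
}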
@@ -20,10 +20,11 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
static const char kGraphVizPath[] = "graph_viz_path";
std::unique_ptr<ir::Graph> GraphVizPass::Apply(
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string graph_viz_path = Get<std::string>("graph_viz_path");
const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout;
@@ -67,4 +68,5 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass);
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
@@ -28,8 +28,8 @@ namespace framework {
namespace ir {
class GraphVizPass : public Pass {
public:
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
......
@@ -17,6 +17,22 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
"Required pass atrribute %s not registered.", attr);
}
for (const std::string& attr : required_graph_attrs_) {
PADDLE_ENFORCE(graph->Has(attr), "Required graph attribute %s does not exist.",
attr);
}
auto applied_graph = ApplyImpl(std::move(graph));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*applied_graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
return applied_graph;
}
PassRegistry& PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;
......
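To illustrate what the new wrapper catches, consider a hypothetical buggy pass (not from this commit) whose ApplyImpl wires a two-node cycle into the graph; Apply() now aborts via PADDLE_ENFORCE(!HasCircle(...)) instead of handing the broken graph to later stages:

class BuggyPass : public Pass {
 protected:
  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override {
    // Hypothetical bug: connect an op node and a control-dep var in both
    // directions, forming the cycle op -> var -> op.
    auto *op = graph->CreateEmptyNode("noop", Node::Type::kOperation);
    auto *var = graph->CreateControlDepVar();
    op->outputs.push_back(var);
    var->inputs.push_back(op);
    var->outputs.push_back(op);
    op->inputs.push_back(var);
    return graph;  // HasCircle(*graph) is now true; Apply() will abort.
  }
};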
@@ -19,6 +19,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
@@ -26,6 +27,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
template <typename PassType>
struct PassRegistrar;
class Pass {
public:
@@ -40,7 +43,7 @@ class Pass {
attr_dels_.clear();
}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
// Get a reference to the attribute previously set.
template <typename AttrType>
@@ -69,7 +72,25 @@ class Pass {
attrs_[attr_name] = attr;
}
protected:
virtual std::unique_ptr<Graph> ApplyImpl(
std::unique_ptr<Graph> graph) const = 0;
private:
template <typename PassType>
friend struct PassRegistrar;
void RegisterRequiredPassAttrs(const std::unordered_set<std::string> &attrs) {
required_pass_attrs_.insert(attrs.begin(), attrs.end());
}
void RegisterRequiredGraphAttrs(
const std::unordered_set<std::string> &attrs) {
required_graph_attrs_.insert(attrs.begin(), attrs.end());
}
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
@@ -119,10 +140,28 @@ struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type);
PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr<Pass> {
return std::unique_ptr<Pass>(new PassType());
});
PassRegistry::Instance().Insert(
pass_type, [this]() -> std::unique_ptr<Pass> {
std::unique_ptr<Pass> pass(new PassType());
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
return pass;
});
}
PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
required_pass_attrs_.insert(attr);
return *this;
}
PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
required_graph_attrs_.insert(attr);
return *this;
}
private:
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
@@ -132,16 +171,19 @@ struct PassRegistrar : public Registrar {
msg)
// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
}
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
......
@@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
if (member_->executor_) {
auto &sharded_var_device =
member_->executor_->Graph().Get<details::ShardedVarDevice>(
"sharded_var_device");
details::kShardedVarDevice);
if (sharded_var_device.find(var) != sharded_var_device.end()) {
var_dev_id = sharded_var_device.at(var);
}
......