提交 5173a53c 编写于 作者: X Xin Pan

fix reorder issue.

上级 21a45420
## Motivation ## Motivation
There is a ```gap``` between the ```Program``` defined by There is a `gap` between the `Program` defined by
user and the ```Executable``` that can be scheduled user and the `Executable` that can be scheduled
efficiently on heterogeneous hardware, either locally efficiently on heterogeneous hardware, either locally
or distributedly. or distributedly.
Usually, the ```gap``` is bridged by Usually, the `gap` is bridged by
* A serious transformations with defined order. * A serious transformations with defined order.
* These transformations usually involve * These transformations usually involve
```insert, delete, clustering, split, dependency analysis```. `insert, delete, clustering, split, dependency analysis`.
* Has a simple way to verify and debug each transformation. * Has a simple way to verify and debug each transformation.
...@@ -38,44 +38,44 @@ design below. ...@@ -38,44 +38,44 @@ design below.
#### Node #### Node
```Node``` represents an operation that performs some computation or `Node` represents an operation that performs some computation or
a variable that is input or output of operation. a variable that is input or output of operation.
```Node```s are connected to other ```Node```s via inputs and outputs. `Node`s are connected to other `Node`s via inputs and outputs.
Other properties (maybe device placement information) can be added Other properties (maybe device placement information) can be added
to ```Node``` in the future if it's a to `Node` in the future if it's a
common requirement of many other ```Pass```es. Otherwise, it should live common requirement of many other `Pass`es. Otherwise, it should live
in a ```Node``` wrapper class that is private to some ```Pass``` or be in a `Node` wrapper class that is private to some `Pass` or be
a local member of a ```Pass```. a local member of a `Pass`.
#### Graph #### Graph
```Graph``` contains a list of ```Node```s, which are connected to `Graph` contains a list of `Node`s, which are connected to
each other via inputs and outputs. each other via inputs and outputs.
TODO: Better definitions for the graph. TODO: Better definitions for the graph.
```Graph``` can also contain ```Attribute```s. ```Attribute```s `Graph` can also contain `Attribute`s. `Attribute`s
can be ``any`` thing. For example, it can be a list of "wraper" can be `any` thing. For example, it can be a list of "wraper"
nodes. The ```wrapper``` nodes compose ```Node```s and provide nodes. The `wrapper` nodes compose `Node`s and provide
helper method for execution or transformation. ```Attribute``` helper method for execution or transformation. `Attribute`
can also contain other things that describe some properties of can also contain other things that describe some properties of
the ```Graph``` or ```Graph``` nodes. ```Attribute``` can be passed the `Graph` or `Graph` nodes. `Attribute` can be passed
across ```Pass```. However, it should be used with care. across `Pass`. However, it should be used with care.
#### Pass #### Pass
```Pass``` represents a transformation of ```Graph```. Its input `Pass` represents a transformation of `Graph`. Its input
is a ```Graph``` and its output is also a ```Graph```. For example, is a `Graph` and its output is also a `Graph`. For example,
a ```Pass``` can simply print out the ```Graph```. A ```Pass``` a `Pass` can simply print out the `Graph`. A `Pass`
can also fuse some ```Graph```'s ```Node```s. can also fuse some `Graph`'s `Node`s.
#### Optimize #### Optimize
```Optimize``` contains a series of ```Pass``` with defined order. `Optimize` contains a series of `Pass` with defined order.
```Optimize``` transforms a ```Graph``` that only contains raw `Optimize` transforms a `Graph` that only contains raw
modeling logic to a ```Graph``` that can be run efficiently while modeling logic to a `Graph` that can be run efficiently while
maintaining the original modeling logic. maintaining the original modeling logic.
......
...@@ -196,38 +196,46 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -196,38 +196,46 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std::vector<ir::Node *> ret = ir::TopologySortOperations(graph); std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
size_t last_backward = 0; size_t last_backward = 0;
std::vector<ir::Node *> optimize_ops;
std::vector<ir::Node *> sorted_ret;
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
if (boost::get<int>( if (boost::get<int>(
ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kBackward)) { static_cast<int>(OpRole::kBackward)) {
sorted_ret.push_back(ret[i]); last_backward = i;
last_backward = sorted_ret.size(); }
} else if (boost::get<int>(ret[i]->Op()->GetAttr( }
std::vector<ir::Node *> optimize_ops;
std::vector<ir::Node *> sorted_ret;
for (size_t i = 0; i < ret.size(); ++i) {
if (i < last_backward) {
if (boost::get<int>(ret[i]->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) == OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kOptimize)) { static_cast<int>(OpRole::kOptimize)) {
optimize_ops.push_back(ret[i]); optimize_ops.push_back(ret[i]);
} else { } else {
sorted_ret.push_back(ret[i]); sorted_ret.push_back(ret[i]);
} }
} } else if (i == last_backward) {
sorted_ret.push_back(ret[i]);
// Verify that no operations before optimize ops depends on optimize ops. // Verify that no operations before optimize ops depends on optimize ops.
std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(), std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
optimize_ops.end()); optimize_ops.end());
for (size_t i = 0; i < last_backward; ++i) { for (ir::Node *n : sorted_ret) {
for (ir::Node *in : sorted_ret[i]->inputs) { for (ir::Node *in : n->inputs) {
for (ir::Node *pre_n : in->inputs) { for (ir::Node *pre_n : in->inputs) {
PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(), PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
"optimize operations cannot be depended by forward " "optimize operations cannot be depended by forward "
"or backward node %s -> %s", "or backward node %s -> %s",
pre_n->Name(), sorted_ret[i]->Name()); pre_n->Name(), n->Name());
} }
} }
} }
sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(), sorted_ret.insert(sorted_ret.end(), optimize_ops.begin(),
optimize_ops.end()); optimize_ops.end());
} else {
sorted_ret.push_back(ret[i]);
}
}
return sorted_ret; return sorted_ret;
} }
...@@ -239,7 +247,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -239,7 +247,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
ir::Graph &result = *graph; ir::Graph &result = *graph;
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->NodeType() == ir::Node::Type::kVariable) { if (node->NodeType() == ir::Node::Type::kVariable && node->Var()) {
all_vars_.emplace(node->Name(), node->Var()); all_vars_.emplace(node->Name(), node->Var());
} }
} }
...@@ -361,6 +369,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -361,6 +369,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
} }
} }
} }
/*
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
*/
PolishGraphToSupportDataHazards(&result);
/* /*
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
......
...@@ -17,6 +17,46 @@ ...@@ -17,6 +17,46 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
}
auto it_new = name_pair.second.rbegin();
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
OpHandleBase *write_op = (*it_new)->GeneratedOp();
const auto &read_ops = (*it_old)->PendingOps();
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
// Read Write is the same op.
continue;
}
bool has_dep = false;
for (auto *r_out : read_op->Outputs()) {
for (auto *w_in : write_op->Inputs()) {
if (r_out->Node() == w_in->Node()) {
has_dep = true;
break;
}
}
}
if (has_dep) continue;
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
}
}
}
}
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place, ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
......
...@@ -57,6 +57,12 @@ class SSAGraphBuilder : public ir::Pass { ...@@ -57,6 +57,12 @@ class SSAGraphBuilder : public ir::Pass {
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected: protected:
/*
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
*/
static void PolishGraphToSupportDataHazards(ir::Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
......
...@@ -98,6 +98,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -98,6 +98,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
} }
} }
if (has_dep) continue; if (has_dep) continue;
ir::Node *dep_var = CreateControlDepVar(); ir::Node *dep_var = CreateControlDepVar();
read_op->outputs.push_back(dep_var); read_op->outputs.push_back(dep_var);
dep_var->inputs.push_back(read_op); dep_var->inputs.push_back(read_op);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册