diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md index a141dcbca584c6064c8da863410692a8be911d12..a33b5a9c9312c93247a1e1f3431061a5aad6c884 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/draft.md @@ -1,16 +1,16 @@ ## Motivation -There is a ```gap``` between the ```Program``` defined by -user and the ```Executable``` that can be scheduled +There is a `gap` between the `Program` defined by +user and the `Executable` that can be scheduled efficiently on heterogeneous hardware, either locally or distributedly. -Usually, the ```gap``` is bridged by +Usually, the `gap` is bridged by * A serious transformations with defined order. * These transformations usually involve -```insert, delete, clustering, split, dependency analysis```. +`insert, delete, clustering, split, dependency analysis`. * Has a simple way to verify and debug each transformation. @@ -38,44 +38,44 @@ design below. #### Node -```Node``` represents an operation that performs some computation or +`Node` represents an operation that performs some computation or a variable that is input or output of operation. -```Node```s are connected to other ```Node```s via inputs and outputs. +`Node`s are connected to other `Node`s via inputs and outputs. Other properties (maybe device placement information) can be added -to ```Node``` in the future if it's a -common requirement of many other ```Pass```es. Otherwise, it should live -in a ```Node``` wrapper class that is private to some ```Pass``` or be -a local member of a ```Pass```. +to `Node` in the future if it's a +common requirement of many other `Pass`es. Otherwise, it should live +in a `Node` wrapper class that is private to some `Pass` or be +a local member of a `Pass`. #### Graph -```Graph``` contains a list of ```Node```s, which are connected to +`Graph` contains a list of `Node`s, which are connected to each other via inputs and outputs. TODO: Better definitions for the graph. -```Graph``` can also contain ```Attribute```s. ```Attribute```s -can be ``any`` thing. For example, it can be a list of "wraper" -nodes. The ```wrapper``` nodes compose ```Node```s and provide -helper method for execution or transformation. ```Attribute``` +`Graph` can also contain `Attribute`s. `Attribute`s +can be `any` thing. For example, it can be a list of "wraper" +nodes. The `wrapper` nodes compose `Node`s and provide +helper method for execution or transformation. `Attribute` can also contain other things that describe some properties of -the ```Graph``` or ```Graph``` nodes. ```Attribute``` can be passed -across ```Pass```. However, it should be used with care. +the `Graph` or `Graph` nodes. `Attribute` can be passed +across `Pass`. However, it should be used with care. #### Pass -```Pass``` represents a transformation of ```Graph```. Its input -is a ```Graph``` and its output is also a ```Graph```. For example, -a ```Pass``` can simply print out the ```Graph```. A ```Pass``` -can also fuse some ```Graph```'s ```Node```s. +`Pass` represents a transformation of `Graph`. Its input +is a `Graph` and its output is also a `Graph`. For example, +a `Pass` can simply print out the `Graph`. A `Pass` +can also fuse some `Graph`'s `Node`s. #### Optimize -```Optimize``` contains a series of ```Pass``` with defined order. -```Optimize``` transforms a ```Graph``` that only contains raw -modeling logic to a ```Graph``` that can be run efficiently while +`Optimize` contains a series of `Pass` with defined order. +`Optimize` transforms a `Graph` that only contains raw +modeling logic to a `Graph` that can be run efficiently while maintaining the original modeling logic. diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index de0cac0f1a69617d11311402d3d7deacc344a0e0..22f0cb20d01cc5b40325ec37a8c7cd44105bc6c6 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -196,38 +196,46 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { std::vector ret = ir::TopologySortOperations(graph); size_t last_backward = 0; - std::vector optimize_ops; - std::vector sorted_ret; for (size_t i = 0; i < ret.size(); ++i) { if (boost::get( ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kBackward)) { - sorted_ret.push_back(ret[i]); - last_backward = sorted_ret.size(); - } else if (boost::get(ret[i]->Op()->GetAttr( - OpProtoAndCheckerMaker::OpRoleAttrName())) == - static_cast(OpRole::kOptimize)) { - optimize_ops.push_back(ret[i]); - } else { - sorted_ret.push_back(ret[i]); + last_backward = i; } } - // Verify that no operations before optimize ops depends on optimize ops. - std::unordered_set optimize_set(optimize_ops.begin(), - optimize_ops.end()); - for (size_t i = 0; i < last_backward; ++i) { - for (ir::Node *in : sorted_ret[i]->inputs) { - for (ir::Node *pre_n : in->inputs) { - PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(), - "optimize operations cannot be depended by forward " - "or backward node %s -> %s", - pre_n->Name(), sorted_ret[i]->Name()); + std::vector optimize_ops; + std::vector sorted_ret; + for (size_t i = 0; i < ret.size(); ++i) { + if (i < last_backward) { + if (boost::get(ret[i]->Op()->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) == + static_cast(OpRole::kOptimize)) { + optimize_ops.push_back(ret[i]); + } else { + sorted_ret.push_back(ret[i]); + } + } else if (i == last_backward) { + sorted_ret.push_back(ret[i]); + // Verify that no operations before optimize ops depends on optimize ops. + std::unordered_set optimize_set(optimize_ops.begin(), + optimize_ops.end()); + for (ir::Node *n : sorted_ret) { + for (ir::Node *in : n->inputs) { + for (ir::Node *pre_n : in->inputs) { + PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(), + "optimize operations cannot be depended by forward " + "or backward node %s -> %s", + pre_n->Name(), n->Name()); + } + } } + sorted_ret.insert(sorted_ret.end(), optimize_ops.begin(), + optimize_ops.end()); + } else { + sorted_ret.push_back(ret[i]); } } - sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(), - optimize_ops.end()); return sorted_ret; } @@ -239,7 +247,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( ir::Graph &result = *graph; for (auto &node : nodes) { - if (node->NodeType() == ir::Node::Type::kVariable) { + if (node->NodeType() == ir::Node::Type::kVariable && node->Var()) { all_vars_.emplace(node->Name(), node->Var()); } } @@ -361,6 +369,11 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( } } } + /* + Dependency graph has been constructed. However, there are still data + hazards need to be handled. + */ + PolishGraphToSupportDataHazards(&result); /* * Only variables should be the leaves of graph. diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 203d5fbbc1d163b9ec3e8e110a8fb11f3d741773..506e7eb35cd977869424223cb863dd64dbaa9d30 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -17,6 +17,46 @@ namespace paddle { namespace framework { namespace details { +void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { + for (auto &var_map : graph->Get("vars")) { + for (auto &name_pair : var_map) { + if (name_pair.second.size() <= 1) { + continue; + } + auto it_new = name_pair.second.rbegin(); + auto it_old = name_pair.second.rbegin(); + ++it_old; + for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { + OpHandleBase *write_op = (*it_new)->GeneratedOp(); + const auto &read_ops = (*it_old)->PendingOps(); + + for (auto *read_op : read_ops) { + // Manually add a dependency var from read_op to write_op; + if (read_op == write_op) { + // Read Write is the same op. + continue; + } + bool has_dep = false; + for (auto *r_out : read_op->Outputs()) { + for (auto *w_in : write_op->Inputs()) { + if (r_out->Node() == w_in->Node()) { + has_dep = true; + break; + } + } + } + if (has_dep) continue; + + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + read_op->AddOutput(dep_var); + write_op->AddInput(dep_var); + graph->Get("dep_vars").emplace(dep_var); + } + } + } + } +} + VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset) { diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index e99bab518e9601ffc54cbcb43bba57e3bf4ea3c6..2b4f31f2ff3444f909e3be5eb810ae6737e237b2 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -57,6 +57,12 @@ class SSAGraphBuilder : public ir::Pass { DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: + /* + Dependency graph has been constructed. However, there are still data + hazards need to be handled. + */ + static void PolishGraphToSupportDataHazards(ir::Graph *graph); + static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset); diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index c5b48514773616fb5860fc93aa98238a3283dcec..740acfafb7594d8d9f3ca5439323ce76c5ed271a 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -98,6 +98,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { } } if (has_dep) continue; + ir::Node *dep_var = CreateControlDepVar(); read_op->outputs.push_back(dep_var); dep_var->inputs.push_back(read_op);