Commit dcaf183d authored by Xin Pan

Build SSA graph at the beginning.

Parent 2b2406e5
@@ -221,15 +221,15 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
   // forward, backward nodes. E.g. you can't append an forward node
   // at the end of the node list.
   // TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
-  for (auto &node : nodes) {
-    if (node->NodeType() != ir::Node::Type::kOperation) continue;
+  for (ir::Node *node : TopologySortOperationFromInToOut(nodes)) {
+    VLOG(3) << "apply node: " << node->Name() << reinterpret_cast<void *>(node);
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      CreateRPCOp(&result, node.get());
-    } else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) {
-      CreateDistTrainOp(&result, node.get());
-    } else if (IsScaleLossOp(node.get())) {
+      CreateRPCOp(&result, node);
+    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
+      CreateDistTrainOp(&result, node);
+    } else if (IsScaleLossOp(node)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
           BuildStrategy::GradientScaleStrategy::kCustomized) {
@@ -240,10 +240,11 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
       // It also assumes backward op will always follow the forward op in
       // the block.
       is_forwarding = false;
+      LOG(ERROR) << "forward flipping!!!!!!!";
     } else {
-      int op_dev_id = GetOpDeviceID(node.get());
+      int op_dev_id = GetOpDeviceID(node);
       if (op_dev_id != -1) {  // This op only runs on one specific device.
-        CreateComputationalOp(&result, node.get(), op_dev_id);
+        CreateComputationalOp(&result, node, op_dev_id);
         for (ir::Node *n : node->outputs) {
           var_name_on_devices_.emplace(n->Name(), op_dev_id);
         }
@@ -252,13 +253,11 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
         // gradients.
         if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
           node->Op()->SetAttr("throw_eof_exp", false);
-          CreateComputationalOps(&result, node.get(), places_.size());
-          // TODO(paddle-dev): builder shouldn't depend on the out logic of
-          // a specific op.
+          CreateComputationalOps(&result, node, places_.size());
           const auto &data_var_names = node->Op()->Output("Out");
           InsertDataBalanceOp(&result, data_var_names);
         } else {
-          CreateComputationalOps(&result, node.get(), places_.size());
+          CreateComputationalOps(&result, node, places_.size());
         }
         if (!is_forwarding && places_.size() > 1) {
@@ -479,8 +478,8 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
   PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
   int dev_id = GetVarDeviceID(param_grad[1]);
-  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]",
-                    node->Op()->Type(), param_grad[0]);
+  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
+                    node->Op()->Type(), param_grad[0], param_grad[1]);
   return dev_id;
 }
......
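Note on the hunk above: MultiDevSSAGraphBuilder::Apply now walks operation nodes in dependency order via TopologySortOperationFromInToOut instead of raw block order (the old loop also had to skip non-operation nodes by hand). A minimal standalone sketch of the DFS post-order idea behind that sort, using plain ints instead of Paddle's ir::Node (all names here are illustrative, not Paddle APIs):

    #include <iostream>
    #include <map>
    #include <set>
    #include <vector>

    // Emit `node` only after everything it depends on has been emitted.
    void Visit(const std::map<int, std::set<int>> &deps, int node,
               std::set<int> *visited, std::vector<int> *order) {
      visited->insert(node);
      for (int dep : deps.at(node)) {
        if (!visited->count(dep)) Visit(deps, dep, visited, order);
      }
      order->push_back(node);  // post-order: after all inputs
    }

    int main() {
      // Op 2 consumes the outputs of ops 0 and 1; op 3 consumes op 2's output.
      std::map<int, std::set<int>> deps = {
          {0, {}}, {1, {}}, {2, {0, 1}}, {3, {2}}};
      std::set<int> visited;
      std::vector<int> order;
      for (const auto &kv : deps) {
        if (!visited.count(kv.first)) Visit(deps, kv.first, &visited, &order);
      }
      for (int n : order) std::cout << n << ' ';  // prints "0 1 2 3"
    }

Post-order emission guarantees every producer precedes its consumers, which is exactly the forward->backward ordering the TODO in the hunk asks for.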
@@ -37,6 +37,17 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
           continue;
         }
+        bool has_dep = false;
+        for (auto read_out : read_op->Outputs()) {
+          for (auto write_in : write_op->Inputs()) {
+            if (read_out == write_in) {
+              has_dep = true;
+              break;
+            }
+          }
+        }
+        if (has_dep) continue;
+
         auto *dep_var = new DummyVarHandle(
             graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
         read_op->AddOutput(dep_var);
......
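Note on the has_dep guard above: if any output of read_op is already an input of write_op, write_op already waits on read_op through a real variable, so the extra dummy dependency would be redundant. A hedged standalone sketch of the same check, with plain strings standing in for Paddle's VarHandle pointers (HasDirectDep is a name invented for this sketch):

    #include <string>
    #include <vector>

    // True when write_op already consumes one of read_op's outputs, i.e. a
    // genuine data dependency already orders the two ops.
    bool HasDirectDep(const std::vector<std::string> &read_outputs,
                      const std::vector<std::string> &write_inputs) {
      for (const auto &out : read_outputs) {
        for (const auto &in : write_inputs) {
          if (out == in) return true;
        }
      }
      return false;
    }

In the patch itself the `break` only leaves the inner loop; the result is still correct (has_dep stays true), it just keeps scanning the remaining read outputs, whereas returning early as above stops at the first match.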
@@ -12,14 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
+#include <unordered_set>
+
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 
 namespace paddle {
 namespace framework {
+namespace {
+void SortHelper(
+    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
+    ir::Node *node, std::unordered_set<ir::Node *> *visited,
+    std::vector<ir::Node *> *ret) {
+  visited->insert(node);
+
+  for (auto adj : adj_list.at(node)) {
+    if (visited->find(adj) == visited->end()) {
+      SortHelper(adj_list, adj, visited, ret);
+    }
+  }
+
+  VLOG(3) << "topology sort insert: " << node->Name()
+          << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
+  ret->push_back(node);
+}
+}  // namespace
+
+// NOTE(paddle-dev): This graph contains circle.
 Graph::Graph(const ProgramDesc &program) : program_(program) {
   VLOG(3) << "block in program:" << program_.Size();
   std::unordered_map<std::string, VarDesc *> all_vars;
@@ -27,40 +48,128 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
     all_vars.emplace(var->Name(), var);
   }
 
-  std::map<std::string, ir::Node *> var_nodes;
+  ir::Node *last_backward = nullptr;
+  std::vector<ir::Node *> optimize_ops;
+  std::map<std::string, std::vector<ir::Node *>> var_nodes;
   for (auto *op : program.Block(0).AllOps()) {
     ir::Node *node = CreateOpNode(op);
+    if (boost::get<int>(
+            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+        static_cast<int>(OpRole::kBackward)) {
+      last_backward = node;
+    } else if (boost::get<int>(
+                   op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+               static_cast<int>(OpRole::kOptimize)) {
+      optimize_ops.push_back(node);
+    }
 
     for (auto &each_var_name : op->InputArgumentNames()) {
       ir::Node *var = nullptr;
       if (var_nodes.find(each_var_name) != var_nodes.end()) {
-        var = var_nodes.at(each_var_name);
+        var = var_nodes.at(each_var_name).back();
       } else if (all_vars.count(each_var_name) != 0) {
         var = CreateVarNode(all_vars.at(each_var_name));
-        var_nodes[each_var_name] = var;
+        var_nodes[each_var_name].push_back(var);
       } else {
         // TODO(paddle-dev): Seems some assumption doesn't hold?
         VLOG(3) << op->Type()
                 << " input var not in all_var list: " << each_var_name;
         var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
-        var_nodes[each_var_name] = var;
+        var_nodes[each_var_name].push_back(var);
       }
       node->inputs.push_back(var);
       var->outputs.push_back(node);
     }
 
     for (auto &each_var_name : op->OutputArgumentNames()) {
-      ir::Node *var = nullptr;
-      if (var_nodes.find(each_var_name) != var_nodes.end()) {
-        var = var_nodes.at(each_var_name);
-      } else {
-        var = CreateVarNode(all_vars.at(each_var_name));
-        var_nodes[each_var_name] = var;
-      }
+      ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
+      var_nodes[each_var_name].push_back(var);
       node->outputs.push_back(var);
       var->inputs.push_back(node);
     }
   }
+
+  for (auto &var : var_nodes) {
+    auto &versions = var.second;
+    if (versions.size() <= 1) continue;
+
+    auto it_new = versions.rbegin();
+    auto it_old = versions.rbegin();
+    ++it_old;
+    for (; it_old != versions.rend(); it_new = it_old, ++it_old) {
+      ir::Node *write_op =
+          (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
+      const auto &read_ops = (*it_old)->outputs;
+
+      for (auto *read_op : read_ops) {
+        // Manually add a dependency var from read_op to write_op;
+        if (read_op == write_op) {
+          // Read Write is the same op.
+          continue;
+        }
+        ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable);
+        read_op->outputs.push_back(dep_var);
+        dep_var->inputs.push_back(read_op);
+        write_op->inputs.push_back(dep_var);
+        dep_var->outputs.push_back(write_op);
+      }
+    }
+  }
+
+  if (last_backward) {
+    for (ir::Node *opt_node : optimize_ops) {
+      ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable);
+      last_backward->outputs.push_back(dep_var);
+      dep_var->inputs.push_back(last_backward);
+      opt_node->inputs.push_back(dep_var);
+      dep_var->outputs.push_back(opt_node);
+      VLOG(3) << "appending connect: " << last_backward->Name()
+              << reinterpret_cast<void *>(last_backward) << "->"
+              << opt_node->Name() << reinterpret_cast<void *>(opt_node);
+    }
+  }
+}
+
+std::vector<ir::Node *> TopologySortOperationFromInToOut(
+    const std::vector<std::unique_ptr<ir::Node>> &nodes) {
+  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
+  std::unordered_set<ir::Node *> visited;
+  std::vector<ir::Node *> ret;
+
+  for (auto &n : nodes) {
+    if (n->NodeType() != ir::Node::Type::kOperation) continue;
+    if (adj_list.find(n.get()) == adj_list.end()) {
+      adj_list[n.get()] = std::unordered_set<ir::Node *>();
+    }
+    for (auto &var : n->inputs) {
+      for (auto &adj_n : var->inputs) {
+        PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
+        adj_list[n.get()].insert(adj_n);
+        LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
+                   << " -> " << n->Name() << reinterpret_cast<void *>(n.get())
+                   << " via " << var->Name() << reinterpret_cast<void *>(var);
+      }
+    }
+  }
+
+  for (auto adj : adj_list) {
+    if (visited.find(adj.first) == visited.end()) {
+      SortHelper(adj_list, adj.first, &visited, &ret);
+    }
+  }
+
+  for (ir::Node *n : ret) {
+    std::unordered_set<ir::Node *> dummy;
+    n->inputs.erase(
+        std::remove_if(n->inputs.begin(), n->inputs.end(),
+                       [n](ir::Node *t) { return t->Name() == "dummy"; }),
+        n->inputs.end());
+    n->outputs.erase(
+        std::remove_if(n->outputs.begin(), n->outputs.end(),
+                       [n](ir::Node *t) { return t->Name() == "dummy"; }),
+        n->outputs.end());
+  }
+  return ret;
+}
 
 }  // namespace framework
 }  // namespace paddle
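Note on the constructor change above: var_nodes now maps each variable name to a list of versions rather than a single node. Every op output creates a fresh var node, readers attach to the latest version, and the two trailing loops add "dummy" dependency variables for write-after-read hazards and from the last backward op to each optimize op. A toy sketch of the versioning idea, with a simplified VarVersion struct standing in for Paddle's ir::Node (all names invented for this sketch):

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Toy stand-in for a versioned variable node.
    struct VarVersion {
      std::string name;
      int version;
    };

    int main() {
      std::map<std::string, std::vector<VarVersion>> var_nodes;

      // Each write appends a fresh version instead of reusing the old node.
      auto write = [&](const std::string &name) {
        var_nodes[name].push_back(
            {name, static_cast<int>(var_nodes[name].size())});
      };
      // Readers always attach to the latest version (.back(), as in the patch).
      auto read_latest = [&](const std::string &name) {
        return var_nodes.at(name).back();
      };

      write("w");                       // w@0, e.g. the initialized parameter
      VarVersion r = read_latest("w");  // a forward op reads w@0
      write("w");                       // an optimizer writes w again -> w@1
      // Between w@0's readers and w@1's writer the constructor inserts a
      // "dummy" variable so the write cannot run before the read finishes.
      std::cout << r.name << "@" << r.version << " is read before w@"
                << read_latest("w").version << " is written\n";
    }

The "dummy" nodes exist only to constrain scheduling; TopologySortOperationFromInToOut strips them from each op's inputs and outputs again once the order has been computed.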
@@ -78,5 +78,8 @@ class Graph {
   std::map<std::string, std::function<void(void)>> attr_dels_;
 };
 
+std::vector<ir::Node*> TopologySortOperationFromInToOut(
+    const std::vector<std::unique_ptr<ir::Node>>& nodes);
+
 }  // namespace framework
 }  // namespace paddle
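A short usage sketch of the newly exported declaration (WalkOpsInOrder is a hypothetical wrapper, not part of the patch; it assumes the caller already owns the graph's node vector in the declared type):

    #include <memory>
    #include <vector>
    #include "paddle/fluid/framework/ir/graph.h"

    namespace paddle {
    namespace framework {
    // The helper takes the owning node vector and returns raw op pointers
    // ordered so that producers always precede their consumers.
    void WalkOpsInOrder(const std::vector<std::unique_ptr<ir::Node>> &nodes) {
      for (ir::Node *op : TopologySortOperationFromInToOut(nodes)) {
        VLOG(3) << "visit op: " << op->Name();
      }
    }
    }  // namespace framework
    }  // namespace paddle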
@@ -76,6 +76,7 @@ TEST(GraphTest, Basic) {
   op->SetType("sum");
   op->SetInput("X", {"test_a", "test_b", "test_c"});
   op->SetOutput("Out", {"test_out"});
+  op->SetAttr("op_role", 1);
 
   prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
   prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
......
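Why the test changes: the rewritten Graph constructor reads the role attribute of every op with boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())), so a hand-built OpDesc must set "op_role" or construction fails. The same intent spelled with the named helpers (a sketch; which OpRole entry the raw value 1 maps to is an assumption here, and the constructor simply skips any role other than kBackward/kOptimize):

    // Same intent as the added test line; the exact integer behind each
    // OpRole lives in op_proto_maker.h, so the kBackward cast below is an
    // assumption rather than a guarantee.
    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                static_cast<int>(OpRole::kBackward));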
@@ -50,6 +50,7 @@ class Node {
     PADDLE_ENFORCE(type_ == Type::kVariable);
     return var_desc_;
   }
+
   OpDesc* Op() {
     PADDLE_ENFORCE(type_ == Type::kOperation);
     return op_desc_;
......