提交 9c9e28b5 编写于 作者: X Xin Pan

fix program to graph

上级 64eaa4c8
...@@ -210,7 +210,10 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -210,7 +210,10 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
size_t cur_device_id = 0; size_t cur_device_id = 0;
bool is_forwarding = true; bool is_forwarding = true;
// TODO(panyx0718): FIXME: nodes should be sorted by "program" order. // NOTE: Currently, passes before SSAGraphBuilder cannot reorder
// forward, backward nodes. E.g. you can't append an forward node
// at the end of the node list.
// TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue; if (node->NodeType() != ir::Node::Type::kOperation) continue;
if (boost::get<int>( if (boost::get<int>(
......
...@@ -19,31 +19,43 @@ limitations under the License. */ ...@@ -19,31 +19,43 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// NOTE(paddle-dev): This graph contains circle.
Graph::Graph(const ProgramDesc &program) : program_(program) { Graph::Graph(const ProgramDesc &program) : program_(program) {
std::unordered_map<std::string, VarDesc *> all_vars; std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var); all_vars.emplace(var->Name(), var);
} }
std::map<std::string, ir::Node *> var_nodes;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op); ir::Node *node = CreateOpNode(op);
for (auto &each_var_name : op->InputArgumentNames()) { for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr; ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) { if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name);
} else if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name)); var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
} else { } else {
// TODO(paddle-dev): Seems some assumption doesn't hold? // TODO(paddle-dev): Seems some assumption doesn't hold?
LOG(ERROR) << op->Type() LOG(ERROR) << op->Type()
<< " input var not in all_var list: " << each_var_name; << " input var not in all_var list: " << each_var_name;
var = CreateEmptyNode(each_var_name); var = CreateEmptyNode(each_var_name);
var_nodes[each_var_name] = var;
} }
node->inputs.push_back(var); node->inputs.push_back(var);
var->outputs.push_back(node); var->outputs.push_back(node);
} }
for (auto &each_var_name : op->OutputArgumentNames()) { for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = CreateVarNode(all_vars.at(each_var_name)); ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name);
} else {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
}
node->outputs.push_back(var); node->outputs.push_back(var);
var->inputs.push_back(node); var->inputs.push_back(node);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册