提交 0b3465d2 编写于 作者: X Xin Pan

better

上级 dcaf183d
...@@ -5,7 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod ...@@ -5,7 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph graph_helper)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -186,9 +187,55 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -186,9 +187,55 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
return dev_id; return dev_id;
} }
// Topology sort the graph nodes from inputs to outputs.
// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
// to parameter/gradients before optimizer ops, topo sort is insufficient. (
// some optimizer ops might not depend on any nodes), we manually move all
// optimizer nodes after last backward nodes.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
std::vector<ir::Node *> ret = ir::TopologySort(graph);
size_t last_backward = 0;
std::vector<ir::Node *> optimize_ops;
std::vector<ir::Node *> sorted_ret;
for (size_t i = 0; i < ret.size(); ++i) {
if (boost::get<int>(
ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kBackward)) {
sorted_ret.push_back(ret[i]);
last_backward = sorted_ret.size();
} else if (boost::get<int>(ret[i]->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kOptimize)) {
optimize_ops.push_back(ret[i]);
} else {
sorted_ret.push_back(ret[i]);
}
}
sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
optimize_ops.end());
for (ir::Node *n : sorted_ret) {
n->inputs.erase(std::remove_if(n->inputs.begin(), n->inputs.end(),
[n](ir::Node *t) {
return t->Name() ==
ir::Node::kControlDepVarName;
}),
n->inputs.end());
n->outputs.erase(std::remove_if(n->outputs.begin(), n->outputs.end(),
[n](ir::Node *t) {
return t->Name() ==
ir::Node::kControlDepVarName;
}),
n->outputs.end());
}
return sorted_ret;
}
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<Graph> graph) const { std::unique_ptr<Graph> graph) const {
// Rebuild the graph structure. // Rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
auto nodes = std::move(graph->nodes); auto nodes = std::move(graph->nodes);
graph->nodes.clear(); graph->nodes.clear();
...@@ -217,12 +264,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -217,12 +264,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
size_t cur_device_id = 0; size_t cur_device_id = 0;
bool is_forwarding = true; bool is_forwarding = true;
// NOTE: Currently, passes before SSAGraphBuilder cannot reorder for (ir::Node *node : sorted_ops) {
// forward, backward nodes. E.g. you can't append an forward node
// at the end of the node list.
// TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
for (ir::Node *node : TopologySortOperationFromInToOut(nodes)) {
VLOG(3) << "apply node: " << node->Name() << reinterpret_cast<void *>(node);
if (boost::get<int>( if (boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
...@@ -240,7 +282,6 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -240,7 +282,6 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
// It also assumes backward op will always follow the forward op in // It also assumes backward op will always follow the forward op in
// the block. // the block.
is_forwarding = false; is_forwarding = false;
LOG(ERROR) << "forward flipping!!!!!!!";
} else { } else {
int op_dev_id = GetOpDeviceID(node); int op_dev_id = GetOpDeviceID(node);
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
......
cc_library(node SRCS node.cc DEPS proto_desc) cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node) cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node) cc_library(pass SRCS pass.cc DEPS graph node)
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/*
namespace { namespace {
void SortHelper( void SortHelper(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list, const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
...@@ -39,7 +40,21 @@ void SortHelper( ...@@ -39,7 +40,21 @@ void SortHelper(
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size(); << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
ret->push_back(node); ret->push_back(node);
} }
std::vector<ir::Node*> TopologySort(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
for (auto adj : adj_list) {
if (visited.find(adj.first) == visited.end()) {
SortHelper(adj_list, adj.first, &visited, &ret);
}
}
return ret;
}
} // namespace } // namespace
*/
Graph::Graph(const ProgramDesc &program) : program_(program) { Graph::Graph(const ProgramDesc &program) : program_(program) {
VLOG(3) << "block in program:" << program_.Size(); VLOG(3) << "block in program:" << program_.Size();
...@@ -48,20 +63,9 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -48,20 +63,9 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
all_vars.emplace(var->Name(), var); all_vars.emplace(var->Name(), var);
} }
ir::Node *last_backward = nullptr;
std::vector<ir::Node *> optimize_ops;
std::map<std::string, std::vector<ir::Node *>> var_nodes; std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op); ir::Node *node = CreateOpNode(op);
if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kBackward)) {
last_backward = node;
} else if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kOptimize)) {
optimize_ops.push_back(node);
}
for (auto &each_var_name : op->InputArgumentNames()) { for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr; ir::Node *var = nullptr;
...@@ -106,7 +110,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -106,7 +110,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
// Read Write is the same op. // Read Write is the same op.
continue; continue;
} }
ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable); ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
ir::Node::Type::kVariable);
read_op->outputs.push_back(dep_var); read_op->outputs.push_back(dep_var);
dep_var->inputs.push_back(read_op); dep_var->inputs.push_back(read_op);
write_op->inputs.push_back(dep_var); write_op->inputs.push_back(dep_var);
...@@ -114,62 +119,121 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -114,62 +119,121 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
} }
} }
} }
}
if (last_backward) { /*
for (ir::Node *opt_node : optimize_ops) { bool HasCircleHelper(ir::Node* node,
ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable); const std::map<ir::Node *, std::unordered_set<ir::Node *>>
last_backward->outputs.push_back(dep_var); &adj_list,
dep_var->inputs.push_back(last_backward); std::unordered_set<ir::Node*>* visited,
opt_node->inputs.push_back(dep_var); std::unordered_set<ir::Node*>* in_trace) {
dep_var->outputs.push_back(opt_node); if (visited->find(node) == visited->end()) {
VLOG(3) << "appending connect: " << last_backward->Name() visited->insert(node);
<< reinterpret_cast<void *>(last_backward) << "->" in_trace->insert(node);
<< opt_node->Name() << reinterpret_cast<void *>(opt_node);
for (ir::Node *in : adj_list.at(node)) {
if (visited->find(in) == visited->end() &&
HasCircleHelper(in, adj_list, visited, in_trace)) {
return true;
} else if (in_trace->find(in) != in_trace->end()) {
return true;
}
} }
} }
in_trace->erase(node);
return false;
} }
std::vector<ir::Node *> TopologySortOperationFromInToOut( bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
const std::vector<std::unique_ptr<ir::Node>> &nodes) { &adj_list) {
std::unordered_set<ir::Node*> visited;
std::unordered_set<ir::Node*> in_trace;
for (auto& adj : adj_list) {
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
return true;
}
}
return false;
}
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
const std::vector<ir::Node*> &nodes) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list; std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
for (auto &n : nodes) { for (auto &n : nodes) {
if (n->NodeType() != ir::Node::Type::kOperation) continue; if (n->NodeType() != ir::Node::Type::kOperation) continue;
if (adj_list.find(n.get()) == adj_list.end()) { if (adj_list.find(n) == adj_list.end()) {
adj_list[n.get()] = std::unordered_set<ir::Node *>(); adj_list[n] = std::unordered_set<ir::Node *>();
} }
for (auto &var : n->inputs) { for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) { for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
adj_list[n.get()].insert(adj_n); adj_list[n].insert(adj_n);
LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n) LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n.get()) << " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var); << " via " << var->Name() << reinterpret_cast<void *>(var);
} }
} }
} }
return adj_list;
}
for (auto adj : adj_list) { std::vector<ir::Node *> TopologySortOperationFromInToOut(
if (visited.find(adj.first) == visited.end()) { const std::vector<std::unique_ptr<ir::Node>> &nodes) {
SortHelper(adj_list, adj.first, &visited, &ret); std::vector<ir::Node*> tmp;
for (auto& n : nodes) {
tmp.push_back(n.get());
}
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildAdjList(tmp);
PADDLE_ENFORCE(!HasCircle(adj_list));
std::vector<ir::Node*> ret = TopologySort(adj_list);
ir::Node *last_backward = nullptr;
std::vector<ir::Node *> optimize_ops;
for (ir::Node* n : ret) {
if (boost::get<int>(
n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kBackward)) {
last_backward = n;
} else if (boost::get<int>(
n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kOptimize)) {
optimize_ops.push_back(n);
} }
} }
if (last_backward) {
for (ir::Node *opt_node : optimize_ops) {
ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
ir::Node::Type::kVariable);
last_backward->outputs.push_back(dep_var);
dep_var->inputs.push_back(last_backward);
opt_node->inputs.push_back(dep_var);
dep_var->outputs.push_back(opt_node);
VLOG(3) << "appending connect: " << last_backward->Name()
<< reinterpret_cast<void *>(last_backward) << "->"
<< opt_node->Name() << reinterpret_cast<void *>(opt_node);
}
}
PADDLE_ENFORCE(!HasCircle(adj_list));
for (ir::Node *n : ret) { for (ir::Node *n : ret) {
std::unordered_set<ir::Node *> dummy; std::unordered_set<ir::Node *> dummy;
n->inputs.erase( n->inputs.erase(
std::remove_if(n->inputs.begin(), n->inputs.end(), std::remove_if(n->inputs.begin(), n->inputs.end(),
[n](ir::Node *t) { return t->Name() == "dummy"; }), [n](ir::Node *t) {
return t->Name() == ir::Node::kControlDepVarName; }),
n->inputs.end()); n->inputs.end());
n->outputs.erase( n->outputs.erase(
std::remove_if(n->outputs.begin(), n->outputs.end(), std::remove_if(n->outputs.begin(), n->outputs.end(),
[n](ir::Node *t) { return t->Name() == "dummy"; }), [n](ir::Node *t) {
return t->Name() == ir::Node::kControlDepVarName; }),
n->outputs.end()); n->outputs.end());
} }
return ret; return ret;
} }*/
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -78,8 +78,5 @@ class Graph { ...@@ -78,8 +78,5 @@ class Graph {
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
}; };
std::vector<ir::Node*> TopologySortOperationFromInToOut(
const std::vector<std::unique_ptr<ir::Node>>& nodes);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
void SortHelper(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
ir::Node *node, std::unordered_set<ir::Node *> *visited,
std::vector<ir::Node *> *ret) {
visited->insert(node);
for (auto adj : adj_list.at(node)) {
if (visited->find(adj) == visited->end()) {
SortHelper(adj_list, adj, visited, ret);
}
}
LOG(ERROR) << "topology sort insert: " << node->Name()
<< reinterpret_cast<void *>(node) << " input "
<< node->inputs.size();
ret->push_back(node);
}
bool HasCircleHelper(
ir::Node *node,
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
std::unordered_set<ir::Node *> *visited,
std::unordered_set<ir::Node *> *in_trace) {
if (visited->find(node) == visited->end()) {
visited->insert(node);
in_trace->insert(node);
for (ir::Node *in : adj_list.at(node)) {
if (visited->find(in) == visited->end() &&
HasCircleHelper(in, adj_list, visited, in_trace)) {
return true;
} else if (in_trace->find(in) != in_trace->end()) {
return true;
}
}
}
in_trace->erase(node);
return false;
}
} // namespace
bool HasCircle(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildAdjList(graph);
std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace;
for (auto &adj : adj_list) {
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
return true;
}
}
return false;
}
std::vector<ir::Node *> TopologySort(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildAdjList(graph);
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
for (auto adj : adj_list) {
if (visited.find(adj.first) == visited.end()) {
SortHelper(adj_list, adj.first, &visited, &ret);
}
}
return ret;
}
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
for (auto &n : graph.nodes) {
if (n->NodeType() != ir::Node::Type::kOperation) continue;
if (adj_list.find(n.get()) == adj_list.end()) {
adj_list[n.get()] = std::unordered_set<ir::Node *>();
}
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
adj_list[n.get()].insert(adj_n);
LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n.get())
<< " via " << var->Name() << reinterpret_cast<void *>(var);
}
}
}
return adj_list;
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {
namespace ir {
bool HasCircle(const Graph &graph);
std::vector<ir::Node *> TopologySort(const Graph &graph);
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
const Graph &graph);
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -15,5 +15,9 @@ limitations under the License. */ ...@@ -15,5 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
namespace paddle { namespace paddle {
namespace framework {} // namespace framework namespace framework {
namespace ir {
const char Node::kControlDepVarName[] = "__control_var";
} // namespace ir
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -27,6 +27,8 @@ namespace ir { ...@@ -27,6 +27,8 @@ namespace ir {
class Node { class Node {
public: public:
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static const char kControlDepVarName[];
explicit Node(const std::string& name, Type type) explicit Node(const std::string& name, Type type)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册