未验证 提交 cfa6bbb7 编写于 作者: Y Yan Chunwei 提交者: GitHub

move nodeid from graph to node (#13065)

上级 f88a8ba9
...@@ -87,6 +87,9 @@ bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars, ...@@ -87,6 +87,9 @@ bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
} }
Graph::Graph(const ProgramDesc &program) : program_(program) { Graph::Graph(const ProgramDesc &program) : program_(program) {
// Make the nodes id start from 0.
Node::ResetId();
VLOG(3) << "block in program:" << program_.Size(); VLOG(3) << "block in program:" << program_.Size();
std::unordered_map<std::string, VarDesc *> all_vars; std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
......
...@@ -99,13 +99,13 @@ class Graph { ...@@ -99,13 +99,13 @@ class Graph {
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc); PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc, node_count_++)); return AddNode(new ir::Node(var_desc));
} }
// Create a normal runnable operator with OpDesc. // Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) { ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc); PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc, node_count_++)); return AddNode(new ir::Node(op_desc));
} }
// Create a control dependency var that connects 2 operations. The // Create a control dependency var that connects 2 operations. The
...@@ -115,14 +115,13 @@ class Graph { ...@@ -115,14 +115,13 @@ class Graph {
// TODO(panyx0718): control var name should be really unique. // TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf( const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode( return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
} }
// A more free style way of creating a graph node. Mostly use for test // A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible. // or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type, node_count_++)); return AddNode(new ir::Node(name, type));
} }
// Clear all node information of the graph and return the ownership of the // Clear all node information of the graph and return the ownership of the
...@@ -143,9 +142,13 @@ class Graph { ...@@ -143,9 +142,13 @@ class Graph {
nodes_.erase(node); nodes_.erase(node);
} }
// NOTE low performance, but simple and secure.
Node *RetriveNode(int id) { Node *RetriveNode(int id) {
auto it = id2node_.find(id); for (auto &node : nodes_) {
if (it != id2node_.end()) return it->second; if (node.second->id() == id) {
return node.second.get();
}
}
return nullptr; return nullptr;
} }
...@@ -155,8 +158,6 @@ class Graph { ...@@ -155,8 +158,6 @@ class Graph {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
nodes_[node].reset(node); nodes_[node].reset(node);
node_set_.insert(node); node_set_.insert(node);
PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
id2node_[node->id()] = node;
return node; return node;
} }
...@@ -166,7 +167,6 @@ class Graph { ...@@ -166,7 +167,6 @@ class Graph {
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_; std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_; std::unordered_set<ir::Node *> node_set_;
std::map<int, Node *> id2node_;
int node_count_{0}; int node_count_{0};
}; };
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
constexpr char Node::kControlDepVarName[]; constexpr char Node::kControlDepVarName[];
int Node::count_ = 0;
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -29,26 +29,26 @@ class Node { ...@@ -29,26 +29,26 @@ class Node {
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var"; static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type, int id = -1) explicit Node(const std::string& name, Type type)
: name_(name), : name_(name),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(nullptr), op_desc_(nullptr),
type_(type), type_(type),
id_(id) {} id_(count_++) {}
explicit Node(VarDesc* var_desc, int id = -1) explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()), : name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)), var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable), type_(Type::kVariable),
id_(id) {} id_(count_++) {}
explicit Node(OpDesc* op_desc, int id = -1) explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())), op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation), type_(Type::kOperation),
id_(id) {} id_(count_++) {}
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
...@@ -80,6 +80,9 @@ class Node { ...@@ -80,6 +80,9 @@ class Node {
int id_; int id_;
private: private:
friend class Graph;
static int count_;
static void ResetId() { count_ = 0; }
DISABLE_COPY_AND_ASSIGN(Node); DISABLE_COPY_AND_ASSIGN(Node);
}; };
......
...@@ -102,7 +102,7 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -102,7 +102,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) { void Analyzer::Run(Argument* argument) {
// Ungly support fluid-to-ir-pass // Ugly support fluid-to-ir-pass
argument->Set(kFluidToIrPassesAttr, argument->Set(kFluidToIrPassesAttr,
new std::vector<std::string>({ new std::vector<std::string>({
// Manual update the passes here. // Manual update the passes here.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册