未验证 提交 cfa6bbb7 编写于 作者: Y Yan Chunwei 提交者: GitHub

move nodeid from graph to node (#13065)

上级 f88a8ba9
......@@ -87,6 +87,9 @@ bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
}
Graph::Graph(const ProgramDesc &program) : program_(program) {
// Make the nodes id start from 0.
Node::ResetId();
VLOG(3) << "block in program:" << program_.Size();
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
......
......@@ -99,13 +99,13 @@ class Graph {
// Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc, node_count_++));
return AddNode(new ir::Node(var_desc));
}
// Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc, node_count_++));
return AddNode(new ir::Node(op_desc));
}
// Create a control dependency var that connects 2 operations. The
......@@ -115,14 +115,13 @@ class Graph {
// TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(
new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
}
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type, node_count_++));
return AddNode(new ir::Node(name, type));
}
// Clear all node information of the graph and return the ownership of the
......@@ -143,9 +142,13 @@ class Graph {
nodes_.erase(node);
}
// NOTE low performance, but simple and secure.
Node *RetriveNode(int id) {
auto it = id2node_.find(id);
if (it != id2node_.end()) return it->second;
for (auto &node : nodes_) {
if (node.second->id() == id) {
return node.second.get();
}
}
return nullptr;
}
......@@ -155,8 +158,6 @@ class Graph {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
nodes_[node].reset(node);
node_set_.insert(node);
PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
id2node_[node->id()] = node;
return node;
}
......@@ -166,7 +167,6 @@ class Graph {
std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_;
std::map<int, Node *> id2node_;
int node_count_{0};
};
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace framework {
namespace ir {
constexpr char Node::kControlDepVarName[];
int Node::count_ = 0;
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -29,26 +29,26 @@ class Node {
enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type, int id = -1)
explicit Node(const std::string& name, Type type)
: name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(id) {}
id_(count_++) {}
explicit Node(VarDesc* var_desc, int id = -1)
explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr),
type_(Type::kVariable),
id_(id) {}
id_(count_++) {}
explicit Node(OpDesc* op_desc, int id = -1)
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation),
id_(id) {}
id_(count_++) {}
Type NodeType() const { return type_; }
......@@ -80,6 +80,9 @@ class Node {
int id_;
private:
friend class Graph;
static int count_;
static void ResetId() { count_ = 0; }
DISABLE_COPY_AND_ASSIGN(Node);
};
......
......@@ -102,7 +102,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) {
// Ungly support fluid-to-ir-pass
// Ugly support fluid-to-ir-pass
argument->Set(kFluidToIrPassesAttr,
new std::vector<std::string>({
// Manual update the passes here.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册