diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 21b9fa943e1a49bad210a87ee71d635c63e67a0b..72602840fc33e3b630e26748e9a7070fc8ab0840 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -58,9 +58,9 @@ class Graph { return attr; } - std::vector inputs; - std::vector outputs; - std::vector> nodes; + std::vector inputs; + std::vector outputs; + std::vector> nodes; private: std::map attrs_; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 9a280afb3b7501f53c7811f148eac562dd0ba0a4..0fd80483904b5d0bb51e00a36f0878acda09daba 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include @@ -23,13 +24,23 @@ limitations under the License. */ namespace paddle { namespace framework { +namespace ir { class Node { public: enum class Type { kNone = -1, kOperation, kVariable }; - Node() {} - virtual ~Node() {} + Node(const std::string& name, Type type) : name_(name), type_(type) {} + + virtual ~Node() { + for (auto& attr : attrs_) { + if (attr_dels_.find(attr.first) != attr_dels_.end()) { + attr_dels_[attr.first](); + } + } + attr_dels_.clear(); + attrs_.clear(); + } int64_t ID() const { return id_; } @@ -43,17 +54,42 @@ class Node { Type NodeType() const { return type_; } - std::vector inputs; - std::vector outputs; + template + void Set(const std::string& name, AttrType attr) { + attrs_[name] = attr; + } + + template + void Set(const std::string& name, AttrType* attr, + std::function attr_del) { + attrs_[name] = attr; + attr_dels_[name] = attr_del; + } + + std::vector inputs; + std::vector outputs; protected: - std::map> attrs_; + std::map attrs_; + std::map> attr_dels_; int64_t id_ = 0; std::string name_; Type type_; + private: DISABLE_COPY_AND_ASSIGN(Node); }; +class Variable : public Node { + public: + explicit Variable(const std::string& name) : Node(name, Type::kVariable) {} +}; + +class Operation : public Node { + public: + explicit Operation(const std::string& name) : Node(name, Type::kOperation) {} +}; + +} // namespace ir } // namespace framework } // namespace paddle