From 7231ef6b68334ef095b643b565a8b2e52806c150 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 11 Jul 2018 21:43:42 +0800 Subject: [PATCH] tmp --- paddle/fluid/framework/ir/graph.h | 6 ++-- paddle/fluid/framework/ir/node.h | 46 +++++++++++++++++++++++++++---- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 21b9fa943..72602840f 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -58,9 +58,9 @@ class Graph { return attr; } - std::vector inputs; - std::vector outputs; - std::vector> nodes; + std::vector inputs; + std::vector outputs; + std::vector> nodes; private: std::map attrs_; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 9a280afb3..0fd804839 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include @@ -23,13 +24,23 @@ limitations under the License. */ namespace paddle { namespace framework { +namespace ir { class Node { public: enum class Type { kNone = -1, kOperation, kVariable }; - Node() {} - virtual ~Node() {} + Node(const std::string& name, Type type) : name_(name), type_(type) {} + + virtual ~Node() { + for (auto& attr : attrs_) { + if (attr_dels_.find(attr.first) != attr_dels_.end()) { + attr_dels_[attr.first](); + } + } + attr_dels_.clear(); + attrs_.clear(); + } int64_t ID() const { return id_; } @@ -43,17 +54,42 @@ class Node { Type NodeType() const { return type_; } - std::vector inputs; - std::vector outputs; + template + void Set(const std::string& name, AttrType attr) { + attrs_[name] = attr; + } + + template + void Set(const std::string& name, AttrType* attr, + std::function attr_del) { + attrs_[name] = attr; + attr_dels_[name] = attr_del; + } + + std::vector inputs; + std::vector outputs; protected: - std::map> attrs_; + std::map attrs_; + std::map> attr_dels_; int64_t id_ = 0; std::string name_; Type type_; + private: DISABLE_COPY_AND_ASSIGN(Node); }; +class Variable : public Node { + public: + explicit Variable(const std::string& name) : Node(name, Type::kVariable) {} +}; + +class Operation : public Node { + public: + explicit Operation(const std::string& name) : Node(name, Type::kOperation) {} +}; + +} // namespace ir } // namespace framework } // namespace paddle -- GitLab