diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index cae9af7217660fb7e4b8535ee8e022fb3a127668..c62f9a9d0873d351a9297e8f3a5fc035cc2caf85 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -18,6 +18,7 @@ #include #include +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/platform/place.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 6f4bb172c6ed530ff09fc8953ef5c57825efbf79..d1805d7434270096bdf6eb48f090c62ca30e16ba 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -14,6 +14,25 @@ limitations under the License. */ #pragma once +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/platform/variant.h" + namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +class Graph { + public: + std::map> attrs; + + std::vector inputs; + std::vector outputs; + std::vector> nodes; +}; + +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 0c521270693e20b9fa8027d0759ab4ef5448a61c..9a280afb3b7501f53c7811f148eac562dd0ba0a4 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -15,9 +15,11 @@ limitations under the License. */ #pragma once #include +#include #include #include #include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/variant.h" namespace paddle { namespace framework { @@ -29,11 +31,6 @@ class Node { Node() {} virtual ~Node() {} - template - Subclass &As() { - return *dynamic_cast(this); - } - int64_t ID() const { return id_; } std::string Name() const { return name_; } @@ -42,12 +39,15 @@ class Node { return Name() + "(" + std::to_string(ID()) + ")"; } + virtual std::string DebugString() const = 0; + Type NodeType() const { return type_; } std::vector inputs; std::vector outputs; protected: + std::map> attrs_; int64_t id_ = 0; std::string name_; Type type_; diff --git a/paddle/fluid/platform/variant.h b/paddle/fluid/platform/variant.h index 45f60fc9d76560b133fa06198a24c7eaccc24088..dc9fad29f281a1c6ac300b48f9e600ff802a5752 100644 --- a/paddle/fluid/platform/variant.h +++ b/paddle/fluid/platform/variant.h @@ -38,6 +38,7 @@ limitations under the License. */ #endif #endif +#include #include #include #include