diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 28231a53bad50fe9f19cfe3e73c3dc09aa3762cf..4cf973253cc4f1f22d2fc578a1ac3a8c95e479c9 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") cc_library(pass_builder SRCS pass_builder.cc DEPS pass) +cc_test(node_test SRCS node_test.cc DEPS node) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index b8764e256c15d54ada4d1c8c33b485f7b3279087..98650c23f753a49c017c03036006da98710928a1 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -27,7 +27,24 @@ namespace paddle { namespace framework { namespace ir { -// Node should normally created by Graph::CreateXXXNode(). +// Node should only created by Graph::CreateXXXNode(). +// 1. Every Node should be part of a graph. No dangling Node exists. +// 2. Node only contains members necessary for building graph structure. +// It doesn't contain other unrelated members, such as device, etc. +// +// Sometimes, for specific usages, Node needs to have additional members, +// such as device_placement, version in order to be executed. It is suggested +// to use composition pattern. +// +// class RunnableOp { +// RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); } +// +// int any_thing_; +// } +// +// RunnableOp is owned by the ir::Node that composes it. In other words. +// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node +// is deleted from the graph. class Node { public: virtual ~Node() { @@ -53,6 +70,7 @@ class Node { return op_desc_.get(); } + // Set the `wrapper` that wraps the Node. `wrapper` is owned by Node. template void WrappedBy(T* wrapper) { if (!wrapper_.empty()) { @@ -63,11 +81,13 @@ class Node { wrapper_type_ = std::type_index(typeid(T)); } + // Return a reference to the `wrapper`. template T& Wrapper() { return *boost::any_cast(wrapper_); } + // Test if the Node is wrapped by type T. template bool IsWrappedBy() { return std::type_index(typeid(T)) == wrapper_type_; diff --git a/paddle/fluid/framework/ir/node_test.cc b/paddle/fluid/framework/ir/node_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..694efadda078169c993457181c00f7b357a09e87 --- /dev/null +++ b/paddle/fluid/framework/ir/node_test.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class RunnableOp { + public: + RunnableOp(Node* node, bool* alive) : node_(node), alive_(alive) { + node_->WrappedBy(this); + } + + virtual ~RunnableOp() { *alive_ = false; } + + private: + Node* node_; + bool* alive_; +}; + +class RunnableOp2 { + public: + RunnableOp2(Node* node, bool* alive) : node_(node), alive_(alive) { + node_->WrappedBy(this); + } + + virtual ~RunnableOp2() { *alive_ = false; } + + private: + Node* node_; + bool* alive_; +}; + +TEST(NodeTest, Basic) { + bool alive1 = true; + bool alive2 = true; + std::unique_ptr n1(CreateNodeForTest("n1", Node::Type::kVariable)); + std::unique_ptr n2(CreateNodeForTest("n2", Node::Type::kVariable)); + + EXPECT_FALSE(n1->IsWrappedBy()); + EXPECT_FALSE(n1->IsWrappedBy()); + EXPECT_FALSE(n2->IsWrappedBy()); + EXPECT_FALSE(n2->IsWrappedBy()); + + new RunnableOp(n1.get(), &alive1); + new RunnableOp2(n2.get(), &alive2); + + EXPECT_TRUE(n1->IsWrappedBy()); + EXPECT_FALSE(n1->IsWrappedBy()); + EXPECT_FALSE(n2->IsWrappedBy()); + EXPECT_TRUE(n2->IsWrappedBy()); + + EXPECT_TRUE(alive1); + EXPECT_TRUE(alive2); + + n1.reset(nullptr); + n2.reset(nullptr); + EXPECT_FALSE(alive1); + EXPECT_FALSE(alive2); +} + +} // namespace ir +} // namespace framework +} // namespace paddle