From 25123a3b7e6062c2829eff9de43718c6481ee95c Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 6 Nov 2018 16:32:57 +0800 Subject: [PATCH] add tests test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + paddle/fluid/framework/ir/node.h | 22 ++++++- paddle/fluid/framework/ir/node_test.cc | 80 ++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/framework/ir/node_test.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 28231a53bad..4cf973253cc 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") cc_library(pass_builder SRCS pass_builder.cc DEPS pass) +cc_test(node_test SRCS node_test.cc DEPS node) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index b8764e256c1..98650c23f75 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -27,7 +27,24 @@ namespace paddle { namespace framework { namespace ir { -// Node should normally created by Graph::CreateXXXNode(). +// Node should only created by Graph::CreateXXXNode(). +// 1. Every Node should be part of a graph. No dangling Node exists. +// 2. Node only contains members necessary for building graph structure. +// It doesn't contain other unrelated members, such as device, etc. +// +// Sometimes, for specific usages, Node needs to have additional members, +// such as device_placement, version in order to be executed. It is suggested +// to use composition pattern. +// +// class RunnableOp { +// RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); } +// +// int any_thing_; +// } +// +// RunnableOp is owned by the ir::Node that composes it. In other words. +// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node +// is deleted from the graph. class Node { public: virtual ~Node() { @@ -53,6 +70,7 @@ class Node { return op_desc_.get(); } + // Set the `wrapper` that wraps the Node. `wrapper` is owned by Node. template void WrappedBy(T* wrapper) { if (!wrapper_.empty()) { @@ -63,11 +81,13 @@ class Node { wrapper_type_ = std::type_index(typeid(T)); } + // Return a reference to the `wrapper`. template T& Wrapper() { return *boost::any_cast(wrapper_); } + // Test if the Node is wrapped by type T. template bool IsWrappedBy() { return std::type_index(typeid(T)) == wrapper_type_; diff --git a/paddle/fluid/framework/ir/node_test.cc b/paddle/fluid/framework/ir/node_test.cc new file mode 100644 index 00000000000..694efadda07 --- /dev/null +++ b/paddle/fluid/framework/ir/node_test.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class RunnableOp { + public: + RunnableOp(Node* node, bool* alive) : node_(node), alive_(alive) { + node_->WrappedBy(this); + } + + virtual ~RunnableOp() { *alive_ = false; } + + private: + Node* node_; + bool* alive_; +}; + +class RunnableOp2 { + public: + RunnableOp2(Node* node, bool* alive) : node_(node), alive_(alive) { + node_->WrappedBy(this); + } + + virtual ~RunnableOp2() { *alive_ = false; } + + private: + Node* node_; + bool* alive_; +}; + +TEST(NodeTest, Basic) { + bool alive1 = true; + bool alive2 = true; + std::unique_ptr n1(CreateNodeForTest("n1", Node::Type::kVariable)); + std::unique_ptr n2(CreateNodeForTest("n2", Node::Type::kVariable)); + + EXPECT_FALSE(n1->IsWrappedBy()); + EXPECT_FALSE(n1->IsWrappedBy()); + EXPECT_FALSE(n2->IsWrappedBy()); + EXPECT_FALSE(n2->IsWrappedBy()); + + new RunnableOp(n1.get(), &alive1); + new RunnableOp2(n2.get(), &alive2); + + EXPECT_TRUE(n1->IsWrappedBy()); + EXPECT_FALSE(n1->IsWrappedBy()); + EXPECT_FALSE(n2->IsWrappedBy()); + EXPECT_TRUE(n2->IsWrappedBy()); + + EXPECT_TRUE(alive1); + EXPECT_TRUE(alive2); + + n1.reset(nullptr); + n2.reset(nullptr); + EXPECT_FALSE(alive1); + EXPECT_FALSE(alive2); +} + +} // namespace ir +} // namespace framework +} // namespace paddle -- GitLab