提交 25123a3b 编写于 作者: X Xin Pan

add tests

test=develop
上级 8c11d3fe
...@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") ...@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass) cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
......
...@@ -27,7 +27,24 @@ namespace paddle { ...@@ -27,7 +27,24 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Node should normally created by Graph::CreateXXXNode(). // Node should only created by Graph::CreateXXXNode().
// 1. Every Node should be part of a graph. No dangling Node exists.
// 2. Node only contains members necessary for building graph structure.
// It doesn't contain other unrelated members, such as device, etc.
//
// Sometimes, for specific usages, Node needs to have additional members,
// such as device_placement, version in order to be executed. It is suggested
// to use composition pattern.
//
// class RunnableOp {
// RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); }
//
// int any_thing_;
// }
//
// RunnableOp is owned by the ir::Node that composes it. In other words.
// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node
// is deleted from the graph.
class Node { class Node {
public: public:
virtual ~Node() { virtual ~Node() {
...@@ -53,6 +70,7 @@ class Node { ...@@ -53,6 +70,7 @@ class Node {
return op_desc_.get(); return op_desc_.get();
} }
// Set the `wrapper` that wraps the Node. `wrapper` is owned by Node.
template <typename T> template <typename T>
void WrappedBy(T* wrapper) { void WrappedBy(T* wrapper) {
if (!wrapper_.empty()) { if (!wrapper_.empty()) {
...@@ -63,11 +81,13 @@ class Node { ...@@ -63,11 +81,13 @@ class Node {
wrapper_type_ = std::type_index(typeid(T)); wrapper_type_ = std::type_index(typeid(T));
} }
// Return a reference to the `wrapper`.
template <typename T> template <typename T>
T& Wrapper() { T& Wrapper() {
return *boost::any_cast<T*>(wrapper_); return *boost::any_cast<T*>(wrapper_);
} }
// Test if the Node is wrapped by type T.
template <typename T> template <typename T>
bool IsWrappedBy() { bool IsWrappedBy() {
return std::type_index(typeid(T)) == wrapper_type_; return std::type_index(typeid(T)) == wrapper_type_;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class RunnableOp {
public:
RunnableOp(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
class RunnableOp2 {
public:
RunnableOp2(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp2() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
TEST(NodeTest, Basic) {
bool alive1 = true;
bool alive2 = true;
std::unique_ptr<Node> n1(CreateNodeForTest("n1", Node::Type::kVariable));
std::unique_ptr<Node> n2(CreateNodeForTest("n2", Node::Type::kVariable));
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp2>());
new RunnableOp(n1.get(), &alive1);
new RunnableOp2(n2.get(), &alive2);
EXPECT_TRUE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_TRUE(n2->IsWrappedBy<RunnableOp2>());
EXPECT_TRUE(alive1);
EXPECT_TRUE(alive2);
n1.reset(nullptr);
n2.reset(nullptr);
EXPECT_FALSE(alive1);
EXPECT_FALSE(alive2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册