未验证 提交 044a82d8 编写于 作者: L levi131 提交者: GitHub

Convert all blocks in program into SSAgraphs. (#33320)

As the title, this PR converts all blocks in program into SSA sub graphs and it is guarded by flag
上级 0e4bcede
......@@ -17,6 +17,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/operator.h"
DEFINE_bool(convert_all_blocks, false,
"Convert all blocks in program into SSAgraphs");
namespace paddle {
namespace framework {
namespace ir {
......@@ -24,16 +27,9 @@ namespace ir {
Graph::Graph(const ProgramDesc &program)
: Graph(program, 0, program.Block(0).AllOps().size()) {}
Graph::Graph(const ProgramDesc &program, int64_t start_op_index,
int64_t end_op_index)
: program_(program) {
auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
const ProgramDesc &program, int64_t start_op_index, int64_t end_op_index) {
VLOG(3) << "block in program:" << program_.Size();
Graph::Graph(const ProgramDesc &program, const int64_t start_op_index,
const int64_t end_op_index)
: program_(program), main_graph_(nullptr) {
PADDLE_ENFORCE_GE(start_op_index, 0,
platform::errors::InvalidArgument(
"Required start_op_index >= 0, but received "
......@@ -44,16 +40,65 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
"Required end_op_index >= start_op_index, but received "
"end_op_index: %d < start_op_index: %d",
end_op_index, start_op_index));
PADDLE_ENFORCE_GE(
program_.Size(), 1,
platform::errors::InvalidArgument("Can't construct a graph from this "
"program, it doesn't have a block"));
const int64_t block_op_size = program_.Block(0).AllOps().size();
PADDLE_ENFORCE_LE(end_op_index, block_op_size,
platform::errors::InvalidArgument(
"Required end_op_index <= block_op_size, but received "
"end_op_index: %d > block_op_size: %d",
end_op_index, block_op_size));
if (FLAGS_convert_all_blocks) {
// NOTE(levi): start_op_index and end_op_index only work on the first
// sub_graph.
std::unique_ptr<Graph> first_sub_graph = std::make_unique<Graph>(
program_.Block(0), this, start_op_index, end_op_index);
sub_graphs_.push_back(std::move(first_sub_graph));
for (size_t idx = 1; idx < program_.Size(); ++idx) {
std::unique_ptr<Graph> sub_graph =
std::make_unique<Graph>(program_.Block(idx), this);
sub_graphs_.push_back(std::move(sub_graph));
}
} else {
auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
}
Graph::Graph(const BlockDesc &block, const Graph *main_graph)
: Graph(block, main_graph, 0, block.AllOps().size()) {}
Graph::Graph(const BlockDesc &block, const Graph *main_graph,
const int64_t start_op_index, const int64_t end_op_index)
: main_graph_(main_graph) {
auto var_nodes = InitFromBlock(block, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
// TODO(levi): delete this interface after when we can convert all
// blocks into sub_graphs.
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
const ProgramDesc &program, const int64_t start_op_index,
const int64_t end_op_index) {
VLOG(3) << "block in program:" << program_.Size();
return InitFromBlock(program.Block(0), start_op_index, end_op_index);
}
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
const BlockDesc &block, const int64_t start_op_index,
const int64_t end_op_index) {
std::unordered_map<std::string, VarDesc *> all_vars;
// var nodes for each var name, will have multiple versions in SSA
std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *var : program.Block(0).AllVars()) {
for (auto *var : block.AllVars()) {
all_vars.emplace(var->Name(), var);
}
auto not_visited_vars = all_vars;
auto all_ops = program.Block(0).AllOps();
auto all_ops = block.AllOps();
PADDLE_ENFORCE_LE(
end_op_index, all_ops.size(),
platform::errors::InvalidArgument(
......@@ -210,6 +255,18 @@ void Graph::ResolveHazard(
}
std::shared_ptr<Graph> Graph::Clone() {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument(
"This graph is a sub_graph, and can't be cloned individually"));
if (FLAGS_convert_all_blocks) {
auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseSubGraphs();
for (size_t idx = 0; idx < this->program_.Size(); ++idx) {
cloned_graph->AddSubGraph(this->CloneSubGraph(idx));
}
return cloned_graph;
} else {
auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseNodes();
cloned_graph->num_node_created_ = 0;
......@@ -242,6 +299,49 @@ std::shared_ptr<Graph> Graph::Clone() {
}
}
return cloned_graph;
}
}
std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument("This graph is not main_graph"));
PADDLE_ENFORCE_LT(
idx, this->sub_graphs_.size(),
platform::errors::InvalidArgument("Invalid sub_graph index"));
std::unique_ptr<Graph> cloned_sub_graph =
std::make_unique<Graph>(this->program_.Block(idx), this);
cloned_sub_graph->ReleaseNodes();
cloned_sub_graph->num_node_created_ = 0;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
"The node to be cloned is nullptr."));
ir::Node *cloned_node = nullptr;
if (n->IsCtrlVar()) {
cloned_node = cloned_sub_graph->CreateControlDepVar();
} else if (!n->var_desc_ && !n->op_desc_) { // empty node
cloned_node = cloned_sub_graph->CreateEmptyNode(n->Name(), n->NodeType());
} else if (n->IsVar()) {
cloned_node = cloned_sub_graph->CreateVarNode(n->Var());
} else if (n->IsOp()) {
cloned_node = cloned_sub_graph->CreateOpNode(n->Op());
}
PADDLE_ENFORCE_NOT_NULL(
cloned_node,
platform::errors::InvalidArgument(
"Failed to clone new node from original node in graph."));
origin_to_cloned[n] = cloned_node;
}
for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]);
}
for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) {
origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]);
}
}
return cloned_sub_graph;
}
bool IsControlDepVar(const ir::Node &var) {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <map>
#include <memory>
#include <string>
......@@ -25,6 +26,8 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
DECLARE_bool(convert_all_blocks);
namespace paddle {
namespace framework {
class OpDesc;
......@@ -78,10 +81,20 @@ namespace ir {
*/
class Graph {
public:
// Construct a main_graph with some sub_graphs
explicit Graph(const ProgramDesc &program);
// Construct a Graph with ops[start_op_index, end_op_index)
explicit Graph(const ProgramDesc &program, int64_t start_op_index,
int64_t end_op_index);
// Construct a main_graph with some sub_graphs, and the 1st sub_graph is
// constructed with ops[start_op_index, end_op_index)
Graph(const ProgramDesc &program, const int64_t start_op_index,
const int64_t end_op_index);
// Construct a sub_graph
Graph(const BlockDesc &block, const Graph *main_graph);
// Construct a sub_graph with ops[start_op_index, end_op_index)
Graph(const BlockDesc &block, const Graph *main_graph,
const int64_t start_op_index, const int64_t end_op_index);
virtual ~Graph() {
for (auto &attr : attrs_) {
......@@ -94,11 +107,21 @@ class Graph {
bool IsConstructedByPartialProgram() const { return is_partial_; }
bool Has(const std::string &attr_name) const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->Has(attr_name);
}
}
return attrs_.count(attr_name) > 0;
}
template <typename AttrType>
AttrType &GetOrInit(const std::string &attr_name) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->GetOrInit<AttrType>(attr_name);
}
}
if (!Has(attr_name)) {
Set(attr_name, new AttrType);
}
......@@ -107,6 +130,11 @@ class Graph {
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->Get<AttrType>(attr_name);
}
}
PADDLE_ENFORCE_EQ(
Has(attr_name), true,
platform::errors::PreconditionNotMet(
......@@ -123,6 +151,11 @@ class Graph {
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->Set<AttrType>(attr_name, attr);
}
}
PADDLE_ENFORCE_EQ(
attrs_.count(attr_name), 0,
platform::errors::AlreadyExists(
......@@ -137,6 +170,11 @@ class Graph {
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->SetNotOwned<AttrType>(attr_name, attr);
}
}
PADDLE_ENFORCE_EQ(
attrs_.count(attr_name), 0,
platform::errors::AlreadyExists("The attribute %s to be set(not owned) "
......@@ -147,6 +185,11 @@ class Graph {
}
void Erase(const std::string &attr_name) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->Erase(attr_name);
}
}
PADDLE_ENFORCE_NE(
attrs_.count(attr_name), 0,
platform::errors::NotFound(
......@@ -157,10 +200,22 @@ class Graph {
attr_dels_.erase(attr_name);
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
const std::unordered_set<ir::Node *> &Nodes() const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->Nodes();
}
}
return node_set_;
}
// Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->CreateVarNode(var_desc);
}
}
PADDLE_ENFORCE_NOT_NULL(
var_desc, platform::errors::InvalidArgument(
"The VarDesc used to create variable node is null."));
......@@ -171,6 +226,11 @@ class Graph {
// Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->CreateOpNode(op_desc);
}
}
PADDLE_ENFORCE_NOT_NULL(
op_desc, platform::errors::InvalidArgument(
"The OpDesc used to create operator node is null."));
......@@ -183,6 +243,11 @@ class Graph {
// var doesn't hold any data. Other than that, it's no different from
// other var, considering dependency analysis.
ir::Node *CreateControlDepVar() {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->CreateControlDepVar();
}
}
// TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf(
"%s@%llu", static_cast<const char *>(ir::Node::kControlDepVarName),
......@@ -195,6 +260,11 @@ class Graph {
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->CreateEmptyNode(name, type);
}
}
auto *x = AddNode(new ir::Node(name, type));
x->SetId(num_node_created_++);
return x;
......@@ -203,6 +273,11 @@ class Graph {
// Clear all node information of the graph and return the ownership of the
// nodes.
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->ReleaseNodes();
}
}
std::vector<std::unique_ptr<ir::Node>> ret;
for (auto &n : nodes_) {
ret.emplace_back(n.second.release());
......@@ -213,6 +288,11 @@ class Graph {
}
std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->RemoveNode(node);
}
}
PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true,
platform::errors::PreconditionNotMet(
"The node to be removed does not exist."));
......@@ -225,6 +305,11 @@ class Graph {
// NOTE low performance, but simple and secure.
Node *RetrieveNode(int id) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->RetrieveNode(id);
}
}
for (auto &node : nodes_) {
if (node.second->id() == id) {
return node.second.get();
......@@ -237,10 +322,22 @@ class Graph {
// WARN: After a series of passes, the current graph can be quite
// different from OriginProgram. Caller shouldn't assume much from
// the returned OriginProgram.
const ProgramDesc &OriginProgram() const { return program_; }
const ProgramDesc &OriginProgram() const {
if (FLAGS_convert_all_blocks) {
if (!IsMainGraph()) {
return main_graph_->OriginProgram();
}
}
return program_;
}
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->AddNode(node);
}
}
PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true,
platform::errors::PreconditionNotMet(
"The node to be added already exists."));
......@@ -256,12 +353,59 @@ class Graph {
// WARN: The method only clones the graph structure, not its attributes.
std::shared_ptr<Graph> Clone();
bool IsMainGraph() const { return main_graph_ == nullptr; }
Graph *GetSubGraph(const size_t idx) const {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument("This graph is not main_graph"));
PADDLE_ENFORCE_LT(
idx, sub_graphs_.size(),
platform::errors::InvalidArgument("Invalid sub_graph index"));
return sub_graphs_.at(idx).get();
}
size_t SubGraphsSize() const {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument("This graph is not main_graph"));
return sub_graphs_.size();
}
private:
// TODO(levi): delete this interface after when we can convert all
// blocks into sub_graphs.
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program, int64_t start_op_index, int64_t end_op_index);
const ProgramDesc &program, const int64_t start_op_index,
const int64_t end_op_index);
std::map<std::string, std::vector<ir::Node *>> InitFromBlock(
const BlockDesc &block, const int64_t start_op_index,
const int64_t end_op_index);
void ReleaseSubGraphs() {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument("This graph is not main_graph"));
sub_graphs_.clear();
}
void AddSubGraph(std::unique_ptr<Graph> sub_graph) {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument("This graph is not main_graph"));
sub_graphs_.push_back(std::move(sub_graph));
}
std::unique_ptr<Graph> CloneSubGraph(const size_t idx);
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc program_;
// NOTE: main_graph_ doesn't hold any node. It's used as a container of
// sub_graphs, and the sub_graph holds the nodes.
const Graph *main_graph_; // not owned.
std::vector<std::unique_ptr<Graph>> sub_graphs_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
......
......@@ -264,5 +264,181 @@ TEST(GraphTest, TestAttrCopy) {
ASSERT_FALSE(dst_g.Has(kFloatValue));
}
TEST(GraphTest, TestInterfaceConvertAllBlocks) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
FLAGS_convert_all_blocks = true;
ProgramDesc prog;
prog.MutableBlock(0)->Var("init_var")->SetType(proto::VarType::SELECTED_ROWS);
ir::Graph g(prog);
ASSERT_TRUE(g.IsMainGraph());
const std::string kIntValue = "int_value";
const int INT_VALUE = 3;
g.Set<int>(kIntValue, new int(INT_VALUE));
ASSERT_TRUE(g.Has(kIntValue));
ASSERT_EQ(g.GetOrInit<int>(kIntValue), INT_VALUE);
ASSERT_EQ(g.Get<int>(kIntValue), INT_VALUE);
g.Erase(kIntValue);
ASSERT_TRUE(!g.Has(kIntValue));
g.SetNotOwned<int>(kIntValue, new int(INT_VALUE));
ASSERT_TRUE(g.Has(kIntValue));
g.Erase(kIntValue);
g.ReleaseNodes();
ASSERT_EQ(g.Nodes().size(), 0UL);
g.CreateVarNode(new VarDesc("temp_var_desc_name"));
g.CreateOpNode(prog.MutableBlock(0)->AppendOp());
g.CreateControlDepVar();
g.CreateEmptyNode("temp_empty_node_name", ir::Node::Type::kVariable);
ASSERT_EQ(g.Nodes().size(), 4UL);
g.RemoveNode(g.RetrieveNode(1));
ASSERT_EQ(g.Nodes().size(), 3UL);
// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks = flag_temp;
}
TEST(GraphTest, TestMultiBlock) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
FLAGS_convert_all_blocks = true;
// Step1: Build a program with 3 blocks.
ProgramDesc prog;
ASSERT_EQ(prog.Size(), 1UL);
prog.AppendBlock(prog.Block(0));
prog.AppendBlock(prog.Block(0));
ASSERT_EQ(prog.Size(), 3UL);
// Set contents in block_0.
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType());
// Set contents in block_1.
op = prog.MutableBlock(1)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(1)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetAttr("op_role", 1);
prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
// Set contents in block_2.
op = prog.MutableBlock(2)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(2)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
// Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs.
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ASSERT_EQ(g->IsMainGraph(), true);
ASSERT_EQ(g->SubGraphsSize(), 3UL);
// Check contents in sub_graph_0.
const ir::Graph *g0 = g->GetSubGraph(0);
std::vector<ir::Node *> nodes(g0->Nodes().begin(), g0->Nodes().end());
for (ir::Node *n : nodes) {
if (n->Name() == "sum") {
ASSERT_EQ(n->inputs.size(), 3UL);
ASSERT_EQ(n->outputs.size(), 1UL);
} else if (n->Name() == "test_a" || n->Name() == "test_b" ||
n->Name() == "test_c") {
ASSERT_EQ(n->inputs.size(), 0UL);
ASSERT_EQ(n->outputs.size(), 1UL);
} else if (n->Name() == "test_out") {
ASSERT_EQ(n->inputs.size(), 1UL);
ASSERT_EQ(n->outputs.size(), 0UL);
}
}
ASSERT_EQ(nodes.size(), 5UL);
// Check contents in sub_graph_1.
const ir::Graph *g1 = g->GetSubGraph(1);
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g1->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
// Check contents in sub_graph_2.
const ir::Graph *g2 = g->GetSubGraph(2);
control_dep1 = nullptr;
control_dep2 = nullptr;
for (ir::Node *n : g2->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
// Step3: Clone graph.
std::shared_ptr<ir::Graph> clone_g = g->Clone();
ASSERT_EQ(clone_g->IsMainGraph(), true);
ASSERT_EQ(clone_g->SubGraphsSize(), 3UL);
// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks = flag_temp;
}
} // namespace framework
} // namespace paddle
......@@ -135,6 +135,93 @@ TEST(PassTest, TestPassAttrCheck) {
exception.npos);
}
TEST(PassTest, TestPassAttrCheckConvertAllBlocks) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
FLAGS_convert_all_blocks = true;
ProgramDesc prog;
auto pass = PassRegistry::Instance().Get("test_pass");
std::unique_ptr<Graph> graph(new Graph(prog));
std::string exception;
try {
graph.reset(pass->Apply(graph.release()));
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("Required atrribute test_pass_attr for pass < "
"test_pass > is not set") != exception.npos);
int val = 1;
graph.reset(new Graph(prog));
pass->SetNotOwned<int>("test_pass_attr", &val);
for (std::string try_type : {"bool", "const int", "std::string"}) {
try {
if (try_type == "bool") {
pass->Get<bool>("test_pass_attr");
} else if (try_type == "const int") {
pass->Get<const int>("test_pass_attr");
} else if (try_type == "std::string") {
pass->Get<std::string>("test_pass_attr");
}
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
std::string msg = "Invalid type for attritube test_pass_attr, expected: " +
try_type + ", actual: int";
ASSERT_TRUE(exception.find(msg) != exception.npos);
}
try {
graph.reset(pass->Apply(graph.release()));
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find(
"Required atrribute test_graph_attr for graph is not set") !=
exception.npos);
graph.reset(new Graph(prog));
graph->Set<int>("test_graph_attr", new int);
graph->Get<int>("test_graph_attr") = 1;
graph.reset(pass->Apply(graph.release()));
ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);
// Allow apply more than once.
graph.reset(new Graph(prog));
graph->Set<int>("test_graph_attr", new int);
graph.reset(pass->Apply(graph.release()));
pass = PassRegistry::Instance().Get("test_pass");
pass->SetNotOwned<int>("test_pass_attr", &val);
graph.reset(new Graph(prog));
BuildCircleGraph(graph.get());
graph->Set<int>("test_graph_attr", new int);
graph->Get<int>("test_graph_attr") = 2;
try {
pass->Apply(graph.release());
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("shouldn't contain cycle") != exception.npos);
pass = PassRegistry::Instance().Get("test_pass");
pass->Set<int>("test_pass_attr", new int);
try {
pass->Set<int>("test_pass_attr", new int);
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(
exception.find("Attribute test_pass_attr already set in the pass") !=
exception.npos);
// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks = flag_temp;
}
class TestPassWithDefault : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const {
......@@ -160,6 +247,28 @@ TEST(PassTest, TestPassDefaultAttrCheck) {
ASSERT_EQ(pass->Get<int>("default_attr"), 3);
}
TEST(PassTest, TestPassDefaultAttrCheckConvertAllBlocks) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
FLAGS_convert_all_blocks = true;
ProgramDesc prog;
// check if default value is set
auto pass = PassRegistry::Instance().Get("test_pass_default_attr");
std::unique_ptr<Graph> graph(new Graph(prog));
ASSERT_EQ(pass->Get<int>("default_attr"), 1);
graph.reset(pass->Apply(graph.release()));
ASSERT_EQ(graph->Get<int>("copy_default_attr"), 2);
// check if new value overrides default value
pass = PassRegistry::Instance().Get("test_pass_default_attr");
pass->Set<int>("default_attr", new int{3});
ASSERT_EQ(pass->Get<int>("default_attr"), 3);
// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks = flag_temp;
}
TEST(PassTest, TestPassRegistrarDeconstructor) {
auto pass_registrary =
new PassRegistrar<paddle::framework::ir::TestPassWithDefault>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册