未验证 提交 167523e7 编写于 作者: J jiangcheng 提交者: GitHub

graph_to_program topology sort (#33949)

See https://github.com/PaddlePaddle/Paddle/pull/33949 for details
上级 f1654de6
...@@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index, ...@@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index,
// sub_graph. // sub_graph.
std::unique_ptr<Graph> first_sub_graph = std::make_unique<Graph>( std::unique_ptr<Graph> first_sub_graph = std::make_unique<Graph>(
program_.Block(0), this, start_op_index, end_op_index); program_.Block(0), this, start_op_index, end_op_index);
first_sub_graph->block_id_ = 0;
sub_graphs_.push_back(std::move(first_sub_graph)); sub_graphs_.push_back(std::move(first_sub_graph));
for (size_t idx = 1; idx < program_.Size(); ++idx) { for (size_t idx = 1; idx < program_.Size(); ++idx) {
std::unique_ptr<Graph> sub_graph = std::unique_ptr<Graph> sub_graph =
std::make_unique<Graph>(program_.Block(idx), this); std::make_unique<Graph>(program_.Block(idx), this);
sub_graph->block_id_ = idx;
sub_graphs_.push_back(std::move(sub_graph)); sub_graphs_.push_back(std::move(sub_graph));
} }
} else { } else {
...@@ -90,14 +92,32 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( ...@@ -90,14 +92,32 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock( std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
const BlockDesc &block, const int64_t start_op_index, const BlockDesc &block, const int64_t start_op_index,
const int64_t end_op_index) { const int64_t end_op_index) {
std::unordered_map<std::string, VarDesc *> all_vars; std::unordered_map<std::string, std::pair<VarDesc *, int>>
name_to_desc_block_id;
const BlockDesc *block_var_visible = &block;
while (block_var_visible != nullptr) {
for (auto *var : block_var_visible->AllVars()) {
name_to_desc_block_id.emplace(
var->Name(), std::make_pair(var, block_var_visible->ID()));
}
const BlockDesc *forward_block = block_var_visible->ForwardBlock();
if (forward_block != nullptr) {
for (auto *var : forward_block->AllVars()) {
name_to_desc_block_id.emplace(var->Name(),
std::make_pair(var, forward_block->ID()));
}
}
block_var_visible = block_var_visible->ParentBlock();
}
// var nodes for each var name, will have multiple versions in SSA // var nodes for each var name, will have multiple versions in SSA
std::map<std::string, std::vector<ir::Node *>> var_nodes; std::map<std::string, std::vector<ir::Node *>> var_nodes;
std::unordered_map<std::string, VarDesc *> not_visited_vars;
for (auto *var : block.AllVars()) { for (auto *var : block.AllVars()) {
all_vars.emplace(var->Name(), var); not_visited_vars.emplace(var->Name(), var);
} }
auto not_visited_vars = all_vars; int desc_order = 0;
auto all_ops = block.AllOps(); auto all_ops = block.AllOps();
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
end_op_index, all_ops.size(), end_op_index, all_ops.size(),
...@@ -109,6 +129,8 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock( ...@@ -109,6 +129,8 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
auto *op = all_ops[i]; auto *op = all_ops[i];
VLOG(3) << "create OpNode by " << op->Type(); VLOG(3) << "create OpNode by " << op->Type();
ir::Node *node = CreateOpNode(op); ir::Node *node = CreateOpNode(op);
node->SetDescOrder(desc_order);
++desc_order;
// For input args, reuse the same var name if it was created before. // For input args, reuse the same var name if it was created before.
// Otherwise, create a new one. // Otherwise, create a new one.
for (auto &each_var_name : op->InputArgumentNames()) { for (auto &each_var_name : op->InputArgumentNames()) {
...@@ -116,8 +138,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock( ...@@ -116,8 +138,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
ir::Node *var = nullptr; ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) { if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name).back(); var = var_nodes.at(each_var_name).back();
} else if (all_vars.count(each_var_name) != 0) { } else if (name_to_desc_block_id.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name)); auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
var_nodes[each_var_name].push_back(var); var_nodes[each_var_name].push_back(var);
} else { } else {
// Operation input var can be optional (dispensable). Which means // Operation input var can be optional (dispensable). Which means
...@@ -143,8 +166,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock( ...@@ -143,8 +166,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
} }
ir::Node *var = nullptr; ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) { if (name_to_desc_block_id.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name)); auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
} else { } else {
// Operation output vars can be @EMPTY@. For example, while_grad // Operation output vars can be @EMPTY@. For example, while_grad
// can have multi @EMPTY@ outputs with no VarDesc. // can have multi @EMPTY@ outputs with no VarDesc.
...@@ -270,6 +294,7 @@ std::shared_ptr<Graph> Graph::Clone() { ...@@ -270,6 +294,7 @@ std::shared_ptr<Graph> Graph::Clone() {
auto cloned_graph = std::make_shared<Graph>(this->program_); auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseNodes(); cloned_graph->ReleaseNodes();
cloned_graph->num_node_created_ = 0; cloned_graph->num_node_created_ = 0;
cloned_graph->block_id_ = this->block_id_;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned; std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
for (auto *n : this->node_set_) { for (auto *n : this->node_set_) {
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
...@@ -313,6 +338,7 @@ std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) { ...@@ -313,6 +338,7 @@ std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
std::make_unique<Graph>(this->program_.Block(idx), this); std::make_unique<Graph>(this->program_.Block(idx), this);
cloned_sub_graph->ReleaseNodes(); cloned_sub_graph->ReleaseNodes();
cloned_sub_graph->num_node_created_ = 0; cloned_sub_graph->num_node_created_ = 0;
cloned_sub_graph->block_id_ = idx;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned; std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
for (auto *n : this->sub_graphs_.at(idx)->Nodes()) { for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
......
...@@ -104,7 +104,14 @@ class Graph { ...@@ -104,7 +104,14 @@ class Graph {
attr_dels_.clear(); attr_dels_.clear();
} }
bool IsConstructedByPartialProgram() const { return is_partial_; } bool IsConstructedByPartialProgram() const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->IsConstructedByPartialProgram();
}
}
return is_partial_;
}
bool Has(const std::string &attr_name) const { bool Has(const std::string &attr_name) const {
if (FLAGS_convert_all_blocks) { if (FLAGS_convert_all_blocks) {
...@@ -210,7 +217,7 @@ class Graph { ...@@ -210,7 +217,7 @@ class Graph {
} }
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc, int block_id = -1) {
if (FLAGS_convert_all_blocks) { if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) { if (IsMainGraph()) {
return GetSubGraph(0)->CreateVarNode(var_desc); return GetSubGraph(0)->CreateVarNode(var_desc);
...@@ -219,7 +226,8 @@ class Graph { ...@@ -219,7 +226,8 @@ class Graph {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var_desc, platform::errors::InvalidArgument( var_desc, platform::errors::InvalidArgument(
"The VarDesc used to create variable node is null.")); "The VarDesc used to create variable node is null."));
auto *x = AddNode(new ir::Node(var_desc)); auto *x =
AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
x->SetId(num_node_created_++); x->SetId(num_node_created_++);
return x; return x;
} }
...@@ -252,7 +260,7 @@ class Graph { ...@@ -252,7 +260,7 @@ class Graph {
const std::string name = string::Sprintf( const std::string name = string::Sprintf(
"%s@%llu", static_cast<const char *>(ir::Node::kControlDepVarName), "%s@%llu", static_cast<const char *>(ir::Node::kControlDepVarName),
num_node_created_); num_node_created_);
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable)); auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
x->SetId(num_node_created_++); x->SetId(num_node_created_++);
return x; return x;
} }
...@@ -265,7 +273,7 @@ class Graph { ...@@ -265,7 +273,7 @@ class Graph {
return GetSubGraph(0)->CreateEmptyNode(name, type); return GetSubGraph(0)->CreateEmptyNode(name, type);
} }
} }
auto *x = AddNode(new ir::Node(name, type)); auto *x = AddNode(new ir::Node(name, type, block_id_));
x->SetId(num_node_created_++); x->SetId(num_node_created_++);
return x; return x;
} }
...@@ -365,6 +373,15 @@ class Graph { ...@@ -365,6 +373,15 @@ class Graph {
return sub_graphs_.at(idx).get(); return sub_graphs_.at(idx).get();
} }
int GetBlockId() const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->block_id_;
}
}
return block_id_;
}
size_t SubGraphsSize() const { size_t SubGraphsSize() const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true, this->IsMainGraph(), true,
...@@ -394,6 +411,9 @@ class Graph { ...@@ -394,6 +411,9 @@ class Graph {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true, this->IsMainGraph(), true,
platform::errors::InvalidArgument("This graph is not main_graph")); platform::errors::InvalidArgument("This graph is not main_graph"));
PADDLE_ENFORCE_EQ(sub_graphs_.size(), sub_graph->block_id_,
platform::errors::InvalidArgument(
"sub_graph idx is not equal to block_id_"));
sub_graphs_.push_back(std::move(sub_graph)); sub_graphs_.push_back(std::move(sub_graph));
} }
...@@ -416,6 +436,8 @@ class Graph { ...@@ -416,6 +436,8 @@ class Graph {
// parts: forward graph and backward graph, which can be executed // parts: forward graph and backward graph, which can be executed
// independently. // independently.
bool is_partial_{false}; bool is_partial_{false};
// The block this SubGraph belongs to.
int block_id_{0};
}; };
bool IsControlDepVar(const ir::Node &var); bool IsControlDepVar(const ir::Node &var);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include <queue>
#include <stack> #include <stack>
DEFINE_string(print_sub_graph_dir, "", DEFINE_string(print_sub_graph_dir, "",
...@@ -395,6 +396,85 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph, ...@@ -395,6 +396,85 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph,
} }
} }
class DescOrderComparator {
public:
bool operator()(const Node *n1, const Node *n2) {
return (n1->DescOrder() > n2->DescOrder()) ||
((n1->DescOrder() == n2->DescOrder()) &&
(n1->ToString() > n2->ToString()));
}
};
std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {
std::vector<ir::Node *> sorted_ops;
std::priority_queue<Node *, std::vector<Node *>, DescOrderComparator> q;
std::unordered_map<Node *, std::unordered_set<Node *>> in_ops;
std::unordered_map<Node *, std::unordered_set<Node *>> out_ops;
// ensure all op node in 'in_ops' and 'out_ops'
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
in_ops.emplace(n, std::unordered_set<Node *>());
out_ops.emplace(n, std::unordered_set<Node *>());
}
// record all op's input op and output op
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
// traverse all input op
for (const auto &var : n->inputs) {
for (const auto &in : var->inputs) {
// use at instead of [] to prevent no unrecorded op node
in_ops.at(n).insert(in);
out_ops.at(in).insert(n);
}
}
}
// find topology entrance
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
if (in_ops.at(n).empty()) {
q.push(n);
}
}
// topological sorting
while (!q.empty()) {
// Do not get by reference!!! The element will pop later.
const auto cur_op = q.top();
q.pop();
sorted_ops.push_back(cur_op);
for (const auto &out : out_ops.at(cur_op)) {
PADDLE_ENFORCE_GT(in_ops.at(out).count(cur_op), 0,
platform::errors::InvalidArgument(
"We find %s in %s's output list, "
"but cannot find %s in %s's input list. "
"Please ensure graph completely.",
out->Name().c_str(), cur_op->Name().c_str(),
cur_op->Name().c_str(), out->Name().c_str()));
in_ops.at(out).erase(cur_op);
// push if in-degree is 0
if (in_ops.at(out).empty()) {
q.push(out);
}
}
}
PADDLE_ENFORCE_EQ(
sorted_ops.size(), in_ops.size(),
platform::errors::InvalidArgument("Topological sorting incompletely, "
"only sorted %zd op but total %zd.",
sorted_ops.size(), in_ops.size()));
return sorted_ops;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -87,6 +87,8 @@ std::vector<T *> FilterByNodeWrapper(const Graph &graph) { ...@@ -87,6 +87,8 @@ std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
return ret; return ret;
} }
std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,7 +14,13 @@ limitations under the License. */ ...@@ -14,7 +14,13 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <gflags/gflags.h>
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
DECLARE_bool(convert_all_blocks);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -27,13 +33,10 @@ namespace framework { ...@@ -27,13 +33,10 @@ namespace framework {
namespace ir { namespace ir {
void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const { void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
// Remove the unneeded variables after memory optimization. PADDLE_ENFORCE_EQ(graph->IsMainGraph(), true,
std::unordered_set<std::string> vars2remove; platform::errors::InvalidArgument(
if (graph->Has(kGraphToProgramVarsToRemove)) { "This graph is a sub_graph, "
vars2remove = graph->Get<std::unordered_set<std::string>>( "and can't convert to program individually"));
kGraphToProgramVarsToRemove);
VLOG(2) << "graph to program remove " << vars2remove.size() << " nodes";
}
ProgramDesc& program = Get<ProgramDesc>("program"); ProgramDesc& program = Get<ProgramDesc>("program");
...@@ -42,12 +45,79 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const { ...@@ -42,12 +45,79 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
auto block = program_pb->mutable_blocks(kRootBlockIndex); auto block = program_pb->mutable_blocks(kRootBlockIndex);
block->set_idx(kRootBlockIndex); block->set_idx(kRootBlockIndex);
if (FLAGS_convert_all_blocks) {
GraphToBlock(graph->GetSubGraph(kRootBlockIndex), block);
VLOG(3) << "Graph to program need convert " << graph->SubGraphsSize()
<< " sub graph";
for (size_t idx = 0; idx < graph->SubGraphsSize(); ++idx) {
// avoid kRootBlockIndex not 0
if (idx == kRootBlockIndex) continue;
block = program_pb->add_blocks();
block->set_idx(idx);
GraphToBlock(graph->GetSubGraph(idx), block);
}
} else {
GraphToBlock(graph, block);
}
program.CopyFrom(*program_pb);
}
OpDesc* ReplaceScaleLossGradOp(ir::Node* node, OpDesc* desc) {
desc->SetType("fill_constant");
desc->SetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
desc->SetAttr("value", 1.0f);
std::vector<std::string> output_names;
for (auto out : node->outputs) {
output_names.emplace_back(out->Name());
}
desc->SetOutput("Out", output_names);
return desc;
}
std::vector<OpDesc>* GetGraphOpDesc(const std::vector<ir::Node*>& nodes,
std::vector<OpDesc>* ops) {
for (ir::Node* n : nodes) {
// if node is not Op, skip
if (!n->IsOp()) continue;
// create fill_constant op
if (n->Name() == "scale_loss_grad") {
ops->emplace_back();
auto& desc = ops->back();
ReplaceScaleLossGradOp(n, &desc);
} else if (n->Op()) {
ops->emplace_back(*n->Op());
} else {
// delete no OpDesc op
}
}
return ops;
}
void GraphToProgramPass::GraphToBlock(const Graph* graph,
proto::BlockDesc* block) const {
// Remove the unneeded variables after memory optimization.
std::unordered_set<std::string> vars2remove;
if (graph->Has(kGraphToProgramVarsToRemove)) {
vars2remove = graph->Get<std::unordered_set<std::string>>(
kGraphToProgramVarsToRemove);
VLOG(2) << "graph (id: " << block->idx() << ") to program remove "
<< vars2remove.size() << " nodes";
}
block->clear_vars(); block->clear_vars();
std::unordered_set<std::string> visited_vars; std::unordered_set<std::string> visited_vars;
for (ir::Node* n : graph->Nodes()) { for (ir::Node* n : graph->Nodes()) {
if (n->IsVar()) { if (n->IsVar()) {
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 && if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 &&
!vars2remove.count(n->Var()->Name())) { !vars2remove.count(n->Var()->Name()) &&
n->GetVarNodeBlockId() == graph->GetBlockId()) {
visited_vars.insert(n->Var()->Name()); visited_vars.insert(n->Var()->Name());
block->add_vars()->MergeFrom(*n->Var()->Proto()); block->add_vars()->MergeFrom(*n->Var()->Proto());
} }
...@@ -61,17 +131,19 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const { ...@@ -61,17 +131,19 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
int sort_kind = Get<int>(kGraphToProgramSortKind); int sort_kind = Get<int>(kGraphToProgramSortKind);
nodes = TopologyVarientSort( nodes = TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(sort_kind)); *graph, static_cast<framework::ir::SortKind>(sort_kind));
} else {
if (FLAGS_convert_all_blocks) {
nodes = TopologySortGraphByDescOrder(*graph);
} else { } else {
nodes = TopologySortOperations(*graph); nodes = TopologySortOperations(*graph);
} }
for (ir::Node* n : nodes) {
if (!n->Op()) continue;
block->add_ops()->MergeFrom(*n->Op()->Proto());
} }
program.CopyFrom(*program_pb); std::vector<OpDesc> ops;
GetGraphOpDesc(nodes, &ops);
for (auto& op : ops) {
block->add_ops()->MergeFrom(*op.Proto());
}
} }
} // namespace ir } // namespace ir
......
...@@ -29,6 +29,9 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__"; ...@@ -29,6 +29,9 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
class GraphToProgramPass : public Pass { class GraphToProgramPass : public Pass {
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private:
void GraphToBlock(const Graph* graph, proto::BlockDesc* block) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -14,8 +14,14 @@ limitations under the License. */ ...@@ -14,8 +14,14 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <algorithm>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -103,6 +109,382 @@ TEST(GraphToProgramPass, Basic) { ...@@ -103,6 +109,382 @@ TEST(GraphToProgramPass, Basic) {
EXPECT_TRUE(vars.find("var2") != vars.end()); EXPECT_TRUE(vars.find("var2") != vars.end());
EXPECT_TRUE(vars.find("var3") != vars.end()); EXPECT_TRUE(vars.find("var3") != vars.end());
} }
void BuildProgramWithMultiBlock(ProgramDesc* program) {
auto* global_block = program->MutableBlock(0);
auto* mul_1_x = global_block->Var("Mul_1_X");
mul_1_x->SetType(proto::VarType::LOD_TENSOR);
mul_1_x->SetLoDLevel(0);
mul_1_x->SetDataType(proto::VarType::FP32);
mul_1_x->SetShape({1000, 784});
auto* mul_1_y = global_block->Var("Mul_1_Y");
mul_1_y->SetType(proto::VarType::LOD_TENSOR);
mul_1_y->SetLoDLevel(0);
mul_1_y->SetDataType(proto::VarType::FP32);
mul_1_y->SetShape({784, 100});
auto* mul_1_out = global_block->Var("Mul_1_Out");
mul_1_out->SetType(proto::VarType::LOD_TENSOR);
auto* mul_op_1 = global_block->AppendOp();
mul_op_1->SetType("mul");
mul_op_1->SetInput("X", {mul_1_x->Name()});
mul_op_1->SetInput("Y", {mul_1_y->Name()});
mul_op_1->SetOutput("Y", {mul_1_out->Name()});
// building cond op such as less_than
auto* less_than_op_1 = global_block->AppendOp();
less_than_op_1->SetType("less_than");
auto* less_than_1_x = global_block->Var("Less_than_1_X");
less_than_1_x->SetType(proto::VarType::LOD_TENSOR);
less_than_1_x->SetLoDLevel(0);
less_than_1_x->SetDataType(proto::VarType::FP32);
less_than_1_x->SetShape({1});
auto* less_than_1_y = global_block->Var("Less_than_1_Y");
less_than_1_y->SetType(proto::VarType::LOD_TENSOR);
less_than_1_y->SetLoDLevel(0);
less_than_1_y->SetDataType(proto::VarType::FP32);
less_than_1_y->SetShape({1});
auto* less_than_1_out = global_block->Var("Less_than_1_Out");
less_than_1_out->SetType(proto::VarType::BOOL);
less_than_op_1->SetInput("X", {less_than_1_x->Name()});
less_than_op_1->SetInput("Y", {less_than_1_y->Name()});
less_than_op_1->SetOutput("Out", {less_than_1_out->Name()});
BlockDesc* sub_block = program->AppendBlock(*global_block);
std::vector<BlockDesc*> sub_blocks;
sub_blocks.push_back(sub_block);
BlockDesc* sub_block2 =
program->AppendBlock(*sub_block); // for testing nested case.
sub_blocks.push_back(sub_block2);
// building while op in sub_block
auto* while_op = global_block->AppendOp();
while_op->SetType("while");
while_op->SetAttr("sub_block", sub_blocks[0]);
auto* while_x = global_block->Var("While_X");
while_x->SetType(proto::VarType::LOD_TENSOR);
while_x->SetLoDLevel(0);
while_x->SetDataType(proto::VarType::FP32);
while_x->SetShape({1});
while_op->SetInput("kX", {while_x->Name()});
while_op->SetInput("kCondition", {less_than_1_out->Name()});
auto* while_out = global_block->Var("While_Out");
while_out->SetType(proto::VarType::LOD_TENSOR);
while_out->SetLoDLevel(0);
while_out->SetDataType(proto::VarType::FP32);
while_out->SetShape({1});
auto* steps = global_block->Var("StepScopes");
while_op->SetOutput("kOutputs", {while_out->Name()});
while_op->SetOutput("kStepScopes", {steps->Name()});
auto* mul_2_x = global_block->Var("Mul_2_X");
mul_2_x->SetType(proto::VarType::LOD_TENSOR);
mul_2_x->SetLoDLevel(0);
mul_2_x->SetDataType(proto::VarType::FP32);
mul_2_x->SetShape({1000, 784});
auto* mul_2_y = global_block->Var("Mul_2_Y");
mul_2_y->SetType(proto::VarType::LOD_TENSOR);
mul_2_y->SetLoDLevel(0);
mul_2_y->SetDataType(proto::VarType::FP32);
mul_2_y->SetShape({784, 100});
auto* mul_op_2 = sub_blocks[0]->AppendOp();
mul_op_2->SetType("mul");
mul_op_2->SetInput("X", {mul_2_x->Name()});
mul_op_2->SetInput("Y", {mul_2_y->Name()});
auto* mul_2_out = global_block->Var("Mul_2_Out");
mul_2_out->SetType(proto::VarType::LOD_TENSOR);
mul_op_2->SetOutput("Y", {mul_2_out->Name()});
auto* less_than_op_2 = sub_blocks[0]->AppendOp();
less_than_op_2->SetType("less_than");
auto* less_than_2_x = global_block->Var("Less_than_2_X");
less_than_2_x->SetType(proto::VarType::LOD_TENSOR);
less_than_2_x->SetLoDLevel(0);
less_than_2_x->SetDataType(proto::VarType::FP32);
less_than_2_x->SetShape({1});
auto* less_than_2_y = global_block->Var("Less_than_2_Y");
less_than_2_y->SetType(proto::VarType::LOD_TENSOR);
less_than_2_y->SetLoDLevel(0);
less_than_2_y->SetDataType(proto::VarType::FP32);
less_than_2_y->SetShape({1});
less_than_op_2->SetInput("X", {less_than_2_x->Name()});
less_than_op_2->SetInput("Y", {less_than_2_y->Name()});
auto* less_than_2_out = global_block->Var("Less_than_2_Out");
less_than_2_out->SetType(proto::VarType::BOOL);
less_than_op_2->SetOutput("Out", {less_than_2_out->Name()});
auto* cond_op = sub_blocks[0]->AppendOp();
cond_op->SetType("conditional_block");
cond_op->SetAttr("sub_block", sub_blocks[1]);
auto* cond_x = sub_blocks[0]->Var("Cond_X");
cond_x->SetType(proto::VarType::LOD_TENSOR);
cond_x->SetLoDLevel(0);
cond_x->SetDataType(proto::VarType::FP32);
cond_x->SetShape({1});
cond_op->SetInput("kInputs", {cond_x->Name()});
cond_op->SetInput("kCondition", {less_than_2_out->Name()});
auto* cond_out = sub_blocks[0]->Var("Cond_Out");
cond_out->SetType(proto::VarType::LOD_TENSOR);
cond_out->SetLoDLevel(0);
cond_out->SetDataType(proto::VarType::FP32);
cond_out->SetShape({1});
auto* scope = sub_blocks[0]->Var("Scope");
scope->SetType(proto::VarType::STEP_SCOPES);
cond_op->SetOutput("kOutputs", {cond_out->Name()});
cond_op->SetOutput("kScope", {scope->Name()});
auto* mul_3_x = global_block->Var("Mul_3_X");
mul_3_x->SetType(proto::VarType::LOD_TENSOR);
mul_3_x->SetLoDLevel(0);
mul_3_x->SetDataType(proto::VarType::FP32);
mul_3_x->SetShape({1000, 784});
auto* mul_3_y = global_block->Var("Mul_3_Y");
mul_3_y->SetType(proto::VarType::LOD_TENSOR);
mul_3_y->SetLoDLevel(0);
mul_3_y->SetDataType(proto::VarType::FP32);
mul_3_y->SetShape({784, 100});
auto* mul_3_out = global_block->Var("Mul_3_Out");
mul_3_out->SetType(proto::VarType::LOD_TENSOR);
auto* mul_op_3 = sub_blocks[1]->AppendOp();
mul_op_3->SetType("mul");
mul_op_3->SetInput("X", {mul_3_x->Name()});
mul_op_3->SetInput("Y", {mul_3_y->Name()});
mul_op_3->SetOutput("Y", {mul_3_out->Name()});
}
bool VarComparator(const VarDesc* a, const VarDesc* b) {
return a->Name() < b->Name();
}
void CheckBlockVarsEqual(const BlockDesc& before_block,
const BlockDesc& after_block) {
auto before_vars = before_block.AllVars();
auto after_vars = after_block.AllVars();
EXPECT_EQ(before_vars.size(), after_vars.size());
// var's order is unimportant
std::sort(before_vars.begin(), before_vars.end(), VarComparator);
std::sort(after_vars.begin(), after_vars.end(), VarComparator);
for (size_t var_idx = 0; var_idx < before_vars.size(); ++var_idx) {
const auto& before_var = before_vars.at(var_idx);
const auto& after_var = after_vars.at(var_idx);
EXPECT_EQ(before_var->Name(), after_var->Name());
EXPECT_EQ(before_var->GetType(), after_var->GetType());
}
}
void CheckOpInputsEqual(const OpDesc* before_op, const OpDesc* after_op) {
const auto& before_inputs = before_op->InputNames();
const auto& after_inputs = after_op->InputNames();
EXPECT_EQ(before_inputs.size(), after_inputs.size());
for (size_t in_idx = 0; in_idx < before_inputs.size(); ++in_idx) {
const auto& before_in_arg = before_inputs[in_idx];
const auto& after_in_arg = after_inputs[in_idx];
EXPECT_EQ(before_in_arg, after_in_arg);
const auto& before_in_vars = before_op->Input(before_in_arg);
const auto& after_in_vars = after_op->Input(after_in_arg);
EXPECT_EQ(before_in_vars, after_in_vars);
}
}
void CheckOpOutputsEqual(const OpDesc* before_op, const OpDesc* after_op) {
const auto& before_outputs = before_op->OutputNames();
const auto& after_outputs = after_op->OutputNames();
EXPECT_EQ(before_outputs.size(), after_outputs.size());
for (size_t out_idx = 0; out_idx < before_outputs.size(); ++out_idx) {
const auto& before_out_arg = before_outputs[out_idx];
const auto& after_out_arg = after_outputs[out_idx];
EXPECT_EQ(before_out_arg, after_out_arg);
const auto& before_out_vars = before_op->Output(before_out_arg);
const auto& after_out_vars = after_op->Output(after_out_arg);
EXPECT_EQ(before_out_vars, after_out_vars);
}
}
void CheckOpAttrsEqual(const OpDesc* before_op, const OpDesc* after_op) {
const auto& before_attrs = before_op->AttrNames();
const auto& after_attrs = after_op->AttrNames();
EXPECT_EQ(before_attrs.size(), after_attrs.size());
for (size_t attr_idx = 0; attr_idx < before_attrs.size(); ++attr_idx) {
const auto& before_attr = before_attrs[attr_idx];
const auto& after_attr = after_attrs[attr_idx];
EXPECT_EQ(before_attr, after_attr);
EXPECT_EQ(before_op->GetAttrType(before_attr),
after_op->GetAttrType(after_attr));
}
}
void CheckBlockOpsEqual(const BlockDesc& before_block,
const BlockDesc& after_block) {
EXPECT_EQ(before_block.OpSize(), after_block.OpSize());
// op's order must be the same
for (size_t op_idx = 0; op_idx < before_block.OpSize(); ++op_idx) {
const auto& before_op = before_block.Op(op_idx);
const auto& after_op = after_block.Op(op_idx);
EXPECT_EQ(before_op->Type(), after_op->Type());
// Step4.2.1 : check each op's input
CheckOpInputsEqual(before_op, after_op);
// Step4.2.2 : check each op's output
CheckOpOutputsEqual(before_op, after_op);
// Step4.2.3 : check each op's attribute
CheckOpAttrsEqual(before_op, after_op);
}
}
TEST(GraphToProgramPass, MultiBlock) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
FLAGS_convert_all_blocks = true;
// Step1: Build a program with multi block
ProgramDesc before_prog;
BuildProgramWithMultiBlock(&before_prog);
// Step2: Convert program into graph
std::unique_ptr<Graph> g(new ir::Graph(before_prog));
// Step3 : Convert graph back to program
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
"graph_to_program_pass");
ProgramDesc after_prog;
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &after_prog);
pass->Apply(g.get());
// Step4 : Check tow program equal
EXPECT_EQ(before_prog.Size(), after_prog.Size());
for (size_t block_idx = 0; block_idx < before_prog.Size(); ++block_idx) {
const auto& before_block = before_prog.Block(block_idx);
const auto& after_block = after_prog.Block(block_idx);
EXPECT_EQ(before_block.ID(), after_block.ID());
// Step4.1 : check each block's var
CheckBlockVarsEqual(before_block, after_block);
// Step4.2 : check each block's op
CheckBlockOpsEqual(before_block, after_block);
}
// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks = flag_temp;
}
void BuildProgramWithScaleLossGrad(Graph* g) {
OpDesc op1;
op1.SetType("op1");
OpDesc op2;
op2.SetType("op2");
OpDesc op3;
op3.SetType("op3");
OpDesc op4;
op4.SetType("op4");
VarDesc var1("var1");
VarDesc var2("var2");
ir::Node* o1 = g->CreateOpNode(&op1);
ir::Node* o2 = g->CreateOpNode(&op2);
ir::Node* o3 =
g->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation);
ir::Node* o4 =
g->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation);
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
// o1->v1->o2
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
// o3->v1
o3->outputs.push_back(v1);
v1->inputs.push_back(o1);
v1->inputs.push_back(o3);
// o4->v2
o4->outputs.push_back(v2);
v2->inputs.push_back(o4);
}
TEST(GraphToProgramPass, ReplaceScaleLossGrad) {
// Step1: Build a program with multi block
ProgramDesc before_prog;
Graph before_graph(before_prog);
BuildProgramWithScaleLossGrad(&before_graph);
// Step2 : Convert graph back to program
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
"graph_to_program_pass");
ProgramDesc after_prog;
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &after_prog);
pass->Apply(&before_graph);
// Step3 : statistics scale_loss_grad and fill_constant number
int scale_node_num = 0, fill_node_num = 0;
const auto& before_nodes_set = before_graph.Nodes();
for (const auto& n : before_nodes_set) {
if (n->Name() == "scale_loss_grad") {
++scale_node_num;
} else if (n->Name() == "fill_constant") {
++fill_node_num;
}
}
int scale_op_num = 0, fill_op_num = 0;
const auto& block = after_prog.Block(0);
for (const auto& op : block.AllOps()) {
if (op->Type() == "fill_constant") {
++fill_op_num;
} else if (op->Type() == "scale_loss_grad") {
++scale_op_num;
}
}
// Check pass OK
EXPECT_EQ(scale_op_num, 0);
EXPECT_EQ(scale_node_num + fill_node_num, fill_op_num);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -26,6 +26,13 @@ using OpVariant = operators::OpVariant; ...@@ -26,6 +26,13 @@ using OpVariant = operators::OpVariant;
class WhileOpEagerDeletionPass : public ir::Pass { class WhileOpEagerDeletionPass : public ir::Pass {
protected: protected:
void ApplyImpl(ir::Graph *graph) const override { void ApplyImpl(ir::Graph *graph) const override {
if (!graph->IsMainGraph()) {
// TODO(zhhsplendid): the WhileOpEagerDeletionPass is based on old Graph,
// which only applies to the main block graph. The new Eager Deletion
// Technical can be added after we write new while_op based on SubGraph
// instead of SubBlock
return;
}
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph); auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
// Find all while_op and while_grad_op. In case of @to_static, graph // Find all while_op and while_grad_op. In case of @to_static, graph
...@@ -47,6 +54,7 @@ class WhileOpEagerDeletionPass : public ir::Pass { ...@@ -47,6 +54,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
} }
} }
if (graph->IsConstructedByPartialProgram()) { if (graph->IsConstructedByPartialProgram()) {
VLOG(4) << "Is Paritial Program";
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
target_ops.size(), 1, target_ops.size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -69,8 +77,11 @@ class WhileOpEagerDeletionPass : public ir::Pass { ...@@ -69,8 +77,11 @@ class WhileOpEagerDeletionPass : public ir::Pass {
} }
for (auto &ops_pair : target_ops) { for (auto &ops_pair : target_ops) {
VLOG(4) << "Scope Idx = " << ops_pair.first;
auto &while_ops = ops_pair.second.first; auto &while_ops = ops_pair.second.first;
VLOG(4) << "while_ops.size() = " << while_ops.size();
auto &while_grad_ops = ops_pair.second.second; auto &while_grad_ops = ops_pair.second.second;
VLOG(4) << "while_grad_ops.size() = " << while_grad_ops.size();
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
graph->OriginProgram(), while_ops, while_grad_ops); graph->OriginProgram(), while_ops, while_grad_ops);
} }
......
...@@ -30,7 +30,7 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name, ...@@ -30,7 +30,7 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
} }
std::unique_ptr<Node> CreateNodeForTest(VarDesc *var_desc) { std::unique_ptr<Node> CreateNodeForTest(VarDesc *var_desc) {
return std::unique_ptr<Node>(new Node(var_desc)); return std::unique_ptr<Node>(new Node(var_desc, 0));
} }
std::unique_ptr<Node> CreateNodeForTest(OpDesc *op_desc) { std::unique_ptr<Node> CreateNodeForTest(OpDesc *op_desc) {
......
...@@ -136,9 +136,98 @@ class Node { ...@@ -136,9 +136,98 @@ class Node {
var_desc_->SetName(new_name); var_desc_->SetName(new_name);
} }
int DescOrder() const { return desc_order_; }
int GetVarNodeBlockId() const {
PADDLE_ENFORCE_EQ(
type_ == Type::kVariable && var_desc_, true,
platform::errors::InvalidArgument("Node must be type of variable."));
return block_id_;
}
const std::string ToString() const {
if (IsOp()) {
std::string op_str(Name());
const auto& op = Op();
if (op == nullptr) {
// Node is an Op but hasn't OpDesc (often create by CreateEmptyNode),
// like ScaleLossGradOp, it's type is OpHandle, which created by Pass
// and then inserted into graph.
// For OpHandle, we have to use Node's input and output for sorting.
std::vector<Node*> sorted_inputs(inputs);
std::vector<Node*> sorted_outputs(outputs);
auto comparator = [](Node* a, Node* b) {
return a->Name() > b->Name();
};
std::stable_sort(sorted_inputs.begin(), sorted_inputs.end(),
comparator);
std::stable_sort(sorted_outputs.begin(), sorted_outputs.end(),
comparator);
std::string out_str = "{";
std::string pre_str = "";
for (const auto& output : sorted_outputs) {
out_str.append(pre_str + output->Name());
pre_str = ", ";
}
out_str.append("} = ");
std::string in_str = "(";
pre_str = "";
for (const auto& input : sorted_inputs) {
in_str.append(pre_str + input->Name());
pre_str = ", ";
}
in_str.append(")");
op_str = out_str + op_str + in_str;
} else {
// A normal Op, has OpDesc, create from ProgramDesc
std::string out_str = "{";
std::string outer_pre_str = "";
for (const auto& output : op->OutputNames()) {
out_str.append(outer_pre_str + output + "=[");
std::string inner_pre_str = "";
for (const auto& arg : op->Output(output)) {
out_str.append(inner_pre_str + arg);
inner_pre_str = " ,";
}
outer_pre_str = ", ";
out_str.append("]");
}
out_str.append("} = ");
std::string in_str = "(";
outer_pre_str = "";
for (const auto& input : op->InputNames()) {
in_str.append(outer_pre_str + input + "=[");
std::string inner_pre_str = "";
for (const auto& arg : op->Input(input)) {
in_str.append(inner_pre_str + arg);
inner_pre_str = " ,";
}
outer_pre_str = " ,";
in_str.append("]");
}
in_str.append(")");
op_str = out_str + op_str + in_str;
}
return op_str;
}
return Name();
}
std::vector<Node*> inputs; std::vector<Node*> inputs;
std::vector<Node*> outputs; std::vector<Node*> outputs;
// Because NO_DESC_ORDER is a constexpr number,
// no one can change it, meanwhile, we need
// check whether the DescOrder invalid sometime,
// so expose it is a good idea
static constexpr int NO_DESC_ORDER = INT_MAX;
protected: protected:
std::string name_; std::string name_;
std::unique_ptr<VarDesc> var_desc_; std::unique_ptr<VarDesc> var_desc_;
...@@ -146,30 +235,45 @@ class Node { ...@@ -146,30 +235,45 @@ class Node {
Type type_; Type type_;
int id_; int id_;
int desc_order_;
int block_id_{-1};
private: private:
// ID can only set by a Graph. // ID can only set by a Graph.
void SetId(int id) { id_ = id; } void SetId(int id) { id_ = id; }
// desc_order can only set by a Graph when constructing a Graph from a
// BlockDesc.
void SetDescOrder(int desc_order) { desc_order_ = desc_order; }
friend class Graph; friend class Graph;
friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name, friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type); Node::Type type);
friend std::unique_ptr<Node> CreateNodeForTest(VarDesc* var_desc); friend std::unique_ptr<Node> CreateNodeForTest(VarDesc* var_desc);
friend std::unique_ptr<Node> CreateNodeForTest(OpDesc* op_desc); friend std::unique_ptr<Node> CreateNodeForTest(OpDesc* op_desc);
explicit Node(const std::string& name, Type type) explicit Node(const std::string& name, Type type, int block_id = 0)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} : name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
desc_order_(NO_DESC_ORDER),
block_id_(block_id) {}
explicit Node(VarDesc* var_desc) explicit Node(VarDesc* var_desc, int block_id)
: name_(var_desc->Name()), : name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)), var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable) {} type_(Type::kVariable),
desc_order_(NO_DESC_ORDER),
block_id_(block_id) {}
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())), op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation) {} type_(Type::kOperation),
desc_order_(NO_DESC_ORDER) {}
Node() = delete; Node() = delete;
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -75,6 +76,32 @@ TEST(NodeTest, Basic) { ...@@ -75,6 +76,32 @@ TEST(NodeTest, Basic) {
EXPECT_FALSE(alive2); EXPECT_FALSE(alive2);
} }
TEST(NodeTest, ToString) {
VarDesc var_desc("n2");
OpDesc op_desc;
op_desc.SetType("test_op");
op_desc.SetInput("X", {"x1", "x2", "x3"});
op_desc.SetOutput("Y", {"y1", "y2"});
std::unique_ptr<Node> n1(CreateNodeForTest("n1", Node::Type::kVariable));
std::unique_ptr<Node> n2(CreateNodeForTest(&var_desc));
std::unique_ptr<Node> n3(CreateNodeForTest("n3", Node::Type::kOperation));
std::unique_ptr<Node> n4(CreateNodeForTest(&op_desc));
EXPECT_EQ(n1->ToString(), "n1");
EXPECT_EQ(n2->ToString(), "n2");
EXPECT_EQ(n3->Op(), nullptr);
EXPECT_EQ(n3->ToString(), "{} = n3()");
EXPECT_NE(n4->Op(), nullptr);
EXPECT_EQ(n4->ToString(), "{Y=[y1 ,y2]} = test_op(X=[x1 ,x2 ,x3])");
n3->inputs.push_back(n1.get());
n3->outputs.push_back(n2.get());
EXPECT_EQ(n3->Op(), nullptr);
EXPECT_EQ(n3->ToString(), "{n2} = n3(n1)");
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册