diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index 9eba6fc89a2e9654c85e124510db884c72127d36..085298314ea3ffeba7e5924d0017e87ecf97d91a 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/generate_pass.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { @@ -224,6 +225,115 @@ bool GeneratePass::VerifyGraph(const Graph& graph) { return true; } +namespace generate_pass { + +VarHelper::VarHelper(const char* name) : name_(name), type_(Type::kInput) {} +VarHelper::VarHelper(const std::string& name, Type type) + : name_(name), type_(type) {} + +OpHelper::OpHelper(const char* type, SubgraphHelper* subgraph_helper) + : type_(type), subgraph_helper_(subgraph_helper) { + op_desc_ = subgraph_helper_->ProgramDesc()->mutable_blocks(0)->add_ops(); + op_desc_->set_type(type_); +} + +OpHelper::Arguments::Arguments(const char* parameter, + const VarHelper& var_helper) + : parameter_(parameter) { + var_helpers_.push_back(var_helper); +} + +OpHelper::Arguments::Arguments(const char* parameter, + std::initializer_list var_helpers) + : parameter_(parameter), var_helpers_(var_helpers) {} + +OpHelper& OpHelper::operator()(const Arguments& input) { + proto::OpDesc::Var* var = op_desc_->add_inputs(); + var->set_parameter(input.parameter_); + for (const VarHelper& var_helper : input.var_helpers_) { + var->add_arguments()->assign(var_helper.name_); + if (VarHelper::Type::kInput == var_helper.type_) { + subgraph_helper_->AddInputVar(var_helper.name_); + } + } + return *this; +} + +OpHelper& OpHelper::operator()(std::initializer_list inputs) { + for (const auto& input : inputs) { + operator()(input); + } + return *this; +} + +VarHelper OpHelper::Out(const char* name) { + std::string argument = patterns::UniqueKey(type_); + proto::OpDesc::Var* var = op_desc_->add_outputs(); + var->set_parameter(name); + var->add_arguments()->assign(argument); + return VarHelper(argument, VarHelper::Type::kOutput); +} + +proto::ProgramDesc* SubgraphHelper::ProgramDesc() { return &program_desc_; } + +const proto::ProgramDesc& SubgraphHelper::ProgramDesc() const { + return program_desc_; +} + +const std::vector& SubgraphHelper::InputVars() const { + return input_vars_; +} + +const std::vector& SubgraphHelper::OutputVars() const { + return output_vars_; +} + +void SubgraphHelper::AddInputVar(const std::string& name) { + auto iter = std::find(input_vars_.begin(), input_vars_.end(), name); + if (input_vars_.end() == iter) { + input_vars_.push_back(name); + } +} + +void SubgraphHelper::AddOutputVars(const VarHelper& var_helper) { + output_vars_.push_back(var_helper.name_); +} + +} // namespace generate_pass + +PassPairs::PassPairs(const SubgraphType& pattern, const SubgraphType& replace) { + AddPassDesc(pattern, replace); +} + +void PassPairs::AddPassDesc(const SubgraphType& pattern, + const SubgraphType& replace) { + proto::PassDesc* pass_desc = multi_pass_desc_.add_pass_descs(); + pass_desc->mutable_pattern()->CopyFrom(pattern.ProgramDesc()); + pass_desc->mutable_replace()->CopyFrom(replace.ProgramDesc()); + PADDLE_ENFORCE_EQ(pattern.InputVars().size(), replace.InputVars().size(), + platform::errors::InvalidArgument( + "Size of lambda expression arguments is not equal " + "between pattern/replace subgraph.")); + for (size_t i = 0; i < pattern.InputVars().size(); i++) { + proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); + var_map->set_pattern_var(pattern.InputVars()[i]); + var_map->set_replace_var(replace.InputVars()[i]); + } + PADDLE_ENFORCE_EQ(pattern.OutputVars().size(), replace.OutputVars().size(), + platform::errors::InvalidArgument( + "Size of lambda expression returns is not equal " + "between pattern/replace subgraph.")); + for (size_t i = 0; i < pattern.OutputVars().size(); i++) { + proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); + var_map->set_pattern_var(pattern.OutputVars()[i]); + var_map->set_replace_var(replace.OutputVars()[i]); + } +} + +const proto::MultiPassDesc& PassPairs::MultiPassDesc() const { + return multi_pass_desc_; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h index f73173233aed321fefbc3f9018a051b5bbc86519..26e5231fbc16e7c9038a3bf060f82628501fa12a 100644 --- a/paddle/fluid/framework/ir/generate_pass.h +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -13,7 +13,6 @@ // limitations under the License. #pragma once -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/pass_desc.pb.h" @@ -43,6 +42,158 @@ class GeneratePass : public Pass { proto::MultiPassDesc multi_pass_desc_; }; +namespace generate_pass { + +class VarHelper; +class OpHelper; +class SubgraphHelper; + +// VarHelper is used to represent a variable node. +struct VarHelper { + enum class Type { kInput, kOutput }; + + explicit VarHelper(const char* name); + VarHelper(const std::string& name, Type type); + + std::string name_; + Type type_; +}; + +// OpHelper is used to represent a operator node. +class OpHelper { + public: + // Convert multiple inputs. + struct Arguments { + Arguments(const char* parameter, const VarHelper& var_helper); + Arguments(const char* parameter, + std::initializer_list var_helpers); + + std::string parameter_; + std::vector var_helpers_; + }; + + OpHelper(const char* type, SubgraphHelper* subgraph_helper); + + OpHelper& operator()(const Arguments& input); + OpHelper& operator()(std::initializer_list inputs); + + VarHelper Out(const char* name); + + private: + OpHelper() = delete; + DISABLE_COPY_AND_ASSIGN(OpHelper); + + const char* type_; + proto::OpDesc* op_desc_; + SubgraphHelper* subgraph_helper_; +}; + +/* + * SubgraphHelper is used to define pattern/replace subgraphs. + * + * Use lambda expression to define subgraph like Python. SubgraphHelper + * converts lambda expression to ProgramDesc. + * + * In order to define a subgraph, user need to use VarHelper and OpHelper. + * Use the macros instead of class names, so user can develop better and + * don't need to know too much about underlying implementation. + * + * An example of defining a subgraph as follows: + * + * SUBGRAPH_(subgraph)([subgraph=&subgraph](VAR_(x), VAR_(y), VAR_(z)) { + * auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out"); + * auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out"); + * return ewadd2; + * }); + * + */ +class SubgraphHelper { + public: + SubgraphHelper() = default; + // The lambda expression is a prvalue expression. + template + SubgraphHelper& operator=(const T&& f) { + proto::BlockDesc* block = program_desc_.add_blocks(); + block->set_idx(0); + block->set_parent_idx(0); + AddOutputVars(f()); + return *this; + } + + proto::ProgramDesc* ProgramDesc(); + const proto::ProgramDesc& ProgramDesc() const; + const std::vector& InputVars() const; + const std::vector& OutputVars() const; + + void AddInputVar(const std::string& name); + + void AddOutputVars(const VarHelper& var_helper); + + template * = nullptr> + void AddOutputVars(const std::tuple& outputs) { + AddOutputVars(std::get(outputs)); + AddOutputVars(outputs); + } + + template * = nullptr> + void AddOutputVars(const std::tuple& outputs) { + AddOutputVars(std::get(outputs)); + } + + template + void AddOutputVars(const std::tuple& outputs) { + AddOutputVars<0>(outputs); + } + + private: + DISABLE_COPY_AND_ASSIGN(SubgraphHelper); + std::vector input_vars_; + std::vector output_vars_; + proto::ProgramDesc program_desc_; +}; + +} // namespace generate_pass + +class PassPairs { + public: + using SubgraphType = generate_pass::SubgraphHelper; + + PassPairs() = default; + PassPairs(const SubgraphType& pattern, const SubgraphType& replace); + + void AddPassDesc(const SubgraphType& pattern, const SubgraphType& replace); + + const proto::MultiPassDesc& MultiPassDesc() const; + + private: + proto::MultiPassDesc multi_pass_desc_; +}; + +// Use function to register in CC. +template +class MacroPassHelper : public GeneratePass { + public: + MacroPassHelper() : GeneratePass(Functor().MultiPassDesc()) {} +}; + +#define VAR_(name) \ + ::paddle::framework::ir::generate_pass::VarHelper name = \ + ::paddle::framework::ir::generate_pass::VarHelper(#name) +#define OP_(type) \ + ::paddle::framework::ir::generate_pass::OpHelper(#type, subgraph) +#define SUBGRAPH_(name) \ + ::paddle::framework::ir::generate_pass::SubgraphHelper name; \ + name + +#define REGISTER_GENERATE_PASS(pass_type) \ + paddle::framework::ir::PassPairs register_##pass_type(); \ + REGISTER_PASS( \ + pass_type, \ + ::paddle::framework::ir::MacroPassHelper<®ister_##pass_type>); \ + paddle::framework::ir::PassPairs register_##pass_type() + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/generate_pass_tester.cc b/paddle/fluid/framework/ir/generate_pass_tester.cc index c3852d29c308ff2594c63f3b5be06ad12c7211a4..6876dde50c157c3e14d2aa5b1212e9ffd48f90bf 100644 --- a/paddle/fluid/framework/ir/generate_pass_tester.cc +++ b/paddle/fluid/framework/ir/generate_pass_tester.cc @@ -16,234 +16,71 @@ #include "gtest/gtest.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h" -namespace paddle { -namespace framework { -namespace ir { - -template -class CXXGeneratePass : public GeneratePass { - public: - CXXGeneratePass() : GeneratePass(Functor()) {} -}; - -#define REGISTER_GENERATE_PASS(pass_type, function) \ - REGISTER_PASS(pass_type, ::paddle::framework::ir::CXXGeneratePass<&function>) - -proto::MultiPassDesc generate_fc_fuse() { - proto::MultiPassDesc multi_pass_desc; +REGISTER_GENERATE_PASS(generate_fc_fuse) { + paddle::framework::ir::PassPairs pass_pairs; for (bool with_relu : {true, false}) { - proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs(); - proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks(); - pattern->set_idx(0); - pattern->set_parent_idx(0); - proto::OpDesc* mul = pattern->add_ops(); - mul->set_type("mul"); - proto::OpDesc::Var* mul_x = mul->add_inputs(); - mul_x->set_parameter("X"); - mul_x->add_arguments()->assign("x"); - proto::OpDesc::Var* mul_y = mul->add_inputs(); - mul_y->set_parameter("Y"); - mul_y->add_arguments()->assign("w"); - proto::OpDesc::Var* mul_out = mul->add_outputs(); - mul_out->set_parameter("Out"); - mul_out->add_arguments()->assign("mul_out"); - proto::OpDesc* ewadd = pattern->add_ops(); - ewadd->set_type("elementwise_add"); - proto::OpDesc::Var* ewadd_x = ewadd->add_inputs(); - ewadd_x->set_parameter("X"); - ewadd_x->add_arguments()->assign("mul_out"); - proto::OpDesc::Var* ewadd_y = ewadd->add_inputs(); - ewadd_y->set_parameter("Y"); - ewadd_y->add_arguments()->assign("b"); - proto::OpDesc::Var* ewadd_out = ewadd->add_outputs(); - ewadd_out->set_parameter("Out"); - ewadd_out->add_arguments()->assign("ewadd_out"); - proto::OpDesc* relu = nullptr; - proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks(); - replace->set_idx(0); - replace->set_parent_idx(0); - proto::OpDesc* fc = replace->add_ops(); - fc->set_type("fc"); - proto::OpDesc::Var* fc_x = fc->add_inputs(); - fc_x->set_parameter("Input"); - fc_x->add_arguments()->assign("x"); - proto::OpDesc::Var* fc_w = fc->add_inputs(); - fc_w->set_parameter("W"); - fc_w->add_arguments()->assign("w"); - proto::OpDesc::Var* fc_b = fc->add_inputs(); - fc_b->set_parameter("Bias"); - fc_b->add_arguments()->assign("b"); - proto::OpDesc::Var* fc_out = fc->add_outputs(); - fc_out->set_parameter("Out"); - fc_out->add_arguments()->assign("fc_out"); - for (const char* var : {"x", "w", "b", "fc_out"}) { - proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); - var_map->set_pattern_var(var); - var_map->set_replace_var(var); - } - proto::PassDesc::AttrMap* attr_map = pass_desc->add_attr_maps(); - attr_map->set_pattern_op_idx(0); - attr_map->set_pattern_name("x_num_col_dims"); - attr_map->set_replace_op_idx(0); - attr_map->set_replace_name("in_num_col_dims"); - if (with_relu) { - relu = pattern->add_ops(); - relu->set_type("relu"); - proto::OpDesc::Var* relu_x = relu->add_inputs(); - relu_x->set_parameter("X"); - relu_x->add_arguments()->assign("ewadd_out"); - proto::OpDesc::Var* relu_out = relu->add_outputs(); - relu_out->set_parameter("Out"); - relu_out->add_arguments()->assign("relu_out"); - pass_desc->mutable_var_maps(3)->set_pattern_var("relu_out"); - proto::OpDesc::Attr* attr = fc->add_attrs(); - attr->set_name("activation_type"); - attr->set_type(proto::AttrType::STRING); - attr->set_s("relu"); - } else { - pass_desc->mutable_var_maps(3)->set_pattern_var("ewadd_out"); - } + // pattern + SUBGRAPH_(pattern) = + [ subgraph = &pattern, with_relu ](VAR_(x), VAR_(y), VAR_(z)) { + VLOG(3) << "exec lambda func."; + auto mul = OP_(mul)({{"X", x}, {"Y", y}}).Out("Out"); + auto ewadd = OP_(elementwise_add)({{"X", mul}, {"Y", z}}).Out("Out"); + if (with_relu) { + return OP_(relu)({"X", ewadd}).Out("Out"); + } else { + return ewadd; + } + }; + // replace + SUBGRAPH_(replace) = + [ subgraph = &replace, with_relu ](VAR_(x), VAR_(y), VAR_(z)) { + auto& fc = OP_(fc)({{"Input", x}, {"W", y}, {"Bias", z}}); + return fc.Out("Out"); + }; + pass_pairs.AddPassDesc(pattern, replace); } - return multi_pass_desc; + return pass_pairs; } -proto::MultiPassDesc generate_multi_add_to_addn() { - proto::MultiPassDesc multi_pass_desc; - proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs(); - proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks(); - proto::OpDesc* ewadd_0 = pattern->add_ops(); - ewadd_0->set_type("elementwise_add"); - proto::OpDesc::Var* ewadd_0_x = ewadd_0->add_inputs(); - ewadd_0_x->set_parameter("X"); - ewadd_0_x->add_arguments()->assign("a"); - proto::OpDesc::Var* ewadd_0_y = ewadd_0->add_inputs(); - ewadd_0_y->set_parameter("Y"); - ewadd_0_y->add_arguments()->assign("b"); - proto::OpDesc::Var* ewadd_0_out = ewadd_0->add_outputs(); - ewadd_0_out->set_parameter("Out"); - ewadd_0_out->add_arguments()->assign("ewadd_out_0"); - proto::OpDesc* ewadd_1 = pattern->add_ops(); - ewadd_1->set_type("elementwise_add"); - proto::OpDesc::Var* ewadd_1_x = ewadd_1->add_inputs(); - ewadd_1_x->set_parameter("X"); - ewadd_1_x->add_arguments()->assign("ewadd_out_0"); - proto::OpDesc::Var* ewadd_1_y = ewadd_1->add_inputs(); - ewadd_1_y->set_parameter("Y"); - ewadd_1_y->add_arguments()->assign("c"); - proto::OpDesc::Var* ewadd_1_out = ewadd_1->add_outputs(); - ewadd_1_out->set_parameter("Out"); - ewadd_1_out->add_arguments()->assign("ewadd_out_1"); - proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks(); - proto::OpDesc* addn = replace->add_ops(); - addn->set_type("add_n"); - proto::OpDesc::Var* addn_x = addn->add_inputs(); - addn_x->set_parameter("X"); - addn_x->add_arguments()->assign("a"); - addn_x->add_arguments()->assign("b"); - addn_x->add_arguments()->assign("c"); - proto::OpDesc::Var* addn_out = addn->add_outputs(); - addn_out->set_parameter("Out"); - addn_out->add_arguments()->assign("addn_out"); - for (const char* var : {"a", "b", "c", "ewadd_out_1"}) { - proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); - var_map->set_pattern_var(var); - var_map->set_replace_var(var); - } - pass_desc->mutable_var_maps(3)->set_replace_var("addn_out"); - return multi_pass_desc; +REGISTER_GENERATE_PASS(generate_multi_add_to_addn) { + // pattern + SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) { + auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out"); + auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out"); + return ewadd2; + }; + // replace + SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) { + return OP_(sum)({"X", {x, y, z}}).Out("Out"); + }; + return {pattern, replace}; } -proto::MultiPassDesc generate_combine_matmul() { - proto::MultiPassDesc multi_pass_desc; - proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs(); - proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks(); - proto::OpDesc* matmul_0 = pattern->add_ops(); - matmul_0->set_type("matmul"); - proto::OpDesc::Var* matmul_0_x = matmul_0->add_inputs(); - matmul_0_x->set_parameter("X"); - matmul_0_x->add_arguments()->assign("a"); - proto::OpDesc::Var* matmul_0_y = matmul_0->add_inputs(); - matmul_0_y->set_parameter("Y"); - matmul_0_y->add_arguments()->assign("b"); - proto::OpDesc::Var* matmul_0_out = matmul_0->add_outputs(); - matmul_0_out->set_parameter("Out"); - matmul_0_out->add_arguments()->assign("matmul_out_0"); - proto::OpDesc* matmul_1 = pattern->add_ops(); - matmul_1->set_type("matmul"); - proto::OpDesc::Var* matmul_1_x = matmul_1->add_inputs(); - matmul_1_x->set_parameter("X"); - matmul_1_x->add_arguments()->assign("a"); - proto::OpDesc::Var* matmul_1_y = matmul_1->add_inputs(); - matmul_1_y->set_parameter("Y"); - matmul_1_y->add_arguments()->assign("c"); - proto::OpDesc::Var* matmul_1_out = matmul_1->add_outputs(); - matmul_1_out->set_parameter("Out"); - matmul_1_out->add_arguments()->assign("matmul_out_1"); - proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks(); - proto::OpDesc* concat = replace->add_ops(); - concat->set_type("concat"); - proto::OpDesc::Var* concat_x = concat->add_inputs(); - concat_x->set_parameter("X"); - concat_x->add_arguments()->assign("b"); - concat_x->add_arguments()->assign("c"); - proto::OpDesc::Var* concat_out = concat->add_outputs(); - concat_out->set_parameter("Out"); - concat_out->add_arguments()->assign("concat_out"); - proto::OpDesc* matmul = replace->add_ops(); - matmul->set_type("matmul"); - proto::OpDesc::Var* matmul_x = matmul->add_inputs(); - matmul_x->set_parameter("X"); - matmul_x->add_arguments()->assign("a"); - proto::OpDesc::Var* matmul_y = matmul->add_inputs(); - matmul_y->set_parameter("Y"); - matmul_y->add_arguments()->assign("concat_out"); - proto::OpDesc::Var* matmul_out = matmul->add_outputs(); - matmul_out->set_parameter("Out"); - matmul_out->add_arguments()->assign("matmul_out"); - proto::OpDesc* slice_0 = replace->add_ops(); - slice_0->set_type("slice"); - proto::OpDesc::Var* slice_0_x = slice_0->add_inputs(); - slice_0_x->set_parameter("X"); - slice_0_x->add_arguments()->assign("matmul_out"); - proto::OpDesc::Var* slice_0_out = slice_0->add_outputs(); - slice_0_out->set_parameter("Out"); - slice_0_out->add_arguments()->assign("slice_out_0"); - proto::OpDesc* slice_1 = replace->add_ops(); - slice_1->set_type("slice"); - proto::OpDesc::Var* slice_1_x = slice_1->add_inputs(); - slice_1_x->set_parameter("X"); - slice_1_x->add_arguments()->assign("matmul_out"); - proto::OpDesc::Var* slice_1_out = slice_1->add_outputs(); - slice_1_out->set_parameter("Out"); - slice_1_out->add_arguments()->assign("slice_out_1"); - for (const char* var : {"a", "b", "c", "matmul_out_0", "matmul_out_1"}) { - proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); - var_map->set_pattern_var(var); - var_map->set_replace_var(var); - } - pass_desc->mutable_var_maps(3)->set_replace_var("slice_out_0"); - pass_desc->mutable_var_maps(4)->set_replace_var("slice_out_1"); - return multi_pass_desc; +REGISTER_GENERATE_PASS(generate_combine_matmul) { + // pattern + SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) { + auto matmul1 = OP_(matmul)({{"X", x}, {"Y", y}}).Out("Out"); + auto matmul2 = OP_(matmul)({{"X", x}, {"Y", z}}).Out("Out"); + return std::make_tuple(matmul1, matmul2); + }; + // replace + SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) { + auto concat = OP_(concat)({"X", {y, z}}).Out("Out"); + auto matmul = OP_(matmul)({{"X", x}, {"Y", concat}}).Out("Out"); + auto slice1 = OP_(slice)({"X", matmul}).Out("Out"); + auto slice2 = OP_(slice)({"X", matmul}).Out("Out"); + return std::make_tuple(slice1, slice2); + }; + return {pattern, replace}; } -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_GENERATE_PASS(generate_fc_fuse, - paddle::framework::ir::generate_fc_fuse); -REGISTER_GENERATE_PASS(generate_multi_add_to_addn, - paddle::framework::ir::generate_multi_add_to_addn); -REGISTER_GENERATE_PASS(generate_combine_matmul, - paddle::framework::ir::generate_combine_matmul); - namespace paddle { namespace framework { namespace ir { TEST(GeneratePass, construct_with_string) { std::string binary_str; - generate_fc_fuse().SerializeToString(&binary_str); + register_generate_fc_fuse().MultiPassDesc().SerializeToString(&binary_str); GeneratePass generate_pass(binary_str); } @@ -318,7 +155,7 @@ TEST(GeneratePass, generate_multi_add_to_addn) { graph.reset(pass->Apply(graph.release())); int num_nodes_after = graph->Nodes().size(); - int num_addn_nodes_after = GetNumOpNodes(graph, "add_n"); + int num_addn_nodes_after = GetNumOpNodes(graph, "sum"); VLOG(3) << DebugString(graph); PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 2,