diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 8f366e675e4ada146a9209b806db8454e781f809..2d5ef75d1c6c4570326cbe0637130c1fdb0045c8 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -138,6 +138,14 @@ DEFINE_OP_TYPE_ID = """ IR_DEFINE_EXPLICIT_TYPE_ID({op_name}) """ +scalar_type_maps = { + 'int': 'ir::Int32Attribute', + 'int64_t': 'ir::Int64Attribute', + 'float': 'ir::FloatAttribute', + 'double': 'ir::DoubleAttribute', + 'bool': 'ir::BoolAttribute', +} + def to_phi_and_fluid_op_name(op_item): # Templat: - op : phi_name (fluid_name) @@ -151,13 +159,14 @@ def to_phi_and_fluid_op_name(op_item): return phi_name, fluid_name -scalar_type_maps = { - 'int': 'ir::Int32Attribute', - 'int64_t': 'ir::Int64Attribute', - 'float': 'ir::FloatAttribute', - 'dobule': 'ir::DoubleAttribute', - 'bool': 'ir::BoolAttribute', -} +def to_phi_and_fluid_grad_op_name(op_item): + # Template: sum_grad (reduce_sum_grad), sum_double_grad + rtn = [] + all_names = op_item.split(', ') + for name in all_names: + backward_phi_name, backward_fluid_name = to_phi_and_fluid_op_name(name) + rtn.append([backward_phi_name, backward_fluid_name]) + return rtn # ===================================== @@ -171,9 +180,16 @@ class OpCompatParser: def get_compat(self, op_name): for compat in self.ops_compat: - phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op']) - if op_name == phi_name: + forward_phi_name, forward_fluid_name = to_phi_and_fluid_op_name( + compat['op'] + ) + if op_name == forward_phi_name: return compat + elif 'backward' in compat.keys(): + bkw_names = to_phi_and_fluid_grad_op_name(compat['backward']) + for name in bkw_names: + if op_name == name[0]: + return compat return None diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 
5fbe508ce80c17d332831125193d7048f720fa1c..913d8f26f9cd1a9b7e71f41cd64b7bdc9ea39259 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -114,17 +114,17 @@ def OpNameNormalizerInitialization( insert_new_mutable_attributes( legacy_name, op_compat_item["int_array"] ) + for backward_op in legacy_backward_op_names: + insert_new_mutable_attributes( + backward_op, op_compat_item["int_array"] + ) if "scalar" in op_compat_item: insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) - - if "int_array" in op_compat_item: - insert_new_mutable_attributes( - legacy_name, op_compat_item["int_array"] - ) - - if "scalar" in op_compat_item: - insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) + for backward_op in legacy_backward_op_names: + insert_new_mutable_attributes( + backward_op, op_compat_item["scalar"] + ) # special op mappings op_name_mappings["fetch_v2"] = "fetch" diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index c50c579ca51daad6d258bbebfbdf231a7bd32f62..50b3623b4528ae8460f6a073082e2c5d90032a08 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -56,14 +56,13 @@ TEST(PaddleDialectTest, MainProgram) { ctx->GetOrRegisterDialect(); auto program = paddle::TranslateLegacyProgramToProgram(p); - size_t op_size = program->block()->size(); - // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int - // array op + full op - EXPECT_EQ(op_size, - p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8); - std::stringstream ss; program->Print(ss); + + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int + // array op + full op (Note: p already has a full) + EXPECT_EQ(program->block()->size(), + p.Block(0).OpSize() + program->parameters_num() + 20 + 5 + 8); EXPECT_GT(ss.str().size(), 0u); }