From b4d7e1e06a2ce03ff38299bb8b8ebf959ee4bc3d Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Tue, 11 Jul 2023 14:03:05 +0800 Subject: [PATCH] [IR] Add op compat info for grad op (#55277) * fix bug * fix bug * fix bug --- .../fluid/ir/dialect/op_generator/op_gen.py | 34 ++++++++++++++----- .../ir_adaptor/translator/op_compat_gen.py | 16 ++++----- test/cpp/ir/core/program_translator_test.cc | 11 +++--- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 8f366e675e4..2d5ef75d1c6 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -138,6 +138,14 @@ DEFINE_OP_TYPE_ID = """ IR_DEFINE_EXPLICIT_TYPE_ID({op_name}) """ +scalar_type_maps = { + 'int': 'ir::Int32Attribute', + 'int64_t': 'ir::Int64Attribute', + 'float': 'ir::FloatAttribute', + 'dobule': 'ir::DoubleAttribute', + 'bool': 'ir::BoolAttribute', +} + def to_phi_and_fluid_op_name(op_item): # Templat: - op : phi_name (fluid_name) @@ -151,13 +159,14 @@ def to_phi_and_fluid_op_name(op_item): return phi_name, fluid_name -scalar_type_maps = { - 'int': 'ir::Int32Attribute', - 'int64_t': 'ir::Int64Attribute', - 'float': 'ir::FloatAttribute', - 'dobule': 'ir::DoubleAttribute', - 'bool': 'ir::BoolAttribute', -} +def to_phi_and_fluid_grad_op_name(op_item): + # Templat: sum_grad (reduce_sum_grad), sum_double_grad + rtn = [] + all_names = op_item.split(', ') + for name in all_names: + backward_phi_name, backward_fluid_name = to_phi_and_fluid_op_name(name) + rtn.append([backward_phi_name, backward_fluid_name]) + return rtn # ===================================== @@ -171,9 +180,16 @@ class OpCompatParser: def get_compat(self, op_name): for compat in self.ops_compat: - phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op']) - if op_name == phi_name: + forward_phi_name, forward_fluid_name = to_phi_and_fluid_op_name( + compat['op'] + ) + if op_name == forward_phi_name: return compat + elif 'backward' in compat.keys(): + bkw_names = to_phi_and_fluid_grad_op_name(compat['backward']) + for name in bkw_names: + if op_name == name[0]: + return compat return None diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 5fbe508ce80..913d8f26f9c 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -114,17 +114,17 @@ def OpNameNormalizerInitialization( insert_new_mutable_attributes( legacy_name, op_compat_item["int_array"] ) + for backward_op in legacy_backward_op_names: + insert_new_mutable_attributes( + backward_op, op_compat_item["int_array"] + ) if "scalar" in op_compat_item: insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) - - if "int_array" in op_compat_item: - insert_new_mutable_attributes( - legacy_name, op_compat_item["int_array"] - ) - - if "scalar" in op_compat_item: - insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) + for backward_op in legacy_backward_op_names: + insert_new_mutable_attributes( + backward_op, op_compat_item["scalar"] + ) # special op mappings op_name_mappings["fetch_v2"] = "fetch" diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index c50c579ca51..50b3623b452 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -56,14 +56,13 @@ TEST(PaddleDialectTest, MainProgram) { ctx->GetOrRegisterDialect(); auto program = paddle::TranslateLegacyProgramToProgram(p); - size_t op_size = program->block()->size(); - // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int - // array op + full op - EXPECT_EQ(op_size, - p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8); - std::stringstream ss; program->Print(ss); + + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int + // array op + full op (Note: p already has a full) + EXPECT_EQ(program->block()->size(), + p.Block(0).OpSize() + program->parameters_num() + 20 + 5 + 8); EXPECT_GT(ss.str().size(), 0u); } -- GitLab