未验证 提交 b4d7e1e0 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Add op compat info for grad op (#55277)

* fix bug

* fix bug

* fix bug
上级 036c0ae1
...@@ -138,6 +138,14 @@ DEFINE_OP_TYPE_ID = """ ...@@ -138,6 +138,14 @@ DEFINE_OP_TYPE_ID = """
IR_DEFINE_EXPLICIT_TYPE_ID({op_name}) IR_DEFINE_EXPLICIT_TYPE_ID({op_name})
""" """
scalar_type_maps = {
'int': 'ir::Int32Attribute',
'int64_t': 'ir::Int64Attribute',
'float': 'ir::FloatAttribute',
'dobule': 'ir::DoubleAttribute',
'bool': 'ir::BoolAttribute',
}
def to_phi_and_fluid_op_name(op_item): def to_phi_and_fluid_op_name(op_item):
# Templat: - op : phi_name (fluid_name) # Templat: - op : phi_name (fluid_name)
...@@ -151,13 +159,14 @@ def to_phi_and_fluid_op_name(op_item): ...@@ -151,13 +159,14 @@ def to_phi_and_fluid_op_name(op_item):
return phi_name, fluid_name return phi_name, fluid_name
scalar_type_maps = { def to_phi_and_fluid_grad_op_name(op_item):
'int': 'ir::Int32Attribute', # Templat: sum_grad (reduce_sum_grad), sum_double_grad
'int64_t': 'ir::Int64Attribute', rtn = []
'float': 'ir::FloatAttribute', all_names = op_item.split(', ')
'dobule': 'ir::DoubleAttribute', for name in all_names:
'bool': 'ir::BoolAttribute', backward_phi_name, backward_fluid_name = to_phi_and_fluid_op_name(name)
} rtn.append([backward_phi_name, backward_fluid_name])
return rtn
# ===================================== # =====================================
...@@ -171,9 +180,16 @@ class OpCompatParser: ...@@ -171,9 +180,16 @@ class OpCompatParser:
def get_compat(self, op_name): def get_compat(self, op_name):
for compat in self.ops_compat: for compat in self.ops_compat:
phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op']) forward_phi_name, forward_fluid_name = to_phi_and_fluid_op_name(
if op_name == phi_name: compat['op']
)
if op_name == forward_phi_name:
return compat return compat
elif 'backward' in compat.keys():
bkw_names = to_phi_and_fluid_grad_op_name(compat['backward'])
for name in bkw_names:
if op_name == name[0]:
return compat
return None return None
......
...@@ -114,17 +114,17 @@ def OpNameNormalizerInitialization( ...@@ -114,17 +114,17 @@ def OpNameNormalizerInitialization(
insert_new_mutable_attributes( insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"] legacy_name, op_compat_item["int_array"]
) )
for backward_op in legacy_backward_op_names:
insert_new_mutable_attributes(
backward_op, op_compat_item["int_array"]
)
if "scalar" in op_compat_item: if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
for backward_op in legacy_backward_op_names:
if "int_array" in op_compat_item: insert_new_mutable_attributes(
insert_new_mutable_attributes( backward_op, op_compat_item["scalar"]
legacy_name, op_compat_item["int_array"] )
)
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
# special op mappings # special op mappings
op_name_mappings["fetch_v2"] = "fetch" op_name_mappings["fetch_v2"] = "fetch"
......
...@@ -56,14 +56,13 @@ TEST(PaddleDialectTest, MainProgram) { ...@@ -56,14 +56,13 @@ TEST(PaddleDialectTest, MainProgram) {
ctx->GetOrRegisterDialect<ir::BuiltinDialect>(); ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p); auto program = paddle::TranslateLegacyProgramToProgram(p);
size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
// array op + full op
EXPECT_EQ(op_size,
p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
std::stringstream ss; std::stringstream ss;
program->Print(ss); program->Print(ss);
// ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
// array op + full op (Note: p already has a full)
EXPECT_EQ(program->block()->size(),
p.Block(0).OpSize() + program->parameters_num() + 20 + 5 + 8);
EXPECT_GT(ss.str().size(), 0u); EXPECT_GT(ss.str().size(), 0u);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册