未验证 提交 b4d7e1e0 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Add op compat info for grad op (#55277)

* fix bug

* fix bug

* fix bug
上级 036c0ae1
......@@ -138,6 +138,14 @@ DEFINE_OP_TYPE_ID = """
IR_DEFINE_EXPLICIT_TYPE_ID({op_name})
"""
# Maps a C++ scalar type name (as spelled in the op yaml definitions) to
# the new-IR attribute type used when generating op definitions.
scalar_type_maps = {
    'int': 'ir::Int32Attribute',
    'int64_t': 'ir::Int64Attribute',
    'float': 'ir::FloatAttribute',
    # BUG FIX: the correct C++ type name is 'double'; without this key a
    # double-typed scalar could never be mapped.  The misspelled 'dobule'
    # entry is kept for backward compatibility with any existing lookup.
    'double': 'ir::DoubleAttribute',
    'dobule': 'ir::DoubleAttribute',
    'bool': 'ir::BoolAttribute',
}
def to_phi_and_fluid_op_name(op_item):
# Template: - op : phi_name (fluid_name)
......@@ -151,13 +159,14 @@ def to_phi_and_fluid_op_name(op_item):
return phi_name, fluid_name
# Maps a C++ scalar type name (as spelled in the op yaml definitions) to
# the new-IR attribute type used when generating op definitions.
scalar_type_maps = {
    'int': 'ir::Int32Attribute',
    'int64_t': 'ir::Int64Attribute',
    'float': 'ir::FloatAttribute',
    # BUG FIX: the correct C++ type name is 'double'; without this key a
    # double-typed scalar could never be mapped.  The misspelled 'dobule'
    # entry is kept for backward compatibility with any existing lookup.
    'double': 'ir::DoubleAttribute',
    'dobule': 'ir::DoubleAttribute',
    'bool': 'ir::BoolAttribute',
}
def to_phi_and_fluid_grad_op_name(op_item):
    # Template: sum_grad (reduce_sum_grad), sum_double_grad
    # Split the comma-separated backward-op list and resolve each entry
    # into its [phi_name, fluid_name] pair.
    return [
        list(to_phi_and_fluid_op_name(grad_name))
        for grad_name in op_item.split(', ')
    ]
# =====================================
......@@ -171,9 +180,16 @@ class OpCompatParser:
def get_compat(self, op_name):
# Return the op_compat.yaml entry whose forward op name — or, after this
# change, one of whose backward (grad) op names — equals `op_name`;
# returns None when nothing matches.
# NOTE(review): this span is a rendered diff hunk; the removed lines
# (phi_name lookup) and the added lines (forward_phi_name lookup plus the
# backward-name fallback) both appear below, so the text is not runnable
# Python as displayed.
for compat in self.ops_compat:
# --- pre-change (removed) lines: ---
phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op'])
if op_name == phi_name:
# --- post-change (added) lines: ---
forward_phi_name, forward_fluid_name = to_phi_and_fluid_op_name(
compat['op']
)
if op_name == forward_phi_name:
return compat
# New: also treat a match against any backward op's phi name as a hit,
# so grad ops can find their compat info too.
elif 'backward' in compat.keys():
bkw_names = to_phi_and_fluid_grad_op_name(compat['backward'])
for name in bkw_names:
# name is [backward_phi_name, backward_fluid_name]; match phi name.
if op_name == name[0]:
return compat
return None
......
......@@ -114,17 +114,17 @@ def OpNameNormalizerInitialization(
# NOTE(review): rendered diff hunk from inside OpNameNormalizerInitialization
# (the enclosing def is outside this view).  The first int_array/scalar
# registration below is the pre-change text; the reordered version that
# additionally registers the backward ops' scalar attributes follows it.
insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"]
)
# New: register the int_array mutable attributes for each backward op too.
for backward_op in legacy_backward_op_names:
insert_new_mutable_attributes(
backward_op, op_compat_item["int_array"]
)
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
# --- post-change ordering of the same registrations: ---
if "int_array" in op_compat_item:
insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"]
)
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
# New: register the scalar mutable attributes for each backward op too.
for backward_op in legacy_backward_op_names:
insert_new_mutable_attributes(
backward_op, op_compat_item["scalar"]
)
# special op mappings
op_name_mappings["fetch_v2"] = "fetch"
......
......@@ -56,14 +56,13 @@ TEST(PaddleDialectTest, MainProgram) {
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// Translate the legacy (fluid) ProgramDesc `p` into a new-IR program.
auto program = paddle::TranslateLegacyProgramToProgram(p);
size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
// array op + full op
// NOTE(review): this span is a rendered diff hunk — the pre-change
// expectation (+ 20 + 3 + 8) and the post-change one (+ 20 + 5 + 8, with
// the extra "p already has a full" note) both appear below; only one
// EXPECT_EQ exists in either real version of the file.
EXPECT_EQ(op_size,
p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
std::stringstream ss;
program->Print(ss);
// ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
// array op + full op (Note: p already has a full)
EXPECT_EQ(program->block()->size(),
p.Block(0).OpSize() + program->parameters_num() + 20 + 5 + 8);
// Sanity check: printing the translated program produced some text.
EXPECT_GT(ss.str().size(), 0u);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册