未验证 提交 219c46b1 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Fix inplace op with set_parameter op bug (#56519)

* fix inplace with set_parameter

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* refine code

* refine code
上级 e914f7fc
...@@ -123,7 +123,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( ...@@ -123,7 +123,7 @@ LegacyKernelInstruction::LegacyKernelInstruction(
phi_kernel_ = new phi::Kernel(kernel_result.kernel); phi_kernel_ = new phi::Kernel(kernel_result.kernel);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name); phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name);
VLOG(6) << "finish process select kernel"; VLOG(6) << "finish process select kernel: " << kernel_name;
Scope* inner_scope = local_scope == nullptr ? scope : local_scope; Scope* inner_scope = local_scope == nullptr ? scope : local_scope;
......
...@@ -134,7 +134,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{ ...@@ -134,7 +134,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }}; std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{"{kernel_key_dtype}"}}, {{{inplace}}}, {{{view}}}); paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{"{kernel_key_dtype}"}}, {{{inplace}}}, {{{view}}});
return std::make_tuple(inputs, attributes, outputs, run_time_info); return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}} }}
""" """
CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})""" CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})"""
...@@ -1024,6 +1024,7 @@ def OpGenerator( ...@@ -1024,6 +1024,7 @@ def OpGenerator(
kernel_key_dtype=kernel_key_dtype, kernel_key_dtype=kernel_key_dtype,
inplace=inplace_str, inplace=inplace_str,
view=view_str, view=view_str,
origin_op_name=op_info.op_yaml_item['name'],
) )
# generate op verify function str # generate op verify function str
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>, using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
std::vector<paddle::dialect::OpAttributeInfo>, std::vector<paddle::dialect::OpAttributeInfo>,
std::vector<paddle::dialect::OpOutputInfo>, std::vector<paddle::dialect::OpOutputInfo>,
paddle::dialect::OpRunTimeInfo>; paddle::dialect::OpRunTimeInfo,
std::string>;
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
......
...@@ -41,7 +41,7 @@ OpInfoTuple AddNOp::GetOpInfo() { ...@@ -41,7 +41,7 @@ OpInfoTuple AddNOp::GetOpInfo() {
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {}); OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {});
return std::make_tuple(inputs, attributes, outputs, run_time_info); return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
} }
void AddNOp::Verify() { void AddNOp::Verify() {
......
...@@ -187,5 +187,9 @@ void OpYamlInfoParser::parse() { ...@@ -187,5 +187,9 @@ void OpYamlInfoParser::parse() {
} }
} }
const std::string& OpYamlInfoParser::GetOriginOpName() const {
return std::get<4>(op_info_tuple_);
}
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
...@@ -57,6 +57,8 @@ class OpYamlInfoParser { ...@@ -57,6 +57,8 @@ class OpYamlInfoParser {
const std::string& ViewName(const std::string& out_name) const; const std::string& ViewName(const std::string& out_name) const;
const std::string& GetOriginOpName() const;
private: private:
void parse(); void parse();
inline const std::vector<OpInputInfo>& InputInfo() const { inline const std::vector<OpInputInfo>& InputInfo() const {
......
...@@ -18,11 +18,7 @@ namespace paddle { ...@@ -18,11 +18,7 @@ namespace paddle {
namespace dialect { namespace dialect {
const std::unordered_set<std::string> LegacyOpList = { const std::unordered_set<std::string> LegacyOpList = {
"pd.fused_softmax_mask_upper_triangle", "pd.load_combine", "pd.c_concat", "pd.c_broadcast_"};
"pd.fused_softmax_mask_upper_triangle_grad",
"pd.load_combine",
"pd.c_concat",
"pd.load_combine"};
enum class AttrType { enum class AttrType {
UNDEFINED = 0, UNDEFINED = 0,
......
...@@ -80,6 +80,12 @@ void RenameData(ir::Value value, ...@@ -80,6 +80,12 @@ void RenameData(ir::Value value,
std::map<std::string, int>* var_name_2_id) { std::map<std::string, int>* var_name_2_id) {
(*value_2_var_name)[value] = new_name; (*value_2_var_name)[value] = new_name;
for (auto kv : (*value_2_var_name)) {
if (kv.second == orig_name) {
(*value_2_var_name)[kv.first] = new_name;
}
}
for (auto kv : (*variable_2_var_name)) { for (auto kv : (*variable_2_var_name)) {
if (kv.second == orig_name) { if (kv.second == orig_name) {
(*variable_2_var_name)[kv.first] = new_name; (*variable_2_var_name)[kv.first] = new_name;
...@@ -588,9 +594,7 @@ void BuildRuntimeContext( ...@@ -588,9 +594,7 @@ void BuildRuntimeContext(
auto& name2id = op_yaml_info.InputName2Id(); auto& name2id = op_yaml_info.InputName2Id();
auto pd_op_name = std::string fluid_op_name = op_yaml_info.GetOriginOpName();
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
auto fluid_op_name = pd_op_name.substr(3); // pd_op_name start with "pd.xxx"
auto& op_normalizer = paddle::translator::OpNameNormalizer::instance(); auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
...@@ -621,7 +625,7 @@ void BuildRuntimeContext( ...@@ -621,7 +625,7 @@ void BuildRuntimeContext(
ir::Value ptr = op->result(i); ir::Value ptr = op->result(i);
auto in_var_name = name_map.at(ptr); auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name; VLOG(6) << "ctx->EmplaceBackOutput: " << name << "\t" << in_var_name;
PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name),
phi::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
...@@ -664,9 +668,7 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase( ...@@ -664,9 +668,7 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
auto& name2id = op_yaml_info.InputName2Id(); auto& name2id = op_yaml_info.InputName2Id();
auto pd_op_name = std::string fluid_op_name = op_yaml_info.GetOriginOpName();
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
auto fluid_op_name = pd_op_name.substr(3); // pd_op_name start with "pd.xxx"
auto& op_normalizer = paddle::translator::OpNameNormalizer::instance(); auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
......
...@@ -653,7 +653,7 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx, ...@@ -653,7 +653,7 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx,
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos; OpOutputInfoList output_infos;
std::tie(input_infos, attr_infos, output_infos, std::ignore) = std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
op_info_concept->get_op_info_(); op_info_concept->get_op_info_();
this->InsertSliceOperationForInput( this->InsertSliceOperationForInput(
...@@ -769,7 +769,7 @@ struct AssignValueOpTranscriber : public OpTranscriber { ...@@ -769,7 +769,7 @@ struct AssignValueOpTranscriber : public OpTranscriber {
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos; OpOutputInfoList output_infos;
std::tie(input_infos, attr_infos, output_infos, std::ignore) = std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
op_info_concept->get_op_info_(); op_info_concept->get_op_info_();
std::unordered_map<std::string, OpAttributeInfo> attr_info_maps; std::unordered_map<std::string, OpAttributeInfo> attr_info_maps;
for (auto info : attr_infos) { for (auto info : attr_infos) {
...@@ -1092,7 +1092,7 @@ struct FetchOpTranscriber : public OpTranscriber { ...@@ -1092,7 +1092,7 @@ struct FetchOpTranscriber : public OpTranscriber {
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos; OpOutputInfoList output_infos;
std::tie(input_infos, attr_infos, output_infos, std::ignore) = std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
op_info_concept->get_op_info_(); op_info_concept->get_op_info_();
this->InsertSliceOperationForInput( this->InsertSliceOperationForInput(
......
...@@ -490,7 +490,8 @@ OpInfoTuple Conv2dFusionOpTest::GetOpInfo() { ...@@ -490,7 +490,8 @@ OpInfoTuple Conv2dFusionOpTest::GetOpInfo() {
{}, {},
{}); {});
return std::make_tuple(inputs, attributes, outputs, run_time_info); return std::make_tuple(
inputs, attributes, outputs, run_time_info, "conv2d_fusion_test");
} }
void Conv2dFusionOpTest::Build(ir::Builder &builder, void Conv2dFusionOpTest::Build(ir::Builder &builder,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册