Unverified commit 219c46b1 authored by zhangbo9674 and committed by GitHub

[IR] Fix inplace op with set_parameter op bug (#56519)

* fix inplace with set_parameter

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* refine code

* refine code
Parent e914f7fc
@@ -123,7 +123,7 @@ LegacyKernelInstruction::LegacyKernelInstruction(
   phi_kernel_ = new phi::Kernel(kernel_result.kernel);
   PADDLE_ENFORCE_EQ(
       phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name);
-  VLOG(6) << "finish process select kernel";
+  VLOG(6) << "finish process select kernel: " << kernel_name;

   Scope* inner_scope = local_scope == nullptr ? scope : local_scope;
......
@@ -134,7 +134,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
   std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
   paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{"{kernel_key_dtype}"}}, {{{inplace}}}, {{{view}}});
-  return std::make_tuple(inputs, attributes, outputs, run_time_info);
+  return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
 }}
 """
 CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})"""
@@ -1024,6 +1024,7 @@ def OpGenerator(
             kernel_key_dtype=kernel_key_dtype,
             inplace=inplace_str,
             view=view_str,
+            origin_op_name=op_info.op_yaml_item['name'],
         )

         # generate op verify function str
......
@@ -20,7 +20,8 @@
 using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
                                std::vector<paddle::dialect::OpAttributeInfo>,
                                std::vector<paddle::dialect::OpOutputInfo>,
-                               paddle::dialect::OpRunTimeInfo>;
+                               paddle::dialect::OpRunTimeInfo,
+                               std::string>;

 namespace paddle {
 namespace dialect {
......
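Note: because OpInfoTuple now carries a fifth element (the original op name as a std::string), every call site that unpacks it with std::tie needs one more std::ignore, as the OpTranscriber hunks below show. A minimal sketch of the two access patterns, assuming the dialect headers are available and using a hypothetical SomeOp that implements GetOpInfo():

// Sketch only: consuming the widened OpInfoTuple. "SomeOp" is a placeholder
// for any generated dialect op; it is not part of this commit.
OpInfoTuple info = SomeOp::GetOpInfo();

std::vector<paddle::dialect::OpInputInfo> inputs;
std::vector<paddle::dialect::OpAttributeInfo> attrs;
std::vector<paddle::dialect::OpOutputInfo> outputs;

// Both the run-time info and the new trailing name must be ignored explicitly,
// otherwise std::tie no longer matches the tuple arity.
std::tie(inputs, attrs, outputs, std::ignore, std::ignore) = info;

// Or read the original (fluid) op name directly; this is what
// OpYamlInfoParser::GetOriginOpName() returns below.
const std::string& origin_op_name = std::get<4>(info);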
@@ -41,7 +41,7 @@ OpInfoTuple AddNOp::GetOpInfo() {
   paddle::dialect::OpRunTimeInfo run_time_info =
       OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {});
-  return std::make_tuple(inputs, attributes, outputs, run_time_info);
+  return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
 }

 void AddNOp::Verify() {
......
@@ -187,5 +187,9 @@ void OpYamlInfoParser::parse() {
   }
 }

+const std::string& OpYamlInfoParser::GetOriginOpName() const {
+  return std::get<4>(op_info_tuple_);
+}
+
 }  // namespace dialect
 }  // namespace paddle
@@ -57,6 +57,8 @@ class OpYamlInfoParser {
   const std::string& ViewName(const std::string& out_name) const;

+  const std::string& GetOriginOpName() const;
+
  private:
   void parse();

   inline const std::vector<OpInputInfo>& InputInfo() const {
......
@@ -18,11 +18,7 @@ namespace paddle {
 namespace dialect {

 const std::unordered_set<std::string> LegacyOpList = {
-    "pd.fused_softmax_mask_upper_triangle",
-    "pd.fused_softmax_mask_upper_triangle_grad",
-    "pd.load_combine",
-    "pd.c_concat",
-    "pd.load_combine"};
+    "pd.load_combine", "pd.c_concat", "pd.c_broadcast_"};

 enum class AttrType {
   UNDEFINED = 0,
......
@@ -80,6 +80,12 @@ void RenameData(ir::Value value,
                 std::map<std::string, int>* var_name_2_id) {
+  (*value_2_var_name)[value] = new_name;
+  for (auto kv : (*value_2_var_name)) {
+    if (kv.second == orig_name) {
+      (*value_2_var_name)[kv.first] = new_name;
+    }
+  }
   for (auto kv : (*variable_2_var_name)) {
     if (kv.second == orig_name) {
       (*variable_2_var_name)[kv.first] = new_name;
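Note: this block is the core of the inplace fix. The passed value is renamed directly, and every other IR value that still maps to orig_name is redirected as well, which matters when inplace ops leave several values sharing one underlying variable name (an inference from the commit title, not stated in the diff). A self-contained sketch of the rename-all-aliases pattern, with std::string standing in for the real ir::Value / Variable* keys:

#include <iostream>
#include <map>
#include <string>

// Every map entry whose mapped name equals orig_name is redirected to
// new_name, not just the single entry for the value being renamed.
void RenameAllAliases(std::map<std::string, std::string>* name_map,
                      const std::string& orig_name,
                      const std::string& new_name) {
  for (auto& kv : *name_map) {
    if (kv.second == orig_name) {
      kv.second = new_name;
    }
  }
}

int main() {
  // Two "values" alias the same variable name, as an inplace op would produce.
  std::map<std::string, std::string> value_2_var_name = {{"v0", "tmp_0"},
                                                         {"v1", "tmp_0"}};
  RenameAllAliases(&value_2_var_name, "tmp_0", "param_0");
  for (const auto& kv : value_2_var_name) {
    std::cout << kv.first << " -> " << kv.second << "\n";  // both renamed
  }
  return 0;
}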
@@ -588,9 +594,7 @@ void BuildRuntimeContext(
   auto& name2id = op_yaml_info.InputName2Id();
-  auto pd_op_name =
-      op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
-  auto fluid_op_name = pd_op_name.substr(3); // pd_op_name start with "pd.xxx"
+  std::string fluid_op_name = op_yaml_info.GetOriginOpName();
   auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
@@ -621,7 +625,7 @@ void BuildRuntimeContext(
     ir::Value ptr = op->result(i);
     auto in_var_name = name_map.at(ptr);
-    VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name;
+    VLOG(6) << "ctx->EmplaceBackOutput: " << name << "\t" << in_var_name;
     PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name),
                             phi::errors::PreconditionNotMet(
@@ -664,9 +668,7 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
   auto& name2id = op_yaml_info.InputName2Id();
-  auto pd_op_name =
-      op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
-  auto fluid_op_name = pd_op_name.substr(3); // pd_op_name start with "pd.xxx"
+  std::string fluid_op_name = op_yaml_info.GetOriginOpName();
   auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
......
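Note: both BuildRuntimeContext and BuildOperatorBase previously recovered the fluid op name by stripping the leading "pd." from the dialect op name. That only works when the dialect name is exactly "pd." plus the fluid name; for names such as "pd.c_broadcast_" (added to LegacyOpList above) the stripped string keeps the trailing underscore and no longer matches the fluid op (assumed here to be "c_broadcast"), which is why the name is now taken from the op's YAML info instead. A small illustrative sketch:

#include <iostream>
#include <string>

int main() {
  std::string pd_op_name = "pd.c_broadcast_";
  // Old approach: assumes the dialect name is always "pd." + fluid op name.
  std::cout << pd_op_name.substr(3) << "\n";  // prints "c_broadcast_", not "c_broadcast"
  // New approach (sketch): the name comes from OpInfoTuple via the parser, e.g.
  //   std::string fluid_op_name = op_yaml_info.GetOriginOpName();
  return 0;
}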
@@ -653,7 +653,7 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx,
   OpInputInfoList input_infos;
   OpAttributeInfoList attr_infos;
   OpOutputInfoList output_infos;
-  std::tie(input_infos, attr_infos, output_infos, std::ignore) =
+  std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
       op_info_concept->get_op_info_();

   this->InsertSliceOperationForInput(
@@ -769,7 +769,7 @@ struct AssignValueOpTranscriber : public OpTranscriber {
     OpInputInfoList input_infos;
     OpAttributeInfoList attr_infos;
     OpOutputInfoList output_infos;
-    std::tie(input_infos, attr_infos, output_infos, std::ignore) =
+    std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
         op_info_concept->get_op_info_();
     std::unordered_map<std::string, OpAttributeInfo> attr_info_maps;
     for (auto info : attr_infos) {
@@ -1092,7 +1092,7 @@ struct FetchOpTranscriber : public OpTranscriber {
     OpInputInfoList input_infos;
     OpAttributeInfoList attr_infos;
     OpOutputInfoList output_infos;
-    std::tie(input_infos, attr_infos, output_infos, std::ignore) =
+    std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
         op_info_concept->get_op_info_();

     this->InsertSliceOperationForInput(
......
@@ -490,7 +490,8 @@ OpInfoTuple Conv2dFusionOpTest::GetOpInfo() {
       {},
       {});
-  return std::make_tuple(inputs, attributes, outputs, run_time_info);
+  return std::make_tuple(
+      inputs, attributes, outputs, run_time_info, "conv2d_fusion_test");
 }

 void Conv2dFusionOpTest::Build(ir::Builder &builder,
......