From a15fec8b8c7d51c1bf19d67aee6d9ad115719bf1 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Thu, 24 Aug 2023 14:50:40 +0800 Subject: [PATCH] [IR] Fix BuildOperatorBase bug (#56552) * fix inplace with set_parameter * fix bug * fix bug * fix bug * fix bug * fix bug * refine code * refine code * fix bug * refine op_gem * fix bug --- .../fluid/ir/dialect/op_generator/op_gen.py | 4 +- .../ir/phi_kernel_adaptor/phi_kernel_util.cc | 45 +++++++++++++++++++ .../ir/transforms/pd_op_to_kernel_pass.cc | 2 +- 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 2042c626f06..b3858b6e673 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -151,7 +151,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{ std::vector inputs = {{ {inputs} }}; std::vector attributes = {{ {attributes} }}; std::vector outputs = {{ {outputs} }}; - paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{"{kernel_key_dtype}"}}, {{{inplace}}}, {{{view}}}); + paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{inplace}}}, {{{view}}}); return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); }} @@ -1020,6 +1020,8 @@ def OpGenerator( kernel_key_dtype = '", "'.join( op_kernel_map['data_type']['candidates'] ) + if kernel_key_dtype != "": + kernel_key_dtype = '"' + kernel_key_dtype + '"' inplace_str = "" view_str = "" diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc index 456aeae1468..7e7b0dbe76b 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc @@ -704,6 +704,51 @@ std::shared_ptr BuildOperatorBase( attr_map[name] = val.dyn_cast().data(); } else if (val.isa()) { attr_map[name] = val.dyn_cast().data(); + } else if (val.isa()) { + auto array_list = val.dyn_cast().AsVector(); + PADDLE_ENFORCE( + array_list.size() > 0, + paddle::platform::errors::Fatal("Attribute %s is empty", name)); + if (array_list[0].isa()) { + std::vector vec_int; + for (auto attribute : array_list) { + vec_int.push_back(attribute.dyn_cast().data()); + } + attr_map[name] = vec_int; + } else if (array_list[0].isa()) { + std::vector vec_int64; + for (auto attribute : array_list) { + vec_int64.push_back(attribute.dyn_cast().data()); + } + attr_map[name] = vec_int64; + } else if (array_list[0].isa()) { + std::vector vec_bool; + for (auto attribute : array_list) { + vec_bool.push_back(attribute.dyn_cast().data()); + } + attr_map[name] = vec_bool; + } else if (array_list[0].isa()) { + std::vector vec_float; + for (auto attribute : array_list) { + vec_float.push_back(attribute.dyn_cast().data()); + } + attr_map[name] = vec_float; + } else if (array_list[0].isa()) { + std::vector vec_double; + for (auto attribute : array_list) { + vec_double.push_back( + attribute.dyn_cast().data()); + } + attr_map[name] = vec_double; + } else { + std::stringstream ss; + val.Print(ss); + VLOG(1) << "type not support " << ss.str() << std::endl; + PADDLE_THROW("Type[%s] in attribute map not support yet", ss.str()); + } + } else if (val.isa()) { + attr_map[name] = paddle::framework::TransToProtoVarType( + val.dyn_cast().data()); } else { std::stringstream ss; val.Print(ss); diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc index 47aaf0dc600..017188bcaca 100644 --- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc @@ -248,7 +248,7 @@ phi::KernelKey GetKernelKey( auto attr_map = op->attributes(); auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype; - if (!data_type_info.empty() && !data_type_info[0].empty()) { + if (!data_type_info.empty()) { // only support single input and attribute auto slot_name = data_type_info[0]; auto& input_map = op_info_parser->InputName2Id(); -- GitLab