diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index 2042c626f067be903387254fe035cc58bdcb9d8a..b3858b6e6730e6555ca0d7697c66c8c759813344 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -151,7 +151,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
   std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
   std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
-  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{"{kernel_key_dtype}"}}, {{{inplace}}}, {{{view}}});
+  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{inplace}}}, {{{view}}});
 
   return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
 }}
@@ -1020,6 +1020,8 @@ def OpGenerator(
             kernel_key_dtype = '", "'.join(
                 op_kernel_map['data_type']['candidates']
             )
+            if kernel_key_dtype != "":
+                kernel_key_dtype = '"' + kernel_key_dtype + '"'
 
         inplace_str = ""
         view_str = ""
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index 456aeae14688500de9843400020b49b7dcaa8a4d..7e7b0dbe76bbbc2f9f913ce09e9631390a169927 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -704,6 +704,51 @@ std::shared_ptr<OperatorBase> BuildOperatorBase(
       attr_map[name] = val.dyn_cast<ir::DoubleAttribute>().data();
     } else if (val.isa<ir::Int64Attribute>()) {
       attr_map[name] = val.dyn_cast<ir::Int64Attribute>().data();
+    } else if (val.isa<ir::ArrayAttribute>()) {
+      auto array_list = val.dyn_cast<ir::ArrayAttribute>().AsVector();
+      PADDLE_ENFORCE(
+          array_list.size() > 0,
+          paddle::platform::errors::Fatal("Attribute %s is empty", name));
+      if (array_list[0].isa<ir::Int32Attribute>()) {
+        std::vector<int> vec_int;
+        for (auto attribute : array_list) {
+          vec_int.push_back(attribute.dyn_cast<ir::Int32Attribute>().data());
+        }
+        attr_map[name] = vec_int;
+      } else if (array_list[0].isa<ir::Int64Attribute>()) {
+        std::vector<int64_t> vec_int64;
+        for (auto attribute : array_list) {
+          vec_int64.push_back(attribute.dyn_cast<ir::Int64Attribute>().data());
+        }
+        attr_map[name] = vec_int64;
+      } else if (array_list[0].isa<ir::BoolAttribute>()) {
+        std::vector<bool> vec_bool;
+        for (auto attribute : array_list) {
+          vec_bool.push_back(attribute.dyn_cast<ir::BoolAttribute>().data());
+        }
+        attr_map[name] = vec_bool;
+      } else if (array_list[0].isa<ir::FloatAttribute>()) {
+        std::vector<float> vec_float;
+        for (auto attribute : array_list) {
+          vec_float.push_back(attribute.dyn_cast<ir::FloatAttribute>().data());
+        }
+        attr_map[name] = vec_float;
+      } else if (array_list[0].isa<ir::DoubleAttribute>()) {
+        std::vector<double> vec_double;
+        for (auto attribute : array_list) {
+          vec_double.push_back(
+              attribute.dyn_cast<ir::DoubleAttribute>().data());
+        }
+        attr_map[name] = vec_double;
+      } else {
+        std::stringstream ss;
+        val.Print(ss);
+        VLOG(1) << "type not support " << ss.str() << std::endl;
+        PADDLE_THROW("Type[%s] in attribute map not support yet", ss.str());
+      }
+    } else if (val.isa<paddle::dialect::DataTypeAttribute>()) {
+      attr_map[name] = paddle::framework::TransToProtoVarType(
+          val.dyn_cast<paddle::dialect::DataTypeAttribute>().data());
     } else {
       std::stringstream ss;
       val.Print(ss);
diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
index 47aaf0dc600e4aa89fce0bd4b28c5baf7ad28a1c..017188bcacaeb79fec0eb01f7294fbf87ca44d30 100644
--- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
@@ -248,7 +248,7 @@ phi::KernelKey GetKernelKey(
   auto attr_map = op->attributes();
 
   auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype;
-  if (!data_type_info.empty() && !data_type_info[0].empty()) {
+  if (!data_type_info.empty()) {
     // only support single input and attribute
     auto slot_name = data_type_info[0];
    auto& input_map = op_info_parser->InputName2Id();