diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index f3cf7982aa9c67690361a70ce20f4c795b6becc9..8071edee33607c95df994ef6cf2959e4bdf00a64 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -151,8 +151,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{ std::vector inputs = {{ {inputs} }}; std::vector attributes = {{ {attributes} }}; std::vector outputs = {{ {outputs} }}; - paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{inplace}}}, {{{view}}}); - + paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}); return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); }} """ @@ -1013,6 +1012,7 @@ def OpGenerator( kernel_func_str = "" kernel_param_str = "" kernel_key_dtype = "" + kernel_key_backend = "" if op_kernel_map is not None: kernel_func_str = '", "'.join(op_kernel_map['func']) kernel_param_str = '", "'.join(op_kernel_map['param']) @@ -1022,6 +1022,12 @@ def OpGenerator( ) if kernel_key_dtype != "": kernel_key_dtype = '"' + kernel_key_dtype + '"' + if 'backend' in op_kernel_map and op_kernel_map['backend']: + kernel_key_backend = '", "'.join( + op_kernel_map['backend']['candidates'] + ) + if kernel_key_backend != "": + kernel_key_backend = '"' + kernel_key_backend + '"' inplace_str = "" view_str = "" @@ -1045,6 +1051,7 @@ def OpGenerator( kernel_func=kernel_func_str, kernel_param=kernel_param_str, kernel_key_dtype=kernel_key_dtype, + kernel_key_backend=kernel_key_backend, inplace=inplace_str, view=view_str, origin_op_name=op_info.op_yaml_item['name'], diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc index 0c51b98d42fd11497e347bb6419884f6da6a36fc..4a131bbf1dc506561b235699e764a5a78ed1e5cf 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc @@ -39,7 +39,7 @@ OpInfoTuple AddNOp::GetOpInfo() { std::vector outputs = { OpOutputInfo("out", "paddle::dialect::DenseTensorType", false, false)}; paddle::dialect::OpRunTimeInfo run_time_info = - OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {}); + OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {}, {}); return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n"); } diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h b/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h index 5dd319c69fddc0402e93bfda586b88335b1a855b..eaa37a3a7de9fa66ffa12d560b91ba1fb72d76d2 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h @@ -77,6 +77,7 @@ struct OpRunTimeInfo { std::vector kernel_func; std::vector kernel_param; std::vector kernel_key_dtype; + std::vector kernel_key_backend; std::vector> inplace; std::vector> view; OpRunTimeInfo(const std::string& infer_meta_func, @@ -84,6 +85,7 @@ struct OpRunTimeInfo { const std::vector& kernel_func, const std::vector& kernel_param, const std::vector& dtype, + const std::vector& backend, const std::vector>& inplace, const std::vector>& view) : infer_meta_func(infer_meta_func), @@ -91,6 +93,7 @@ struct OpRunTimeInfo { kernel_func(kernel_func), kernel_param(kernel_param), kernel_key_dtype(dtype), + kernel_key_backend(backend), inplace(inplace), view(view) {} }; diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc index 017188bcacaeb79fec0eb01f7294fbf87ca44d30..4198098f2bd4fbb65e00f452d043d56b80ae6f55 100644 --- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc @@ -209,6 +209,150 @@ ir::OpResult AddPlaceTransferOp(ir::OpResult in, } } +phi::DataType GetKernelDataTypeByYamlInfo( + const ir::Operation* op, + const std::unordered_map& map_value_pair, + const dialect::OpYamlInfoParser* op_info_parser) { + auto& attr_map = op->attributes(); + auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype; + phi::DataType kernel_data_type = phi::DataType::UNDEFINED; + + for (size_t i = 0; i < data_type_info.size(); ++i) { + auto slot_name = data_type_info[i]; + auto& input_map = op_info_parser->InputName2Id(); + + auto find_it = Str2PhiDataType.find(slot_name); + if (find_it != Str2PhiDataType.end()) { + kernel_data_type = find_it->second; + } else if (input_map.count(slot_name)) { + // parse from input + int in_index = input_map.at(slot_name); + auto type = map_value_pair.at(op->operand_source(in_index)).type(); + + if (type.isa()) { + kernel_data_type = TransToPhiDataType( + type.dyn_cast().dtype()); + } else if (type.isa()) { + auto vec_data = type.dyn_cast().data(); + if (vec_data.empty()) { + kernel_data_type = phi::DataType::UNDEFINED; + } else { + if (vec_data[0].isa()) { + kernel_data_type = TransToPhiDataType( + vec_data[0] + .dyn_cast() + .dtype()); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support DenseTensorType in vector")); + } + } + } else if (type.isa()) { + kernel_data_type = TransToPhiDataType( + type.dyn_cast() + .dtype()); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support DenseTensorType, SelectedRows, VectorType")); + } + + } else { + PADDLE_ENFORCE_EQ(attr_map.count(slot_name), + true, + phi::errors::PreconditionNotMet( + "[%s] MUST in attribute map", slot_name)); + + auto attr_type = op_info_parser->AttrTypeName(slot_name); + PADDLE_ENFORCE_EQ(attr_type, + "paddle::dialect::DataTypeAttribute", + phi::errors::PreconditionNotMet( + "Type of [%s] should be DataType", slot_name)); + kernel_data_type = attr_map.at(slot_name) + .dyn_cast() + .data(); + } + + if (kernel_data_type != phi::DataType::UNDEFINED) { + // In yaml definition, data type have an order + // like: data_type : dtype > x + // Should break when found a defined data type + break; + } + } + + return kernel_data_type; +} + +phi::Backend GetKernelBackendByYamlInfo( + const ir::Operation* op, + const std::unordered_map& map_value_pair, + const dialect::OpYamlInfoParser* op_info_parser) { + auto& attr_map = op->attributes(); + auto& backend_info = op_info_parser->OpRuntimeInfo().kernel_key_backend; + phi::Backend kernel_backend = phi::Backend::UNDEFINED; + for (size_t i = 0; i < backend_info.size(); ++i) { + auto slot_name = backend_info[i]; + auto& input_map = op_info_parser->InputName2Id(); + + if (input_map.count(slot_name)) { + // parse from input + int in_index = input_map.at(slot_name); + auto type = map_value_pair.at(op->operand_source(in_index)).type(); + + if (type.isa()) { + kernel_backend = paddle::experimental::ParseBackend( + type.dyn_cast().place()); + } else if (type.isa()) { + auto vec_data = type.dyn_cast().data(); + if (vec_data.empty()) { + kernel_backend = phi::Backend::UNDEFINED; + } else { + if (vec_data[0].isa()) { + kernel_backend = paddle::experimental::ParseBackend( + vec_data[0] + .dyn_cast() + .place()); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support DenseTensorType in vector")); + } + } + } else if (type.isa()) { + kernel_backend = paddle::experimental::ParseBackend( + type.dyn_cast() + .place()); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support DenseTensorType, SelectedRows, VectorType")); + } + + } else { + PADDLE_ENFORCE_EQ(attr_map.count(slot_name), + true, + phi::errors::PreconditionNotMet( + "[%s] MUST in attribute map", slot_name)); + + auto attr_type = op_info_parser->AttrTypeName(slot_name); + PADDLE_ENFORCE_EQ(attr_type, + "paddle::dialect::PlaceAttribute", + phi::errors::PreconditionNotMet( + "Type of [%s] should be DataType", slot_name)); + kernel_backend = paddle::experimental::ParseBackend( + attr_map.at(slot_name) + .dyn_cast() + .data()); + } + if (kernel_backend != phi::Backend::UNDEFINED) { + // In yaml definition, backend have an order + // like: backend : place > x + // Should break when found a defined data type + break; + } + } + + return kernel_backend; +} + phi::KernelKey GetKernelKey( ir::Operation* op, const phi::Place& place, @@ -245,66 +389,11 @@ phi::KernelKey GetKernelKey( // only suppurt non vector input for now int tensor_input_number = op_info_parser->InputTensorNumber(); - auto attr_map = op->attributes(); - auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype; - - if (!data_type_info.empty()) { - // only support single input and attribute - auto slot_name = data_type_info[0]; - auto& input_map = op_info_parser->InputName2Id(); - - auto find_it = Str2PhiDataType.find(slot_name); - if (find_it != Str2PhiDataType.end()) { - kernel_data_type = find_it->second; - } else if (input_map.count(slot_name)) { - // parse from input - int in_index = input_map.at(slot_name); - auto type = map_value_pair.at(op->operand_source(in_index)).type(); - - if (type.isa()) { - kernel_data_type = TransToPhiDataType( - type.dyn_cast() - .dtype()); - } else if (type.isa()) { - auto vec_data = type.dyn_cast().data(); - if (vec_data.empty()) { - kernel_data_type = phi::DataType::UNDEFINED; - } else { - if (vec_data[0].isa()) { - kernel_data_type = TransToPhiDataType( - vec_data[0] - .dyn_cast() - .dtype()); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Only support DenseTensorType in vector")); - } - } - } else if (type.isa()) { - kernel_data_type = TransToPhiDataType( - type.dyn_cast() - .dtype()); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Only support DenseTensorType, SelectedRows, VectorType")); - } - - } else { - PADDLE_ENFORCE_EQ(attr_map.count(slot_name), - true, - phi::errors::PreconditionNotMet( - "[%s] MUST in attribute map", slot_name)); - - auto attr_type = op_info_parser->AttrTypeName(slot_name); - PADDLE_ENFORCE_EQ(attr_type, - "paddle::dialect::DataTypeAttribute", - phi::errors::PreconditionNotMet( - "Type of [%s] should be DataType", slot_name)); - kernel_data_type = attr_map.at(slot_name) - .dyn_cast() - .data(); - } - } + // get datatype info + kernel_data_type = + GetKernelDataTypeByYamlInfo(op, map_value_pair, op_info_parser); + kernel_backend = + GetKernelBackendByYamlInfo(op, map_value_pair, op_info_parser); // parse all the input tensor if (tensor_input_number == 0 || op->name() == "pd.full_") { diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index 9f5e97deda93e4dd86c72aee84495474694184d4..9a479f7191b73c816d2f50b943580a7f922f0353 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -488,6 +488,7 @@ OpInfoTuple Conv2dFusionOpTest::GetOpInfo() { "user_workspace_size"}, {"input"}, {}, + {}, {}); return std::make_tuple(