From 46abe798d8ca7edc72f76f117878b5b7edc7b6d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Thu, 17 Mar 2022 06:35:16 +0800 Subject: [PATCH] [infrt] add default kernel argument remap feature in phi_op_convert_pass. (#40633) --- .../dialect/phi/pass/phi_op_convert_pass.cc | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc index f9e124aba6c..13cba6eeabb 100644 --- a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc +++ b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc @@ -32,6 +32,7 @@ #include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h" #include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h" #include "paddle/phi/core/compat/op_utils.h" +#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/ops/compat/signatures.h" namespace { @@ -94,42 +95,49 @@ void PhiOpConvertPass::convertStage() { // Todo: print log continue; } - - ::phi::KernelSignature kernel_sign = - ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( - infrt::ProtoArgumentMappingContext(op)); - // resort input&output according to kernel_sign - ::llvm::SmallVector inputs, ori_output; - ::llvm::SmallVector output_types; - for (const std::string &str : std::get<0>(kernel_sign.args)) { - if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) { - LOG(ERROR) << "No input info for Op " << op_name << " and argument " - << str; - return; + auto loc = getFunction().getLoc(); + builder.setInsertionPoint(op); + if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_name)) { + std::string kernel_name = phi::TransToPhiKernelName(op_name); + auto kernel_op = builder.create(loc, + op->getResultTypes(), + op->getOperands(), + kernel_name, + op->getAttrDictionary()); + op->replaceAllUsesWith(kernel_op.getResults()); + } else { + ::phi::KernelSignature kernel_sign = + ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( + infrt::ProtoArgumentMappingContext(op)); + // resort input&output according to kernel_sign + ::llvm::SmallVector inputs, ori_output; + ::llvm::SmallVector output_types; + for (const std::string &str : std::get<0>(kernel_sign.args)) { + if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) { + LOG(ERROR) << "No input info for Op " << op_name << " and argument " + << str; + return; + } + uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str); + inputs.push_back(op->getOperands()[index]); } - uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str); - inputs.push_back(op->getOperands()[index]); - } - for (const std::string &str : std::get<2>(kernel_sign.args)) { - if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) { - LOG(ERROR) << "No output info for Op " << op_name << " and argument " - << str; - return; + for (const std::string &str : std::get<2>(kernel_sign.args)) { + if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) { + LOG(ERROR) << "No output info for Op " << op_name << " and argument " + << str; + return; + } + uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str); + output_types.push_back(op->getResultTypes()[index]); + ori_output.push_back(op->getResult(index)); + } + auto kernel_op = builder.create( + loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary()); + for (size_t index = 0; index < ori_output.size(); ++index) { + ori_output[index].replaceAllUsesWith(kernel_op.getResult(index)); } - uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str); - output_types.push_back(op->getResultTypes()[index]); - ori_output.push_back(op->getResult(index)); - } - - auto loc = getFunction().getLoc(); - builder.setInsertionPoint(op); - auto kernel_op = builder.create( - loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary()); - for (size_t index = 0; index < ori_output.size(); ++index) { - ori_output[index].replaceAllUsesWith(kernel_op.getResult(index)); } - CHECK(op->use_empty()); op->erase(); } -- GitLab