未验证 提交 46abe798 编写于 作者: 王明冬 提交者: GitHub

[infrt] add default kernel argument remap feature in phi_op_convert_pass. (#40633)

上级 3a256637
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h" #include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h" #include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h" #include "paddle/phi/ops/compat/signatures.h"
namespace { namespace {
...@@ -94,42 +95,49 @@ void PhiOpConvertPass::convertStage() { ...@@ -94,42 +95,49 @@ void PhiOpConvertPass::convertStage() {
// Todo: print log // Todo: print log
continue; continue;
} }
auto loc = getFunction().getLoc();
::phi::KernelSignature kernel_sign = builder.setInsertionPoint(op);
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_name)) {
infrt::ProtoArgumentMappingContext(op)); std::string kernel_name = phi::TransToPhiKernelName(op_name);
// resort input&output according to kernel_sign auto kernel_op = builder.create<infrt::KernelOp>(loc,
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output; op->getResultTypes(),
::llvm::SmallVector<mlir::Type, 4> output_types; op->getOperands(),
for (const std::string &str : std::get<0>(kernel_sign.args)) { kernel_name,
if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) { op->getAttrDictionary());
LOG(ERROR) << "No input info for Op " << op_name << " and argument " op->replaceAllUsesWith(kernel_op.getResults());
<< str; } else {
return; ::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
infrt::ProtoArgumentMappingContext(op));
// resort input&output according to kernel_sign
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
::llvm::SmallVector<mlir::Type, 4> output_types;
for (const std::string &str : std::get<0>(kernel_sign.args)) {
if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No input info for Op " << op_name << " and argument "
<< str;
return;
}
uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
inputs.push_back(op->getOperands()[index]);
} }
uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
inputs.push_back(op->getOperands()[index]);
}
for (const std::string &str : std::get<2>(kernel_sign.args)) { for (const std::string &str : std::get<2>(kernel_sign.args)) {
if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) { if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No output info for Op " << op_name << " and argument " LOG(ERROR) << "No output info for Op " << op_name << " and argument "
<< str; << str;
return; return;
}
uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
output_types.push_back(op->getResultTypes()[index]);
ori_output.push_back(op->getResult(index));
}
auto kernel_op = builder.create<infrt::KernelOp>(
loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
for (size_t index = 0; index < ori_output.size(); ++index) {
ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
} }
uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
output_types.push_back(op->getResultTypes()[index]);
ori_output.push_back(op->getResult(index));
}
auto loc = getFunction().getLoc();
builder.setInsertionPoint(op);
auto kernel_op = builder.create<infrt::KernelOp>(
loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
for (size_t index = 0; index < ori_output.size(); ++index) {
ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
} }
CHECK(op->use_empty()); CHECK(op->use_empty());
op->erase(); op->erase();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册