未验证 提交 46abe798 编写于 作者: 王明冬 提交者: GitHub

[infrt] add default kernel argument remap feature in phi_op_convert_pass. (#40633)

上级 3a256637
......@@ -32,6 +32,7 @@
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h"
namespace {
......@@ -94,42 +95,49 @@ void PhiOpConvertPass::convertStage() {
// Todo: print log
continue;
}
::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
infrt::ProtoArgumentMappingContext(op));
// resort input&output according to kernel_sign
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
::llvm::SmallVector<mlir::Type, 4> output_types;
for (const std::string &str : std::get<0>(kernel_sign.args)) {
if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No input info for Op " << op_name << " and argument "
<< str;
return;
auto loc = getFunction().getLoc();
builder.setInsertionPoint(op);
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_name)) {
std::string kernel_name = phi::TransToPhiKernelName(op_name);
auto kernel_op = builder.create<infrt::KernelOp>(loc,
op->getResultTypes(),
op->getOperands(),
kernel_name,
op->getAttrDictionary());
op->replaceAllUsesWith(kernel_op.getResults());
} else {
::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
infrt::ProtoArgumentMappingContext(op));
// resort input&output according to kernel_sign
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
::llvm::SmallVector<mlir::Type, 4> output_types;
for (const std::string &str : std::get<0>(kernel_sign.args)) {
if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No input info for Op " << op_name << " and argument "
<< str;
return;
}
uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
inputs.push_back(op->getOperands()[index]);
}
uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
inputs.push_back(op->getOperands()[index]);
}
for (const std::string &str : std::get<2>(kernel_sign.args)) {
if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No output info for Op " << op_name << " and argument "
<< str;
return;
for (const std::string &str : std::get<2>(kernel_sign.args)) {
if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No output info for Op " << op_name << " and argument "
<< str;
return;
}
uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
output_types.push_back(op->getResultTypes()[index]);
ori_output.push_back(op->getResult(index));
}
auto kernel_op = builder.create<infrt::KernelOp>(
loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
for (size_t index = 0; index < ori_output.size(); ++index) {
ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
}
uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
output_types.push_back(op->getResultTypes()[index]);
ori_output.push_back(op->getResult(index));
}
auto loc = getFunction().getLoc();
builder.setInsertionPoint(op);
auto kernel_op = builder.create<infrt::KernelOp>(
loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
for (size_t index = 0; index < ori_output.size(); ++index) {
ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
}
CHECK(op->use_empty());
op->erase();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册