未验证 提交 46abe798 编写于 作者: 王明冬 提交者: GitHub

[infrt] add default kernel argument remap feature in phi_op_convert_pass. (#40633)

上级 3a256637
......@@ -32,6 +32,7 @@
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h"
namespace {
......@@ -94,7 +95,17 @@ void PhiOpConvertPass::convertStage() {
// Todo: print log
continue;
}
auto loc = getFunction().getLoc();
builder.setInsertionPoint(op);
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_name)) {
std::string kernel_name = phi::TransToPhiKernelName(op_name);
auto kernel_op = builder.create<infrt::KernelOp>(loc,
op->getResultTypes(),
op->getOperands(),
kernel_name,
op->getAttrDictionary());
op->replaceAllUsesWith(kernel_op.getResults());
} else {
::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
infrt::ProtoArgumentMappingContext(op));
......@@ -121,15 +132,12 @@ void PhiOpConvertPass::convertStage() {
output_types.push_back(op->getResultTypes()[index]);
ori_output.push_back(op->getResult(index));
}
auto loc = getFunction().getLoc();
builder.setInsertionPoint(op);
auto kernel_op = builder.create<infrt::KernelOp>(
loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
for (size_t index = 0; index < ori_output.size(); ++index) {
ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
}
}
CHECK(op->use_empty());
op->erase();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册