未验证 提交 927767ca 编写于 作者: 王明冬 提交者: GitHub

[infrt]Refine phi dialect (#40505)

* change some symbol names

* add test

* add phi to opt.cc

* clean code

* up

* update

* up

* up

* Update pten_pass.mlir

* Update convolution_grad_kernel.cc

* update

* restore init_infrt_dialects

* restore

* up

* up

* up
Co-authored-by: NSuperjomn <yanchunwei@outlook.com>
上级 a04a6bd5
...@@ -53,9 +53,9 @@ def Infrt_CallOp : Infrt_Op<"call"> { ...@@ -53,9 +53,9 @@ def Infrt_CallOp : Infrt_Op<"call"> {
}]; }];
} }
def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> { def Infrt_TensorCastOp : Infrt_Op<"tensor_cast", [NoSideEffect]> {
let summary = "convert tensor type op"; let summary = "cast tensor type op";
let description = [{convert tensor type op!}]; let description = [{cast tensor type op!}];
let arguments = (ins AnyType:$input); let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output); let results = (outs AnyType:$output);
} }
...@@ -5,17 +5,17 @@ include "mlir/Interfaces/SideEffectInterfaces.td" ...@@ -5,17 +5,17 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "paddle/infrt/dialect/infrt/ir/infrt_ops.td" include "paddle/infrt/dialect/infrt/ir/infrt_ops.td"
include "paddle/infrt/dialect/pd_ops.td" include "paddle/infrt/dialect/pd_ops.td"
def FuseCvtTensorPattern : Pat< def FuseTensorCastPattern : Pat<
(Infrt_CvtTensorOp (Infrt_CvtTensorOp $arg)), (Infrt_TensorCastOp (Infrt_TensorCastOp $arg)),
(Infrt_CvtTensorOp $arg)>; (Infrt_TensorCastOp $arg)>;
def FuseFeedCvtTensorPattern : Pat< def FuseFeedTensorCastPattern : Pat<
(Infrt_CvtTensorOp (PD_FeedOp $name)), (Infrt_TensorCastOp (PD_FeedOp $name)),
(PD_FeedOp $name)>; (PD_FeedOp $name)>;
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>; def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
def RedundantCvtTensorOptPattern : Pat< def RedundantTensorCastOptPattern : Pat<
(Infrt_CvtTensorOp:$res $arg), (replaceWithValue $arg), (Infrt_TensorCastOp:$res $arg), (replaceWithValue $arg),
[(TypesAreIdentical $res, $arg)]>; [(TypesAreIdentical $res, $arg)]>;
......
...@@ -27,8 +27,12 @@ struct InfrtOpFusePass ...@@ -27,8 +27,12 @@ struct InfrtOpFusePass
: public mlir::PassWrapper<InfrtOpFusePass, mlir::FunctionPass> { : public mlir::PassWrapper<InfrtOpFusePass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "infrtOpFusePass"; } ::llvm::StringRef getName() const override { return "infrtOpFusePass"; }
llvm::StringRef getArgument() const override { return "infrt-op-fuse"; }
void runOnFunction() override; void runOnFunction() override;
}; };
// Implementation of the InfrtOpFusePass. // Implementation of the InfrtOpFusePass.
void InfrtOpFusePass::runOnFunction() { void InfrtOpFusePass::runOnFunction() {
::mlir::RewritePatternSet patterns(&getContext()); ::mlir::RewritePatternSet patterns(&getContext());
...@@ -39,14 +43,18 @@ void InfrtOpFusePass::runOnFunction() { ...@@ -39,14 +43,18 @@ void InfrtOpFusePass::runOnFunction() {
if (nullptr == terminator_op) return; if (nullptr == terminator_op) return;
for (auto operand : terminator_op->getOperands()) { for (auto operand : terminator_op->getOperands()) {
auto *op1 = operand.getDefiningOp(); auto *op1 = operand.getDefiningOp();
auto cvt_op = ::llvm::dyn_cast<::infrt::CvtTensorOp>(op1); auto cvt_op = ::llvm::dyn_cast<::infrt::TensorCastOp>(op1);
if (!cvt_op) continue; if (!cvt_op) continue;
mlir::Value value = cvt_op.input(); mlir::Value value = cvt_op.input();
operand.replaceAllUsesWith(value); operand.replaceAllUsesWith(value);
cvt_op.erase(); cvt_op.erase();
} }
} }
} // namespace } // namespace
std::unique_ptr<mlir::Pass> infrt::createInfrtOpFusePass() { std::unique_ptr<mlir::Pass> infrt::createInfrtOpFusePass() {
return std::make_unique<InfrtOpFusePass>(); return std::make_unique<InfrtOpFusePass>();
} }
mlir::PassRegistration<InfrtOpFusePass> infrt_op_fuse_pass;
...@@ -5,9 +5,6 @@ endif() ...@@ -5,9 +5,6 @@ endif()
add_subdirectory(ir) add_subdirectory(ir)
add_subdirectory(pass) add_subdirectory(pass)
add_executable(phi-ir-exec phi_ir_exec.cc)
target_link_libraries(phi-ir-exec infrt)
add_executable(phi-exec phi_exec.cc) add_executable(phi-exec phi_exec.cc)
target_link_libraries(phi-exec infrt) target_link_libraries(phi-exec infrt)
......
...@@ -29,6 +29,7 @@ namespace infrt { ...@@ -29,6 +29,7 @@ namespace infrt {
namespace phi { namespace phi {
void PHIDialect::initialize() { void PHIDialect::initialize() {
LOG(INFO) << "PHI Dialect initalized";
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "paddle/infrt/dialect/phi/ir/infrt_phi_base.cpp.inc" // NOLINT #include "paddle/infrt/dialect/phi/ir/infrt_phi_base.cpp.inc" // NOLINT
......
...@@ -2,6 +2,8 @@ core_gather_headers() ...@@ -2,6 +2,8 @@ core_gather_headers()
gather_srcs(infrt_src SRCS gather_srcs(infrt_src SRCS
proto_arg_map_context.cc proto_arg_map_context.cc
phi_op_cvt_pass.cc phi_op_convert_pass.cc
kernel_op_desc.cc kernel_op_desc.cc
) )
cc_test(test_kernel_op_desc SRCS kernel_op_desc_test.cc DEPS infrt)
...@@ -73,7 +73,7 @@ std::string getPhiLayoutSuffix(LayoutType layout) { ...@@ -73,7 +73,7 @@ std::string getPhiLayoutSuffix(LayoutType layout) {
} }
} }
std::vector<PhiKernelDesc> getCandidateKernels( std::vector<PhiKernelDesc> GetCandidateKernels(
std::string name, const std::vector<Place>& valid_palces) { std::string name, const std::vector<Place>& valid_palces) {
std::vector<PhiKernelDesc> candidate_kernels; std::vector<PhiKernelDesc> candidate_kernels;
PhiKernelDesc phi_kernel_desc; PhiKernelDesc phi_kernel_desc;
...@@ -88,19 +88,20 @@ std::vector<PhiKernelDesc> getCandidateKernels( ...@@ -88,19 +88,20 @@ std::vector<PhiKernelDesc> getCandidateKernels(
if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) continue; if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) continue;
place.layout = LayoutType::ANY; place.layout = LayoutType::ANY;
} }
phi_kernel_desc.kernelType = place; phi_kernel_desc.kernel_type = place;
phi_kernel_desc.inputsType.clear(); phi_kernel_desc.input_types.clear();
phi_kernel_desc.outputsType.clear(); phi_kernel_desc.output_types.clear();
phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def(); phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def();
const paddle::SmallVector<phi::TensorArgDef>& input_arg = const paddle::SmallVector<phi::TensorArgDef>& input_arg =
args_def.input_defs(); args_def.input_defs();
const paddle::SmallVector<phi::TensorArgDef>& output_arg = const paddle::SmallVector<phi::TensorArgDef>& output_arg =
args_def.output_defs(); args_def.output_defs();
for (auto tensor_arg : input_arg) { for (auto tensor_arg : input_arg) {
phi_kernel_desc.inputsType.emplace_back(ConvertPlaceFromPhi(tensor_arg)); phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg));
} }
for (auto tensor_arg : output_arg) { for (auto tensor_arg : output_arg) {
phi_kernel_desc.outputsType.emplace_back(ConvertPlaceFromPhi(tensor_arg)); phi_kernel_desc.output_types.emplace_back(
ConvertPlaceFromPhi(tensor_arg));
} }
candidate_kernels.emplace_back(phi_kernel_desc); candidate_kernels.emplace_back(phi_kernel_desc);
} }
......
...@@ -21,16 +21,16 @@ ...@@ -21,16 +21,16 @@
namespace infrt { namespace infrt {
struct PhiKernelDesc { struct PhiKernelDesc {
std::vector<Place> inputsType; // kernel input place std::vector<Place> input_types; // kernel input place
std::vector<Place> outputsType; // kernel output place std::vector<Place> output_types; // kernel output place
Place kernelType; // kernel place Place kernel_type; // kernel place
}; };
std::string getPhiTargetPrefix(TargetType target); std::string getPhiTargetPrefix(TargetType target);
std::string getPhiPrecisionSuffix(PrecisionType precision); std::string getPhiPrecisionSuffix(PrecisionType precision);
std::string getPhiLayoutSuffix(LayoutType layout); std::string getPhiLayoutSuffix(LayoutType layout);
std::vector<PhiKernelDesc> getCandidateKernels( std::vector<PhiKernelDesc> GetCandidateKernels(
std::string name, const std::vector<Place>& valid_palces); std::string name, const std::vector<Place>& valid_palces);
} // namespace infrt } // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/phi/kernels/declarations.h"
namespace infrt {
TEST(phi, get_op_desc) {
std::vector<Place> places;
places.emplace_back(
TargetType::CPU, PrecisionType::FLOAT32, LayoutType::NCHW);
auto kernels = GetCandidateKernels("addmm", places);
ASSERT_GE(kernels.size(), 1UL);
}
} // namespace infrt
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h" #include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <llvm/ADT/SetVector.h> #include <llvm/ADT/SetVector.h>
...@@ -24,35 +24,52 @@ ...@@ -24,35 +24,52 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/phi/ir/phi_kernels.h"
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h" #include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h" #include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/ops/compat/signatures.h" #include "paddle/phi/ops/compat/signatures.h"
namespace { namespace {
class phiOpCvtPass class PhiOpConvertPass
: public mlir::PassWrapper<phiOpCvtPass, mlir::FunctionPass> { : public mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "phiOpCvtPass"; } ::llvm::StringRef getName() const override { return "PhiOpConvertPass"; }
void runOnFunction() override; void runOnFunction() override;
explicit phiOpCvtPass( PhiOpConvertPass();
std::vector<infrt::Place> valid_places = std::vector<infrt::Place>()) explicit PhiOpConvertPass(const std::vector<infrt::Place> &valid_places)
: valid_places_(valid_places) {} : valid_places_(valid_places) {}
PhiOpConvertPass(const PhiOpConvertPass &other)
: mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass>(*this),
valid_places_(other.valid_places_) {}
::llvm::StringRef getArgument() const override { return "phi-op-convert"; }
void getDependentDialects(mlir::DialectRegistry &registry) const override;
private: private:
void convertStage(); void convertStage();
void diapatchStage(); void dispatchStage();
// Force a specified data format for all layout sensitive operations.
Option<std::string> valid_places_options_{
*this,
"valid-targets",
llvm::cl::desc("Set the valid target, [CPU-FP32-NCHW]")};
std::vector<infrt::Place> valid_places_; std::vector<infrt::Place> valid_places_;
}; };
// Implementation of the PhiOpConvertPass.
// Implementation of the phiOpCvtPass. void PhiOpConvertPass::runOnFunction() {
void phiOpCvtPass::runOnFunction() {
convertStage(); convertStage();
diapatchStage(); dispatchStage();
} }
void phiOpCvtPass::convertStage() {
void PhiOpConvertPass::convertStage() {
mlir::Block &body = getFunction().front(); mlir::Block &body = getFunction().front();
std::vector<mlir::Operation *> worklist; std::vector<mlir::Operation *> worklist;
for (auto &op : body.without_terminator()) { for (auto &op : body.without_terminator()) {
...@@ -62,9 +79,9 @@ void phiOpCvtPass::convertStage() { ...@@ -62,9 +79,9 @@ void phiOpCvtPass::convertStage() {
while (!worklist.empty()) { while (!worklist.empty()) {
auto *op = worklist.back(); auto *op = worklist.back();
worklist.pop_back(); worklist.pop_back();
if (op == nullptr) continue; if (!op) continue;
std::string op_name = op->getName().getIdentifier().str(); auto op_name = op->getName().getIdentifier().str();
// only convert op in pd dialect. // only convert op in pd dialect.
if (op_name.substr(0, 3) != "pd.") continue; if (op_name.substr(0, 3) != "pd.") continue;
...@@ -73,6 +90,7 @@ void phiOpCvtPass::convertStage() { ...@@ -73,6 +90,7 @@ void phiOpCvtPass::convertStage() {
pd_dialect_inputs_info_map_.end() || pd_dialect_inputs_info_map_.end() ||
pd_dialect_outputs_info_map_.find(op_name) == pd_dialect_outputs_info_map_.find(op_name) ==
pd_dialect_outputs_info_map_.end()) { pd_dialect_outputs_info_map_.end()) {
LOG(WARNING) << "No op info found for " << op_name;
// Todo: print log // Todo: print log
continue; continue;
} }
...@@ -85,7 +103,8 @@ void phiOpCvtPass::convertStage() { ...@@ -85,7 +103,8 @@ void phiOpCvtPass::convertStage() {
::llvm::SmallVector<mlir::Type, 4> output_types; ::llvm::SmallVector<mlir::Type, 4> output_types;
for (const std::string &str : std::get<0>(kernel_sign.args)) { for (const std::string &str : std::get<0>(kernel_sign.args)) {
if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) { if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
// Todo: print error log LOG(ERROR) << "No input info for Op " << op_name << " and argument "
<< str;
return; return;
} }
uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str); uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
...@@ -94,7 +113,8 @@ void phiOpCvtPass::convertStage() { ...@@ -94,7 +113,8 @@ void phiOpCvtPass::convertStage() {
for (const std::string &str : std::get<2>(kernel_sign.args)) { for (const std::string &str : std::get<2>(kernel_sign.args)) {
if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) { if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
// Todo: print error log LOG(ERROR) << "No output info for Op " << op_name << " and argument "
<< str;
return; return;
} }
uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str); uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
...@@ -109,14 +129,13 @@ void phiOpCvtPass::convertStage() { ...@@ -109,14 +129,13 @@ void phiOpCvtPass::convertStage() {
for (size_t index = 0; index < ori_output.size(); ++index) { for (size_t index = 0; index < ori_output.size(); ++index) {
ori_output[index].replaceAllUsesWith(kernel_op.getResult(index)); ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
} }
if (!op->use_empty()) {
// Todo: print error log CHECK(op->use_empty());
return;
}
op->erase(); op->erase();
} }
} }
void phiOpCvtPass::diapatchStage() {
void PhiOpConvertPass::dispatchStage() {
std::vector<infrt::KernelOp> worklist; std::vector<infrt::KernelOp> worklist;
mlir::Block &block = getFunction().front(); mlir::Block &block = getFunction().front();
for (auto &op : block) { for (auto &op : block) {
...@@ -129,7 +148,7 @@ void phiOpCvtPass::diapatchStage() { ...@@ -129,7 +148,7 @@ void phiOpCvtPass::diapatchStage() {
for (infrt::KernelOp kernel_op : worklist) { for (infrt::KernelOp kernel_op : worklist) {
std::string kernel_name = kernel_op.name().str(); std::string kernel_name = kernel_op.name().str();
std::vector<infrt::PhiKernelDesc> candidates = std::vector<infrt::PhiKernelDesc> candidates =
getCandidateKernels(kernel_name, valid_places_); GetCandidateKernels(kernel_name, valid_places_);
if (candidates.empty()) { if (candidates.empty()) {
LOG(FATAL) << "No candidate kernels for op:" << kernel_name; LOG(FATAL) << "No candidate kernels for op:" << kernel_name;
continue; continue;
...@@ -140,17 +159,17 @@ void phiOpCvtPass::diapatchStage() { ...@@ -140,17 +159,17 @@ void phiOpCvtPass::diapatchStage() {
const infrt::PhiKernelDesc &phi_kernel_desc = candidates.front(); const infrt::PhiKernelDesc &phi_kernel_desc = candidates.front();
kernel_name = kernel_name =
infrt::getPhiTargetPrefix(phi_kernel_desc.kernelType.target) + infrt::getPhiTargetPrefix(phi_kernel_desc.kernel_type.target) +
kernel_name + kernel_name +
infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernelType.precision) + infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernel_type.precision) +
infrt::getPhiLayoutSuffix(phi_kernel_desc.kernelType.layout); infrt::getPhiLayoutSuffix(phi_kernel_desc.kernel_type.layout);
mlir::OperationName operation_name(kernel_name, kernel_op.getContext()); mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
mlir::OperationState operation_state(kernel_op.getLoc(), operation_name); mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);
if (phi_context.find(phi_kernel_desc.kernelType.target) == if (phi_context.find(phi_kernel_desc.kernel_type.target) ==
phi_context.end()) { phi_context.end()) {
switch (phi_kernel_desc.kernelType.target) { switch (phi_kernel_desc.kernel_type.target) {
case infrt::TargetType::CPU: { case infrt::TargetType::CPU: {
auto context_value = auto context_value =
builder builder
...@@ -169,33 +188,36 @@ void phiOpCvtPass::diapatchStage() { ...@@ -169,33 +188,36 @@ void phiOpCvtPass::diapatchStage() {
} }
} }
operation_state.addOperands( operation_state.addOperands(
phi_context.at(phi_kernel_desc.kernelType.target)); phi_context.at(phi_kernel_desc.kernel_type.target));
for (size_t index = 0; index < phi_kernel_desc.inputsType.size(); ++index) {
for (size_t index = 0; index < phi_kernel_desc.input_types.size();
++index) {
mlir::Value input = kernel_op.getOperand(index); mlir::Value input = kernel_op.getOperand(index);
auto cvt_tensor_type_op = builder.create<infrt::CvtTensorOp>( auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
kernel_op.getLoc(), kernel_op.getLoc(),
infrt::DenseTensorType::get( infrt::DenseTensorType::get(
kernel_op.getContext(), kernel_op.getContext(),
phi_kernel_desc.inputsType[index].target, phi_kernel_desc.input_types[index].target,
phi_kernel_desc.inputsType[index].precision, phi_kernel_desc.input_types[index].precision,
phi_kernel_desc.inputsType[index].layout), phi_kernel_desc.input_types[index].layout),
input); input);
operation_state.addOperands(cvt_tensor_type_op.output()); operation_state.addOperands(cvt_tensor_type_op.output());
} }
for (size_t index = 0; index < phi_kernel_desc.outputsType.size();
for (size_t index = 0; index < phi_kernel_desc.output_types.size();
++index) { ++index) {
operation_state.addTypes(infrt::DenseTensorType::get( operation_state.addTypes(infrt::DenseTensorType::get(
kernel_op.getContext(), kernel_op.getContext(),
phi_kernel_desc.outputsType[index].target, phi_kernel_desc.output_types[index].target,
phi_kernel_desc.outputsType[index].precision, phi_kernel_desc.output_types[index].precision,
phi_kernel_desc.outputsType[index].layout)); phi_kernel_desc.output_types[index].layout));
} }
operation_state.addAttributes(kernel_op.attrsAttr().getValue()); operation_state.addAttributes(kernel_op.attrsAttr().getValue());
mlir::Operation *phi_operation = builder.createOperation(operation_state); mlir::Operation *phi_operation = builder.createOperation(operation_state);
for (size_t index = 0; index < phi_kernel_desc.outputsType.size(); for (size_t index = 0; index < phi_kernel_desc.output_types.size();
++index) { ++index) {
mlir::Value input = phi_operation->getResult(index); mlir::Value input = phi_operation->getResult(index);
auto cvt_tensor_type_op = builder.create<infrt::CvtTensorOp>( auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
kernel_op.getLoc(), kernel_op.getResultTypes()[index], input); kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
kernel_op.getResult(index).replaceAllUsesWith( kernel_op.getResult(index).replaceAllUsesWith(
cvt_tensor_type_op.output()); cvt_tensor_type_op.output());
...@@ -204,9 +226,35 @@ void phiOpCvtPass::diapatchStage() { ...@@ -204,9 +226,35 @@ void phiOpCvtPass::diapatchStage() {
} }
} }
PhiOpConvertPass::PhiOpConvertPass() {
if (!valid_places_options_.hasValue()) {
valid_places_.emplace_back(infrt::TargetType::CPU,
infrt::PrecisionType::FLOAT32,
infrt::LayoutType::NCHW);
return;
}
LOG(FATAL) << "To be done for specifying places in command line";
}
void PhiOpConvertPass::getDependentDialects(
mlir::DialectRegistry &registry) const {
registry.insert<infrt::InfrtDialect>();
registry.insert<infrt::phi::PHIDialect>();
registry.insert<infrt::phi::PHIDenseTensorDialect>();
registry.insert<infrt::phi::PHICPUKernelDialect>();
registry.insert<infrt::phi::PHIGPUKernelDialect>();
}
} // namespace } // namespace
mlir::PassRegistration<PhiOpConvertPass> phi_op_convert;
std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass( std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass(
std::vector<Place> valid_places) { std::vector<Place> valid_places) {
return std::make_unique<phiOpCvtPass>(valid_places); return std::make_unique<PhiOpConvertPass>(valid_places);
}
std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass() {
return std::make_unique<PhiOpConvertPass>();
} }
...@@ -21,7 +21,8 @@ namespace infrt { ...@@ -21,7 +21,8 @@ namespace infrt {
* phiOpCvtPass. * phiOpCvtPass.
* Convert the general operators from pd Dialect to phi dialect. * Convert the general operators from pd Dialect to phi dialect.
*/ */
std::unique_ptr<mlir::Pass> createPhiOpCvtPass( std::unique_ptr<mlir::Pass> createPhiOpCvtPass(std::vector<Place> valid_places);
std::vector<Place> valid_places = std::vector<Place>());
std::unique_ptr<mlir::Pass> createPhiOpCvtPass();
} // namespace infrt } // namespace infrt
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include "paddle/infrt/kernel/test_kernels.h" #include "paddle/infrt/kernel/test_kernels.h"
#ifdef INFRT_WITH_PHI #ifdef INFRT_WITH_PHI
#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h" #include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h" #include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h" #include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/kernel/phi/registry.h" #include "paddle/infrt/kernel/phi/registry.h"
#endif #endif
......
configure_file(lit.cfg.py.in "${CMAKE_SOURCE_DIR}/paddle/infrt/tests/lit.cfg.py") configure_file(lit.cfg.py.in "${CMAKE_SOURCE_DIR}/paddle/infrt/tests/lit.cfg.py")
add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle/infrt/tests --filter-out \"disabled_*\"" add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle/infrt/tests --filter-out \"disabled_*\""
DEPENDS infrtopt infrtexec phi-ir-exec) DEPENDS infrtopt infrtexec)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir)
// RUN: phi-ir-exec %s // RUN: infrtopt -phi-op-convert -infrt-op-fuse %s
// CHECK-LABEL: @ops // CHECK-LABEL: @ops
func @ops() { func @ops() {
%a = pd.feed() {name="input0"} : !infrt.lod_tensor<?xf32,0> %a = pd.feed() {name="input0"} : !infrt.lod_tensor<?xf32,0>
...@@ -8,3 +9,10 @@ func @ops() { ...@@ -8,3 +9,10 @@ func @ops() {
%h = "pd.abs"(%g):(tensor<?xf32>) -> tensor<?xf32> %h = "pd.abs"(%g):(tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%h) {name="output"} :(tensor<?xf32>)->() "pd.fetch"(%h) {name="output"} :(tensor<?xf32>)->()
} }
// CHECK-LABEL: @op_execute
func @op_execute(%a:!infrt.lod_tensor<?xf32,0>, %b:!infrt.lod_tensor<?xf32,0>, %c:!infrt.lod_tensor<?xf32,0>) -> !infrt.lod_tensor<?xf32,0> {
%g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.lod_tensor<?xf32,0>, !infrt.lod_tensor<?xf32>) -> tensor<?xf32>
%h = "pd.abs"(%g):(tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%h) {name="output"} :(tensor<?xf32>)->()
}
...@@ -93,7 +93,7 @@ function infrt_gen_and_build() { ...@@ -93,7 +93,7 @@ function infrt_gen_and_build() {
exit 7; exit 7;
fi fi
make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec phi-ir-exec phi-exec infrt_lib_dist paddle-mlir-convert;build_error=$? make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec phi-exec infrt_lib_dist paddle-mlir-convert;build_error=$?
if [ "$build_error" != 0 ];then if [ "$build_error" != 0 ];then
exit 7; exit 7;
fi fi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册