未验证 提交 927767ca 编写于 作者: 王明冬 提交者: GitHub

[infrt]Refine phi dialect (#40505)

* change some symbol names

* add test

* add phi to opt.cc

* clean code

* up

* update

* up

* up

* Update pten_pass.mlir

* Update convolution_grad_kernel.cc

* update

* restore init_infrt_dialects

* restore

* up

* up

* up
Co-authored-by: NSuperjomn <yanchunwei@outlook.com>
上级 a04a6bd5
......@@ -53,9 +53,9 @@ def Infrt_CallOp : Infrt_Op<"call"> {
}];
}
def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> {
let summary = "convert tensor type op";
let description = [{convert tensor type op!}];
def Infrt_TensorCastOp : Infrt_Op<"tensor_cast", [NoSideEffect]> {
let summary = "cast tensor type op";
let description = [{cast tensor type op!}];
let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output);
}
......@@ -5,17 +5,17 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "paddle/infrt/dialect/infrt/ir/infrt_ops.td"
include "paddle/infrt/dialect/pd_ops.td"
def FuseCvtTensorPattern : Pat<
(Infrt_CvtTensorOp (Infrt_CvtTensorOp $arg)),
(Infrt_CvtTensorOp $arg)>;
def FuseTensorCastPattern : Pat<
(Infrt_TensorCastOp (Infrt_TensorCastOp $arg)),
(Infrt_TensorCastOp $arg)>;
def FuseFeedCvtTensorPattern : Pat<
(Infrt_CvtTensorOp (PD_FeedOp $name)),
def FuseFeedTensorCastPattern : Pat<
(Infrt_TensorCastOp (PD_FeedOp $name)),
(PD_FeedOp $name)>;
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
def RedundantCvtTensorOptPattern : Pat<
(Infrt_CvtTensorOp:$res $arg), (replaceWithValue $arg),
def RedundantTensorCastOptPattern : Pat<
(Infrt_TensorCastOp:$res $arg), (replaceWithValue $arg),
[(TypesAreIdentical $res, $arg)]>;
......
......@@ -27,8 +27,12 @@ struct InfrtOpFusePass
: public mlir::PassWrapper<InfrtOpFusePass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "infrtOpFusePass"; }
llvm::StringRef getArgument() const override { return "infrt-op-fuse"; }
void runOnFunction() override;
};
// Implementation of the InfrtOpFusePass.
void InfrtOpFusePass::runOnFunction() {
::mlir::RewritePatternSet patterns(&getContext());
......@@ -39,14 +43,18 @@ void InfrtOpFusePass::runOnFunction() {
if (nullptr == terminator_op) return;
for (auto operand : terminator_op->getOperands()) {
auto *op1 = operand.getDefiningOp();
auto cvt_op = ::llvm::dyn_cast<::infrt::CvtTensorOp>(op1);
auto cvt_op = ::llvm::dyn_cast<::infrt::TensorCastOp>(op1);
if (!cvt_op) continue;
mlir::Value value = cvt_op.input();
operand.replaceAllUsesWith(value);
cvt_op.erase();
}
}
} // namespace
std::unique_ptr<mlir::Pass> infrt::createInfrtOpFusePass() {
return std::make_unique<InfrtOpFusePass>();
}
mlir::PassRegistration<InfrtOpFusePass> infrt_op_fuse_pass;
......@@ -5,9 +5,6 @@ endif()
add_subdirectory(ir)
add_subdirectory(pass)
add_executable(phi-ir-exec phi_ir_exec.cc)
target_link_libraries(phi-ir-exec infrt)
add_executable(phi-exec phi_exec.cc)
target_link_libraries(phi-exec infrt)
......
......@@ -29,6 +29,7 @@ namespace infrt {
namespace phi {
void PHIDialect::initialize() {
LOG(INFO) << "PHI Dialect initalized";
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/phi/ir/infrt_phi_base.cpp.inc" // NOLINT
......
......@@ -2,6 +2,8 @@ core_gather_headers()
gather_srcs(infrt_src SRCS
proto_arg_map_context.cc
phi_op_cvt_pass.cc
phi_op_convert_pass.cc
kernel_op_desc.cc
)
cc_test(test_kernel_op_desc SRCS kernel_op_desc_test.cc DEPS infrt)
......@@ -73,7 +73,7 @@ std::string getPhiLayoutSuffix(LayoutType layout) {
}
}
std::vector<PhiKernelDesc> getCandidateKernels(
std::vector<PhiKernelDesc> GetCandidateKernels(
std::string name, const std::vector<Place>& valid_palces) {
std::vector<PhiKernelDesc> candidate_kernels;
PhiKernelDesc phi_kernel_desc;
......@@ -88,19 +88,20 @@ std::vector<PhiKernelDesc> getCandidateKernels(
if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) continue;
place.layout = LayoutType::ANY;
}
phi_kernel_desc.kernelType = place;
phi_kernel_desc.inputsType.clear();
phi_kernel_desc.outputsType.clear();
phi_kernel_desc.kernel_type = place;
phi_kernel_desc.input_types.clear();
phi_kernel_desc.output_types.clear();
phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def();
const paddle::SmallVector<phi::TensorArgDef>& input_arg =
args_def.input_defs();
const paddle::SmallVector<phi::TensorArgDef>& output_arg =
args_def.output_defs();
for (auto tensor_arg : input_arg) {
phi_kernel_desc.inputsType.emplace_back(ConvertPlaceFromPhi(tensor_arg));
phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg));
}
for (auto tensor_arg : output_arg) {
phi_kernel_desc.outputsType.emplace_back(ConvertPlaceFromPhi(tensor_arg));
phi_kernel_desc.output_types.emplace_back(
ConvertPlaceFromPhi(tensor_arg));
}
candidate_kernels.emplace_back(phi_kernel_desc);
}
......
......@@ -21,16 +21,16 @@
namespace infrt {
struct PhiKernelDesc {
std::vector<Place> inputsType; // kernel input place
std::vector<Place> outputsType; // kernel output place
Place kernelType; // kernel place
std::vector<Place> input_types; // kernel input place
std::vector<Place> output_types; // kernel output place
Place kernel_type; // kernel place
};
std::string getPhiTargetPrefix(TargetType target);
std::string getPhiPrecisionSuffix(PrecisionType precision);
std::string getPhiLayoutSuffix(LayoutType layout);
std::vector<PhiKernelDesc> getCandidateKernels(
std::vector<PhiKernelDesc> GetCandidateKernels(
std::string name, const std::vector<Place>& valid_palces);
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/phi/kernels/declarations.h"
namespace infrt {
TEST(phi, get_op_desc) {
std::vector<Place> places;
places.emplace_back(
TargetType::CPU, PrecisionType::FLOAT32, LayoutType::NCHW);
auto kernels = GetCandidateKernels("addmm", places);
ASSERT_GE(kernels.size(), 1UL);
}
} // namespace infrt
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h"
#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
#include <glog/logging.h>
#include <llvm/ADT/SetVector.h>
......@@ -24,35 +24,52 @@
#include <unordered_set>
#include <vector>
#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/phi/ir/phi_kernels.h"
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/ops/compat/signatures.h"
namespace {
class phiOpCvtPass
: public mlir::PassWrapper<phiOpCvtPass, mlir::FunctionPass> {
class PhiOpConvertPass
: public mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "phiOpCvtPass"; }
::llvm::StringRef getName() const override { return "PhiOpConvertPass"; }
void runOnFunction() override;
explicit phiOpCvtPass(
std::vector<infrt::Place> valid_places = std::vector<infrt::Place>())
PhiOpConvertPass();
explicit PhiOpConvertPass(const std::vector<infrt::Place> &valid_places)
: valid_places_(valid_places) {}
PhiOpConvertPass(const PhiOpConvertPass &other)
: mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass>(*this),
valid_places_(other.valid_places_) {}
::llvm::StringRef getArgument() const override { return "phi-op-convert"; }
void getDependentDialects(mlir::DialectRegistry &registry) const override;
private:
void convertStage();
void diapatchStage();
void dispatchStage();
// Force a specified data format for all layout sensitive operations.
Option<std::string> valid_places_options_{
*this,
"valid-targets",
llvm::cl::desc("Set the valid target, [CPU-FP32-NCHW]")};
std::vector<infrt::Place> valid_places_;
};
// Implementation of the phiOpCvtPass.
void phiOpCvtPass::runOnFunction() {
// Implementation of the PhiOpConvertPass.
void PhiOpConvertPass::runOnFunction() {
convertStage();
diapatchStage();
dispatchStage();
}
void phiOpCvtPass::convertStage() {
void PhiOpConvertPass::convertStage() {
mlir::Block &body = getFunction().front();
std::vector<mlir::Operation *> worklist;
for (auto &op : body.without_terminator()) {
......@@ -62,9 +79,9 @@ void phiOpCvtPass::convertStage() {
while (!worklist.empty()) {
auto *op = worklist.back();
worklist.pop_back();
if (op == nullptr) continue;
if (!op) continue;
std::string op_name = op->getName().getIdentifier().str();
auto op_name = op->getName().getIdentifier().str();
// only convert op in pd dialect.
if (op_name.substr(0, 3) != "pd.") continue;
......@@ -73,6 +90,7 @@ void phiOpCvtPass::convertStage() {
pd_dialect_inputs_info_map_.end() ||
pd_dialect_outputs_info_map_.find(op_name) ==
pd_dialect_outputs_info_map_.end()) {
LOG(WARNING) << "No op info found for " << op_name;
// Todo: print log
continue;
}
......@@ -85,7 +103,8 @@ void phiOpCvtPass::convertStage() {
::llvm::SmallVector<mlir::Type, 4> output_types;
for (const std::string &str : std::get<0>(kernel_sign.args)) {
if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
// Todo: print error log
LOG(ERROR) << "No input info for Op " << op_name << " and argument "
<< str;
return;
}
uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
......@@ -94,7 +113,8 @@ void phiOpCvtPass::convertStage() {
for (const std::string &str : std::get<2>(kernel_sign.args)) {
if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
// Todo: print error log
LOG(ERROR) << "No output info for Op " << op_name << " and argument "
<< str;
return;
}
uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
......@@ -109,14 +129,13 @@ void phiOpCvtPass::convertStage() {
for (size_t index = 0; index < ori_output.size(); ++index) {
ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
}
if (!op->use_empty()) {
// Todo: print error log
return;
}
CHECK(op->use_empty());
op->erase();
}
}
void phiOpCvtPass::diapatchStage() {
void PhiOpConvertPass::dispatchStage() {
std::vector<infrt::KernelOp> worklist;
mlir::Block &block = getFunction().front();
for (auto &op : block) {
......@@ -129,7 +148,7 @@ void phiOpCvtPass::diapatchStage() {
for (infrt::KernelOp kernel_op : worklist) {
std::string kernel_name = kernel_op.name().str();
std::vector<infrt::PhiKernelDesc> candidates =
getCandidateKernels(kernel_name, valid_places_);
GetCandidateKernels(kernel_name, valid_places_);
if (candidates.empty()) {
LOG(FATAL) << "No candidate kernels for op:" << kernel_name;
continue;
......@@ -140,17 +159,17 @@ void phiOpCvtPass::diapatchStage() {
const infrt::PhiKernelDesc &phi_kernel_desc = candidates.front();
kernel_name =
infrt::getPhiTargetPrefix(phi_kernel_desc.kernelType.target) +
infrt::getPhiTargetPrefix(phi_kernel_desc.kernel_type.target) +
kernel_name +
infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernelType.precision) +
infrt::getPhiLayoutSuffix(phi_kernel_desc.kernelType.layout);
infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernel_type.precision) +
infrt::getPhiLayoutSuffix(phi_kernel_desc.kernel_type.layout);
mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);
if (phi_context.find(phi_kernel_desc.kernelType.target) ==
if (phi_context.find(phi_kernel_desc.kernel_type.target) ==
phi_context.end()) {
switch (phi_kernel_desc.kernelType.target) {
switch (phi_kernel_desc.kernel_type.target) {
case infrt::TargetType::CPU: {
auto context_value =
builder
......@@ -169,33 +188,36 @@ void phiOpCvtPass::diapatchStage() {
}
}
operation_state.addOperands(
phi_context.at(phi_kernel_desc.kernelType.target));
for (size_t index = 0; index < phi_kernel_desc.inputsType.size(); ++index) {
phi_context.at(phi_kernel_desc.kernel_type.target));
for (size_t index = 0; index < phi_kernel_desc.input_types.size();
++index) {
mlir::Value input = kernel_op.getOperand(index);
auto cvt_tensor_type_op = builder.create<infrt::CvtTensorOp>(
auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
kernel_op.getLoc(),
infrt::DenseTensorType::get(
kernel_op.getContext(),
phi_kernel_desc.inputsType[index].target,
phi_kernel_desc.inputsType[index].precision,
phi_kernel_desc.inputsType[index].layout),
phi_kernel_desc.input_types[index].target,
phi_kernel_desc.input_types[index].precision,
phi_kernel_desc.input_types[index].layout),
input);
operation_state.addOperands(cvt_tensor_type_op.output());
}
for (size_t index = 0; index < phi_kernel_desc.outputsType.size();
for (size_t index = 0; index < phi_kernel_desc.output_types.size();
++index) {
operation_state.addTypes(infrt::DenseTensorType::get(
kernel_op.getContext(),
phi_kernel_desc.outputsType[index].target,
phi_kernel_desc.outputsType[index].precision,
phi_kernel_desc.outputsType[index].layout));
phi_kernel_desc.output_types[index].target,
phi_kernel_desc.output_types[index].precision,
phi_kernel_desc.output_types[index].layout));
}
operation_state.addAttributes(kernel_op.attrsAttr().getValue());
mlir::Operation *phi_operation = builder.createOperation(operation_state);
for (size_t index = 0; index < phi_kernel_desc.outputsType.size();
for (size_t index = 0; index < phi_kernel_desc.output_types.size();
++index) {
mlir::Value input = phi_operation->getResult(index);
auto cvt_tensor_type_op = builder.create<infrt::CvtTensorOp>(
auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
kernel_op.getResult(index).replaceAllUsesWith(
cvt_tensor_type_op.output());
......@@ -204,9 +226,35 @@ void phiOpCvtPass::diapatchStage() {
}
}
PhiOpConvertPass::PhiOpConvertPass() {
if (!valid_places_options_.hasValue()) {
valid_places_.emplace_back(infrt::TargetType::CPU,
infrt::PrecisionType::FLOAT32,
infrt::LayoutType::NCHW);
return;
}
LOG(FATAL) << "To be done for specifying places in command line";
}
void PhiOpConvertPass::getDependentDialects(
mlir::DialectRegistry &registry) const {
registry.insert<infrt::InfrtDialect>();
registry.insert<infrt::phi::PHIDialect>();
registry.insert<infrt::phi::PHIDenseTensorDialect>();
registry.insert<infrt::phi::PHICPUKernelDialect>();
registry.insert<infrt::phi::PHIGPUKernelDialect>();
}
} // namespace
mlir::PassRegistration<PhiOpConvertPass> phi_op_convert;
std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass(
std::vector<Place> valid_places) {
return std::make_unique<phiOpCvtPass>(valid_places);
return std::make_unique<PhiOpConvertPass>(valid_places);
}
std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass() {
return std::make_unique<PhiOpConvertPass>();
}
......@@ -21,7 +21,8 @@ namespace infrt {
* phiOpCvtPass.
* Convert the general operators from pd Dialect to phi dialect.
*/
std::unique_ptr<mlir::Pass> createPhiOpCvtPass(
std::vector<Place> valid_places = std::vector<Place>());
std::unique_ptr<mlir::Pass> createPhiOpCvtPass(std::vector<Place> valid_places);
std::unique_ptr<mlir::Pass> createPhiOpCvtPass();
} // namespace infrt
......@@ -30,7 +30,7 @@
#include "paddle/infrt/kernel/test_kernels.h"
#ifdef INFRT_WITH_PHI
#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h"
#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/kernel/phi/registry.h"
#endif
......
configure_file(lit.cfg.py.in "${CMAKE_SOURCE_DIR}/paddle/infrt/tests/lit.cfg.py")
add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle/infrt/tests --filter-out \"disabled_*\""
DEPENDS infrtopt infrtexec phi-ir-exec)
DEPENDS infrtopt infrtexec)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir)
// RUN: phi-ir-exec %s
// RUN: infrtopt -phi-op-convert -infrt-op-fuse %s
// CHECK-LABEL: @ops
func @ops() {
%a = pd.feed() {name="input0"} : !infrt.lod_tensor<?xf32,0>
......@@ -8,3 +9,10 @@ func @ops() {
%h = "pd.abs"(%g):(tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%h) {name="output"} :(tensor<?xf32>)->()
}
// CHECK-LABEL: @op_execute
func @op_execute(%a:!infrt.lod_tensor<?xf32,0>, %b:!infrt.lod_tensor<?xf32,0>, %c:!infrt.lod_tensor<?xf32,0>) -> !infrt.lod_tensor<?xf32,0> {
%g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.lod_tensor<?xf32,0>, !infrt.lod_tensor<?xf32>) -> tensor<?xf32>
%h = "pd.abs"(%g):(tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%h) {name="output"} :(tensor<?xf32>)->()
}
......@@ -93,7 +93,7 @@ function infrt_gen_and_build() {
exit 7;
fi
make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec phi-ir-exec phi-exec infrt_lib_dist paddle-mlir-convert;build_error=$?
make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec phi-exec infrt_lib_dist paddle-mlir-convert;build_error=$?
if [ "$build_error" != 0 ];then
exit 7;
fi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册