未验证 提交 35471b1f 编写于 作者: S Shang Zhizhou 提交者: GitHub

【infrt】add TrtOpConverterPass (#39902)

* add some trt layers

* trtOpConverter pass ok

* add comments

* add constraints to some attrs in the pd_lower_to_trt patterns

* update constraint

* fix code style

* update pass name

* update code style

* change .hpp.inc to .cpp.inc in mlir_add_rewriter
上级 3cb93edf
...@@ -99,7 +99,7 @@ endfunction() ...@@ -99,7 +99,7 @@ endfunction()
function(mlir_add_rewriter td_base) function(mlir_add_rewriter td_base)
set(LLVM_TARGET_DEFINITIONS ${td_base}.td) set(LLVM_TARGET_DEFINITIONS ${td_base}.td)
mlir_tablegen(${td_base}.hpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass") mlir_tablegen(${td_base}.cpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass")
add_public_tablegen_target(${td_base}_IncGen) add_public_tablegen_target(${td_base}_IncGen)
add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
endfunction() endfunction()
......
...@@ -97,6 +97,7 @@ set(infrt_mlir_incs ...@@ -97,6 +97,7 @@ set(infrt_mlir_incs
pd_extra_ops_inc pd_extra_ops_inc
rewrite_inc rewrite_inc
trt_ops_inc trt_ops_inc
pd_lower_to_trt_inc
) )
if (INFRT_WITH_PHI) if (INFRT_WITH_PHI)
......
...@@ -54,6 +54,20 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT ...@@ -54,6 +54,20 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
return b.getIntegerAttr(b.getI32Type(), constant); return b.getIntegerAttr(b.getI32Type(), constant);
} }
// Wraps `constant` in the attribute produced by
// mlir::OpBuilder::getSI32IntegerAttr. `loc` is accepted so the signature
// lines up with the sibling create*Attr helpers, but it is not read here.
template <typename T>
static mlir::IntegerAttr createSI32Attr(mlir::OpBuilder &b,  // NOLINT
                                        mlir::Location loc,
                                        T constant) {
  mlir::IntegerAttr si32_attr = b.getSI32IntegerAttr(constant);
  return si32_attr;
}
// Wraps `constant` in the attribute produced by
// mlir::OpBuilder::getF32FloatAttr. `loc` is accepted so the signature lines
// up with the sibling create*Attr helpers, but it is not read here.
template <typename T>
static mlir::FloatAttr createF32Attr(mlir::OpBuilder &b,  // NOLINT
                                     mlir::Location loc,
                                     T constant) {
  mlir::FloatAttr f32_attr = b.getF32FloatAttr(constant);
  return f32_attr;
}
static mlir::SmallVector<mlir::Value, 4> cvtValueToValueRange( static mlir::SmallVector<mlir::Value, 4> cvtValueToValueRange(
const mlir::Value &operand) { const mlir::Value &operand) {
return mlir::SmallVector<mlir::Value, 4>(1, operand); return mlir::SmallVector<mlir::Value, 4>(1, operand);
......
...@@ -28,6 +28,12 @@ def BufferType : OpaqueType<"b", "buffer", "buffer">; ...@@ -28,6 +28,12 @@ def BufferType : OpaqueType<"b", "buffer", "buffer">;
class INFRT_createI32Attr<string value> : NativeCodeCall< class INFRT_createI32Attr<string value> : NativeCodeCall<
"infrt::createI32Attr($_builder, $_loc, " # value # ")">; "infrt::createI32Attr($_builder, $_loc, " # value # ")">;
// Rewrite-pattern helper: expands to a C++ call to infrt::createSI32Attr,
// passing the rewriter's builder ($_builder), the match location ($_loc) and
// the text of `value` spliced in verbatim as the third argument.
class INFRT_createSI32Attr<string value> : NativeCodeCall<
"infrt::createSI32Attr($_builder, $_loc, " # value # ")">;
// Rewrite-pattern helper: expands to a C++ call to infrt::createF32Attr,
// passing the rewriter's builder ($_builder), the match location ($_loc) and
// the text of `value` spliced in verbatim as the third argument.
class INFRT_createF32Attr<string value> : NativeCodeCall<
"infrt::createF32Attr($_builder, $_loc, " # value # ")">;
def INFRT_cvtValueToValueRange : NativeCodeCall< def INFRT_cvtValueToValueRange : NativeCodeCall<
"infrt::cvtValueToValueRange($0)">; "infrt::cvtValueToValueRange($0)">;
......
...@@ -24,11 +24,11 @@ ...@@ -24,11 +24,11 @@
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_extra_ops.cpp.inc" // NOLINT #include "paddle/infrt/dialect/pd_extra_ops.cpp.inc" // NOLINT
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
namespace mlir { namespace mlir {
namespace pd { namespace pd {
#include "paddle/infrt/dialect/rewrite.cpp.inc" // NOLINT
PaddleDialect::PaddleDialect(MLIRContext *context) PaddleDialect::PaddleDialect(MLIRContext *context)
: Dialect("pd", context, TypeID::get<PaddleDialect>()) { : Dialect("pd", context, TypeID::get<PaddleDialect>()) {
addOperations< addOperations<
......
...@@ -2,11 +2,13 @@ core_gather_headers() ...@@ -2,11 +2,13 @@ core_gather_headers()
gather_srcs(infrt_src SRCS gather_srcs(infrt_src SRCS
trt_ops.cc trt_ops.cc
trt_op_converter_pass.cc
trt_op_teller_pass.cc trt_op_teller_pass.cc
trt_graph_fuse_pass.cc trt_graph_fuse_pass.cc
trt_graph_split_pass.cc trt_graph_split_pass.cc
) )
mlir_tablegen_on(trt_ops) mlir_tablegen_on(trt_ops)
mlir_add_rewriter(pd_lower_to_trt)
add_executable(trt-exec trt_exec.cc) add_executable(trt-exec trt_exec.cc)
target_link_libraries(trt-exec infrt ${MLIR_IR_LIBS}) target_link_libraries(trt-exec infrt ${MLIR_IR_LIBS})
#ifndef PD_LOWER_TO_TRT
#define PD_LOWER_TO_TRT
include "mlir/Interfaces/SideEffectInterfaces.td"
include "paddle/infrt/dialect/infrt_base.td"
include "paddle/infrt/dialect/pd_ops.td"
include "paddle/infrt/dialect/tensorrt/trt_ops.td"
// Rewrites PD_MatmulOp into TRT_MatrixMultiplyOp.
// The pattern only matches when the two trailing attributes of PD_MatmulOp
// are the constants 1.0 (F32) and 1 (SI32).
// NOTE(review): presumably these are alpha and head_number — confirm against
// pd_ops.td.
// The result interleaves each operand with its transpose flag:
// ($X, $transpose_X, $Y, $transpose_Y).
def PD2TRT_Matmul_Lower : Pat<
(PD_MatmulOp $X, $Y, $transpose_X, $transpose_Y, ConstantAttr<F32Attr, "1.0">, ConstantAttr<SI32Attr, "1">),
(TRT_MatrixMultiplyOp $X, $transpose_X, $Y, $transpose_Y)>;
// TODO(shangzhizhou): replace INFRT_createSI32Attr<"0"> with the enum value
// nvinfer1::ElementWiseOperation::kSUM.
//
// Rewrites PD_Elementwise_addOp into TRT_ElementWiseOp with operation code 0
// (kSUM). The pattern only matches when the third attribute of the source op
// is the SI32 constant -1.
def PD2TRT_ElementwiseAdd_Lower : Pat<
(PD_Elementwise_addOp $X, $Y, ConstantAttr<SI32Attr, "-1">),
(TRT_ElementWiseOp $X, $Y, (INFRT_createSI32Attr<"0">)/*kSUM*/)>;
// TODO(shangzhizhou): replace INFRT_createSI32Attr<"0"> with the enum value
// nvinfer1::ActivationType::kRELU.
//
// Rewrites PD_ReluOp into TRT_ActivationOp with activation code 0 (kRELU);
// the two trailing F32 attributes are both set to 0.0.
def PD2TRT_Relu_Lower : Pat<
(PD_ReluOp $X),
(TRT_ActivationOp $X, (INFRT_createSI32Attr<"0">)/*kRELU*/, (INFRT_createF32Attr<"0.0">), (INFRT_createF32Attr<"0.0">))>;
// TODO(shangzhizhou): replace INFRT_createSI32Attr<"8"> with the enum value
// nvinfer1::ActivationType::kCLIP.
//
// Rewrites PD_Relu6Op into TRT_ActivationOp with activation code 8 (kCLIP).
// The first trailing F32 attribute is set to 0.0 and the source op's
// $threshold is forwarded as the second one.
def PD2TRT_Relu6_Lower : Pat<
(PD_Relu6Op $X, $threshold),
(TRT_ActivationOp $X, (INFRT_createSI32Attr<"8">)/*kCLIP*/, (INFRT_createF32Attr<"0.0">), $threshold)>;
#endif // PD_LOWER_TO_TRT
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
int main(int argc, char** argv) { int main(int argc, char** argv) {
...@@ -36,9 +37,10 @@ int main(int argc, char** argv) { ...@@ -36,9 +37,10 @@ int main(int argc, char** argv) {
mlir::PassManager pm(context); mlir::PassManager pm(context);
mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>(); mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>();
trt_pass_manager.addPass(std::make_unique<infrt::trt::trtOpTellerPass>()); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpTellerPass>());
trt_pass_manager.addPass(std::make_unique<infrt::trt::trtGraphFusePass>()); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphFusePass>());
trt_pass_manager.addPass(std::make_unique<infrt::trt::trtGraphSplitPass>(10)); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphSplitPass>(1));
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpConverterPass>());
if (mlir::failed(pm.run(*module))) { if (mlir::failed(pm.run(*module))) {
std::cout << "\npass failed!\n" << std::endl; std::cout << "\npass failed!\n" << std::endl;
return 4; return 4;
......
...@@ -142,7 +142,7 @@ void topoSortBlock(mlir::Block &body) { // NOLINT ...@@ -142,7 +142,7 @@ void topoSortBlock(mlir::Block &body) { // NOLINT
} // namespace } // namespace
// Implementation of the trtGraphFusePass. // Implementation of the trtGraphFusePass.
void trtGraphFusePass::runOnFunction() { void TRTGraphFusePass::runOnFunction() {
mlir::Block &body = getFunction().front(); mlir::Block &body = getFunction().front();
mlir::OpBuilder builder(&body, body.begin()); mlir::OpBuilder builder(&body, body.begin());
bool changed = false; bool changed = false;
......
...@@ -52,8 +52,8 @@ namespace trt { ...@@ -52,8 +52,8 @@ namespace trt {
* "pd.fetch" %d, %f * "pd.fetch" %d, %f
* } * }
*/ */
class trtGraphFusePass class TRTGraphFusePass
: public mlir::PassWrapper<trtGraphFusePass, mlir::FunctionPass> { : public mlir::PassWrapper<TRTGraphFusePass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtGraphFusePass"; } ::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override; void runOnFunction() override;
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtGraphSplitPass。 // Implementation of the trtGraphSplitPass。
void trtGraphSplitPass::runOnFunction() { void TRTGraphSplitPass::runOnFunction() {
std::vector<mlir::pd::GraphOp> worklist; std::vector<mlir::pd::GraphOp> worklist;
mlir::Block& block = getFunction().front(); mlir::Block& block = getFunction().front();
for (auto& op : block) { for (auto& op : block) {
......
...@@ -45,12 +45,12 @@ namespace trt { ...@@ -45,12 +45,12 @@ namespace trt {
* "pd.fetch" (%d, %f) * "pd.fetch" (%d, %f)
* } * }
*/ */
class trtGraphSplitPass class TRTGraphSplitPass
: public mlir::PassWrapper<trtGraphSplitPass, mlir::FunctionPass> { : public mlir::PassWrapper<TRTGraphSplitPass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void runOnFunction() override; void runOnFunction() override;
explicit trtGraphSplitPass(size_t min_subgraph_size = 3) explicit TRTGraphSplitPass(size_t min_subgraph_size = 3)
: min_subgraph_size_(min_subgraph_size) {} : min_subgraph_size_(min_subgraph_size) {}
private: private:
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/pd_ops.h"
namespace infrt {
namespace trt {
#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT
using namespace mlir;
void TRTOpConverterPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to TensorRTDialect from
// PaddleDialect
target.addLegalDialect<TensorRTDialect>();
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the TensorRT operations.
RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
} // namespace trt
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
/*
 * TRTOpConverterPass.
 *
 * A pass scheduled on mlir::FuncOp that rewrites the Paddle ops found inside
 * a "pd.graph" region into their TensorRT dialect counterparts. The ops
 * surrounding the graph (feed / fetch / return) are left untouched.
 * The IR below only illustrates the shape of the transformation; the set of
 * ops actually converted is decided by the rewrite patterns used in the
 * implementation file.
 *
 * Before:
 * func @main() -> tensor<?xf32> {
 *  %a = "pd.feed"()...
 *  %d, %f = "pd.graph"(%a) {
 *     %m = "pd.conv2d"(%a)...
 *     %n = "pd.conv3d"(%m)...
 *     %s = "pd.conv2d"(%a)...
 *     "pd.return" %n, %s
 *  } ...
 *  "pd.fetch" %d, %f
 * }
 *
 * After:
 * func @main() -> tensor<?xf32> {
 *  %a = "pd.feed"()...
 *  %d, %f = "pd.graph"(%a) {
 *     %m = "trt.Convolution"(%a)...
 *     %n = "trt.Convolution"(%m)...
 *     %s = "trt.Convolution"(%a)...
 *     "pd.return" %n, %s
 *  } ...
 *  "pd.fetch" %d, %f
 * }
 */
struct TRTOpConverterPass
    : public mlir::PassWrapper<TRTOpConverterPass,
                               mlir::OperationPass<mlir::FuncOp>> {
  // Name reported for this pass.
  ::llvm::StringRef getName() const override { return "trtOpConverterPass"; }

  // Registers TensorRTDialect as a dependency of this pass, since the
  // rewrite creates operations from that dialect.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<TensorRTDialect>();
  }

  // Performs the conversion; defined out of line.
  void runOnOperation() final;
};
} // namespace trt
} // namespace infrt
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtOpTellerPass。 // Implementation of the trtOpTellerPass。
void trtOpTellerPass::runOnFunction() { void TRTOpTellerPass::runOnFunction() {
mlir::Block &body = getFunction().front(); mlir::Block &body = getFunction().front();
std::vector<mlir::Operation *> worklist; std::vector<mlir::Operation *> worklist;
worklist.reserve(body.getOperations().size()); worklist.reserve(body.getOperations().size());
......
...@@ -52,8 +52,8 @@ namespace trt { ...@@ -52,8 +52,8 @@ namespace trt {
* TODO(winter-wang): Supplementary how to judge the operators can be supported * TODO(winter-wang): Supplementary how to judge the operators can be supported
* by tensorrt. * by tensorrt.
*/ */
class trtOpTellerPass class TRTOpTellerPass
: public mlir::PassWrapper<trtOpTellerPass, mlir::FunctionPass> { : public mlir::PassWrapper<TRTOpTellerPass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtOpTellerPass"; } ::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override; void runOnFunction() override;
......
...@@ -23,8 +23,48 @@ def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { ...@@ -23,8 +23,48 @@ def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
Describe a tensorrt subgraph. Describe a tensorrt subgraph.
}]; }];
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<TRT_Tensor>:$inputs);
let results = (outs Variadic<TRT_Tensor>:$outputs); let results = (outs Variadic<TRT_Tensor>:$outputs);
} }
// "Activation" op of the TRT dialect, marked NoSideEffect.
// Takes one tensor plus three attributes and yields one tensor:
//   activation_type - required SI32 code selecting the activation
//                     (NOTE(review): expected to mirror
//                     nvinfer1::ActivationType values — confirm),
//   alpha, beta     - F32 parameters, each defaulting to 0.0.
def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
let summary = "TensorRT IActivationLayer";
let description = [{
TensorRT IActivationLayer.
}];
let arguments = (ins TRT_Tensor:$input, SI32Attr:$activation_type,
DefaultValuedAttr<F32Attr, "0.0">:$alpha,
DefaultValuedAttr<F32Attr, "0.0">:$beta);
let results = (outs TRT_Tensor:$output);
}
// "ElementWise" op of the TRT dialect, marked NoSideEffect.
// Takes two tensors and a required SI32 attribute `elementwise_operation`
// selecting the operation, and yields one tensor.
// NOTE(review): the code is expected to mirror
// nvinfer1::ElementWiseOperation values — confirm.
def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> {
let summary = "TensorRT IElementWiseLayer";
let description = [{
TensorRT IElementWiseLayer.
}];
let arguments = (ins TRT_Tensor:$input1, TRT_Tensor:$input2, SI32Attr:$elementwise_operation);
let results = (outs TRT_Tensor:$output);
}
// "MatrixMultiply" op of the TRT dialect, marked NoSideEffect.
// Arguments are interleaved: each tensor operand is followed by its own
// boolean transpose flag (input1, transpose1, input2, transpose2).
// Yields one tensor.
def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> {
let summary = "TensorRT IMatrixMultiplyLayer";
let description = [{
TensorRT IMatrixMultiplyLayer.
}];
let arguments = (ins TRT_Tensor:$input1, BoolAttr:$transpose1,
TRT_Tensor:$input2, BoolAttr:$transpose2);
let results = (outs TRT_Tensor:$output);
}
#endif // TRT_OPS #endif // TRT_OPS
...@@ -7,15 +7,15 @@ func @main() -> tensor<?xf32> { ...@@ -7,15 +7,15 @@ func @main() -> tensor<?xf32> {
%bias1 = "pd.feed"() {name="input4"} : () -> tensor<?xf32> %bias1 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
%bias2 = "pd.feed"() {name="input5"} : () -> tensor<?xf32> %bias2 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32> %e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
%c1 = "pd.matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %c1 = "pd.matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d1 = "pd.elementwise_add"(%c1, %bias1) {axis=1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e1 = "pd.relu"(%d1) {} : (tensor<?xf32>) -> tensor<?xf32> %e1 = "pd.relu"(%d1) {} : (tensor<?xf32>) -> tensor<?xf32>
%c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32> %e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->() "pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册