未验证 提交 e72ef603 编写于 作者: S Shang Zhizhou 提交者: GitHub

Add trt execute (#40224)

* add trt.execute

* merge trt.engine type

* update return op

* update comments

* fix style

* fix style
上级 99fc1b08
...@@ -18,6 +18,22 @@ def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> { ...@@ -18,6 +18,22 @@ def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> {
let results = (outs Variadic<AnyType>); let results = (outs Variadic<AnyType>);
} }
// "infrt.return": the return operation of the infrt dialect.
// It carries the `Terminator` trait, so it is the op that ends a block.
def Infrt_ReturnOp : Infrt_Op<"return", [Terminator]> {
let summary = "host executor return operation";
let description = [{
The "infrt.return" operation represents a return operation within a function.
func @foo() : (i32, f8) {
infrt.return %0, %1 : i32, f8
}
}];
// Zero or more returned values; `AnyType` places no constraint on their types.
let arguments = (ins Variadic<AnyType>:$operands);
// Convenience builder taking no arguments: it forwards `llvm::None` as the
// operand list, i.e. it builds a return with no operands.
let builders = [OpBuilder<(ins),
[{ build($_builder, $_state, llvm::None); }]>];
}
def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> { def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> {
let summary = "convert tensor type op"; let summary = "convert tensor type op";
let description = [{convert tensor type op!}]; let description = [{convert tensor type op!}];
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <mlir/IR/Matchers.h> #include <mlir/IR/Matchers.h>
#include <mlir/IR/PatternMatch.h> #include <mlir/IR/PatternMatch.h>
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/infrt_base.h"
#define GET_OP_CLASSES #define GET_OP_CLASSES
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <mlir/Interfaces/InferTypeOpInterface.h> #include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h> #include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
namespace mlir { namespace mlir {
namespace pd { namespace pd {
......
...@@ -23,6 +23,8 @@ class EngineType ...@@ -23,6 +23,8 @@ class EngineType
: public mlir::Type::TypeBase<EngineType, mlir::Type, mlir::TypeStorage> { : public mlir::Type::TypeBase<EngineType, mlir::Type, mlir::TypeStorage> {
public: public:
using Base::Base; using Base::Base;
static EngineType get();
static EngineType get(mlir::MLIRContext *context);
}; };
} // namespace trt } // namespace trt
......
...@@ -53,9 +53,9 @@ bool reverseDfs(std::vector<mlir::Operation *> source, ...@@ -53,9 +53,9 @@ bool reverseDfs(std::vector<mlir::Operation *> source,
} }
// merge the first&second graph op to a new graph op. // merge the first&second graph op to a new graph op.
void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
CreateEngineOp first, mlir::pd::GraphOp first,
CreateEngineOp second) { mlir::pd::GraphOp second) {
// compute inputs and outputs // compute inputs and outputs
::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs; ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (mlir::Value input : second.getOperands()) { for (mlir::Value input : second.getOperands()) {
...@@ -84,8 +84,7 @@ void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT ...@@ -84,8 +84,7 @@ void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT
// create the new graph op // create the new graph op
builder.setInsertionPoint(first); builder.setInsertionPoint(first);
auto loc = first.getLoc(); auto loc = first.getLoc();
auto graph_op = auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
builder.create<CreateEngineOp>(loc, return_types, inputs, true);
mlir::Block *block = new mlir::Block; mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator(); auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(), block->getOperations().splice(block->begin(),
...@@ -98,7 +97,7 @@ void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT ...@@ -98,7 +97,7 @@ void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT
copy_range.begin(), copy_range.begin(),
copy_range.end()); copy_range.end());
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
builder.create<::infrt::dialect::ReturnOp>(loc, outputs); builder.create<::infrt::ReturnOp>(loc, outputs);
graph_op.body().push_back(block); graph_op.body().push_back(block);
// mapping the output // mapping the output
...@@ -150,12 +149,13 @@ void TRTGraphFusePass::runOnFunction() { ...@@ -150,12 +149,13 @@ void TRTGraphFusePass::runOnFunction() {
do { do {
changed = false; changed = false;
for (auto &op : body) { for (auto &op : body) {
CreateEngineOp graph_op = ::llvm::dyn_cast_or_null<CreateEngineOp>(&op); mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr == graph_op) continue; if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) { for (auto user_op : op.getUsers()) {
CreateEngineOp user_graph_op = mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<CreateEngineOp>(user_op); ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
if (nullptr == user_graph_op) continue; if (nullptr == user_graph_op) continue;
// get all dst input nodes except src. // get all dst input nodes except src.
std::vector<mlir::Operation *> source_nodes; std::vector<mlir::Operation *> source_nodes;
...@@ -168,7 +168,7 @@ void TRTGraphFusePass::runOnFunction() { ...@@ -168,7 +168,7 @@ void TRTGraphFusePass::runOnFunction() {
// Reverse DFS from the source_nodes. // Reverse DFS from the source_nodes.
if (!reverseDfs(source_nodes, if (!reverseDfs(source_nodes,
[&op](const mlir::Operation *n) { return n == &op; })) { [&op](const mlir::Operation *n) { return n == &op; })) {
mergeTwoAdjacentCreateEngineOp(builder, graph_op, user_graph_op); mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true; changed = true;
break; break;
} }
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -26,40 +25,37 @@ namespace trt { ...@@ -26,40 +25,37 @@ namespace trt {
* *
* source func: * source func:
* *
* func @main() -> tensor<?xf32> { * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %a = "pd.feed"()... * %c = "pd.graph"(%a) {
* %c = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "Infrt.return" %m * "infrt.return" (%m)
* } ... * } ...
* %d = "trt.create_engine"(%c) { * %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)... * %m = "pd.conv3d"(%c)...
* "Infrt.return" %m * "infrt.return" (%m)
* } ... * } ...
* %f = "trt.create_engine"(%a) { * %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "Infrt.return" %m * "infrt.return" (%m)
* } ... * } ...
* "pd.fetch" %d, %f * "infrt.return" (%d, %f)..
* }
* *
* destination func: * destination func:
* func @main() -> tensor<?xf32> { * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %a = "pd.feed"()... * %d, %f = "pd.graph"(%a) {
* %d, %f = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)... * %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)... * %s = "pd.conv2d"(%a)...
* "Infrt.return" %n, %s * "infrt.return" (%n, %s)
* } ... * } ...
* "pd.fetch" %d, %f * "infrt.return" (%d, %f)
* } * }
*/ */
class TRTGraphFusePass class TRTGraphFusePass
: public mlir::PassWrapper<TRTGraphFusePass, mlir::FunctionPass> { : public mlir::PassWrapper<TRTGraphFusePass, mlir::FunctionPass> {
public: public:
void getDependentDialects(mlir::DialectRegistry &registry) const override { void getDependentDialects(mlir::DialectRegistry &registry) const override {}
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
::llvm::StringRef getName() const override { return "trtGraphFusePass"; } ::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override; void runOnFunction() override;
}; };
......
...@@ -16,23 +16,23 @@ ...@@ -16,23 +16,23 @@
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtGraphSplitPass. // Implementation of the trtGraphSplitPass.
void TRTGraphSplitPass::runOnFunction() { void TRTGraphSplitPass::runOnFunction() {
std::vector<CreateEngineOp> worklist; std::vector<mlir::pd::GraphOp> worklist;
mlir::Block& block = getFunction().front(); mlir::Block& block = getFunction().front();
for (auto& op : block) { for (auto& op : block) {
CreateEngineOp graph_op = ::llvm::dyn_cast_or_null<CreateEngineOp>(&op); mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr != graph_op && if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op); worklist.push_back(graph_op);
} }
} }
while (!worklist.empty()) { while (!worklist.empty()) {
CreateEngineOp graph_op = worklist.back(); mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back(); worklist.pop_back();
mlir::Block* body = graph_op.getBody(); mlir::Block* body = graph_op.getBody();
auto return_op = body->getTerminator(); auto return_op = body->getTerminator();
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -27,33 +26,29 @@ namespace trt { ...@@ -27,33 +26,29 @@ namespace trt {
* *
* source func: * source func:
* *
* func @main() -> tensor<?xf32> { * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %a = "pd.feed"()... * %d, %f = "pd.graph"(%a) {
* %d, %f = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)... * %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)... * %s = "pd.conv2d"(%a)...
* "Infrt.return" (%n, %s) * "infrt.return" (%n, %s)...
* } ... * } ...
* "pd.fetch" (%d, %f) * "infrt.return" (%d, %f)...
* } * }
* *
* destination func: * destination func:
* func @main() -> tensor<?xf32> { * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %a = "pd.feed"()...
* %c = "pd.conv2d"(%a) ... * %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ... * %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ... * %f = "pd.conv2d"(%a) ...
* "pd.fetch" (%d, %f) * "infrt.return" (%d, %f)...
* } * }
*/ */
class TRTGraphSplitPass class TRTGraphSplitPass
: public mlir::PassWrapper<TRTGraphSplitPass, mlir::FunctionPass> { : public mlir::PassWrapper<TRTGraphSplitPass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void getDependentDialects(mlir::DialectRegistry &registry) const override { void getDependentDialects(mlir::DialectRegistry &registry) const override {}
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
void runOnFunction() override; void runOnFunction() override;
explicit TRTGraphSplitPass(size_t min_subgraph_size = 3) explicit TRTGraphSplitPass(size_t min_subgraph_size = 3)
: min_subgraph_size_(min_subgraph_size) {} : min_subgraph_size_(min_subgraph_size) {}
......
...@@ -16,12 +16,64 @@ ...@@ -16,12 +16,64 @@
#include <mlir/Transforms/DialectConversion.h> #include <mlir/Transforms/DialectConversion.h>
#include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT #include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT
struct PD2TRT_GraphLower : public ::mlir::RewritePattern {
PD2TRT_GraphLower(::mlir::MLIRContext *context)
: ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {}
::mlir::LogicalResult matchAndRewrite(
::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override {
auto casted_op = ::llvm::dyn_cast<mlir::pd::GraphOp>(op);
::mlir::Operation::operand_range inputs = casted_op.inputs();
auto ods_loc = rewriter.getFusedLoc(op->getLoc());
CreateEngineOp create_engine_op;
// inputs
::mlir::SmallVector<::mlir::Value, 4> trt_inputs;
for (auto v : inputs) {
trt_inputs.push_back(v);
}
create_engine_op = rewriter.create<CreateEngineOp>(
ods_loc,
::llvm::SmallVector<mlir::Type, 4>(1, EngineType::get()),
trt_inputs,
true /*run_once*/);
::mlir::Block *block = new ::mlir::Block;
block->getOperations().splice(block->begin(),
casted_op.getBody()->getOperations(),
casted_op.getBody()->begin(),
casted_op.getBody()->end());
create_engine_op.body().push_back(block);
// trt.execute
// outputs
::llvm::SmallVector<::mlir::Type, 4> execute_outputs_types;
for (auto v : casted_op.getODSResults(0)) {
execute_outputs_types.push_back(v.getType());
}
// inputs
::mlir::SmallVector<::mlir::Value, 4> execute_inputs(
create_engine_op.getODSResults(0));
for (auto v : inputs) {
execute_inputs.push_back(v);
}
auto execute_op = rewriter.create<ExecuteOp>(
ods_loc, execute_outputs_types, execute_inputs);
::llvm::SmallVector<::mlir::Value, 4> replace_values;
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{execute_op.getODSResults(0)}) {
replace_values.push_back(v);
}
rewriter.replaceOp(op, replace_values);
return ::mlir::success();
}
};
void TRTOpConverterPass::runOnOperation() { void TRTOpConverterPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the // The first thing to define is the conversion target. This will define the
// final target for this lowering. // final target for this lowering.
...@@ -36,6 +88,7 @@ void TRTOpConverterPass::runOnOperation() { ...@@ -36,6 +88,7 @@ void TRTOpConverterPass::runOnOperation() {
// the set of patterns that will lower the TensorRT operations. // the set of patterns that will lower the TensorRT operations.
::mlir::RewritePatternSet patterns(&getContext()); ::mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns); populateWithGenerated(patterns);
patterns.add<PD2TRT_GraphLower>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal` // conversion. The conversion will signal failure if any of our `illegal`
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
...@@ -23,27 +24,26 @@ namespace trt { ...@@ -23,27 +24,26 @@ namespace trt {
* trtOpConverterPass. * trtOpConverterPass.
* *
* source ir: * source ir:
* func @main() -> tensor<?xf32> { * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %a = "pd.feed"()... * %d, %f = "pd.graph"(%a) {
* %d, %f = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)... * %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)... * %s = "pd.conv2d"(%a)...
* "Infrt.return" %n, %s * "infrt.return" (%n, %s)...
* } ... * } ...
* "pd.fetch" %d, %f * "infrt.return" (%d, %f)...
* } * }
* *
* destination ir: * destination ir:
* func @main() -> tensor<?xf32> { * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %a = "pd.feed"()... * %engine = "trt.create_engine"(%a) ({
* %d, %f = "trt.create_engine"(%a) {
* %m = "trt.Convolution"(%a)... * %m = "trt.Convolution"(%a)...
* %n = "trt.Convolution"(%m)... * %n = "trt.Convolution"(%m)...
* %s = "trt.Convolution"(%a)... * %s = "trt.Convolution"(%a)...
* "Infrt.return" %n, %s * "infrt.return" (%n, %s)...
* } ... * }){run_once = true} ...
* "pd.fetch" %d, %f * %d, %f = "trt.execute"(%engine, %a)...
* "infrt.return" (%d, %f)...
* } * }
*/ */
struct TRTOpConverterPass struct TRTOpConverterPass
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/basic_kernels.h" #include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
namespace infrt { namespace infrt {
...@@ -37,11 +38,11 @@ void TRTOpTellerPass::runOnFunction() { ...@@ -37,11 +38,11 @@ void TRTOpTellerPass::runOnFunction() {
if (::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op)) continue; if (::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op)) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op)) continue; if (::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op)) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op)) continue; if (::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op)) continue;
if (::llvm::dyn_cast_or_null<CreateEngineOp>(op)) continue; if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
builder.setInsertionPoint(op); builder.setInsertionPoint(op);
auto loc = getFunction().getLoc(); auto loc = getFunction().getLoc();
auto graph_op = builder.create<CreateEngineOp>( auto graph_op = builder.create<mlir::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands(), true); loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values; ::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
for (auto v : for (auto v :
...@@ -54,7 +55,7 @@ void TRTOpTellerPass::runOnFunction() { ...@@ -54,7 +55,7 @@ void TRTOpTellerPass::runOnFunction() {
graph_op.body().push_back(block); graph_op.body().push_back(block);
op->moveBefore(block, block->begin()); op->moveBefore(block, block->begin());
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
builder.create<::infrt::dialect::ReturnOp>(loc, op->getResults()); builder.create<::infrt::ReturnOp>(loc, op->getResults());
} }
} }
} // namespace trt } // namespace trt
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -26,30 +25,28 @@ namespace trt { ...@@ -26,30 +25,28 @@ namespace trt {
* *
* source func: * source func:
* *
* func @main() -> tensor<?xf32> { * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %a = "pd.feed"()...
* %c = "pd.conv2d"(%a) ... * %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ... * %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ... * %f = "pd.conv2d"(%a) ...
* "pd.fetch" (%d, %f) * "infrt.return"(%d, %f) ...
* } * }
* *
* destination func: * destination func:
* func @main() -> tensor<?xf32> { * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %a = "pd.feed"()... * %c = "pd.graph"(%a) {
* %c = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "Infrt.return" (%m) * "infrt.return" (%m)
* } ... * } ...
* %d = "trt.create_engine"(%c) { * %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)... * %m = "pd.conv3d"(%c)...
* "Infrt.return" (%m) * "infrt.return" (%m)
* } ... * } ...
* %f = "trt.create_engine"(%a) { * %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "Infrt.return" (%m) * "infrt.return" (%m)
* } ... * } ...
* "pd.fetch" (%d, %f) * "infrt.return" (%d, %f)
* } * }
* TODO(winter-wang): Describe how to judge whether an operator can be supported * TODO(winter-wang): Describe how to judge whether an operator can be supported
* by tensorrt. * by tensorrt.
...@@ -57,9 +54,7 @@ namespace trt { ...@@ -57,9 +54,7 @@ namespace trt {
class TRTOpTellerPass class TRTOpTellerPass
: public mlir::PassWrapper<TRTOpTellerPass, mlir::FunctionPass> { : public mlir::PassWrapper<TRTOpTellerPass, mlir::FunctionPass> {
public: public:
void getDependentDialects(mlir::DialectRegistry &registry) const override { void getDependentDialects(mlir::DialectRegistry &registry) const override {}
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
::llvm::StringRef getName() const override { return "trtOpTellerPass"; } ::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override; void runOnFunction() override;
}; };
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include <mlir/IR/DialectImplementation.h> #include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Matchers.h> #include <mlir/IR/Matchers.h>
...@@ -19,11 +18,20 @@ ...@@ -19,11 +18,20 @@
#include <mlir/IR/PatternMatch.h> #include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/CallInterfaces.h> #include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/tensorrt/trt_dilaect_types.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
// Returns the EngineType instance belonging to the context that
// ::infrt::Global::getMLIRContext() hands back.
EngineType EngineType::get() {
  auto *global_context = ::infrt::Global::getMLIRContext();
  return Base::get(global_context);
}
// Returns the EngineType instance belonging to the caller-supplied `context`
// by forwarding it to Base::get.
EngineType EngineType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context) TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
: mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) { : mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
addTypes<EngineType>(); addTypes<EngineType>();
......
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
#include <mlir/Interfaces/LoopLikeInterface.h> #include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/basic_kernels.h" #include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/pd_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
......
...@@ -7,14 +7,24 @@ include "mlir/Interfaces/CallInterfaces.td" ...@@ -7,14 +7,24 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/tensorrt/trt_op_base.td" include "paddle/infrt/dialect/tensorrt/trt_op_base.td"
def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::dialect::ReturnOp">]> {
let summary = "trt Graph Op"; def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "trt CreateEngine Op";
let description = [{ let description = [{
Describe a tensorrt subgraph. Describe a tensorrt subgraph.
}]; }];
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<TRT_Tensor>:$inputs, DefaultValuedAttr<BoolAttr, "true">:$run_once); let arguments = (ins Variadic<TRT_Tensor>:$inputs, DefaultValuedAttr<BoolAttr, "true">:$run_once);
let results = (outs Variadic<TRT_Tensor>:$outputs); let results = (outs TRT_EngineType:$output);
}
def TRT_ExecuteOp : TRT_Op<"execute", [NoSideEffect]> {
let summary = "trt execute Op";
let description = [{
Describe a tensorrt runtime.
}];
let arguments = (ins TRT_EngineType:$engine, Variadic<TRT_Tensor>:$inputs);
let results = (outs Variadic<TRT_Tensor>:$output);
} }
def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> { def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
......
// RUN: trt-exec %s // RUN: trt-exec %s
// CHECK-LABEL: @main // CHECK-LABEL: @main
func @main() -> tensor<?xf32> { func @main(%bias:tensor<?xf32>, %c:tensor<?xf32>, %b1:tensor<?xf32>, %b2:tensor<?xf32>, %bias1:tensor<?xf32>, %bias2:tensor<?xf32>) -> tensor<?xf32> {
%bias = "pd.feed"() {name="input0"} : () -> tensor<?xf32>
%c = "pd.feed"() {name="input1"} : () -> tensor<?xf32>
%b1 = "pd.feed"() {name="input2"} : () -> tensor<?xf32>
%b2 = "pd.feed"() {name="input3"} : () -> tensor<?xf32>
%bias1 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
%bias2 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32> %e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
...@@ -19,5 +12,5 @@ func @main() -> tensor<?xf32> { ...@@ -19,5 +12,5 @@ func @main() -> tensor<?xf32> {
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32> %e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->() "infrt.return"(%e2) : (tensor<?xf32>)->()
} }
...@@ -33,7 +33,7 @@ def PD_ReturnOp : PD_Op<"return", [Terminator]> { ...@@ -33,7 +33,7 @@ def PD_ReturnOp : PD_Op<"return", [Terminator]> {
let arguments = (ins Variadic<PD_Tensor>:$inputs); let arguments = (ins Variadic<PD_Tensor>:$inputs);
} }
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> { def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "paddle graph Op"; let summary = "paddle graph Op";
let description = [{ let description = [{
Describe a paddle graph or subgraph. Describe a paddle graph or subgraph.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册