未验证 提交 e72ef603 编写于 作者: S Shang Zhizhou 提交者: GitHub

Add trt execute (#40224)

* add trt.execute

* merge trt.engine type

* update return op

* update comments

* fix style

* fix style
上级 99fc1b08
......@@ -18,6 +18,22 @@ def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> {
let results = (outs Variadic<AnyType>);
}
// Terminator for infrt function bodies and region-holding ops: yields a
// variadic list of values back to the enclosing scope.
def Infrt_ReturnOp : Infrt_Op<"return", [Terminator]> {
let summary = "host executor return operation";
let description = [{
The "infrt.return" operation represents a return operation within a function.
func @foo() : (i32, f8) {
infrt.return %0, %1 : i32, f8
}
}];
// Any number of operands, of any type.
let arguments = (ins Variadic<AnyType>:$operands);
// Extra zero-argument builder so the op can be created with no results
// (e.g. as an implicit terminator).
let builders = [OpBuilder<(ins),
[{ build($_builder, $_state, llvm::None); }]>];
}
def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> {
let summary = "convert tensor type op";
let description = [{convert tensor type op!}];
......
......@@ -16,7 +16,6 @@
#include <mlir/IR/Matchers.h>
#include <mlir/IR/PatternMatch.h>
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/infrt_base.h"
#define GET_OP_CLASSES
......
......@@ -28,6 +28,7 @@
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
namespace mlir {
namespace pd {
......
......@@ -23,6 +23,8 @@ class EngineType
: public mlir::Type::TypeBase<EngineType, mlir::Type, mlir::TypeStorage> {
public:
using Base::Base;
static EngineType get();
static EngineType get(mlir::MLIRContext *context);
};
} // namespace trt
......
......@@ -53,9 +53,9 @@ bool reverseDfs(std::vector<mlir::Operation *> source,
}
// merge the first&second graph op to a new graph op.
void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT
CreateEngineOp first,
CreateEngineOp second) {
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
mlir::pd::GraphOp first,
mlir::pd::GraphOp second) {
// comput inputs and outputs
::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (mlir::Value input : second.getOperands()) {
......@@ -84,8 +84,7 @@ void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT
// create the new graph op
builder.setInsertionPoint(first);
auto loc = first.getLoc();
auto graph_op =
builder.create<CreateEngineOp>(loc, return_types, inputs, true);
auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(),
......@@ -98,7 +97,7 @@ void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT
copy_range.begin(),
copy_range.end());
builder.setInsertionPointToEnd(block);
builder.create<::infrt::dialect::ReturnOp>(loc, outputs);
builder.create<::infrt::ReturnOp>(loc, outputs);
graph_op.body().push_back(block);
// mapping the output
......@@ -150,12 +149,13 @@ void TRTGraphFusePass::runOnFunction() {
do {
changed = false;
for (auto &op : body) {
CreateEngineOp graph_op = ::llvm::dyn_cast_or_null<CreateEngineOp>(&op);
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) {
CreateEngineOp user_graph_op =
::llvm::dyn_cast_or_null<CreateEngineOp>(user_op);
mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
if (nullptr == user_graph_op) continue;
// get all dst input nodes except src.
std::vector<mlir::Operation *> source_nodes;
......@@ -168,7 +168,7 @@ void TRTGraphFusePass::runOnFunction() {
// Reverse DFS from the source_nodes.
if (!reverseDfs(source_nodes,
[&op](const mlir::Operation *n) { return n == &op; })) {
mergeTwoAdjacentCreateEngineOp(builder, graph_op, user_graph_op);
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true;
break;
}
......
......@@ -15,7 +15,6 @@
#pragma once
#include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -26,40 +25,37 @@ namespace trt {
*
* source func:
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %c = "trt.create_engine"(%a) {
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "Infrt.return" %m
* "infrt.return" (%m)
* } ...
* %d = "trt.create_engine"(%c) {
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "Infrt.return" %m
* "infrt.return" (%m)
* } ...
* %f = "trt.create_engine"(%a) {
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "Infrt.return" %m
* "infrt.return" (%m)
* } ...
* "pd.fetch" %d, %f
* "infrt.return" (%d, %f)..
* }
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "trt.create_engine"(%a) {
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "Infrt.return" %n, %s
* "infrt.return" (%n, %s)
* } ...
* "pd.fetch" %d, %f
* "infrt.return" (%d, %f)
* }
*/
class TRTGraphFusePass
: public mlir::PassWrapper<TRTGraphFusePass, mlir::FunctionPass> {
public:
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
void getDependentDialects(mlir::DialectRegistry &registry) const override {}
::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override;
};
......
......@@ -16,23 +16,23 @@
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
// Implementation of the trtGraphSplitPass。
void TRTGraphSplitPass::runOnFunction() {
std::vector<CreateEngineOp> worklist;
std::vector<mlir::pd::GraphOp> worklist;
mlir::Block& block = getFunction().front();
for (auto& op : block) {
CreateEngineOp graph_op = ::llvm::dyn_cast_or_null<CreateEngineOp>(&op);
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op);
}
}
while (!worklist.empty()) {
CreateEngineOp graph_op = worklist.back();
mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back();
mlir::Block* body = graph_op.getBody();
auto return_op = body->getTerminator();
......
......@@ -15,7 +15,6 @@
#pragma once
#include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -27,33 +26,29 @@ namespace trt {
*
* source func:
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "trt.create_engine"(%a) {
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "Infrt.return" (%n, %s)
* "infrt.return" (%n, %s)...
* } ...
* "pd.fetch" (%d, %f)
* "infrt.return" (%d, %f)...
* }
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" (%d, %f)
* "infrt.return" (%d, %f)...
* }
*/
class TRTGraphSplitPass
: public mlir::PassWrapper<TRTGraphSplitPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
void getDependentDialects(mlir::DialectRegistry &registry) const override {}
void runOnFunction() override;
explicit TRTGraphSplitPass(size_t min_subgraph_size = 3)
: min_subgraph_size_(min_subgraph_size) {}
......
......@@ -16,12 +16,64 @@
#include <mlir/Transforms/DialectConversion.h>
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
namespace infrt {
namespace trt {
#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT
// Rewrite pattern that lowers a `pd.graph` op into the pair
// `trt.create_engine` + `trt.execute`:
//   * create_engine takes the graph's inputs, owns the (moved) subgraph body,
//     and yields a single !trt.engine value;
//   * execute takes the engine plus the original inputs and produces the
//     tensor results that replace the old `pd.graph` results.
struct PD2TRT_GraphLower : public ::mlir::RewritePattern {
PD2TRT_GraphLower(::mlir::MLIRContext *context)
// Benefit 1; generated op name "trt.create_engine" is declared so the
// driver knows this pattern may create it.
: ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {}
::mlir::LogicalResult matchAndRewrite(
::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override {
// Pattern is registered on "pd.graph", so this cast is expected to
// succeed for every op handed to us.
auto casted_op = ::llvm::dyn_cast<mlir::pd::GraphOp>(op);
::mlir::Operation::operand_range inputs = casted_op.inputs();
auto ods_loc = rewriter.getFusedLoc(op->getLoc());
CreateEngineOp create_engine_op;
// inputs
::mlir::SmallVector<::mlir::Value, 4> trt_inputs;
for (auto v : inputs) {
trt_inputs.push_back(v);
}
// Single result of !trt.engine type; EngineType::get() uses the global
// MLIR context (see trt_ops.cc).
create_engine_op = rewriter.create<CreateEngineOp>(
ods_loc,
::llvm::SmallVector<mlir::Type, 4>(1, EngineType::get()),
trt_inputs,
true /*run_once*/);
// Move (splice) the whole pd.graph body — including its terminator —
// into a fresh block owned by the new engine op's region.
::mlir::Block *block = new ::mlir::Block;
block->getOperations().splice(block->begin(),
casted_op.getBody()->getOperations(),
casted_op.getBody()->begin(),
casted_op.getBody()->end());
create_engine_op.body().push_back(block);
// trt.execute
// outputs
// Execute op returns exactly the types the original pd.graph returned.
::llvm::SmallVector<::mlir::Type, 4> execute_outputs_types;
for (auto v : casted_op.getODSResults(0)) {
execute_outputs_types.push_back(v.getType());
}
// inputs
// Operand list = [engine, original graph inputs...].
::mlir::SmallVector<::mlir::Value, 4> execute_inputs(
create_engine_op.getODSResults(0));
for (auto v : inputs) {
execute_inputs.push_back(v);
}
auto execute_op = rewriter.create<ExecuteOp>(
ods_loc, execute_outputs_types, execute_inputs);
// Replace all uses of the pd.graph results with the execute results and
// erase the original op.
::llvm::SmallVector<::mlir::Value, 4> replace_values;
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{execute_op.getODSResults(0)}) {
replace_values.push_back(v);
}
rewriter.replaceOp(op, replace_values);
return ::mlir::success();
}
};
void TRTOpConverterPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
......@@ -36,6 +88,7 @@ void TRTOpConverterPass::runOnOperation() {
// the set of patterns that will lower the TensorRT operations.
::mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
patterns.add<PD2TRT_GraphLower>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
......
......@@ -15,6 +15,7 @@
#pragma once
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
......@@ -23,27 +24,26 @@ namespace trt {
* trtOpConverterPass.
*
* source ir:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "trt.create_engine"(%a) {
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "Infrt.return" %n, %s
* "infrt.return" (%n, %s)...
* } ...
* "pd.fetch" %d, %f
* "infrt.return" (%d, %f)...
* }
*
* destination ir:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "trt.create_engine"(%a) {
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %engine = "trt.create_engine"(%a) ({
* %m = "trt.Convolution"(%a)...
* %n = "trt.Convolution"(%m)...
* %s = "trt.Convolution"(%a)...
* "Infrt.return" %n, %s
* } ...
* "pd.fetch" %d, %f
* "infrt.return" (%n, %s)...
* }){run_once = true} ...
* %d, %f = "trt.execute"(%engine, %a)...
* "infrt.return" (%d, %f)...
* }
*/
struct TRTOpConverterPass
......
......@@ -16,6 +16,7 @@
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/pd_ops.h"
namespace infrt {
......@@ -37,11 +38,11 @@ void TRTOpTellerPass::runOnFunction() {
if (::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op)) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op)) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op)) continue;
if (::llvm::dyn_cast_or_null<CreateEngineOp>(op)) continue;
if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
builder.setInsertionPoint(op);
auto loc = getFunction().getLoc();
auto graph_op = builder.create<CreateEngineOp>(
loc, op->getResultTypes(), op->getOperands(), true);
auto graph_op = builder.create<mlir::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
for (auto v :
......@@ -54,7 +55,7 @@ void TRTOpTellerPass::runOnFunction() {
graph_op.body().push_back(block);
op->moveBefore(block, block->begin());
builder.setInsertionPointToEnd(block);
builder.create<::infrt::dialect::ReturnOp>(loc, op->getResults());
builder.create<::infrt::ReturnOp>(loc, op->getResults());
}
}
} // namespace trt
......
......@@ -15,7 +15,6 @@
#pragma once
#include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -26,30 +25,28 @@ namespace trt {
*
* source func:
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" (%d, %f)
* "infrt.return"(%d, %f) ...
* }
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %c = "trt.create_engine"(%a) {
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "Infrt.return" (%m)
* "infrt.return" (%m)
* } ...
* %d = "trt.create_engine"(%c) {
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "Infrt.return" (%m)
* "infrt.return" (%m)
* } ...
* %f = "trt.create_engine"(%a) {
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "Infrt.return" (%m)
* "infrt.return" (%m)
* } ...
* "pd.fetch" (%d, %f)
* "infrt.return" (%d, %f)
* }
* TODO(winter-wang): Supplementary how to judge the operators can be supported
* by tensorrt.
......@@ -57,9 +54,7 @@ namespace trt {
class TRTOpTellerPass
: public mlir::PassWrapper<TRTOpTellerPass, mlir::FunctionPass> {
public:
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
void getDependentDialects(mlir::DialectRegistry &registry) const override {}
::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override;
};
......
......@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Matchers.h>
......@@ -19,11 +18,20 @@
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/tensorrt/trt_dilaect_types.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
namespace infrt {
namespace trt {
// Convenience overload: builds the !trt.engine type in the process-wide
// MLIR context held by infrt::Global.
EngineType EngineType::get() {
return Base::get(::infrt::Global::getMLIRContext());
}
// Builds the !trt.engine type in the caller-supplied context.
EngineType EngineType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
: mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
addTypes<EngineType>();
......
......@@ -29,6 +29,8 @@
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/pd_ops.h"
namespace infrt {
namespace trt {
......
......@@ -7,14 +7,24 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/tensorrt/trt_op_base.td"
def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::dialect::ReturnOp">]> {
let summary = "trt Graph Op";
def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "trt CreateEngine Op";
let description = [{
Describe a tensorrt subgraph.
}];
let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<TRT_Tensor>:$inputs, DefaultValuedAttr<BoolAttr, "true">:$run_once);
let results = (outs Variadic<TRT_Tensor>:$outputs);
let results = (outs TRT_EngineType:$output);
}
// Runtime counterpart of trt.create_engine: executes a previously built
// engine on a list of input tensors and yields the output tensors.
def TRT_ExecuteOp : TRT_Op<"execute", [NoSideEffect]> {
let summary = "trt execute Op";
let description = [{
Describe a tensorrt runtime.
}];
// First operand is the engine handle produced by trt.create_engine,
// followed by the tensors to feed it.
let arguments = (ins TRT_EngineType:$engine, Variadic<TRT_Tensor>:$inputs);
let results = (outs Variadic<TRT_Tensor>:$output);
}
def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
......
// RUN: trt-exec %s
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
%bias = "pd.feed"() {name="input0"} : () -> tensor<?xf32>
%c = "pd.feed"() {name="input1"} : () -> tensor<?xf32>
%b1 = "pd.feed"() {name="input2"} : () -> tensor<?xf32>
%b2 = "pd.feed"() {name="input3"} : () -> tensor<?xf32>
%bias1 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
%bias2 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
func @main(%bias:tensor<?xf32>, %c:tensor<?xf32>, %b1:tensor<?xf32>, %b2:tensor<?xf32>, %bias1:tensor<?xf32>, %bias2:tensor<?xf32>) -> tensor<?xf32> {
%d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
......@@ -19,5 +12,5 @@ func @main() -> tensor<?xf32> {
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
"infrt.return"(%e2) : (tensor<?xf32>)->()
}
......@@ -33,7 +33,7 @@ def PD_ReturnOp : PD_Op<"return", [Terminator]> {
let arguments = (ins Variadic<PD_Tensor>:$inputs);
}
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> {
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "paddle graph Op";
let description = [{
Describe a paddle graph or subgraph.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册