未验证 提交 acdf0663 编写于 作者: S Shang Zhizhou 提交者: GitHub

update pd_2_trt lower pass (#40019)

* update pd_2_trt lower pass

* update pd_2_trt lower pass

* update style

* update

* change trt.graph to trt.create_engine

* update comments

* update comments

* add test
上级 852a872f
......@@ -53,9 +53,9 @@ bool reverseDfs(std::vector<mlir::Operation *> source,
}
// merge the first&second graph op to a new graph op.
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
mlir::pd::GraphOp first,
mlir::pd::GraphOp second) {
void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT
CreateEngineOp first,
CreateEngineOp second) {
// comput inputs and outputs
::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (mlir::Value input : second.getOperands()) {
......@@ -84,7 +84,8 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
// create the new graph op
builder.setInsertionPoint(first);
auto loc = first.getLoc();
auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
auto graph_op =
builder.create<CreateEngineOp>(loc, return_types, inputs, true);
mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(),
......@@ -97,7 +98,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
copy_range.begin(),
copy_range.end());
builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::ReturnOp>(loc, outputs);
builder.create<::infrt::dialect::ReturnOp>(loc, outputs);
graph_op.body().push_back(block);
// mapping the output
......@@ -149,13 +150,12 @@ void TRTGraphFusePass::runOnFunction() {
do {
changed = false;
for (auto &op : body) {
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
CreateEngineOp graph_op = ::llvm::dyn_cast_or_null<CreateEngineOp>(&op);
if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) {
mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
CreateEngineOp user_graph_op =
::llvm::dyn_cast_or_null<CreateEngineOp>(user_op);
if (nullptr == user_graph_op) continue;
// get all dst input nodes except src.
std::vector<mlir::Operation *> source_nodes;
......@@ -168,7 +168,7 @@ void TRTGraphFusePass::runOnFunction() {
// Reverse DFS from the source_nodes.
if (!reverseDfs(source_nodes,
[&op](const mlir::Operation *n) { return n == &op; })) {
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
mergeTwoAdjacentCreateEngineOp(builder, graph_op, user_graph_op);
changed = true;
break;
}
......
......@@ -14,6 +14,8 @@
#pragma once
#include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -26,28 +28,28 @@ namespace trt {
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %c = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.return" %m
* "Infrt.return" %m
* } ...
* %d = "pd.graph"(%c) {
* %d = "trt.create_engine"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.return" %m
* "Infrt.return" %m
* } ...
* %f = "pd.graph"(%a) {
* %f = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.return" %m
* "Infrt.return" %m
* } ...
* "pd.fetch" %d, %f
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "pd.graph"(%a) {
* %d, %f = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.return" %n, %s
* "Infrt.return" %n, %s
* } ...
* "pd.fetch" %d, %f
* }
......@@ -55,6 +57,9 @@ namespace trt {
// Function pass that fuses adjacent trt.create_engine ops into a single one
// (see the source/destination example in the comment above this class).
class TRTGraphFusePass
: public mlir::PassWrapper<TRTGraphFusePass, mlir::FunctionPass> {
public:
// Registers the dialects this pass may create ops from, so the context
// loads them before the pass runs.
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
// Unique pass name used in diagnostics and -pass-pipeline text.
::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
// Entry point; implementation lives in trt_graph_fuse_pass.cc.
void runOnFunction() override;
};
......
......@@ -22,18 +22,17 @@ namespace infrt {
namespace trt {
// Implementation of the trtGraphSplitPass.
void TRTGraphSplitPass::runOnFunction() {
std::vector<mlir::pd::GraphOp> worklist;
std::vector<CreateEngineOp> worklist;
mlir::Block& block = getFunction().front();
for (auto& op : block) {
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
CreateEngineOp graph_op = ::llvm::dyn_cast_or_null<CreateEngineOp>(&op);
if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op);
}
}
while (!worklist.empty()) {
mlir::pd::GraphOp graph_op = worklist.back();
CreateEngineOp graph_op = worklist.back();
worklist.pop_back();
mlir::Block* body = graph_op.getBody();
auto return_op = body->getTerminator();
......
......@@ -14,6 +14,8 @@
#pragma once
#include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -27,11 +29,11 @@ namespace trt {
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "pd.graph"(%a) {
* %d, %f = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.return" (%n, %s)
* "Infrt.return" (%n, %s)
* } ...
* "pd.fetch" (%d, %f)
* }
......@@ -49,6 +51,9 @@ class TRTGraphSplitPass
: public mlir::PassWrapper<TRTGraphSplitPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
void runOnFunction() override;
explicit TRTGraphSplitPass(size_t min_subgraph_size = 3)
: min_subgraph_size_(min_subgraph_size) {}
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
#include <mlir/IR/Builders.h>
#include <mlir/Transforms/DialectConversion.h>
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/pd_ops.h"
......@@ -22,12 +22,10 @@ namespace trt {
#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT
using namespace mlir;
void TRTOpConverterPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
::mlir::ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to TensorRTDialect from
......@@ -36,13 +34,13 @@ void TRTOpConverterPass::runOnOperation() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the TensorRT operations.
RewritePatternSet patterns(&getContext());
::mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(
if (::mlir::failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
......
......@@ -25,11 +25,11 @@ namespace trt {
* source ir:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "pd.graph"(%a) {
* %d, %f = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.return" %n, %s
* "Infrt.return" %n, %s
* } ...
* "pd.fetch" %d, %f
* }
......@@ -37,11 +37,11 @@ namespace trt {
* destination ir:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "pd.graph"(%a) {
* %d, %f = "trt.create_engine"(%a) {
* %m = "trt.Convolution"(%a)...
* %n = "trt.Convolution"(%m)...
* %s = "trt.Convolution"(%a)...
* "pd.return" %n, %s
* "Infrt.return" %n, %s
* } ...
* "pd.fetch" %d, %f
* }
......
......@@ -15,6 +15,7 @@
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/pd_ops.h"
namespace infrt {
......@@ -33,16 +34,14 @@ void TRTOpTellerPass::runOnFunction() {
auto *op = worklist.back();
worklist.pop_back();
if (op == nullptr) continue;
auto op1 = ::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op);
if (op1) continue;
auto op2 = ::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op);
if (op2) continue;
auto op3 = ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op);
if (op3) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op)) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op)) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op)) continue;
if (::llvm::dyn_cast_or_null<CreateEngineOp>(op)) continue;
builder.setInsertionPoint(op);
auto loc = getFunction().getLoc();
auto graph_op = builder.create<mlir::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands());
auto graph_op = builder.create<CreateEngineOp>(
loc, op->getResultTypes(), op->getOperands(), true);
::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
for (auto v :
......@@ -55,7 +54,7 @@ void TRTOpTellerPass::runOnFunction() {
graph_op.body().push_back(block);
op->moveBefore(block, block->begin());
builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::ReturnOp>(loc, op->getResults());
builder.create<::infrt::dialect::ReturnOp>(loc, op->getResults());
}
}
} // namespace trt
......
......@@ -14,6 +14,8 @@
#pragma once
#include <mlir/Pass/Pass.h>
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -35,17 +37,17 @@ namespace trt {
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %c = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.return" (%m)
* "Infrt.return" (%m)
* } ...
* %d = "pd.graph"(%c) {
* %d = "trt.create_engine"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.return" (%m)
* "Infrt.return" (%m)
* } ...
* %f = "pd.graph"(%a) {
* %f = "trt.create_engine"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.return" (%m)
* "Infrt.return" (%m)
* } ...
* "pd.fetch" (%d, %f)
* }
......@@ -55,6 +57,9 @@ namespace trt {
// Function pass that wraps each TensorRT-convertible op into its own
// trt.create_engine region (see the source/destination example in the
// comment above this class).
class TRTOpTellerPass
: public mlir::PassWrapper<TRTOpTellerPass, mlir::FunctionPass> {
public:
// Registers the dialects this pass may create ops from, so the context
// loads them before the pass runs.
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<TensorRTDialect, ::infrt::dialect::INFRTDialect>();
}
// Unique pass name used in diagnostics and -pass-pipeline text.
::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
// Entry point; implementation lives in trt_op_teller_pass.cc.
void runOnFunction() override;
};
......
......@@ -28,6 +28,7 @@
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/basic_kernels.h"
namespace infrt {
namespace trt {
......
......@@ -7,25 +7,14 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/tensorrt/trt_op_base.td"
// Legacy terminator for the trt.graph region (removed in this change in
// favor of ::infrt::dialect::ReturnOp as the implicit terminator of
// trt.create_engine).
def TRT_FetchOp : TRT_Op<"fetch", [Terminator]> {
let summary = "TensorRT engine return operation";
let description = [{
The `trt.fetch` operation terminates and returns values for the
`trt.graph` operation.
}];
let arguments = (ins Variadic<TRT_Tensor>:$inputs);
}
def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::dialect::ReturnOp">]> {
let summary = "trt Graph Op";
let description = [{
Describe a tensorrt subgraph.
}];
let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<TRT_Tensor>:$inputs);
let arguments = (ins Variadic<TRT_Tensor>:$inputs, DefaultValuedAttr<BoolAttr, "true">:$run_once);
let results = (outs Variadic<TRT_Tensor>:$outputs);
}
def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
......
// RUN: trt-exec %s
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
%bias = "pd.feed"() {name="input0"} : () -> tensor<?xf32>
......
......@@ -21,10 +21,11 @@ build_dir = "@CMAKE_BINARY_DIR@"
config.llvm_tools_dir = os.path.join(build_dir, "third_party/install/llvm/bin")
config.llvm_tools_dir = os.path.join(build_dir, "/third_party/install/llvm/lib")
infrtopt_bin = os.path.join(build_dir, "paddle/infrt/dialect/")
trtexec_bin = os.path.join(build_dir, "paddle/infrt/dialect/tensorrt/")
infrtexec_bin = os.path.join(build_dir, "paddle/infrt/host_context/")
llvm_bin = os.path.join(build_dir, "third_party/install/llvm/bin/")
config.environment['PATH'] = os.path.pathsep.join(
(infrtopt_bin, infrtexec_bin, llvm_bin, config.environment['PATH']))
(infrtopt_bin, infrtexec_bin, trtexec_bin, llvm_bin, config.environment['PATH']))
config.suffixes = ['.mlir']
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册