Unverified commit 44112817 — authored by 王明冬, committed by GitHub

[infrt] add tensorrt op teller pass. test=develop (#38304)

Parent: ddc15a18
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
set(PADDLE_INFRT_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_infrt_install_dir" CACHE STRING
set(INFRT_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_infrt_install_dir" CACHE STRING
"A path setting paddle infrt shared and static libraries")
function(copy TARGET)
......@@ -52,18 +52,17 @@ add_custom_target(infrt_lib_dist DEPENDS ${infrt_lib_deps})
# CMakeCache Info
copy(infrt_lib_dist
SRCS ${CMAKE_BINARY_DIR}/CMakeCache.txt
DSTS ${PADDLE_INFRT_INSTALL_DIR})
DSTS ${INFRT_INSTALL_DIR})
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/infrt")
set(paddle_infrt_lib ${PADDLE_BINARY_DIR}/paddle/infrt/libinfrt.*)
set(infrt_lib ${INFRT_BINARY_DIR}/libinfrt.*)
copy(infrt_lib_dist
SRCS ${src_dir}/api/infrt_api.h ${paddle_infrt_lib}
DSTS ${PADDLE_INFRT_INSTALL_DIR}/infrt/include ${PADDLE_INFRT_INSTALL_DIR}/infrt/lib)
SRCS ${INFRT_SOURCE_DIR}/api/infrt_api.h ${infrt_lib}
DSTS ${INFRT_INSTALL_DIR}/infrt/include ${INFRT_INSTALL_DIR}/infrt/lib)
copy(infrt_lib_dist
SRCS ${CMAKE_BINARY_DIR}/paddle/infrt/paddle/framework.pb.h
DSTS ${PADDLE_INFRT_INSTALL_DIR}/infrt/include/internal)
SRCS ${INFRT_BINARY_DIR}/paddle/framework.pb.h
DSTS ${INFRT_INSTALL_DIR}/infrt/include/internal)
# paddle fluid version
function(version version_file)
......@@ -74,4 +73,4 @@ function(version version_file)
file(WRITE ${version_file} "GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n")
file(APPEND ${version_file} "CXX compiler version: ${CMAKE_CXX_COMPILER_VERSION}\n")
endfunction()
version(${PADDLE_INFRT_INSTALL_DIR}/version.txt)
version(${INFRT_INSTALL_DIR}/version.txt)
# infrt is an optional component; skip the rest of this file unless it
# was explicitly enabled at configure time.
if (NOT WITH_INFRT)
  return()
endif()

# Root source/binary directories of the infrt sub-project inside Paddle.
set(INFRT_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/infrt" )
set(INFRT_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/infrt" )

# Packaging/install rules for the infrt distribution (infrt_lib.cmake).
include(infrt_lib)
......@@ -74,6 +76,7 @@ set(infrt_mlir_incs
dense_tensor_inc
pd_ops_inc
rewrite_inc
trt_ops_inc
)
message(STATUS "infrt srcs:\n${infrt_src}")
......
......@@ -51,3 +51,5 @@ infrt_exec_check(test_infrt_tensor_type mlir_tests/tensor_type.mlir)
# MLIR execution-check tests for the infrt dialect.
# NOTE(review): "test_infrt__basic" has a double underscore — presumably a
# typo for "test_infrt_basic", but the target name may be referenced by CI
# scripts; confirm before renaming.
infrt_exec_check(test_infrt__basic mlir_tests/basic.mlir)
infrt_exec_check(test_infrt_benchmark mlir_tests/benchmark.mlir)
infrt_exec_check(test_infrt_mlir_dense_tensor mlir_tests/dense_tensor.mlir)

# TensorRT dialect tests live in their own directory.
add_subdirectory(tensorrt)
......@@ -20,11 +20,6 @@
namespace mlir {
namespace pd {
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
PaddleDialect::PaddleDialect(MLIRContext *context)
: Dialect("pd", context, TypeID::get<PaddleDialect>()) {
addOperations<
......
......@@ -53,5 +53,9 @@ class PaddleDialect : public Dialect {
}
};
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace pd
} // namespace mlir
......@@ -6,7 +6,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/pd_op_base.td"
def PD_FeedOp : PD_Op<"Feed", [NoSideEffect]> {
def PD_FeedOp : PD_Op<"feed", [NoSideEffect]> {
let summary = "Feed Op";
let description = [{
......@@ -21,6 +21,26 @@ def PD_FeedOp : PD_Op<"Feed", [NoSideEffect]> {
}];
}
// Terminator op that returns tensors out of a pd.graph region.
def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
  let summary = "fetch Op";

  let description = [{
    Fetch tensor from the graph.
  }];

  let arguments = (ins Variadic<PD_Tensor>:$inputs);
}

// A single-block sub-graph of paddle ops. The block is implicitly
// terminated by pd.fetch (SingleBlockImplicitTerminator<"FetchOp">),
// whose operands become the graph op's results.
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
  let summary = "paddle graph Op";
  let description = [{
    Describe a paddle graph or subgraph.
  }];
  let regions = (region SizedRegion<1>:$body);
  let arguments = (ins Variadic<PD_Tensor>:$inputs);
  let results = (outs Variadic<PD_Tensor>:$outputs);
}
def PD_ConstantOp : PD_Op<"Constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods<InferTypeOpInterface>, AllTypesMatch<["value", "output"]>]> {
let summary = "constant Op";
let description = [{}];
......
# Build rules for the infrt TensorRT dialect: collect headers and append
# the dialect/pass sources to the global infrt_src list consumed by the
# parent CMakeLists.
core_gather_headers()

gather_srcs(infrt_src SRCS
  trt_ops.cc
  trt_op_teller_pass.cc
  trt_graph_fuse_pass.cc
)

# Generate trt_ops.{hpp,cpp}.inc from trt_ops.td via mlir-tblgen.
mlir_tablegen_on(trt_ops)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include <list>
#include <unordered_set>
#include <vector>
#include "llvm/ADT/SetVector.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
namespace {
// FlexibleDFS
// Performs a reverse DFS from `source`, walking operand (def-use) edges
// backwards. When `leave` is provided, it is invoked after all of a node's
// predecessors have been visited; returning false from `leave` aborts the
// whole traversal.
// Reference the function with the same name but defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc.
void FlexibleDFS(const std::vector<::mlir::Operation *> &source,
                 const std::function<bool(const ::mlir::Operation *)> &leave) {
  typedef struct {
    ::mlir::Operation *node;
    bool leave;
  } FNode;

  std::vector<FNode> stack;
  for (auto &node : source) {
    stack.push_back(FNode{node, false});
  }
  std::unordered_set<const ::mlir::Operation *> visited;
  while (!stack.empty()) {
    auto fnode = stack.back();
    stack.pop_back();
    if (fnode.leave) {
      if (leave && !leave(fnode.node)) return;
    }
    // Skip null entries: callers may pass nullptr source nodes (operands
    // that are block arguments have no defining op), and dereferencing
    // them below would crash.
    if (fnode.node == nullptr || visited.count(fnode.node)) continue;
    visited.insert(fnode.node);
    if (leave) stack.push_back(FNode{fnode.node, true});
    auto values = fnode.node->getOperands();
    for (auto value : values) {
      // getDefiningOp() returns nullptr for block arguments; never push
      // a null node onto the traversal stack.
      ::mlir::Operation *node = value.getDefiningOp();
      if (node != nullptr && !visited.count(node)) {
        stack.push_back(FNode{node, false});
      }
    }
  }
}
// Merge two adjacent pd.graph ops (`second` consumes results of `first`)
// into one new pd.graph op inserted where `first` was; both originals are
// erased afterwards.
void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder,  // NOLINT
                             ::mlir::pd::GraphOp first,
                             ::mlir::pd::GraphOp second) {
  // Compute inputs and outputs of the merged op.
  // Inputs: all of first's operands, plus second's operands that are not
  // produced by first (those edges become internal to the merged graph).
  ::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs;
  for (::mlir::Value input : second.getOperands()) {
    if (input.getDefiningOp() != first) {
      inputs.push_back(input);
    }
  }

  // Outputs: results of first that still have users outside `second`
  // (op_output_mapping remembers their index in the merged result list),
  // followed by everything second's terminator returns.
  ::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping;
  for (::mlir::Value output : first.getResults()) {
    for (::mlir::Operation *user : output.getUsers()) {
      if (user != second && user->getParentOp() != second) {
        op_output_mapping[output] = outputs.size();
        outputs.push_back(output);
        break;
      }
    }
  }
  auto fetch_op = second.getBody()->getTerminator();
  outputs.append(fetch_op->getOperands().begin(),
                 fetch_op->getOperands().end());
  ::llvm::SmallVector<::mlir::Type, 4> fetch_types;
  for (auto value : outputs) {
    fetch_types.push_back(value.getType());
  }

  // Create the new graph op at first's position.
  builder.setInsertionPoint(first);
  auto loc = first.getLoc();
  auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs);
  ::mlir::Block *block = new ::mlir::Block;
  // Splice second's body (minus terminator) into the new block, then splice
  // first's body in front of it, so first's ops still execute before
  // second's in the merged block.
  auto copy_range = second.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                second.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  copy_range = first.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                first.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  builder.setInsertionPointToEnd(block);
  builder.create<mlir::pd::FetchOp>(loc, outputs);
  graph_op.body().push_back(block);

  // Re-map uses of first's old results: uses inside the merged graph are
  // rewired to the inner value first's old terminator returned; uses
  // outside get the corresponding result of the new graph op.
  unsigned int num_result = first.getNumResults();
  fetch_op = first.getBody()->getTerminator();
  for (unsigned int index = 0; index < num_result; ++index) {
    auto origin_value = first.getResult(index);
    if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
      // Result was only consumed by second (now merged in): forward the
      // inner value directly.
      origin_value.replaceAllUsesWith(fetch_op->getOperand(index));
    } else {
      auto inner_value = fetch_op->getOperand(index);
      auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
      while (!origin_value.use_empty()) {
        auto replace_value =
            origin_value.use_begin()->getOwner()->getParentOp() == graph_op
                ? inner_value
                : outer_value;
        origin_value.use_begin()->set(replace_value);
      }
    }
  }
  // second's results map onto the tail of the merged result list (they were
  // appended after first's escaping results above).
  second.replaceAllUsesWith(
      graph_op.getResults().take_back(second.getNumResults()));
  first.erase();
  second.erase();
}
} // namespace
// Implementation of the trtGraphFusePass.
// Repeatedly scans the function's entry block and merges any pd.graph op
// with a pd.graph user, provided no other path connects them (such a path
// would turn into a cycle after merging). Restarts the scan after every
// merge and iterates until a fixed point.
void trtGraphFusePass::runOnFunction() {
  mlir::Block &body = getFunction().front();
  ::mlir::OpBuilder builder(&body, body.begin());
  bool changed = false;
  do {
    changed = false;
    for (auto &op : body) {
      ::mlir::pd::GraphOp graph_op =
          ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
      if (nullptr == graph_op) continue;

      for (auto user_op : op.getUsers()) {
        ::mlir::pd::GraphOp user_graph_op =
            ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op);
        if (nullptr == user_graph_op) continue;
        // get all dst input nodes except src.
        // NOTE(review): getDefiningOp() is nullptr for block arguments, so
        // source_nodes may contain nullptr — FlexibleDFS must tolerate
        // null entries; confirm.
        std::vector<::mlir::Operation *> source_nodes;
        for (auto operand : user_op->getOperands()) {
          auto input = operand.getDefiningOp();
          if (input != &op) {
            source_nodes.push_back(input);
          }
        }
        // Reverse DFS from the source_nodes: if `op` is reachable through
        // some other path, merging would create a cycle, so skip this pair.
        bool have_excess_path = false;
        FlexibleDFS(source_nodes,
                    [&have_excess_path, &op](const ::mlir::Operation *n) {
                      if (n == &op) {
                        have_excess_path = true;
                        return false;
                      }
                      return true;
                    });
        if (!have_excess_path) {
          mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
          // The block's op list was mutated (ops erased); break out and
          // restart the scan from the beginning.
          changed = true;
          break;
        }
      }
      if (changed) break;
    }
  } while (changed);
}
} // namespace trt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
namespace infrt {
namespace trt {
/*
 * trtGraphFusePass.
 *
 * Merge the adjacent graph op to a new graph op.
 *
 * source func:
 *
 * func @main() -> tensor<?xf32> {
 *  %a = "pd.feed"()
 *  %c = "pd.graph"(%a) {
 *     %m = "pd.conv2d"(%a)...
 *     "pd.fetch" %m
 *  } ...
 *  %d = "pd.graph"(%c) {
 *     %m = "pd.conv3d"(%c)...
 *     "pd.fetch" %m
 *  } ...
 *  %f = "pd.graph"(%a) {
 *     %m = "pd.conv2d"(%a)...
 *     "pd.fetch" %m
 *  } ...
 *  "pd.fetch" %d, %f
 * }
 *
 * destination func:
 * func @main() -> tensor<?xf32> {
 *  %a = "pd.feed"()
 *  %d, %f = "pd.graph"(%a) {
 *     %m = "pd.conv2d"(%a)...
 *     %n = "pd.conv3d"(%m)...
 *     %s = "pd.conv2d"(%a)...
 *     "pd.fetch" %n, %s
 *  } ...
 *  "pd.fetch" %d, %f
 * }
 */
// Function pass; the fusion itself is implemented in runOnFunction()
// (trt_graph_fuse_pass.cc).
class trtGraphFusePass
    : public ::mlir::PassWrapper<trtGraphFusePass, ::mlir::FunctionPass> {
 public:
  ::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
  void runOnFunction() override;
};
} // namespace trt
} // namespace infrt
// This file defines some basic elements of the TensorRT (trt) dialect.
// We learned much from the TensorFlow mlir dialect:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td

#ifndef TRT_OP_BASE
#define TRT_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

def TRT_Dialect : Dialect {
  let name = "trt";

  // Fixed: the description was copy-pasted from the pd dialect and
  // described PaddlePaddle instead of TensorRT.
  let description = [{
    The TensorRT dialect.

    This dialect contains the operators that are handed to TensorRT.
  }];

  // The generated C++ classes live in ::infrt::trt (see trt_ops.h).
  let cppNamespace = "::infrt::trt";
}

// Base class for all ops of the trt dialect.
class TRT_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TRT_Dialect, mnemonic, traits>;

// NOTE(review): the isa<> predicates below originally referenced
// mlir::trt::..., which contradicts cppNamespace above; the generated
// classes are placed in ::infrt::trt, so that namespace is used here.
class TRT_PaddleAttr <string name, string description> :
    Attr<CPred<"$_self.isa<::infrt::trt::" # name # "Attr>()">,
         "PaddlePaddle " # description # " attribute">;

//===----------------------------------------------------------------------===//
// TensorRT dialect type definitions
//===----------------------------------------------------------------------===//

def TRT_TRTDialectType : Type<CPred<"$_self.isa<::infrt::trt::TRTType>()">, "TensorRT dialect type">;

class TRT_PaddleType <string name, string description> :
    Type<CPred<"$_self.isa<::infrt::trt::" # name #"Type>()">,
         "Paddle " # description # " type">,
    BuildableType<"getType<::infrt::trt::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Integer types
def TRT_Bool : AnyTypeOf<[I<1>], "bool">;
def TRT_Int8 : AnyTypeOf<[I8], "8-bit integer">;
def TRT_Int16 : AnyTypeOf<[I16], "16-bit integer">;
def TRT_Int32 : AnyTypeOf<[I32], "32-bit integer">;
def TRT_Int64 : AnyTypeOf<[I64], "64-bit integer">;
def TRT_UInt8 : AnyTypeOf<[UI<8>], "8-bit unsigned integer">;
def TRT_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">;
def TRT_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">;
def TRT_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">;

def TRT_SInt : AnyTypeOf<[TRT_Int8, TRT_Int16, TRT_Int32, TRT_Int64], "signed integer">;
def TRT_UInt : AnyTypeOf<[TRT_UInt8, TRT_UInt16, TRT_UInt32, TRT_UInt64], "unsigned integer">;
def TRT_Int : AnyTypeOf<[TRT_SInt, TRT_UInt], "integer">;

// Float types
def TRT_Float16 : AnyTypeOf<[F16], "16-bit float">;
def TRT_Float32 : AnyTypeOf<[F32], "32-bit float">;
def TRT_Float64 : AnyTypeOf<[F64], "64-bit float">;

def TRT_Float : AnyTypeOf<[TRT_Float16, TRT_Float32, TRT_Float64], "floating-point">;

// Tensor types: any tensor whose elements are a supported trt scalar type.
def TRT_ElementType : Type<Or<[TRT_Float.predicate,
                               TRT_Bool.predicate,
                               TRT_Int.predicate]>,
                          "trt.dtype">;

def TRT_Tensor : TensorOf<[TRT_ElementType]>;

#endif // TRT_OP_BASE
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
// Implementation of the trtOpTellerPass.
// Wraps every op of the entry block — except pd.feed, pd.fetch, and ops
// that are already pd.graph — into its own single-op pd.graph; the
// graph-fuse pass later merges these one-op graphs together.
void trtOpTellerPass::runOnFunction() {
  ::mlir::Block &body = getFunction().front();
  std::vector<::mlir::Operation *> worklist;
  worklist.reserve(body.getOperations().size());
  // Snapshot the ops first: the loop below mutates the block.
  for (auto &op : body) {
    worklist.push_back(&op);
  }
  // Build GraphOp.
  ::mlir::OpBuilder builder(&body, body.begin());
  while (!worklist.empty()) {
    auto *op = worklist.back();
    worklist.pop_back();
    if (op == nullptr) continue;
    auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op);
    if (op1) continue;
    auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op);
    if (op2) continue;
    auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op);
    if (op3) continue;
    builder.setInsertionPoint(op);
    auto loc = getFunction().getLoc();
    // The new graph op mirrors the wrapped op's operand/result signature.
    auto graph_op = builder.create<::mlir::pd::GraphOp>(
        loc, op->getResultTypes(), op->getOperands());

    ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
    for (auto v :
         ::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) {
      tblgen_repl_values.push_back(v);
    }
    // Redirect all users of the original op to the graph op's results
    // BEFORE moving the op inside the graph body; the pd.fetch created
    // below is the only remaining direct user of op's results.
    op->replaceAllUsesWith(tblgen_repl_values);

    // Build graph op.
    ::mlir::Block *block = new ::mlir::Block;
    graph_op.body().push_back(block);
    op->moveBefore(block, block->begin());
    builder.setInsertionPointToEnd(block);
    builder.create<mlir::pd::FetchOp>(loc, op->getResults());
  }
}
} // namespace trt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
namespace infrt {
namespace trt {
/*
* trtOpTellerPass.
*
* Pick out the operators supported by tensorrt and convert it to graph.
*
* source func:
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* }
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* } ...
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.fetch" %m
* } ...
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* } ...
* "pd.fetch" %d, %f
* }
* TODO(winter-wang): Supplementary how to judge the operators can be supported
* by tensorrt.
*/
// Function pass; see trt_op_teller_pass.cc for the implementation.
class trtOpTellerPass
    : public ::mlir::PassWrapper<trtOpTellerPass, ::mlir::FunctionPass> {
 public:
  ::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
  void runOnFunction() override;
};
} // namespace trt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace infrt {
namespace trt {
// Registers the trt dialect and all ops generated from trt_ops.td.
TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect("trt", context, ::mlir::TypeID::get<TensorRTDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc"  // NOLINT
      >();
#undef GET_OP_LIST
}

// Definitions of the tablegen-generated op classes.
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc"  // NOLINT
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace infrt {
namespace trt {
// The TensorRT dialect: groups the trt.* ops declared in trt_ops.td.
class TensorRTDialect : public ::mlir::Dialect {
 public:
  explicit TensorRTDialect(::mlir::MLIRContext* context);
  static llvm::StringRef getDialectNamespace() { return "trt"; }
};

// mlir bug. Can be removed safely when mlir is updated to llvm 11.
using namespace mlir;  // NOLINT

// Declarations of the tablegen-generated op classes.
#define GET_OP_CLASSES
#include "dialect/tensorrt/trt_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
// Operation definitions for the trt dialect (the dialect itself and its
// types are declared in trt_op_base.td).
#ifndef TRT_OPS
#define TRT_OPS

include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/tensorrt/trt_op_base.td"

// Terminator of a trt.graph region: its operands are the values the
// enclosing graph op returns.
def TRT_FetchOp : TRT_Op<"fetch", [Terminator]> {
  let summary = "TensorRT engine return operation";
  let description = [{
    The `trt.fetch` operation terminates and returns values for the
    `trt.graph` operation.
  }];

  let arguments = (ins Variadic<TRT_Tensor>:$inputs);
}

// A single-block sub-graph destined for a TensorRT engine; implicitly
// terminated by trt.fetch.
def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
  let summary = "trt Graph Op";
  let description = [{
    Describe a tensorrt subgraph.
  }];
  let regions = (region SizedRegion<1>:$body);

  let results = (outs Variadic<TRT_Tensor>:$outputs);
}
#endif  // TRT_OPS
......@@ -113,6 +113,9 @@ function main() {
infrt_gen_and_build ${parallel_number}
test_infrt
;;
build_only)
infrt_gen_and_build ${parallel_number}
;;
*)
print_usage
exit 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册