Unverified · Commit bf93050c authored by 王明冬, committed by GitHub

[infrt] move graph op from pd dialect to infrt dialect. (#41003)

Parent: 29d2e949
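The change is mechanical: the graph op's ODS definition moves from the pd dialect to the infrt dialect, so its assembly name changes from "pd.graph" to "infrt.graph" and its generated C++ class from infrt::pd::GraphOp to ::infrt::GraphOp, and every use site is updated accordingly. As a minimal sketch of the call-site change this implies for downstream code (the helper name and include set are illustrative, not part of the commit):

#include "mlir/IR/Operation.h"
#include "llvm/Support/Casting.h"
// Assumes the generated header declaring ::infrt::GraphOp is also included.

// Hypothetical helper: true if `op` is the (moved) graph op.
bool isGraphOp(mlir::Operation *op) {
  // Before this commit: ::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(op)
  return nullptr != ::llvm::dyn_cast_or_null<::infrt::GraphOp>(op);
}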
@@ -9,6 +9,16 @@ class Infrt_Op<string mnemonic, list<OpTrait> traits = []> : Op<Infrt_Dialect, m
   // let parser = [{ return infrt::parse$cppClass(parser, result); }];
 }
 
+def PD_GraphOp : Infrt_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
+  let summary = "paddle graph Op";
+  let description = [{
+    Describe a paddle graph or subgraph.
+  }];
+  let regions = (region SizedRegion<1>:$body);
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+}
+
 def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> {
   let summary = "kernel op";
   let description = [{kernel op!}];
...
@@ -55,8 +55,8 @@ bool reverseDfs(std::vector<mlir::Operation *> source,
 
 // merge the first & second graph ops into a new graph op.
 void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder,  // NOLINT
-                             infrt::pd::GraphOp first,
-                             infrt::pd::GraphOp second) {
+                             ::infrt::GraphOp first,
+                             ::infrt::GraphOp second) {
   // compute inputs and outputs
   ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
   for (mlir::Value input : second.getOperands()) {
@@ -85,7 +85,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder,  // NOLINT
   // create the new graph op
   builder.setInsertionPoint(first);
   auto loc = first.getLoc();
-  auto graph_op = builder.create<infrt::pd::GraphOp>(loc, return_types, inputs);
+  auto graph_op = builder.create<::infrt::GraphOp>(loc, return_types, inputs);
   mlir::Block *block = new mlir::Block;
   auto copy_range = second.getBody()->without_terminator();
   block->getOperations().splice(block->begin(),
@@ -150,13 +150,13 @@ void TRTGraphFusePass::runOnFunction() {
   do {
     changed = false;
     for (auto &op : body) {
-      infrt::pd::GraphOp graph_op =
-          ::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(&op);
+      ::infrt::GraphOp graph_op =
+          ::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op);
       if (nullptr == graph_op) continue;
       for (auto user_op : op.getUsers()) {
-        infrt::pd::GraphOp user_graph_op =
-            ::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(user_op);
+        ::infrt::GraphOp user_graph_op =
+            ::llvm::dyn_cast_or_null<::infrt::GraphOp>(user_op);
         if (nullptr == user_graph_op) continue;
         // get all dst input nodes except src.
         std::vector<mlir::Operation *> source_nodes;
...
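The fuse pass above builds the merged graph op by splicing operation lists between blocks. A self-contained sketch of that splice idiom (the function name is illustrative; this mirrors what mergeTwoAdjacentGraphOp does with the second graph op's body):

#include "mlir/IR/Block.h"

// Move every op except the terminator from `src` into the front of `dst`.
void spliceBodyWithoutTerminator(mlir::Block *src, mlir::Block *dst) {
  auto range = src->without_terminator();  // all ops minus the terminator
  dst->getOperations().splice(dst->begin(), src->getOperations(),
                              range.begin(), range.end());
}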
@@ -25,15 +25,15 @@ namespace trt {
  * source func:
  *
  * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
- *  %c = "pd.graph"(%a) {
+ *  %c = "infrt.graph"(%a) {
  *     %m = "pd.conv2d"(%a)...
  *     infrt.return %m...
  *  } ...
- *  %d = "pd.graph"(%c) {
+ *  %d = "infrt.graph"(%c) {
  *     %m = "pd.conv3d"(%c)...
  *     infrt.return %m...
  *  } ...
- *  %f = "pd.graph"(%a) {
+ *  %f = "infrt.graph"(%a) {
  *     %m = "pd.conv2d"(%a)...
  *     infrt.return %m...
  *  } ...
@@ -42,7 +42,7 @@ namespace trt {
  *
  * destination func:
 * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
- *  %d, %f = "pd.graph"(%a) {
+ *  %d, %f = "infrt.graph"(%a) {
  *     %m = "pd.conv2d"(%a)...
  *     %n = "pd.conv3d"(%m)...
  *     %s = "pd.conv2d"(%a)...
...
@@ -21,18 +21,17 @@ namespace infrt {
 namespace trt {
 // Implementation of the trtGraphSplitPass.
 void TRTGraphSplitPass::runOnFunction() {
-  std::vector<infrt::pd::GraphOp> worklist;
+  std::vector<::infrt::GraphOp> worklist;
   mlir::Block& block = getFunction().front();
   for (auto& op : block) {
-    infrt::pd::GraphOp graph_op =
-        ::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(&op);
+    ::infrt::GraphOp graph_op = ::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op);
     if (nullptr != graph_op &&
         graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
       worklist.push_back(graph_op);
     }
   }
   while (!worklist.empty()) {
-    infrt::pd::GraphOp graph_op = worklist.back();
+    ::infrt::GraphOp graph_op = worklist.back();
     worklist.pop_back();
     mlir::Block* body = graph_op.getBody();
     auto return_op = body->getTerminator();
...
@@ -26,7 +26,7 @@ namespace trt {
  * source func:
  *
  * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
- *  %d, %f = "pd.graph"(%a) {
+ *  %d, %f = "infrt.graph"(%a) {
  *     %m = "pd.conv2d"(%a)...
  *     %n = "pd.conv3d"(%m)...
  *     %s = "pd.conv2d"(%a)...
...
@@ -41,14 +41,15 @@ namespace trt {
 #endif  // INFRT_WITH_TRT
 
 template <typename T>
-::mlir::IntegerAttr createNvinferEnumAttr(::mlir::PatternRewriter &rewriter,
-                                          T enum_value) {
+::mlir::IntegerAttr createNvinferEnumAttr(
+    ::mlir::PatternRewriter &rewriter,  // NOLINT
+    T enum_value) {
   return rewriter.getSI32IntegerAttr((int32_t)enum_value);
 }
 
 template <>
 ::mlir::IntegerAttr createNvinferEnumAttr<std::string>(
-    ::mlir::PatternRewriter &rewriter, std::string enum_value) {
+    ::mlir::PatternRewriter &rewriter, std::string enum_value) {  // NOLINT
   (void)enum_value;
   return rewriter.getSI32IntegerAttr(-1);
 }
@@ -57,10 +58,11 @@ template <>
 struct PD2TRT_GraphLower : public ::mlir::RewritePattern {
   explicit PD2TRT_GraphLower(::mlir::MLIRContext *context)
-      : ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {}
+      : ::mlir::RewritePattern(
+            "infrt.graph", 1, context, {"trt.create_engine"}) {}
   ::mlir::LogicalResult matchAndRewrite(
       ::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override {
-    auto casted_op = ::llvm::dyn_cast<infrt::pd::GraphOp>(op);
+    auto casted_op = ::llvm::dyn_cast<::infrt::GraphOp>(op);
     ::mlir::Operation::operand_range inputs = casted_op.inputs();
     auto ods_loc = rewriter.getFusedLoc(op->getLoc());
     CreateEngineOp create_engine_op;
...
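PD2TRT_GraphLower now matches the root name "infrt.graph" instead of "pd.graph". For context, a RewritePattern like this is typically collected into a pattern set and run by the greedy rewrite driver; the wiring below is a hedged sketch using standard MLIR driver calls, not code from this diff:

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Register the pattern and apply it greedily to `op` and its nested regions.
void applyGraphLower(mlir::MLIRContext *context, mlir::Operation *op) {
  ::mlir::RewritePatternSet patterns(context);
  patterns.add<PD2TRT_GraphLower>(context);  // matches "infrt.graph"
  (void)::mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}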
@@ -25,7 +25,7 @@ namespace trt {
  *
  * source ir:
 * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
- *  %d, %f = "pd.graph"(%a) {
+ *  %d, %f = "infrt.graph"(%a) {
  *     %m = "pd.conv2d"(%a)...
  *     %n = "pd.conv3d"(%m)...
  *     %s = "pd.conv2d"(%a)...
...
@@ -40,12 +40,12 @@ void TRTOpTellerPass::runOnFunction() {
     if (op->getName().getStringRef().substr(0, 3) != "pd.") continue;
     if (::llvm::dyn_cast_or_null<infrt::pd::FeedOp>(op)) continue;
     if (::llvm::dyn_cast_or_null<infrt::pd::FetchOp>(op)) continue;
-    if (::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(op)) continue;
+    if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue;
     if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
     builder.setInsertionPoint(op);
     auto loc = getFunction().getLoc();
-    auto graph_op = builder.create<infrt::pd::GraphOp>(
-        loc, op->getResultTypes(), op->getOperands());
+    auto graph_op = builder.create<::infrt::GraphOp>(
+        loc, op->getResultTypes(), op->getOperands());
     ::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
...
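The teller pass wraps each remaining pd.* op in its own ::infrt::GraphOp; the tail of the hunk is elided above. As a hedged sketch of that wrapping step (the function name and the exact generated ReturnOp builder signature are illustrative, not taken from the diff):

#include "mlir/IR/Builders.h"
// Assumes the headers declaring ::infrt::GraphOp and ::infrt::ReturnOp.

// Wrap `op` in a fresh single-op graph op terminated by infrt.return.
void wrapInGraphOp(mlir::OpBuilder &builder,  // NOLINT
                   mlir::Operation *op) {
  builder.setInsertionPoint(op);
  auto loc = op->getLoc();
  auto graph_op = builder.create<::infrt::GraphOp>(
      loc, op->getResultTypes(), op->getOperands());
  // Point external users at the graph op's results before moving `op`,
  // so the infrt.return created below keeps using `op` directly.
  op->replaceAllUsesWith(graph_op->getResults());
  mlir::Block *block = new mlir::Block;
  graph_op->getRegion(0).push_back(block);  // region takes ownership
  op->moveBefore(block, block->end());
  builder.setInsertionPointToEnd(block);
  builder.create<::infrt::ReturnOp>(loc, op->getResults());
}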
@@ -33,15 +33,15 @@ namespace trt {
  *
  * destination func:
 * func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
- *  %c = "pd.graph"(%a) {
+ *  %c = "infrt.graph"(%a) {
  *     %m = "pd.conv2d"(%a)...
  *     infrt.return %m:...
  *  } ...
- *  %d = "pd.graph"(%c) {
+ *  %d = "infrt.graph"(%c) {
  *     %m = "pd.conv3d"(%c)...
  *     infrt.return %m:...
  *  } ...
- *  %f = "pd.graph"(%a) {
+ *  %f = "infrt.graph"(%a) {
  *     %m = "pd.conv2d"(%a)...
  *     infrt.return %m:...
  *  } ...
...
@@ -23,16 +23,6 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
   let arguments = (ins PD_Tensor :$inputs, StrAttr:$name);
 }
 
-def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
-  let summary = "paddle graph Op";
-  let description = [{
-    Describe a paddle graph or subgraph.
-  }];
-  let regions = (region SizedRegion<1>:$body);
-  let arguments = (ins Variadic<PD_Tensor>:$inputs);
-  let results = (outs Variadic<PD_Tensor>:$outputs);
-}
-
 def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods<InferTypeOpInterface>, AllTypesMatch<["value", "output"]>]> {
   let summary = "constant Op";
   let description = [{}];
...