未验证 提交 bf93050c 编写于 作者: 王明冬 提交者: GitHub

[infrt] move graph op from pd dialect to infrt dialect. (#41003)

上级 29d2e949
......@@ -9,6 +9,16 @@ class Infrt_Op<string mnemonic, list<OpTrait> traits = []> : Op<Infrt_Dialect, m
// let parser = [{ return infrt::parse$cppClass(parser, result); }];
}
def PD_GraphOp : Infrt_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "paddle graph Op";
let description = [{
Describe a paddle graph or subgraph.
}];
let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<AnyType>:$inputs);
let results = (outs Variadic<AnyType>:$outputs);
}
def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> {
let summary = "kernel op";
let description = [{kernel op!}];
......
......@@ -55,8 +55,8 @@ bool reverseDfs(std::vector<mlir::Operation *> source,
// merge the first&second graph op to a new graph op.
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
infrt::pd::GraphOp first,
infrt::pd::GraphOp second) {
::infrt::GraphOp first,
::infrt::GraphOp second) {
// comput inputs and outputs
::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (mlir::Value input : second.getOperands()) {
......@@ -85,7 +85,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
// create the new graph op
builder.setInsertionPoint(first);
auto loc = first.getLoc();
auto graph_op = builder.create<infrt::pd::GraphOp>(loc, return_types, inputs);
auto graph_op = builder.create<::infrt::GraphOp>(loc, return_types, inputs);
mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(),
......@@ -150,13 +150,13 @@ void TRTGraphFusePass::runOnFunction() {
do {
changed = false;
for (auto &op : body) {
infrt::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(&op);
::infrt::GraphOp graph_op =
::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op);
if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) {
infrt::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(user_op);
::infrt::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<::infrt::GraphOp>(user_op);
if (nullptr == user_graph_op) continue;
// get all dst input nodes except src.
std::vector<mlir::Operation *> source_nodes;
......
......@@ -25,15 +25,15 @@ namespace trt {
* source func:
*
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %c = "pd.graph"(%a) {
* %c = "infrt.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* infrt.return %m...
* } ...
* %d = "pd.graph"(%c) {
* %d = "infrt.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* infrt.return %m...
* } ...
* %f = "pd.graph"(%a) {
* %f = "infrt.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* infrt.return %m...
* } ...
......@@ -42,7 +42,7 @@ namespace trt {
*
* destination func:
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "pd.graph"(%a) {
* %d, %f = "infrt.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
......
......@@ -21,18 +21,17 @@ namespace infrt {
namespace trt {
// Implementation of the trtGraphSplitPass。
void TRTGraphSplitPass::runOnFunction() {
std::vector<infrt::pd::GraphOp> worklist;
std::vector<::infrt::GraphOp> worklist;
mlir::Block& block = getFunction().front();
for (auto& op : block) {
infrt::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(&op);
::infrt::GraphOp graph_op = ::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op);
if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op);
}
}
while (!worklist.empty()) {
infrt::pd::GraphOp graph_op = worklist.back();
::infrt::GraphOp graph_op = worklist.back();
worklist.pop_back();
mlir::Block* body = graph_op.getBody();
auto return_op = body->getTerminator();
......
......@@ -26,7 +26,7 @@ namespace trt {
* source func:
*
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "pd.graph"(%a) {
* %d, %f = "infrt.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
......
......@@ -41,14 +41,15 @@ namespace trt {
#endif // INFRT_WITH_TRT
template <typename T>
::mlir::IntegerAttr createNvinferEnumAttr(::mlir::PatternRewriter &rewriter,
::mlir::IntegerAttr createNvinferEnumAttr(
::mlir::PatternRewriter &rewriter, // NOLINT
T enum_value) {
return rewriter.getSI32IntegerAttr((int32_t)enum_value);
}
template <>
::mlir::IntegerAttr createNvinferEnumAttr<std::string>(
::mlir::PatternRewriter &rewriter, std::string enum_value) {
::mlir::PatternRewriter &rewriter, std::string enum_value) { // NOLINT
(void)enum_value;
return rewriter.getSI32IntegerAttr(-1);
}
......@@ -57,10 +58,11 @@ template <>
struct PD2TRT_GraphLower : public ::mlir::RewritePattern {
explicit PD2TRT_GraphLower(::mlir::MLIRContext *context)
: ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {}
: ::mlir::RewritePattern(
"infrt.graph", 1, context, {"trt.create_engine"}) {}
::mlir::LogicalResult matchAndRewrite(
::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override {
auto casted_op = ::llvm::dyn_cast<infrt::pd::GraphOp>(op);
auto casted_op = ::llvm::dyn_cast<::infrt::GraphOp>(op);
::mlir::Operation::operand_range inputs = casted_op.inputs();
auto ods_loc = rewriter.getFusedLoc(op->getLoc());
CreateEngineOp create_engine_op;
......
......@@ -25,7 +25,7 @@ namespace trt {
*
* source ir:
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "pd.graph"(%a) {
* %d, %f = "infrt.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
......
......@@ -40,12 +40,12 @@ void TRTOpTellerPass::runOnFunction() {
if (op->getName().getStringRef().substr(0, 3) != "pd.") continue;
if (::llvm::dyn_cast_or_null<infrt::pd::FeedOp>(op)) continue;
if (::llvm::dyn_cast_or_null<infrt::pd::FetchOp>(op)) continue;
if (::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(op)) continue;
if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue;
if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
builder.setInsertionPoint(op);
auto loc = getFunction().getLoc();
auto graph_op = builder.create<infrt::pd::GraphOp>(
auto graph_op = builder.create<::infrt::GraphOp>(
loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
......
......@@ -33,15 +33,15 @@ namespace trt {
*
* destination func:
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %c = "pd.graph"(%a) {
* %c = "infrt.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* infrt.return %m:...
* } ...
* %d = "pd.graph"(%c) {
* %d = "infrt.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* infrt.return %m:...
* } ...
* %f = "pd.graph"(%a) {
* %f = "infrt.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* infrt.return %m:...
* } ...
......
......@@ -23,16 +23,6 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let arguments = (ins PD_Tensor :$inputs, StrAttr:$name);
}
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "paddle graph Op";
let description = [{
Describe a paddle graph or subgraph.
}];
let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<PD_Tensor>:$inputs);
let results = (outs Variadic<PD_Tensor>:$outputs);
}
def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods<InferTypeOpInterface>, AllTypesMatch<["value", "output"]>]> {
let summary = "constant Op";
let description = [{}];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册