Unverified commit b798fb07, authored by 王明冬, committed by GitHub

[infrt] fold the infrt.cvtTensorOp. test=develop (#40214)

Parent 79a32715
@@ -100,8 +100,8 @@ endfunction()
 function(mlir_add_rewriter td_base)
   set(LLVM_TARGET_DEFINITIONS ${td_base}.td)
   mlir_tablegen(${td_base}.cpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass")
-  add_public_tablegen_target(${td_base}_IncGen)
-  add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
+  add_public_tablegen_target(MLIR${td_base}IncGen)
+  add_dependencies(mlir-headers MLIR${td_base}IncGen)
 endfunction()

 # Execute the mlir script with infrt-exec program.
......
@@ -95,9 +95,7 @@ set(infrt_mlir_incs
   dense_tensor_inc
   pd_ops_inc
   pd_extra_ops_inc
-  rewrite_inc
   trt_ops_inc
-  pd_lower_to_trt_inc
 )
 if (INFRT_WITH_PHI)
......
@@ -13,3 +13,5 @@ mlir_tablegen(infrt_opsAttributes.h.inc -gen-attrdef-decls -dialect=infrt)
 mlir_tablegen(infrt_opsAttributes.cpp.inc -gen-attrdef-defs -dialect=infrt)
 add_public_tablegen_target(MLIRinfrt_opsAttributesIncGen)
 add_dependencies(mlir-headers MLIRinfrt_opsAttributesIncGen)
+add_subdirectory(pass)
core_gather_headers()
gather_srcs(infrt_src SRCS
    infrt_op_fuse_pass.cc
    )
mlir_add_rewriter(infrt_op_fuse)
#ifndef INFRT_OP_FUSE
#define INFRT_OP_FUSE

include "mlir/Interfaces/SideEffectInterfaces.td"
include "paddle/infrt/dialect/infrt/infrt_ops.td"
include "paddle/infrt/dialect/pd_ops.td"

// Fold two nested cvt_tensor ops into the outer one:
// cvt_tensor(cvt_tensor(x)) -> cvt_tensor(x).
def FuseCvtTensorPattern : Pat<
    (Infrt_CvtTensorOp (Infrt_CvtTensorOp $arg)),
    (Infrt_CvtTensorOp $arg)>;

// Drop a cvt_tensor that wraps a feed op: cvt_tensor(feed(name)) -> feed(name).
def FuseFeedCvtTensorPattern : Pat<
    (Infrt_CvtTensorOp (PD_FeedOp $name)),
    (PD_FeedOp $name)>;

def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;

// Erase a cvt_tensor whose result type is identical to its operand type.
def RedundantCvtTensorOptPattern : Pat<
    (Infrt_CvtTensorOp:$res $arg), (replaceWithValue $arg),
    [(TypesAreIdentical $res, $arg)]>;

#endif  // INFRT_OP_FUSE
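Since DRR syntax can be opaque, here is a rough hand-written equivalent of FuseCvtTensorPattern, added for illustration only. CvtTensorOp and its input() accessor come from the pass code in this commit; the class name and the builder overload used by replaceOpWithNewOp are assumptions. The real pattern classes are generated by mlir-tblgen into infrt_op_fuse.cpp.inc and registered via populateWithGenerated().

// Illustrative sketch only: roughly what FuseCvtTensorPattern expands to.
// The class name is hypothetical; the actual C++ is tblgen-generated.
#include <mlir/IR/PatternMatch.h>
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"

struct FoldNestedCvtTensorSketch
    : public mlir::OpRewritePattern<::infrt::CvtTensorOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      ::infrt::CvtTensorOp op, mlir::PatternRewriter &rewriter) const override {
    // Match cvt_tensor(cvt_tensor(x)) and rebuild it as cvt_tensor(x),
    // keeping the outer op's result type, as DRR does for the root op.
    auto inner = op.input().getDefiningOp<::infrt::CvtTensorOp>();
    if (!inner) return mlir::failure();
    // Assumes the default ODS builder build(resultType, input) exists.
    rewriter.replaceOpWithNewOp<::infrt::CvtTensorOp>(
        op, op->getResult(0).getType(), inner.input());
    return mlir::success();
  }
};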
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/pd_ops.h"
namespace {
#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse.cpp.inc" // NOLINT
/*
 * Pass that folds redundant infrt.cvt_tensor ops inside a function.
 */
struct InfrtOpFusePass
    : public mlir::PassWrapper<InfrtOpFusePass, mlir::FunctionPass> {
 public:
  ::llvm::StringRef getName() const override { return "infrtOpFusePass"; }
  void runOnFunction() override;
};
// Implementation of the InfrtOpFusePass.
void InfrtOpFusePass::runOnFunction() {
  ::mlir::RewritePatternSet patterns(&getContext());
  populateWithGenerated(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  // Erase cvt_tensor ops that only feed the function terminator (pd.return),
  // forwarding their inputs directly to the return.
  auto terminator_op = getFunction().front().getTerminator();
  if (nullptr == terminator_op) return;
  for (auto operand : terminator_op->getOperands()) {
    auto *op1 = operand.getDefiningOp();
    if (nullptr == op1) continue;  // skip block arguments
    auto cvt_op = ::llvm::dyn_cast<::infrt::CvtTensorOp>(op1);
    if (!cvt_op) continue;
    mlir::Value value = cvt_op.input();
    operand.replaceAllUsesWith(value);
    cvt_op.erase();
  }
}
} // namespace
std::unique_ptr<mlir::Pass> infrt::createInfrtOpFusePass() {
  return std::make_unique<InfrtOpFusePass>();
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/Pass/Pass.h>
namespace infrt {
/*
 * Create a pass that folds redundant infrt.cvt_tensor ops.
 */
std::unique_ptr<mlir::Pass> createInfrtOpFusePass();
} // namespace infrt
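For context, a minimal sketch of how the new factory function might be driven standalone, assuming the MLIR PassManager API of the LLVM release infrt builds against, where function passes are nested on mlir::FuncOp. The helper name RunInfrtOpFuse is made up here; the commit itself wires the pass into infrt-exec's phi_pass_manager, as shown in the hunk below.

// Usage sketch, not part of the commit: run the op-fuse pass over every
// function in a parsed module. Only createInfrtOpFusePass() comes from
// this commit; everything else is standard MLIR API.
#include <llvm/Support/raw_ostream.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>
#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"

void RunInfrtOpFuse(mlir::MLIRContext *context, mlir::ModuleOp module) {
  mlir::PassManager pm(context);
  // InfrtOpFusePass is a FunctionPass, so nest it under FuncOp.
  pm.addNestedPass<mlir::FuncOp>(infrt::createInfrtOpFusePass());
  if (mlir::failed(pm.run(module))) {
    llvm::errs() << "infrt op-fuse pass failed\n";
  }
}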
@@ -75,7 +75,7 @@ def PD_ElementType : Type<Or<[PD_Float.predicate,
 // def PD_Tensor : TensorOf<[PD_ElementType]>;
 def PD_Tensor1 : TensorOf<[PD_ElementType]>;
-def PD_Tensor : AnyTypeOf<[PD_Tensor1, LoDTensor],"pd.ttype">;
+def PD_Tensor : AnyTypeOf<[PD_Tensor1, LoDTensor, DenseTensor],"pd.ttype">;
 def PD_Tensor_Array : VectorOf<[PD_Tensor]>;
......
@@ -16,6 +16,7 @@
 #include <iostream>
 #include <string>
 #include "paddle/infrt/common/global.h"
+#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
 #include "paddle/infrt/dialect/mlir_loader.h"
 #include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h"
@@ -38,6 +39,7 @@ int main(int argc, char** argv) {
                           infrt::PrecisionType::FLOAT32,
                           infrt::LayoutType::NCHW}};
   phi_pass_manager.addPass(std::make_unique<infrt::phiOpCvtPass>(valid_places));
+  phi_pass_manager.addPass(infrt::createInfrtOpFusePass());
   if (mlir::failed(pm.run(*module))) {
     std::cout << "\npass failed!\n" << std::endl;
     return 4;
......
-def PD_FeedOp : PD_Op<"feed"> {
+def PD_FeedOp : PD_Op<"feed", [NoSideEffect]> {
   let summary = "Feed Op";
   let description = [{
......