// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h" #include #include #include #include "paddle/infrt/dialect/dense_tensor.h" #include "paddle/infrt/dialect/pd/ir/pd_ops.h" #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/phi/ir/phi_base.h" #include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { #ifdef INFRT_WITH_TRT #define STRING_TO_ENUM_TYPE(enum_type) enum_type #define STRING_TO_ENUM_VALUE(enum_value) enum_value #include #else // INFRT_WITH_TRT #define STRING_TO_ENUM_TYPE(enum_type) std::string #define STRING_TO_ENUM_VALUE(enum_value) #enum_value #endif // INFRT_WITH_TRT template ::mlir::IntegerAttr createNvinferEnumAttr( ::mlir::PatternRewriter &rewriter, // NOLINT T enum_value) { return rewriter.getSI32IntegerAttr((int32_t)enum_value); } template <> ::mlir::IntegerAttr createNvinferEnumAttr( ::mlir::PatternRewriter &rewriter, std::string enum_value) { // NOLINT (void)enum_value; return rewriter.getSI32IntegerAttr(-1); } #include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT struct PD2TRT_GraphLower : public ::mlir::RewritePattern { explicit PD2TRT_GraphLower(::mlir::MLIRContext *context) : ::mlir::RewritePattern( "infrt.graph", 1, context, {"trt.create_engine"}) {} ::mlir::LogicalResult matchAndRewrite( ::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override { auto casted_op = ::llvm::dyn_cast<::infrt::GraphOp>(op); ::mlir::Operation::operand_range inputs = casted_op.inputs(); auto ods_loc = rewriter.getFusedLoc(op->getLoc()); CreateEngineOp create_engine_op; // inputs ::mlir::SmallVector<::mlir::Value, 4> trt_inputs; for (auto v : inputs) { trt_inputs.push_back(v); } create_engine_op = rewriter.create( ods_loc, ::llvm::SmallVector(1, EngineType::get()), trt_inputs, true /*run_once*/); auto &block = create_engine_op.body().emplaceBlock(); block.getOperations().splice(block.begin(), casted_op.getBody()->getOperations(), casted_op.getBody()->begin(), casted_op.getBody()->end()); // trt.compute ::llvm::SmallVector<::mlir::Value, 4> replace_values2; auto ctx_op = rewriter.create<::infrt::phi::CreateGPUContextOp>( ods_loc, infrt::phi::ContextType::get(rewriter.getContext(), infrt::TargetType::GPU)); auto compute_op = rewriter.create( ods_loc, ::infrt::DenseTensorListType::get(rewriter.getContext()), create_engine_op.engine(), ctx_op.output()); auto tensor_list_val = compute_op.outputs(); for (size_t i = 0; i < casted_op.getNumResults(); ++i) { auto res = casted_op->getResult(i); auto int_attr = mlir::IntegerAttr::get( mlir::IntegerType::get(rewriter.getContext(), 32), i); auto get_tensor_op = rewriter.create<::infrt::dt::TensorListGetTensorOp>( ods_loc, res.getType(), tensor_list_val, int_attr); replace_values2.push_back(get_tensor_op.output()); } ctx_op->moveBefore(ctx_op->getBlock(), ctx_op->getBlock()->begin()); rewriter.replaceOp(op, replace_values2); return ::mlir::success(); } }; void TRTOpConverterPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. ::mlir::ConversionTarget target(getContext()); // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to TensorRTDialect from // PaddleDialect target.addLegalDialect(); target.addLegalDialect<::infrt::phi::PHIDialect>(); target.addLegalDialect<::infrt::dt::DTDialect>(); target.addLegalDialect(); // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the TensorRT operations. ::mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); patterns.add(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. if (::mlir::failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } } // namespace trt } // namespace infrt