diff --git a/paddle/infrt/dialect/tensorrt/convert.h b/paddle/infrt/dialect/tensorrt/convert.h index fc607aa112714f3f736800726973757af30c1cb5..c1f87ecde787247dbb1ac0ee2888487cca67c985 100644 --- a/paddle/infrt/dialect/tensorrt/convert.h +++ b/paddle/infrt/dialect/tensorrt/convert.h @@ -14,17 +14,49 @@ #pragma once #include +#include +#include #include +#include +#include #include - #include "paddle/infrt/dialect/infrt/common/types.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" #include "paddle/infrt/dialect/pd/ir/pd_ops.h" #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" +#include "paddle/infrt/kernel/tensorrt/trt_helper.h" namespace infrt { namespace trt { + +#ifdef INFRT_WITH_TRT + +#define STRING_TO_ENUM_TYPE(enum_type) enum_type +#define STRING_TO_ENUM_VALUE(enum_value) enum_value +#include + +#else // INFRT_WITH_TRT + +#define STRING_TO_ENUM_TYPE(enum_type) std::string +#define STRING_TO_ENUM_VALUE(enum_value) #enum_value + +#endif // INFRT_WITH_TRT + +template +::mlir::IntegerAttr createNvinferEnumAttr( + ::mlir::PatternRewriter &rewriter, // NOLINT + T enum_value) { + return rewriter.getSI32IntegerAttr((int32_t)enum_value); +} + +template <> +::mlir::IntegerAttr createNvinferEnumAttr( + ::mlir::PatternRewriter &rewriter, std::string enum_value) { // NOLINT + (void)enum_value; + return rewriter.getSI32IntegerAttr(-1); +} + static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, // NOLINT mlir::Operation *op) { auto conv_op = ::llvm::dyn_cast(op); @@ -205,5 +237,127 @@ static mlir::Value createTRTShuffledOp( return rewriter.create( op->getLoc(), resultTypes, operands, attributes); } + +inline mlir::IntegerAttr CreatePoolingType( + mlir::PatternRewriter &builder, // NOLINT + mlir::StringAttr pool_type) { + // pool_type. + auto ptype = pool_type.str(); + if (ptype == "max") { + return createNvinferEnumAttr(builder, nvinfer1::PoolingType::kMAX); + } else if (ptype == "avg") { + return createNvinferEnumAttr(builder, nvinfer1::PoolingType::kAVERAGE); + } else { + llvm_unreachable("unknown pool_type."); + return {}; + } +} + +inline mlir::IntegerAttr CreatePaddingMode( + mlir::PatternRewriter &builder, // NOLINT + mlir::StringAttr padding_algorithm, + mlir::BoolAttr ceil_mode) { + // TODO(Inference): Phi pool kernel seems not process ceil_mode. + auto padding_algo = padding_algorithm.str(); + if (padding_algo == "SAME") { + return createNvinferEnumAttr(builder, nvinfer1::PaddingMode::kSAME_UPPER); + } + if (ceil_mode.getValue() && padding_algo != "SAME") { + return createNvinferEnumAttr(builder, + nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP); + } else { + return createNvinferEnumAttr(builder, + nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN); + } +} + +inline ::llvm::SmallVector<::mlir::Value, 4> CreatePaddleTrtPoolingOp( + mlir::PatternRewriter &builder, // NOLINT + mlir::Value input, + mlir::StringAttr pool_type, + mlir::ArrayAttr ksize, + mlir::BoolAttr global_pooling, + mlir::ArrayAttr strides, + mlir::ArrayAttr paddings, + mlir::BoolAttr exclusive, + mlir::BoolAttr adaptive, + mlir::BoolAttr ceil_mode, + mlir::StringAttr data_format, + mlir::StringAttr padding_algorithm) { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + + // TODO(inference): Support NHWC. + if (data_format.str() != "NCHW") { + CHECK(false) << "The pool2d converter now only support NCHW."; + } + + // TODO(Wilber): How to support dynamic shape? + + auto *input_producer = input.getDefiningOp(); + + // Process pool_type. + auto pool_type_attr = CreatePoolingType(builder, pool_type); + + // Update padding. + auto padding_algorithm_str = padding_algorithm.str(); + auto paddings_attr = paddings; + if (padding_algorithm_str == "EXPLICIT") { + // Do nothing on paddings. + } else if (padding_algorithm_str == "SAME") { + // We should process this case in trt network build phase. + } else if (padding_algorithm_str == "VALID") { + // Set padding to zero. + paddings_attr = builder.getI32ArrayAttr({0, 0}); + } else { + CHECK(false) << "Unknown padding_algotithm."; + } + + // if global_pooling == true or adaptive == true, padding will be ignored + if (global_pooling.getValue() || adaptive.getValue()) { + paddings_attr = builder.getI32ArrayAttr({0, 0}); + } + + // if global_pooling == true, then we should update kernel size to input dims. + if (global_pooling.getValue() == true) { + // Update ksize to input dims. + } + + // The adaptive logic should be processed when we get the context of + // INetworkDefinition, so we place the logic in infrt runtime(trt compile + // time). + + // The `exclusive` may be a naive attr, which can be forward to trt. + + auto padding_mode_attr = + CreatePaddingMode(builder, padding_algorithm, ceil_mode); + + if (global_pooling.getValue() == true) { + CHECK(false) << "Temporarily not support global_pool"; + return tblgen_repl_values; + } + + PoolingOp pool_op; + { + auto ods_loc = builder.getFusedLoc({input_producer->getLoc()}); + builder.create(ods_loc, + input.getType(), + input, + pool_type_attr, + ksize, + strides, + paddings_attr, + padding_mode_attr, + exclusive, + adaptive, + padding_algorithm); + } + + for (auto v : + ::llvm::SmallVector<::mlir::Value, 4>{pool_op.getODSResults(0)}) { + tblgen_repl_values.push_back(v); + } + return tblgen_repl_values; +} + } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td index ad60906ececbf9a170c34cae587897f42c365a6d..227b473c3fc1963c584acd2b904a406764b756e7 100644 --- a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td +++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td @@ -31,9 +31,10 @@ def PD2TRT_Conv2d_Lower : Pat< (PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format), (createTRTConv2dOp $old_value)>; +def createTrtPoolingOp : NativeCodeCall<"::infrt::trt::CreatePaddleTrtPoolingOp($_builder, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10)">; def PD2TRT_Pooling_Lower : Pat< (PD_Pool2dOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format, $padding_algorithm), - (TRT_PoolingOp $Input, (INFRT_createI32Attr<"0">)/*kmax*/, $ksize, $strides, $paddings, $padding_algorithm)>; + (createTrtPoolingOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format, $padding_algorithm)>; def PD2TRT_MatrixMultipl_Lower : Pat< (PD_MulOp $Input1, $Input2, $x_num_col_dims, $y_num_col_dims), diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index 6bcef3d913d79955aeadd1f234e16cf913755128..95dd31fcd5838676e0114869de6c46eb72ea2aeb 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -28,33 +28,6 @@ namespace infrt { namespace trt { -#ifdef INFRT_WITH_TRT - -#define STRING_TO_ENUM_TYPE(enum_type) enum_type -#define STRING_TO_ENUM_VALUE(enum_value) enum_value -#include - -#else // INFRT_WITH_TRT - -#define STRING_TO_ENUM_TYPE(enum_type) std::string -#define STRING_TO_ENUM_VALUE(enum_value) #enum_value - -#endif // INFRT_WITH_TRT - -template -::mlir::IntegerAttr createNvinferEnumAttr( - ::mlir::PatternRewriter &rewriter, // NOLINT - T enum_value) { - return rewriter.getSI32IntegerAttr((int32_t)enum_value); -} - -template <> -::mlir::IntegerAttr createNvinferEnumAttr( - ::mlir::PatternRewriter &rewriter, std::string enum_value) { // NOLINT - (void)enum_value; - return rewriter.getSI32IntegerAttr(-1); -} - #include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT struct PD2TRT_GraphLower : public ::mlir::RewritePattern { diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index 3fd3f377f4ec77457f4932534f7e58c554366b88..68a593e440b50f94f590b33cbbc4ae78022d8530 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -101,7 +101,10 @@ def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> { I32ArrayAttr:$window_size, I32ArrayAttr:$strides, I32ArrayAttr:$paddings, - StrAttr:$padding_mode + I32Attr:$padding_mode, + BoolAttr:$exclusive, + BoolAttr:$adaptive, + StrAttr:$padding_algorithm ); let results = (outs DenseTensor:$output_tensor diff --git a/paddle/infrt/kernel/tensorrt/trt_helper.h b/paddle/infrt/kernel/tensorrt/trt_helper.h index 96122bffacdb2251c28e311ae02fe6f9c5319615..13529430d683db7c722c4709eae4496589ae32b5 100644 --- a/paddle/infrt/kernel/tensorrt/trt_helper.h +++ b/paddle/infrt/kernel/tensorrt/trt_helper.h @@ -28,13 +28,13 @@ namespace infrt { namespace kernel { namespace tensorrt { -static nvinfer1::DataType TensorTypeToWeightType(phi::DataType tensor_type) { +static nvinfer1::DataType TensorTypeToWeightType(::phi::DataType tensor_type) { switch (tensor_type) { - case phi::DataType::FLOAT32: + case ::phi::DataType::FLOAT32: return nvinfer1::DataType::kFLOAT; - case phi::DataType::INT32: + case ::phi::DataType::INT32: return nvinfer1::DataType::kINT32; - case phi::DataType::FLOAT16: + case ::phi::DataType::FLOAT16: return nvinfer1::DataType::kHALF; default: llvm_unreachable("should not reach here"); @@ -52,7 +52,7 @@ static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) { return dims; } -static nvinfer1::Weights TensorToWeights(phi::DenseTensor* tensor) { +static nvinfer1::Weights TensorToWeights(::phi::DenseTensor* tensor) { CHECK_NOTNULL(tensor); nvinfer1::Weights ret; ret.type = TensorTypeToWeightType(tensor->dtype()); diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.cc b/paddle/infrt/kernel/tensorrt/trt_kernels.cc index a6d740f01846d8ea47cb188e79393b956c2d57b5..92e3a624bb02109c5f225a61870e8b9c1f46180a 100644 --- a/paddle/infrt/kernel/tensorrt/trt_kernels.cc +++ b/paddle/infrt/kernel/tensorrt/trt_kernels.cc @@ -129,6 +129,7 @@ namespace tensorrt { // TODO(wilber): Find a way to add layer. for (auto& operation : block.without_terminator()) { + VLOG(1) << "process " << operation.getName().getStringRef().str() << " ..."; if (trt::ActivationOp op = llvm::dyn_cast(operation)) { ActivationFunc( op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); @@ -138,6 +139,8 @@ namespace tensorrt { } else if (trt::ConvolutionOp op = llvm::dyn_cast(operation)) { ConvFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); + } else if (trt::PoolingOp op = llvm::dyn_cast(operation)) { + PoolFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); } else { CHECK(false) << "not supported operation."; } diff --git a/paddle/infrt/kernel/tensorrt/trt_layers.h b/paddle/infrt/kernel/tensorrt/trt_layers.h index 19e20c170ec835444a5a37818b837dafb096b2b8..3a300ad0c10af2177b4fb880049bfffb8e3def10 100644 --- a/paddle/infrt/kernel/tensorrt/trt_layers.h +++ b/paddle/infrt/kernel/tensorrt/trt_layers.h @@ -15,13 +15,15 @@ #pragma once #include +#include +#include #include +#include #include #include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/kernel/tensorrt/trt_helper.h" - #include "paddle/phi/core/dense_tensor.h" namespace infrt { @@ -63,7 +65,12 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT nvinfer1::Dims dims = ArrayAttrToNvDims(size_attrs); auto kernel_weights = TensorToWeights(value_to_tensor_map[op.kernel_weights()]); - auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]); + nvinfer1::Weights bias_weights; + if (op.bias_weights() == mlir::Value()) { + bias_weights = nvinfer1::Weights{}; + } else { + bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]); + } auto* layer = network->addConvolutionNd(*value_to_trt_tensor_map[input_tensor_repr], @@ -77,6 +84,51 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT value_to_trt_tensor_map[out_repr] = out_tensor; } +inline void PoolFunc(trt::PoolingOp& op, // NOLINT + nvinfer1::INetworkDefinition* network, + ValueToITensorMap& value_to_trt_tensor_map, // NOLINT + ValueToTensorMap& value_to_tensor_map) { // NOLINT + mlir::Value input_tensor_repr = op.input_tensor(); + nvinfer1::ITensor* input_itensor = value_to_trt_tensor_map[input_tensor_repr]; + // nvinfer1::Dims input_shape = input_itensor->getDimensions(); + // int input_dims = input_shape.nbDims; + + auto padding_mode = op.padding_mode(); + auto pool_type = op.pool_type(); + mlir::ArrayAttr paddings = op.paddings(); + mlir::ArrayAttr strides = op.strides(); + mlir::ArrayAttr ksize = op.window_size(); + bool exclusive = op.exclusive(); + bool adaptive = op.adaptive(); + auto padding_algorithm = op.padding_algorithm().str(); + + if (padding_algorithm == "SAME") { + // TODO(wilber) + CHECK(false) << "Not supported `same` padding algorithm"; + } + + if (adaptive) { + // TODO(Inference) + CHECK(false) << "Not supported adaptive pool"; + } + + nvinfer1::Dims window_size = ArrayAttrToNvDims(ksize); + + auto* layer = + network->addPoolingNd(*input_itensor, + static_cast(pool_type), + window_size); + CHECK_NOTNULL(layer); + layer->setPaddingMode(static_cast(padding_mode)); + layer->setPaddingNd(ArrayAttrToNvDims(paddings)); + layer->setStrideNd(ArrayAttrToNvDims(strides)); + layer->setAverageCountExcludesPadding(exclusive); + + mlir::Value out_repr = op.output_tensor(); + nvinfer1::ITensor* out_tensor = layer->getOutput(0); + value_to_trt_tensor_map[out_repr] = out_tensor; +} + inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT nvinfer1::INetworkDefinition* network, ValueToITensorMap& value_to_trt_tensor_map, // NOLINT diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir deleted file mode 100644 index ef86dcf1e72a04c478a7763000cf366715665d81..0000000000000000000000000000000000000000 --- a/paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir +++ /dev/null @@ -1,37 +0,0 @@ -// RUN: infrtexec -i %s | FileCheck %s - -// CHECK-LABEL: @run_trt -func @run_trt(%0 : !infrt.dense_tensor, %ctx : !phi.context) { - %a = "trt.create_engine"(%0) ({ - %1 = "trt.Activation"(%0) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor) -> !infrt.dense_tensor - "infrt.return"(%1) : (!infrt.dense_tensor) -> () - }) : (!infrt.dense_tensor) -> !trt.engine - "trt.inspect_engine"(%a) {} : (!trt.engine) -> () - - %res = "trt.compute"(%a, %ctx) {} : (!trt.engine, !phi.context) -> (!infrt.tensor_list) - %size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32) - "infrt.print.i32"(%size) {} : (i32) -> () - - %ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor) - "phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor) -> () - - infrt.return -} - -// CHECK-LABEL: @main -func @main() { - %ctx = "phi_dt.create_context.gpu" (): () -> !phi.context - %t = "phi_dt.create_dense_tensor.gpu" (%ctx) { - precision=#infrt.precision, - layout=#infrt.layout, - dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) - - "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor) -> () - "phi_dt.print_tensor" (%t) : (!infrt.dense_tensor) -> () - - //%res = - infrt.call @run_trt(%t, %ctx) : (!infrt.dense_tensor, !phi.context) -> () - //-> (!infrt.dense_tensor) - - infrt.return -} diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir new file mode 100644 index 0000000000000000000000000000000000000000..557990677696e25a35abb8a384d631f83ea15c85 --- /dev/null +++ b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir @@ -0,0 +1,21 @@ +module { + func @main_graph(%arg0: !infrt.dense_tensor) -> !infrt.dense_tensor { + %0 = "phi_dt.create_context.gpu"() : () -> !phi.context + %1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + %2 = "trt.create_engine"(%1) ( { + %6 = "trt.Activation"(%1) {activation_type = 1 : si32, alpha = 0.000000e+00 : f32, beta = 0.000000e+00 : f32} : (!infrt.dense_tensor) -> !infrt.dense_tensor + infrt.return %6 : !infrt.dense_tensor + }) {run_once = true} : (!infrt.dense_tensor) -> !trt.engine + %3 = "trt.compute"(%2, %0) : (!trt.engine, !phi.context) -> !infrt.tensor_list + %4 = "dt.tensor_list_get_tensor"(%3) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor + %5 = "phi_dt.memcpy.gpu"(%4, %0) {d2h = true} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + infrt.return %5 : !infrt.dense_tensor + } + func @main() { + %0 = "phi_dt.create_context.cpu"() : () -> !phi.context + %1 = "phi_dt.create_inited_dense_tensor.cpu.f32"(%0) {dims = [3, 6, 1, 1], layout = #infrt.layout, lod = [0], value = 1.500000e+00 : f32} : (!phi.context) -> !infrt.dense_tensor + %2 = infrt.call @main_graph(%1) : (!infrt.dense_tensor) -> !infrt.dense_tensor + phi_dt.print_tensor(%2 : !infrt.dense_tensor) + infrt.return + } +} diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir index 78dc4ac1c1093c1eb9b3fb30d0ea3f0cd5be6104..aba706df718435c8240d306d78af964cc9930c7c 100644 --- a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir +++ b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir @@ -1,46 +1,25 @@ -// RUN: infrtexec -i %s | FileCheck %s - -// CHECK-LABEL: @main -func @main() { - %ctx = "phi_dt.create_context.gpu" (): () -> !phi.context - %cpu_ctx = "phi_dt.create_context.cpu" (): () -> !phi.context - - %input_tensor = "phi_dt.create_dense_tensor.gpu" (%ctx) { - precision=#infrt.precision, - layout=#infrt.layout, - dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) - "phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor) -> () - //"phi_dt.print_tensor" (%input_tensor) : (!infrt.dense_tensor) -> () - - %kernel_weight = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) { - precision=#infrt.precision, - layout=#infrt.layout, - dims=[2:i64, 3:i64], lod=[1:i64]} : (!phi.context) -> (!infrt.dense_tensor) - "phi_dt.fill_dense_tensor.f32"(%kernel_weight) {value=[1.:f32, 2.:f32, 3.:f32, 4.:f32, 5.:f32, 6.:f32]} : (!infrt.dense_tensor) -> () - //"phi_dt.print_tensor" (%kernel_weight) : (!infrt.dense_tensor) -> () - - %kernel_bias = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) { - precision=#infrt.precision, - layout=#infrt.layout, - dims=[2:i64], lod=[1:i64]} : (!phi.context) -> (!infrt.dense_tensor) - "phi_dt.fill_dense_tensor.f32"(%kernel_bias) {value=[1.:f32, 2.:f32]} : (!infrt.dense_tensor) -> () - //"phi_dt.print_tensor" (%kernel_bias) : (!infrt.dense_tensor) -> () - - %engine = "trt.create_engine"(%input_tensor, %kernel_weight, %kernel_bias) ({ - %1 = "trt.Activation"(%input_tensor) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor) -> !infrt.dense_tensor - %2 = "trt.FullyConnected"(%input_tensor, %kernel_weight, %kernel_bias) {out_channel_num = 2 : si32} : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor - "infrt.return"(%1, %2) : (!infrt.dense_tensor, !infrt.dense_tensor) -> () - }) : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> !trt.engine - - %res = "trt.compute"(%engine, %ctx) {} : (!trt.engine, !phi.context) -> (!infrt.tensor_list) - %size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32) - "infrt.print.i32"(%size) {} : (i32) -> () - - %ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor) - "phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor) -> () - - %ts1 = "dt.tensor_list_get_tensor"(%res) {id = 1 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor) - "phi_dt.print_tensor" (%ts1) : (!infrt.dense_tensor) -> () - - infrt.return +module { + func @main_graph(%arg0: !infrt.dense_tensor) -> !infrt.dense_tensor { + %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context + %0 = "phi_dt.create_context.gpu"() : () -> !phi.context + %1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + %4 = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout, lod=[0], dims=[2, 6]}: (!phi.context) -> (!infrt.dense_tensor) + %3 = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout, lod=[0], dims=[2]}: (!phi.context) -> (!infrt.dense_tensor) + %5 = "trt.create_engine"(%1, %4, %3) ( { + %10 = "trt.FullyConnected"(%1, %4, %3) {out_channel_num = 2 : si32} : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + infrt.return %10 : !infrt.dense_tensor + }) {run_once = true} : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> !trt.engine + %6 = "trt.compute"(%5, %0) : (!trt.engine, !phi.context) -> !infrt.tensor_list + %7 = "dt.tensor_list_get_tensor"(%6) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor + %8 = "phi_dt.memcpy.gpu"(%7, %0) {d2h = true} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + infrt.return %8 : !infrt.dense_tensor + } + + func @main() { + %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context + %input_tensor = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout, lod=[0], dims=[3, 6, 1, 1]}: (!phi.context) -> (!infrt.dense_tensor) + %res = infrt.call @main_graph(%input_tensor) {} : (!infrt.dense_tensor) -> !infrt.dense_tensor + "phi_dt.print_tensor" (%res) : (!infrt.dense_tensor) -> () + infrt.return + } } diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir new file mode 100644 index 0000000000000000000000000000000000000000..af24ac63d23fecfe84e194f444a4b530ae968f4e --- /dev/null +++ b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir @@ -0,0 +1,21 @@ +module { + func @main_graph(%arg0: !infrt.dense_tensor) -> !infrt.dense_tensor { + %0 = "phi_dt.create_context.gpu"() : () -> !phi.context + %1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + %2 = "trt.create_engine"(%1) ( { + %6 = "trt.Pooling"(%1) {padding_mode = 0 : i32, paddings = [1 : i32, 1 : i32], pool_type = 0 : i32, strides = [2 : i32, 2 : i32], window_size = [3 : i32, 3 : i32], exclusive = false, adaptive = false, padding_algorithm = "EXPLICIT"} : (!infrt.dense_tensor) -> !infrt.dense_tensor + infrt.return %6 : !infrt.dense_tensor + }) {run_once = true} : (!infrt.dense_tensor) -> !trt.engine + %3 = "trt.compute"(%2, %0) : (!trt.engine, !phi.context) -> !infrt.tensor_list + %4 = "dt.tensor_list_get_tensor"(%3) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor + %5 = "phi_dt.memcpy.gpu"(%4, %0) {d2h = true} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + infrt.return %5 : !infrt.dense_tensor + } + func @main() { + %0 = "phi_dt.create_context.cpu"() : () -> !phi.context + %1 = "phi_dt.create_inited_dense_tensor.cpu.f32"(%0) {dims = [1, 3, 10, 10], layout = #infrt.layout, lod = [0], value = 1.500000e+00 : f32} : (!phi.context) -> !infrt.dense_tensor + %2 = infrt.call @main_graph(%1) : (!infrt.dense_tensor) -> !infrt.dense_tensor + phi_dt.print_tensor(%2 : !infrt.dense_tensor) + infrt.return + } +}