From 5fb28500081994159b4f96fdea6d9b274de2af55 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Thu, 28 Oct 2021 14:57:37 +0800 Subject: [PATCH] change api to support trt8 in pool3d_op_convert (#36783) (#36812) * change api for support trt8 --- .../inference/tensorrt/convert/pool3d_op.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc index 9baed499f1..b8e87a8d94 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc @@ -30,8 +30,8 @@ namespace tensorrt { inline void DealCeilMode(const nvinfer1::Dims &input_shape, std::vector ksize, std::vector strides, - std::vector paddings, nvinfer1::DimsCHW *pre_pad, - nvinfer1::DimsCHW *post_pad, int input_dims) { + std::vector paddings, nvinfer1::Dims3 *pre_pad, + nvinfer1::Dims3 *post_pad, int input_dims) { int input_depth = input_shape.d[input_dims - 3]; int input_height = input_shape.d[input_dims - 2]; int input_width = input_shape.d[input_dims - 1]; @@ -56,15 +56,15 @@ inline void DealCeilMode(const nvinfer1::Dims &input_shape, 1; if (floor_d_output_size != ceil_d_output_size) { - post_pad->c() = strides[0] - 1; + post_pad->d[0] = strides[0] - 1; } if (floor_h_output_size != ceil_h_output_size) { - post_pad->h() = strides[1] - 1; + post_pad->d[1] = strides[1] - 1; } if (floor_w_output_size != ceil_w_output_size) { - post_pad->w() = strides[2] - 1; + post_pad->d[2] = strides[2] - 1; } } @@ -118,9 +118,9 @@ class Pool3dOpConverter : public OpConverter { reduce_operation = nvinfer1::ReduceOperation::kAVG; plugin_pool_type = plugin::Pool3DPlugin::Pool3DType::avg; } - nvinfer1::DimsCHW nv_ksize(ksize[0], ksize[1], ksize[2]); - nvinfer1::DimsCHW nv_strides(strides[0], strides[1], strides[2]); - nvinfer1::DimsCHW nv_paddings(paddings[0], paddings[1], paddings[2]); + nvinfer1::Dims3 nv_ksize(ksize[0], ksize[1], ksize[2]); + nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]); + nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]); nvinfer1::ILayer *layer = nullptr; if (op_desc.HasAttr("enable_int8")) { CHECK(op_desc.HasAttr("X_scale")); -- GitLab