未验证 提交 5fb28500 编写于 作者: F feng_shuai 提交者: GitHub

change api to support trt8 in pool3d_op_convert (#36783) (#36812)

* change api for support trt8
上级 8ede9e6f
...@@ -30,8 +30,8 @@ namespace tensorrt { ...@@ -30,8 +30,8 @@ namespace tensorrt {
inline void DealCeilMode(const nvinfer1::Dims &input_shape, inline void DealCeilMode(const nvinfer1::Dims &input_shape,
std::vector<int> ksize, std::vector<int> strides, std::vector<int> ksize, std::vector<int> strides,
std::vector<int> paddings, nvinfer1::DimsCHW *pre_pad, std::vector<int> paddings, nvinfer1::Dims3 *pre_pad,
nvinfer1::DimsCHW *post_pad, int input_dims) { nvinfer1::Dims3 *post_pad, int input_dims) {
int input_depth = input_shape.d[input_dims - 3]; int input_depth = input_shape.d[input_dims - 3];
int input_height = input_shape.d[input_dims - 2]; int input_height = input_shape.d[input_dims - 2];
int input_width = input_shape.d[input_dims - 1]; int input_width = input_shape.d[input_dims - 1];
...@@ -56,15 +56,15 @@ inline void DealCeilMode(const nvinfer1::Dims &input_shape, ...@@ -56,15 +56,15 @@ inline void DealCeilMode(const nvinfer1::Dims &input_shape,
1; 1;
if (floor_d_output_size != ceil_d_output_size) { if (floor_d_output_size != ceil_d_output_size) {
post_pad->c() = strides[0] - 1; post_pad->d[0] = strides[0] - 1;
} }
if (floor_h_output_size != ceil_h_output_size) { if (floor_h_output_size != ceil_h_output_size) {
post_pad->h() = strides[1] - 1; post_pad->d[1] = strides[1] - 1;
} }
if (floor_w_output_size != ceil_w_output_size) { if (floor_w_output_size != ceil_w_output_size) {
post_pad->w() = strides[2] - 1; post_pad->d[2] = strides[2] - 1;
} }
} }
...@@ -118,9 +118,9 @@ class Pool3dOpConverter : public OpConverter { ...@@ -118,9 +118,9 @@ class Pool3dOpConverter : public OpConverter {
reduce_operation = nvinfer1::ReduceOperation::kAVG; reduce_operation = nvinfer1::ReduceOperation::kAVG;
plugin_pool_type = plugin::Pool3DPlugin::Pool3DType::avg; plugin_pool_type = plugin::Pool3DPlugin::Pool3DType::avg;
} }
nvinfer1::DimsCHW nv_ksize(ksize[0], ksize[1], ksize[2]); nvinfer1::Dims3 nv_ksize(ksize[0], ksize[1], ksize[2]);
nvinfer1::DimsCHW nv_strides(strides[0], strides[1], strides[2]); nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]);
nvinfer1::DimsCHW nv_paddings(paddings[0], paddings[1], paddings[2]); nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::ILayer *layer = nullptr; nvinfer1::ILayer *layer = nullptr;
if (op_desc.HasAttr("enable_int8")) { if (op_desc.HasAttr("enable_int8")) {
CHECK(op_desc.HasAttr("X_scale")); CHECK(op_desc.HasAttr("X_scale"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册