未验证 提交 932bbe95 编写于 作者: Z Zhaolong Xing 提交者: GitHub

fix pool trt plugin bug (#26463)

test=develop
上级 0a29fc85
......@@ -104,32 +104,51 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions(
auto stri_0 = expr_builder.constant(strides_[0]);
auto stri_1 = expr_builder.constant(strides_[1]);
auto one_value = expr_builder.constant(1);
auto tmp1_0 =
expr_builder.constant((-ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1);
auto tmp1_1 =
expr_builder.constant((-ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1);
auto v0_tmp = expr_builder.constant(-ksize_[0] + 2 * paddings_[0]);
auto v1_tmp = expr_builder.constant(-ksize_[1] + 2 * paddings_[1]);
auto tmp2_0 = expr_builder.constant(
(-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) / strides_[0] + 1);
auto tmp2_1 = expr_builder.constant(
(-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) / strides_[1] + 1);
auto *a_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV,
*inputs[0].d[2], *stri_0);
auto *b_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV,
*inputs[0].d[3], *stri_1);
auto ceil_tmp =
expr_builder.constant(-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1);
auto ceil1_tmp =
expr_builder.constant(-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1);
if (!ceil_mode_) {
output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*a_d, *tmp1_0);
output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*b_d, *tmp1_1);
output.d[2] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*inputs[0].d[2], *v0_tmp),
*stri_0),
*one_value);
output.d[3] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*inputs[0].d[3], *v1_tmp),
*stri_1),
*one_value);
} else {
output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*a_d, *tmp2_0);
output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*b_d, *tmp2_1);
output.d[2] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*inputs[0].d[2], *ceil_tmp),
*stri_0),
*one_value);
output.d[3] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*inputs[0].d[3], *ceil1_tmp),
*stri_1),
*one_value);
}
return output;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册