未验证 提交 eb32746a 编写于 作者: W Wang Bojun 提交者: GitHub

TRT pool2d adaptive mode bugfix (#46802)

* draft with debug print
上级 eb429936
...@@ -321,16 +321,16 @@ int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -321,16 +321,16 @@ int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
paddings[1] = 0; paddings[1] = 0;
output_shape[2] = 1; output_shape[2] = 1;
output_shape[3] = 1; output_shape[3] = 1;
if (adaptive_) {
output_shape[2] = h;
output_shape[3] = w;
}
} else { } else {
auto data_dim = CalcOutputSize( auto data_dim = CalcOutputSize(
{h, w}, ceil_mode_, adaptive_, ksize_, strides_, paddings_); {h, w}, ceil_mode_, adaptive_, ksize_, strides_, paddings_);
output_shape[2] = data_dim[0]; output_shape[2] = data_dim[0];
output_shape[3] = data_dim[1]; output_shape[3] = data_dim[1];
} }
if (adaptive_) {
output_shape[2] = h;
output_shape[3] = w;
}
if (pool_type_ == "max") { if (pool_type_ == "max") {
phi::funcs::MaxPool<float> pool_process; phi::funcs::MaxPool<float> pool_process;
......
...@@ -460,7 +460,6 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()( ...@@ -460,7 +460,6 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
const int stride_width = strides[1]; const int stride_width = strides[1];
const int padding_height = paddings[0]; const int padding_height = paddings[0];
const int padding_width = paddings[1]; const int padding_width = paddings[1];
int nthreads = batch_size * output_channels * output_height * output_width; int nthreads = batch_size * output_channels * output_height * output_width;
auto pool_divmods = auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height); FastDivModForPooling(input_channels, output_width, output_height);
...@@ -491,6 +490,7 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()( ...@@ -491,6 +490,7 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
pool_compute, pool_compute,
exclusive, exclusive,
output); output);
} else { } else {
int thread_num = 1024; int thread_num = 1024;
#ifdef WITH_NV_JETSON #ifdef WITH_NV_JETSON
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册