未验证 提交 7e8e3bab 编写于 作者: W Wilber 提交者: GitHub

[Cherry-pick Release/1.8] Fix bug of trt plugin pooling. (#29033)

上级 21d98351
......@@ -79,18 +79,25 @@ class Pool2dOpConverter : public OpConverter {
std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
bool ceil_mode = boost::get<bool>(op_desc.GetAttr("ceil_mode"));
bool exclusive = op_desc.HasAttr("exclusive")
? boost::get<bool>(op_desc.GetAttr("exclusive"))
: true;
bool adaptive = false;
if (op_desc.HasAttr("adaptive"))
adaptive = boost::get<bool>(op_desc.GetAttr("adaptive"));
nvinfer1::PoolingType nv_pool_type = nvinfer1::PoolingType::kMAX;
nvinfer1::ReduceOperation reduce_operation =
nvinfer1::ReduceOperation::kMAX;
plugin::PoolPlugin::PoolType plugin_pool_type =
plugin::PoolPlugin::PoolType::max;
if (pool_type == "max") {
nv_pool_type = nvinfer1::PoolingType::kMAX;
reduce_operation = nvinfer1::ReduceOperation::kMAX;
plugin_pool_type = plugin::PoolPlugin::PoolType::max;
} else if (pool_type == "avg") {
nv_pool_type = nvinfer1::PoolingType::kAVERAGE;
reduce_operation = nvinfer1::ReduceOperation::kAVG;
plugin_pool_type = plugin::PoolPlugin::PoolType::avg;
} else {
PADDLE_THROW(platform::errors::Fatal(
......@@ -113,12 +120,17 @@ class Pool2dOpConverter : public OpConverter {
}
if (engine_->with_dynamic_shape()) {
if (!adaptive && pool_type == "max" && !global_pooling && !ceil_mode) {
if (!adaptive && !global_pooling && !ceil_mode) {
auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1,
nv_pool_type, nv_ksize);
pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings);
pool_layer->setAverageCountExcludesPadding(exclusive);
layer = pool_layer;
} else if (global_pooling) {
auto *reduce_layer = TRT_ENGINE_ADD_LAYER(engine_, Reduce, *input1,
reduce_operation, 12, true);
layer = reduce_layer;
} else {
#if IS_TRT_VERSION_GE(6000)
plugin::PoolPluginDynamic *plugin =
......@@ -140,23 +152,27 @@ class Pool2dOpConverter : public OpConverter {
if (global_pooling == true) {
nv_ksize.d[0] = input_shape.d[input_dims - 2];
nv_ksize.d[1] = input_shape.d[input_dims - 1];
auto *layer = TRT_ENGINE_ADD_LAYER(
auto *pool_layer = TRT_ENGINE_ADD_LAYER(
engine_, Pooling, *const_cast<nvinfer1::ITensor *>(input1),
nv_pool_type, nv_ksize);
PADDLE_ENFORCE_NOT_NULL(
layer, platform::errors::Fatal(
"trt pool layer in converter could not be created."));
pool_layer, platform::errors::Fatal(
"trt pool layer in converter could not be created."));
auto output_name = op_desc.Output("Out")[0];
layer->setName(("pool2d (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings);
pool_layer->setAverageCountExcludesPadding(exclusive);
pool_layer->setName(("pool2d (Output: " + output_name + ")").c_str());
pool_layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, pool_layer->getOutput(0));
layer = pool_layer;
if (test_mode) {
engine_->DeclareOutput(output_name);
}
return;
}
if (!adaptive && pool_type == "max") {
if (!adaptive) {
// Under ceil mode, the pre_pad and post_pad are used to
// record the the padding size. In some ceil mode cases,
// we do not need padding, so we initialize the two vars to 0.
......@@ -184,6 +200,7 @@ class Pool2dOpConverter : public OpConverter {
"trt pool layer in converter could not be created."));
pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings);
pool_layer->setAverageCountExcludesPadding(exclusive);
layer = pool_layer;
} else {
// Average pooling needs to exclude the padding pixels from the average
......@@ -203,7 +220,6 @@ class Pool2dOpConverter : public OpConverter {
"trt pool plugin layer in converter could not be created."));
layer = pool_layer;
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册