未验证 提交 602d2ce5 编写于 作者: P Pei Yang 提交者: GitHub

change avg pooling from trt plugin to trt layer (#28032)

上级 5289b72a
...@@ -88,6 +88,9 @@ class Pool2dOpConverter : public OpConverter { ...@@ -88,6 +88,9 @@ class Pool2dOpConverter : public OpConverter {
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
std::vector<int> paddings = std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
bool exclusive = op_desc.HasAttr("exclusive")
? BOOST_GET_CONST(bool, op_desc.GetAttr("exclusive"))
: true;
bool ceil_mode = BOOST_GET_CONST(bool, op_desc.GetAttr("ceil_mode")); bool ceil_mode = BOOST_GET_CONST(bool, op_desc.GetAttr("ceil_mode"));
bool adaptive = false; bool adaptive = false;
if (op_desc.HasAttr("adaptive")) if (op_desc.HasAttr("adaptive"))
...@@ -166,7 +169,7 @@ class Pool2dOpConverter : public OpConverter { ...@@ -166,7 +169,7 @@ class Pool2dOpConverter : public OpConverter {
return; return;
} }
if (!adaptive && pool_type == "max") { if (!adaptive) {
// Under ceil mode, the pre_pad and post_pad are used to // Under ceil mode, the pre_pad and post_pad are used to
// record the the padding size. In some ceil mode cases, // record the the padding size. In some ceil mode cases,
// we do not need padding, so we initialize the two vars to 0. // we do not need padding, so we initialize the two vars to 0.
...@@ -194,6 +197,7 @@ class Pool2dOpConverter : public OpConverter { ...@@ -194,6 +197,7 @@ class Pool2dOpConverter : public OpConverter {
"trt pool layer in converter could not be created.")); "trt pool layer in converter could not be created."));
pool_layer->setStride(nv_strides); pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings); pool_layer->setPadding(nv_paddings);
pool_layer->setAverageCountExcludesPadding(exclusive);
layer = pool_layer; layer = pool_layer;
} else { } else {
// Average pooling needs to exclude the padding pixels from the average // Average pooling needs to exclude the padding pixels from the average
...@@ -213,7 +217,6 @@ class Pool2dOpConverter : public OpConverter { ...@@ -213,7 +217,6 @@ class Pool2dOpConverter : public OpConverter {
"trt pool plugin layer in converter could not be created.")); "trt pool plugin layer in converter could not be created."));
layer = pool_layer; layer = pool_layer;
} }
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册