未验证 提交 602d2ce5 编写于 作者: P Pei Yang 提交者: GitHub

change avg pooling from trt plugin to trt layer (#28032)

上级 5289b72a
......@@ -88,6 +88,9 @@ class Pool2dOpConverter : public OpConverter {
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
bool exclusive = op_desc.HasAttr("exclusive")
? BOOST_GET_CONST(bool, op_desc.GetAttr("exclusive"))
: true;
bool ceil_mode = BOOST_GET_CONST(bool, op_desc.GetAttr("ceil_mode"));
bool adaptive = false;
if (op_desc.HasAttr("adaptive"))
......@@ -166,7 +169,7 @@ class Pool2dOpConverter : public OpConverter {
return;
}
if (!adaptive && pool_type == "max") {
if (!adaptive) {
// Under ceil mode, the pre_pad and post_pad are used to
// record the the padding size. In some ceil mode cases,
// we do not need padding, so we initialize the two vars to 0.
......@@ -194,6 +197,7 @@ class Pool2dOpConverter : public OpConverter {
"trt pool layer in converter could not be created."));
pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings);
pool_layer->setAverageCountExcludesPadding(exclusive);
layer = pool_layer;
} else {
// Average pooling needs to exclude the padding pixels from the average
......@@ -213,7 +217,6 @@ class Pool2dOpConverter : public OpConverter {
"trt pool plugin layer in converter could not be created."));
layer = pool_layer;
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册