未验证 提交 32211fe9 编写于 作者: P Pei Yang 提交者: GitHub

TRT conv2d converter support SAME padding (#31379)

上级 e312a1ff
......@@ -97,6 +97,10 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
const std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
std::string padding_algorithm = "EXPLICIT";
if (op_desc.HasAttr("padding_algorithm"))
padding_algorithm =
BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm"));
nvinfer1::DimsHW nv_ksize(filter_h, filter_w);
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]);
......@@ -126,6 +130,9 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
layer->setStride(nv_strides);
layer->setPadding(nv_paddings);
layer->setNbGroups(groups);
if (padding_algorithm == "SAME") {
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
// set dilations
fset_dilation(layer, nv_dilations);
......
......@@ -129,13 +129,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
std::string padding_algorithm = "EXPLICIT";
if (desc.HasAttr("padding_algorithm"))
padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
if (paddings.size() > 2 ||
(padding_algorithm == "SAME" && op_type != "pool2d"))
return false;
if (paddings.size() > 2) return false;
}
if (op_type == "matmul") {
auto* block = desc.Block();
......
......@@ -67,15 +67,12 @@ class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest):
self.conv_padding = 'VALID'
'''
# conv2d padded in 'SAME' mode is not yet supported in TRT, reopen this when support is complete.
class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = 'SAME'
'''
class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest):
......@@ -131,15 +128,13 @@ class TensorRTSubgraphPassConvTransposeValidPaddingTest(
self.conv_padding = 'VALID'
'''
# conv2d_transpose padded in 'SAME' mode is not yet supported in TRT, reopen this when support is complete.
class TensorRTSubgraphPassConvTransposeSamePaddingTest(TensorRTSubgraphPassConvTransposeTest):
class TensorRTSubgraphPassConvTransposeSamePaddingTest(
TensorRTSubgraphPassConvTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = 'SAME'
'''
class TensorRTSubgraphPassDepthwiseConvTransposeTest(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册