未验证 提交 32211fe9 编写于 作者: P Pei Yang 提交者: GitHub

TRT conv2d converter support SAME padding (#31379)

上级 e312a1ff
...@@ -97,6 +97,10 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, ...@@ -97,6 +97,10 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
const std::vector<int> paddings = const std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
std::string padding_algorithm = "EXPLICIT";
if (op_desc.HasAttr("padding_algorithm"))
padding_algorithm =
BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm"));
nvinfer1::DimsHW nv_ksize(filter_h, filter_w); nvinfer1::DimsHW nv_ksize(filter_h, filter_w);
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]); nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]);
...@@ -126,6 +130,9 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, ...@@ -126,6 +130,9 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
layer->setStride(nv_strides); layer->setStride(nv_strides);
layer->setPadding(nv_paddings); layer->setPadding(nv_paddings);
layer->setNbGroups(groups); layer->setNbGroups(groups);
if (padding_algorithm == "SAME") {
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
// set dilations // set dilations
fset_dilation(layer, nv_dilations); fset_dilation(layer, nv_dilations);
......
...@@ -129,13 +129,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -129,13 +129,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
std::vector<int> paddings = std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
std::string padding_algorithm = "EXPLICIT"; if (paddings.size() > 2) return false;
if (desc.HasAttr("padding_algorithm"))
padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
if (paddings.size() > 2 ||
(padding_algorithm == "SAME" && op_type != "pool2d"))
return false;
} }
if (op_type == "matmul") { if (op_type == "matmul") {
auto* block = desc.Block(); auto* block = desc.Block();
......
...@@ -67,15 +67,12 @@ class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest): ...@@ -67,15 +67,12 @@ class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest):
self.conv_padding = 'VALID' self.conv_padding = 'VALID'
'''
# conv2d padded in 'SAME' mode is not yet supported in TRT, reopen this when support is complete.
class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest): class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest):
def set_params(self): def set_params(self):
self.conv_num_filters = 6 self.conv_num_filters = 6
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 3 self.conv_groups = 3
self.conv_padding = 'SAME' self.conv_padding = 'SAME'
'''
class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest): class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest):
...@@ -131,15 +128,13 @@ class TensorRTSubgraphPassConvTransposeValidPaddingTest( ...@@ -131,15 +128,13 @@ class TensorRTSubgraphPassConvTransposeValidPaddingTest(
self.conv_padding = 'VALID' self.conv_padding = 'VALID'
''' class TensorRTSubgraphPassConvTransposeSamePaddingTest(
# conv2d_transpose padded in 'SAME' mode is not yet supported in TRT, reopen this when support is complete. TensorRTSubgraphPassConvTransposeTest):
class TensorRTSubgraphPassConvTransposeSamePaddingTest(TensorRTSubgraphPassConvTransposeTest):
def set_params(self): def set_params(self):
self.conv_num_filters = 6 self.conv_num_filters = 6
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 1 self.conv_groups = 1
self.conv_padding = 'SAME' self.conv_padding = 'SAME'
'''
class TensorRTSubgraphPassDepthwiseConvTransposeTest( class TensorRTSubgraphPassDepthwiseConvTransposeTest(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册