未验证 提交 85e697d7 编写于 作者: P Pei Yang 提交者: GitHub

Support depthwise_conv2d_transpose in the TensorRT op converter and op teller (#32593)

上级 809ac036
...@@ -160,7 +160,7 @@ class Deconv2dOpConverter : public OpConverter { ...@@ -160,7 +160,7 @@ class Deconv2dOpConverter : public OpConverter {
nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight, nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* { TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* {
auto* layer = auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_input, TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_output,
ksize, weight.get(), bias.get()); ksize, weight.get(), bias.get());
return layer; return layer;
}, },
......
...@@ -109,6 +109,12 @@ class OpConverter { ...@@ -109,6 +109,12 @@ class OpConverter {
it, platform::errors::Unimplemented("no OpConverter for optype [%s]", it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type())); op_desc.Type()));
} }
if (op_desc.Type() == "depthwise_conv2d_transpose") {
it = Registry<OpConverter>::Global().Lookup("conv2d_transpose");
PADDLE_ENFORCE_NOT_NULL(
it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
}
if (op_desc.Type() == "transpose2") { if (op_desc.Type() == "transpose2") {
it = Registry<OpConverter>::Global().Lookup("transpose"); it = Registry<OpConverter>::Global().Lookup("transpose");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
......
...@@ -102,6 +102,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -102,6 +102,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"dropout", "dropout",
"prelu", "prelu",
"conv2d_transpose", "conv2d_transpose",
"depthwise_conv2d_transpose",
"leaky_relu", "leaky_relu",
"fc", "fc",
"shuffle_channel", "shuffle_channel",
...@@ -172,7 +173,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -172,7 +173,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
if (op_type == "conv2d" || op_type == "conv2d_transpose" || if (op_type == "conv2d" || op_type == "conv2d_transpose" ||
op_type == "conv2d_fusion") { op_type == "conv2d_fusion" || op_type == "depthwise_conv2d" ||
op_type == "depthwise_conv2d_transpose") {
std::vector<int> paddings = std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
...@@ -202,7 +204,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -202,7 +204,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
} }
if (op_type == "conv2d_transpose") { if (op_type == "conv2d_transpose" ||
op_type == "depthwise_conv2d_transpose") {
if (!desc.HasAttr("dilations")) { if (!desc.HasAttr("dilations")) {
return false; return false;
} else { } else {
......
...@@ -96,6 +96,7 @@ class TensorRTSubgraphPassConvTransposeTest(InferencePassTest): ...@@ -96,6 +96,7 @@ class TensorRTSubgraphPassConvTransposeTest(InferencePassTest):
groups=self.conv_groups, groups=self.conv_groups,
padding=self.conv_padding, padding=self.conv_padding,
bias_attr=False, bias_attr=False,
use_cudnn=self.use_cudnn,
act=None) act=None)
self.feeds = { self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"), "data": np.random.random([1, 6, 64, 64]).astype("float32"),
...@@ -110,6 +111,7 @@ class TensorRTSubgraphPassConvTransposeTest(InferencePassTest): ...@@ -110,6 +111,7 @@ class TensorRTSubgraphPassConvTransposeTest(InferencePassTest):
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 1 self.conv_groups = 1
self.conv_padding = [1, 1] self.conv_padding = [1, 1]
self.use_cudnn = True
def test_check_output(self): def test_check_output(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
...@@ -126,6 +128,7 @@ class TensorRTSubgraphPassConvTransposeValidPaddingTest( ...@@ -126,6 +128,7 @@ class TensorRTSubgraphPassConvTransposeValidPaddingTest(
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 1 self.conv_groups = 1
self.conv_padding = 'VALID' self.conv_padding = 'VALID'
self.use_cudnn = True
class TensorRTSubgraphPassConvTransposeSamePaddingTest( class TensorRTSubgraphPassConvTransposeSamePaddingTest(
...@@ -135,15 +138,27 @@ class TensorRTSubgraphPassConvTransposeSamePaddingTest( ...@@ -135,15 +138,27 @@ class TensorRTSubgraphPassConvTransposeSamePaddingTest(
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 1 self.conv_groups = 1
self.conv_padding = 'SAME' self.conv_padding = 'SAME'
self.use_cudnn = True
class TensorRTSubgraphPassDepthwiseConvTransposeTest( class TensorRTSubgraphPassConvTransposeMultiGroupTest(
TensorRTSubgraphPassConvTransposeTest): TensorRTSubgraphPassConvTransposeTest):
def set_params(self): def set_params(self):
self.conv_num_filters = 6 self.conv_num_filters = 6
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 1 self.conv_groups = 2
self.conv_padding = [1, 1]
self.use_cudnn = True
class TensorRTSubgraphPassDepthwiseConvTransposeTest(
TensorRTSubgraphPassConvTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 4
self.conv_groups = 6
self.conv_padding = [1, 1] self.conv_padding = [1, 1]
self.use_cudnn = False
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册