未验证 提交 85e697d7 编写于 作者: P Pei Yang 提交者: GitHub

Support depthwise_conv2d_transpose in the TensorRT op converter and op teller (#32593)

上级 809ac036
......@@ -160,7 +160,7 @@ class Deconv2dOpConverter : public OpConverter {
nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* {
auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_input,
TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_output,
ksize, weight.get(), bias.get());
return layer;
},
......
......@@ -109,6 +109,12 @@ class OpConverter {
it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
}
if (op_desc.Type() == "depthwise_conv2d_transpose") {
it = Registry<OpConverter>::Global().Lookup("conv2d_transpose");
PADDLE_ENFORCE_NOT_NULL(
it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
}
if (op_desc.Type() == "transpose2") {
it = Registry<OpConverter>::Global().Lookup("transpose");
PADDLE_ENFORCE_NOT_NULL(
......
......@@ -102,6 +102,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"dropout",
"prelu",
"conv2d_transpose",
"depthwise_conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
......@@ -172,7 +173,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
if (op_type == "conv2d" || op_type == "conv2d_transpose" ||
op_type == "conv2d_fusion") {
op_type == "conv2d_fusion" || op_type == "depthwise_conv2d" ||
op_type == "depthwise_conv2d_transpose") {
std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
......@@ -202,7 +204,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if (op_type == "conv2d_transpose") {
if (op_type == "conv2d_transpose" ||
op_type == "depthwise_conv2d_transpose") {
if (!desc.HasAttr("dilations")) {
return false;
} else {
......
......@@ -96,6 +96,7 @@ class TensorRTSubgraphPassConvTransposeTest(InferencePassTest):
groups=self.conv_groups,
padding=self.conv_padding,
bias_attr=False,
use_cudnn=self.use_cudnn,
act=None)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
......@@ -110,6 +111,7 @@ class TensorRTSubgraphPassConvTransposeTest(InferencePassTest):
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = [1, 1]
self.use_cudnn = True
def test_check_output(self):
if core.is_compiled_with_cuda():
......@@ -126,6 +128,7 @@ class TensorRTSubgraphPassConvTransposeValidPaddingTest(
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = 'VALID'
self.use_cudnn = True
class TensorRTSubgraphPassConvTransposeSamePaddingTest(
......@@ -135,15 +138,27 @@ class TensorRTSubgraphPassConvTransposeSamePaddingTest(
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = 'SAME'
self.use_cudnn = True
class TensorRTSubgraphPassDepthwiseConvTransposeTest(
class TensorRTSubgraphPassConvTransposeMultiGroupTest(
TensorRTSubgraphPassConvTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_groups = 2
self.conv_padding = [1, 1]
self.use_cudnn = True
class TensorRTSubgraphPassDepthwiseConvTransposeTest(
TensorRTSubgraphPassConvTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 4
self.conv_groups = 6
self.conv_padding = [1, 1]
self.use_cudnn = False
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册