From dcce54ea76be48cb3a6ac398b7d9569e996ac054 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Thu, 4 Mar 2021 16:24:41 +0800 Subject: [PATCH] improve performance of depthwise_conv2d (#31099) * improve performance of depthwise_conv2d * add unittest --- python/paddle/fluid/tests/unittests/test_conv2d_op.py | 11 +++++++++++ python/paddle/nn/functional/conv.py | 9 +++++++++ 2 files changed, 20 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 85bf18c8c84..9992efee1b3 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle import paddle.fluid.core as core import paddle.fluid as fluid from op_test import OpTest @@ -1328,6 +1329,16 @@ class TestConv2DAPI(unittest.TestCase): groups=1, data_format="NCHW") + def test_depthwise_conv2d(self): + x_var = paddle.uniform((2, 8, 8, 4), dtype='float32', min=-1., max=1.) + conv = paddle.nn.Conv2D( + in_channels=4, + out_channels=4, + kernel_size=(3, 3), + groups=4, + data_format='NHWC') + y_var = conv(x_var) + class TestConv2DAPI_Error(unittest.TestCase): def test_api(self): diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index eaa4dc4d4f2..75dc62e530d 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -110,6 +110,12 @@ def _conv_nd(x, use_mkldnn=False, name=None): + # Due to the poor performance of NHWC, we transpose the input to NCHW. + origin_format = data_format + if origin_format == "NHWC" and op_type == "depthwise_conv2d": + x = nn.transpose(x, perm=[0, 3, 1, 2]) + data_format = "NCHW" + channel_dim = 1 if in_dygraph_mode(): attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', @@ -154,6 +160,9 @@ def _conv_nd(x, else: out = pre_bias + if origin_format == "NHWC" and op_type == "depthwise_conv2d": + out = nn.transpose(out, perm=[0, 2, 3, 1]) + return out -- GitLab