diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 85bf18c8c84eb72e15bdd6fcd84694490b964e2b..9992efee1b3053fec7515dd7ce063499af22ca16 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle import paddle.fluid.core as core import paddle.fluid as fluid from op_test import OpTest @@ -1328,6 +1329,16 @@ class TestConv2DAPI(unittest.TestCase): groups=1, data_format="NCHW") + def test_depthwise_conv2d(self): + x_var = paddle.uniform((2, 8, 8, 4), dtype='float32', min=-1., max=1.) + conv = paddle.nn.Conv2D( + in_channels=4, + out_channels=4, + kernel_size=(3, 3), + groups=4, + data_format='NHWC') + y_var = conv(x_var) + class TestConv2DAPI_Error(unittest.TestCase): def test_api(self): diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index eaa4dc4d4f2cd959f20956a2c463655868a643eb..75dc62e530d0db81ee4126dc76918e2f08713d30 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -110,6 +110,12 @@ def _conv_nd(x, use_mkldnn=False, name=None): + # Due to the poor performance of NHWC, we transpose the input to NCHW. + origin_format = data_format + if origin_format == "NHWC" and op_type == "depthwise_conv2d": + x = nn.transpose(x, perm=[0, 3, 1, 2]) + data_format = "NCHW" + channel_dim = 1 if in_dygraph_mode(): attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', @@ -154,6 +160,9 @@ def _conv_nd(x, else: out = pre_bias + if origin_format == "NHWC" and op_type == "depthwise_conv2d": + out = nn.transpose(out, perm=[0, 2, 3, 1]) + return out