未验证 提交 dcce54ea 编写于 作者: Z Zhang Ting 提交者: GitHub

improve performance of depthwise_conv2d (#31099)

* improve performance of depthwise_conv2d

* add unittest
上级 4d6d2db8
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from op_test import OpTest from op_test import OpTest
...@@ -1328,6 +1329,16 @@ class TestConv2DAPI(unittest.TestCase): ...@@ -1328,6 +1329,16 @@ class TestConv2DAPI(unittest.TestCase):
groups=1, groups=1,
data_format="NCHW") data_format="NCHW")
def test_depthwise_conv2d(self):
x_var = paddle.uniform((2, 8, 8, 4), dtype='float32', min=-1., max=1.)
conv = paddle.nn.Conv2D(
in_channels=4,
out_channels=4,
kernel_size=(3, 3),
groups=4,
data_format='NHWC')
y_var = conv(x_var)
class TestConv2DAPI_Error(unittest.TestCase): class TestConv2DAPI_Error(unittest.TestCase):
def test_api(self): def test_api(self):
......
...@@ -110,6 +110,12 @@ def _conv_nd(x, ...@@ -110,6 +110,12 @@ def _conv_nd(x,
use_mkldnn=False, use_mkldnn=False,
name=None): name=None):
# Due to the poor performance of NHWC, we transpose the input to NCHW.
origin_format = data_format
if origin_format == "NHWC" and op_type == "depthwise_conv2d":
x = nn.transpose(x, perm=[0, 3, 1, 2])
data_format = "NCHW"
channel_dim = 1
if in_dygraph_mode(): if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn',
...@@ -154,6 +160,9 @@ def _conv_nd(x, ...@@ -154,6 +160,9 @@ def _conv_nd(x,
else: else:
out = pre_bias out = pre_bias
if origin_format == "NHWC" and op_type == "depthwise_conv2d":
out = nn.transpose(out, perm=[0, 2, 3, 1])
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册