未验证 提交 dcce54ea 编写于 作者: Z Zhang Ting 提交者: GitHub

improve performance of depthwise_conv2d (#31099)

* improve performance of depthwise_conv2d

* add unittest
上级 4d6d2db8
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest
......@@ -1328,6 +1329,16 @@ class TestConv2DAPI(unittest.TestCase):
groups=1,
data_format="NCHW")
def test_depthwise_conv2d(self):
    """Run a depthwise ``paddle.nn.Conv2D`` on channels-last input.

    ``groups`` equals ``in_channels`` (4), and ``data_format`` is 'NHWC'.
    Besides checking that the layer runs, verify that the result is still
    laid out as NHWC with the expected spatial size.
    """
    # Input is (N, H, W, C) = (2, 8, 8, 4).
    x_var = paddle.uniform((2, 8, 8, 4), dtype='float32', min=-1., max=1.)
    conv = paddle.nn.Conv2D(
        in_channels=4,
        out_channels=4,
        kernel_size=(3, 3),
        groups=4,
        data_format='NHWC')
    y_var = conv(x_var)
    # With a 3x3 kernel and Conv2D's default stride (1) and padding (0),
    # each spatial dim shrinks from 8 to 8 - 3 + 1 = 6; channels stay last.
    # list() normalizes the shape, which may be a list or a tuple.
    self.assertEqual(list(y_var.shape), [2, 6, 6, 4])
class TestConv2DAPI_Error(unittest.TestCase):
def test_api(self):
......
......@@ -110,6 +110,12 @@ def _conv_nd(x,
use_mkldnn=False,
name=None):
# depthwise_conv2d performs poorly on NHWC input, so transpose the input
# to NCHW here; the output is transposed back to NHWC before returning.
origin_format = data_format
if origin_format == "NHWC" and op_type == "depthwise_conv2d":
x = nn.transpose(x, perm=[0, 3, 1, 2])
data_format = "NCHW"
channel_dim = 1
if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn',
......@@ -154,6 +160,9 @@ def _conv_nd(x,
else:
out = pre_bias
if origin_format == "NHWC" and op_type == "depthwise_conv2d":
out = nn.transpose(out, perm=[0, 2, 3, 1])
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册