未验证 提交 a6854359 编写于 作者: L LielinJiang 提交者: GitHub

fix conv depthwise bug (#27278)

Fix conv deepwise bug when in_channels=1.
上级 bbad3414
...@@ -267,8 +267,8 @@ def conv1d(x, ...@@ -267,8 +267,8 @@ def conv1d(x,
dilation = utils.convert_to_list(dilation, 1, 'dilation') + [1] dilation = utils.convert_to_list(dilation, 1, 'dilation') + [1]
l_type = "conv2d" l_type = "conv2d"
if (num_channels == groups and num_filters % num_channels == 0 and if (num_channels == groups and num_channels != 1 and
not use_cudnn): num_filters % num_channels == 0 and not use_cudnn):
l_type = 'depthwise_conv2d' l_type = 'depthwise_conv2d'
use_cudnn = False use_cudnn = False
...@@ -491,7 +491,8 @@ def conv2d(x, ...@@ -491,7 +491,8 @@ def conv2d(x,
dilation = utils.convert_to_list(dilation, 2, 'dilation') dilation = utils.convert_to_list(dilation, 2, 'dilation')
l_type = "conv2d" l_type = "conv2d"
if (num_channels == groups and num_filters % num_channels == 0): if (num_channels == groups and num_channels != 1 and
num_filters % num_channels == 0):
l_type = 'depthwise_conv2d' l_type = 'depthwise_conv2d'
use_cudnn = False use_cudnn = False
...@@ -761,7 +762,8 @@ def conv_transpose1d(x, ...@@ -761,7 +762,8 @@ def conv_transpose1d(x,
op_type = 'conv2d_transpose' op_type = 'conv2d_transpose'
num_filters = weight.shape[1] num_filters = weight.shape[1]
if (num_channels == groups and num_filters == 1 and not use_cudnn): if (num_channels == groups and num_channels != 1 and num_filters == 1 and
not use_cudnn):
op_type = 'depthwise_conv2d_transpose' op_type = 'depthwise_conv2d_transpose'
use_cudnn = False use_cudnn = False
...@@ -1010,7 +1012,7 @@ def conv_transpose2d(x, ...@@ -1010,7 +1012,7 @@ def conv_transpose2d(x,
op_type = 'conv2d_transpose' op_type = 'conv2d_transpose'
num_filters = weight.shape[1] num_filters = weight.shape[1]
if (num_channels == groups and num_filters == 1): if (num_channels == groups and num_channels != 1 and num_filters == 1):
op_type = 'depthwise_conv2d_transpose' op_type = 'depthwise_conv2d_transpose'
use_cudnn = False use_cudnn = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册