未验证 提交 4c9e34dc 编写于 作者: W Weilong Wu 提交者: GitHub

fix conv axis is not the default value -1 (#51486)

* fix add axis is not default -1

* polish conv logic

* Don't import paddle

* fix error
上级 76adcc80
......@@ -20,6 +20,7 @@ from paddle.device import (
is_compiled_with_rocm,
)
from paddle.fluid.framework import _global_flags, in_dygraph_mode
from paddle.tensor.manipulation import reshape
from paddle.tensor.math import _add_with_axis
from ...common_ops_import import Variable
......@@ -247,12 +248,30 @@ def _conv_nd(
)
if bias is not None:
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias], 'Y': [bias]},
outputs={'Out': [out]},
attrs={'axis': channel_dim, 'use_mkldnn': use_mkldnn},
)
x_shape = list(pre_bias.shape)
y_shape = list(bias.shape)
if channel_dim == -1 or len(x_shape) == len(y_shape):
helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias], 'Y': [bias]},
outputs={'Out': [out]},
attrs={'axis': -1, 'use_mkldnn': use_mkldnn},
)
else:
assert len(x_shape) > len(
y_shape
), 'The length of pre_bias must greater than the length of bias'
padding = len(x_shape) - len(y_shape) - channel_dim
bias = reshape(
bias, [1] * channel_dim + y_shape + [1] * padding
)
helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias], 'Y': [bias]},
outputs={'Out': [out]},
attrs={'axis': -1, 'use_mkldnn': use_mkldnn},
)
else:
out = pre_bias
return out
......@@ -1335,7 +1354,30 @@ def conv2d_transpose(
)
if bias is not None:
out = _add_with_axis(pre_bias, bias, axis=channel_dim)
out = helper.create_variable_for_type_inference(x.dtype)
x_shape = list(pre_bias.shape)
y_shape = list(bias.shape)
if channel_dim == -1 or len(x_shape) == len(y_shape):
helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias], 'Y': [bias]},
outputs={'Out': [out]},
attrs={'axis': -1, 'use_mkldnn': False},
)
else:
assert len(x_shape) > len(
y_shape
), 'The length of pre_bias must greater than the length of bias'
padding = len(x_shape) - len(y_shape) - channel_dim
bias = reshape(
bias, [1] * channel_dim + y_shape + [1] * padding
)
helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias], 'Y': [bias]},
outputs={'Out': [out]},
attrs={'axis': -1, 'use_mkldnn': False},
)
else:
out = pre_bias
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册