未验证 提交 96a8bbe7 编写于 作者: W wanghuancoder 提交者: GitHub

refine conv add for xpu (#48432)

上级 cbb1cfbb
......@@ -145,20 +145,9 @@ def _conv_nd(
data_format,
)
if bias is not None:
channel_dim = (
channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
)
if isinstance(x, tuple):
x = x[0]
if isinstance(bias, tuple):
bias = bias[0]
if len(bias.shape) < len(x.shape):
bias = _C_ops.reshape(
bias,
[1 for i in range(channel_dim)]
+ bias.shape
+ [1 for i in range(len(x.shape) - channel_dim - 1)],
)
new_shape = [1] * len(x.shape)
new_shape[channel_dim] = -1
bias = bias.reshape(new_shape)
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type():
with no_grad():
......@@ -182,16 +171,10 @@ def _conv_nd(
data_format,
)
if bias is not None:
channel_dim = (
channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
)
tmp_bias = _C_ops.reshape(
bias,
[1 for i in range(channel_dim)]
+ bias.shape
+ [1 for i in range(len(x.shape) - channel_dim - 1)],
)
return _C_ops.add(pre_bias, tmp_bias)
new_shape = [1] * len(x.shape)
new_shape[channel_dim] = -1
bias = bias.reshape(new_shape)
return _C_ops.add(pre_bias, bias)
else:
return pre_bias
......@@ -207,14 +190,10 @@ def _conv_nd(
data_format,
)
if bias is not None:
channel_dim = (
channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
)
tmp_bias = _C_ops.reshape(
bias,
bias.shape + [1 for i in range(len(x.shape) - channel_dim - 1)],
)
return _C_ops.add(pre_bias, tmp_bias)
new_shape = [1] * len(x.shape)
new_shape[channel_dim] = -1
bias = bias.reshape(new_shape)
return _C_ops.add(pre_bias, bias)
else:
return pre_bias
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册