未验证 提交 96a8bbe7 编写于 作者: W wanghuancoder 提交者: GitHub

refine conv add for xpu (#48432)

上级 cbb1cfbb
...@@ -145,20 +145,9 @@ def _conv_nd( ...@@ -145,20 +145,9 @@ def _conv_nd(
data_format, data_format,
) )
if bias is not None: if bias is not None:
channel_dim = ( new_shape = [1] * len(x.shape)
channel_dim + len(x.shape) if channel_dim < 0 else channel_dim new_shape[channel_dim] = -1
) bias = bias.reshape(new_shape)
if isinstance(x, tuple):
x = x[0]
if isinstance(bias, tuple):
bias = bias[0]
if len(bias.shape) < len(x.shape):
bias = _C_ops.reshape(
bias,
[1 for i in range(channel_dim)]
+ bias.shape
+ [1 for i in range(len(x.shape) - channel_dim - 1)],
)
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type(): if 'npu' in get_all_custom_device_type():
with no_grad(): with no_grad():
...@@ -182,16 +171,10 @@ def _conv_nd( ...@@ -182,16 +171,10 @@ def _conv_nd(
data_format, data_format,
) )
if bias is not None: if bias is not None:
channel_dim = ( new_shape = [1] * len(x.shape)
channel_dim + len(x.shape) if channel_dim < 0 else channel_dim new_shape[channel_dim] = -1
) bias = bias.reshape(new_shape)
tmp_bias = _C_ops.reshape( return _C_ops.add(pre_bias, bias)
bias,
[1 for i in range(channel_dim)]
+ bias.shape
+ [1 for i in range(len(x.shape) - channel_dim - 1)],
)
return _C_ops.add(pre_bias, tmp_bias)
else: else:
return pre_bias return pre_bias
...@@ -207,14 +190,10 @@ def _conv_nd( ...@@ -207,14 +190,10 @@ def _conv_nd(
data_format, data_format,
) )
if bias is not None: if bias is not None:
channel_dim = ( new_shape = [1] * len(x.shape)
channel_dim + len(x.shape) if channel_dim < 0 else channel_dim new_shape[channel_dim] = -1
) bias = bias.reshape(new_shape)
tmp_bias = _C_ops.reshape( return _C_ops.add(pre_bias, bias)
bias,
bias.shape + [1 for i in range(len(x.shape) - channel_dim - 1)],
)
return _C_ops.add(pre_bias, tmp_bias)
else: else:
return pre_bias return pre_bias
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册