From 96a8bbe78fc16296cbb2d62dd66e2a0a8c0ebdef Mon Sep 17 00:00:00 2001
From: wanghuancoder
Date: Wed, 30 Nov 2022 18:34:20 +0800
Subject: [PATCH] refine conv add for xpu (#48432)

---
 python/paddle/nn/functional/conv.py | 43 ++++++++---------------------
 1 file changed, 11 insertions(+), 32 deletions(-)

diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index 38c5064a1c..d29f91d035 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -145,20 +145,9 @@ def _conv_nd(
             data_format,
         )
         if bias is not None:
-            channel_dim = (
-                channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
-            )
-            if isinstance(x, tuple):
-                x = x[0]
-            if isinstance(bias, tuple):
-                bias = bias[0]
-            if len(bias.shape) < len(x.shape):
-                bias = _C_ops.reshape(
-                    bias,
-                    [1 for i in range(channel_dim)]
-                    + bias.shape
-                    + [1 for i in range(len(x.shape) - channel_dim - 1)],
-                )
+            new_shape = [1] * len(x.shape)
+            new_shape[channel_dim] = -1
+            bias = bias.reshape(new_shape)
             # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
             if 'npu' in get_all_custom_device_type():
                 with no_grad():
@@ -182,16 +171,10 @@ def _conv_nd(
             data_format,
         )
         if bias is not None:
-            channel_dim = (
-                channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
-            )
-            tmp_bias = _C_ops.reshape(
-                bias,
-                [1 for i in range(channel_dim)]
-                + bias.shape
-                + [1 for i in range(len(x.shape) - channel_dim - 1)],
-            )
-            return _C_ops.add(pre_bias, tmp_bias)
+            new_shape = [1] * len(x.shape)
+            new_shape[channel_dim] = -1
+            bias = bias.reshape(new_shape)
+            return _C_ops.add(pre_bias, bias)
         else:
             return pre_bias
 
@@ -207,14 +190,10 @@ def _conv_nd(
             data_format,
         )
         if bias is not None:
-            channel_dim = (
-                channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
-            )
-            tmp_bias = _C_ops.reshape(
-                bias,
-                bias.shape + [1 for i in range(len(x.shape) - channel_dim - 1)],
-            )
-            return _C_ops.add(pre_bias, tmp_bias)
+            new_shape = [1] * len(x.shape)
+            new_shape[channel_dim] = -1
+            bias = bias.reshape(new_shape)
+            return _C_ops.add(pre_bias, bias)
         else:
             return pre_bias
 
-- 
GitLab
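
All three branches now build the bias's broadcast shape the same way: a list of ones with -1 at the channel position, so a 1-D per-channel bias reshapes to something like [1, C, 1, 1] for NCHW or [1, 1, 1, C] for NHWC and broadcasts against the convolution output, with no need to normalize a negative channel_dim first. Below is a minimal standalone sketch of that logic, using NumPy purely for illustration (the patch itself operates on Paddle tensors with bias.reshape and _C_ops.add; the add_bias helper name here is made up):

    import numpy as np

    def add_bias(pre_bias, bias, channel_dim):
        # Build a shape like [1, C, 1, 1] (NCHW) or [1, 1, 1, C] (NHWC).
        # A negative channel_dim (e.g. -1) indexes the list from the end,
        # which is why the old normalization step is no longer needed.
        new_shape = [1] * pre_bias.ndim
        new_shape[channel_dim] = -1
        return pre_bias + bias.reshape(new_shape)

    out_nchw = np.zeros((2, 3, 4, 5), dtype=np.float32)  # conv output, NCHW
    b = np.arange(3, dtype=np.float32)                   # per-channel bias
    print(add_bias(out_nchw, b, channel_dim=1).shape)    # (2, 3, 4, 5)
    out_nhwc = out_nchw.transpose(0, 2, 3, 1)            # same data, NHWC
    print(add_bias(out_nhwc, b, channel_dim=-1).shape)   # (2, 4, 5, 3)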