diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index 38c5064a1cfc394162d37336cd188c413045df2c..d29f91d035f288b3ec658ba9ba242befa181f659 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -145,20 +145,9 @@ def _conv_nd(
             data_format,
         )
         if bias is not None:
-            channel_dim = (
-                channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
-            )
-            if isinstance(x, tuple):
-                x = x[0]
-            if isinstance(bias, tuple):
-                bias = bias[0]
-            if len(bias.shape) < len(x.shape):
-                bias = _C_ops.reshape(
-                    bias,
-                    [1 for i in range(channel_dim)]
-                    + bias.shape
-                    + [1 for i in range(len(x.shape) - channel_dim - 1)],
-                )
+            new_shape = [1] * len(x.shape)
+            new_shape[channel_dim] = -1
+            bias = bias.reshape(new_shape)
             # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
             if 'npu' in get_all_custom_device_type():
                 with no_grad():
@@ -182,16 +171,10 @@ def _conv_nd(
             data_format,
         )
         if bias is not None:
-            channel_dim = (
-                channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
-            )
-            tmp_bias = _C_ops.reshape(
-                bias,
-                [1 for i in range(channel_dim)]
-                + bias.shape
-                + [1 for i in range(len(x.shape) - channel_dim - 1)],
-            )
-            return _C_ops.add(pre_bias, tmp_bias)
+            new_shape = [1] * len(x.shape)
+            new_shape[channel_dim] = -1
+            bias = bias.reshape(new_shape)
+            return _C_ops.add(pre_bias, bias)
         else:
             return pre_bias
 
@@ -207,14 +190,10 @@ def _conv_nd(
             data_format,
         )
         if bias is not None:
-            channel_dim = (
-                channel_dim + len(x.shape) if channel_dim < 0 else channel_dim
-            )
-            tmp_bias = _C_ops.reshape(
-                bias,
-                bias.shape + [1 for i in range(len(x.shape) - channel_dim - 1)],
-            )
-            return _C_ops.add(pre_bias, tmp_bias)
+            new_shape = [1] * len(x.shape)
+            new_shape[channel_dim] = -1
+            bias = bias.reshape(new_shape)
+            return _C_ops.add(pre_bias, bias)
         else:
             return pre_bias
 
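In short: all three branches of `_conv_nd` now build the bias broadcast shape the same way, as a list of 1s with -1 at the channel axis, relying on `reshape`'s -1 inference plus elementwise broadcasting instead of concatenating `bias.shape` into a hand-padded shape. Because Python list indexing accepts negative indices, the old `channel_dim + len(x.shape) if channel_dim < 0` normalization also becomes unnecessary. A minimal sketch of the idea outside `_conv_nd` (the helper name `broadcast_bias` is hypothetical, not part of this patch):

```python
import paddle


def broadcast_bias(bias, x, channel_dim):
    # Hypothetical helper mirroring the refactored branches: reshape a
    # 1-D bias of length C so it broadcasts against x along channel_dim.
    # A negative channel_dim indexes from the end of the list, so no
    # explicit normalization (channel_dim + len(x.shape)) is needed.
    new_shape = [1] * len(x.shape)
    new_shape[channel_dim] = -1  # reshape infers C from bias's size
    return bias.reshape(new_shape)


bias = paddle.rand([8])

x_nchw = paddle.rand([2, 8, 4, 4])               # channel_dim = 1
out = x_nchw + broadcast_bias(bias, x_nchw, 1)   # bias -> [1, 8, 1, 1]
assert out.shape == x_nchw.shape

x_nhwc = paddle.rand([2, 4, 4, 8])               # channel_dim = -1
out = x_nhwc + broadcast_bias(bias, x_nhwc, -1)  # bias -> [1, 1, 1, 8]
assert out.shape == x_nhwc.shape
```

Note that only the reshape logic changes; the dygraph branch keeps the `npu_identity` workaround untouched.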