diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index 23eff58805ee95b679f3a9025c31ecc6e60424f3..816fd3266f184f3fa5cc50c0c7057130fa963251 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -20,6 +20,7 @@ from paddle.device import (
     is_compiled_with_rocm,
 )
 from paddle.fluid.framework import _global_flags, in_dygraph_mode
+from paddle.tensor.manipulation import reshape
 from paddle.tensor.math import _add_with_axis
 
 from ...common_ops_import import Variable
@@ -247,12 +248,30 @@ def _conv_nd(
         )
         if bias is not None:
             out = helper.create_variable_for_type_inference(dtype)
-            helper.append_op(
-                type='elementwise_add',
-                inputs={'X': [pre_bias], 'Y': [bias]},
-                outputs={'Out': [out]},
-                attrs={'axis': channel_dim, 'use_mkldnn': use_mkldnn},
-            )
+            x_shape = list(pre_bias.shape)
+            y_shape = list(bias.shape)
+            if channel_dim == -1 or len(x_shape) == len(y_shape):
+                helper.append_op(
+                    type='elementwise_add',
+                    inputs={'X': [pre_bias], 'Y': [bias]},
+                    outputs={'Out': [out]},
+                    attrs={'axis': -1, 'use_mkldnn': use_mkldnn},
+                )
+            else:
+                assert len(x_shape) > len(
+                    y_shape
+                ), 'The length of pre_bias must greater than the length of bias'
+                padding = len(x_shape) - len(y_shape) - channel_dim
+                bias = reshape(
+                    bias, [1] * channel_dim + y_shape + [1] * padding
+                )
+
+                helper.append_op(
+                    type='elementwise_add',
+                    inputs={'X': [pre_bias], 'Y': [bias]},
+                    outputs={'Out': [out]},
+                    attrs={'axis': -1, 'use_mkldnn': use_mkldnn},
+                )
         else:
             out = pre_bias
         return out
@@ -1335,7 +1354,30 @@ def conv2d_transpose(
         )
 
         if bias is not None:
-            out = _add_with_axis(pre_bias, bias, axis=channel_dim)
+            out = helper.create_variable_for_type_inference(x.dtype)
+            x_shape = list(pre_bias.shape)
+            y_shape = list(bias.shape)
+            if channel_dim == -1 or len(x_shape) == len(y_shape):
+                helper.append_op(
+                    type='elementwise_add',
+                    inputs={'X': [pre_bias], 'Y': [bias]},
+                    outputs={'Out': [out]},
+                    attrs={'axis': -1, 'use_mkldnn': False},
+                )
+            else:
+                assert len(x_shape) > len(
+                    y_shape
+                ), 'The length of pre_bias must greater than the length of bias'
+                padding = len(x_shape) - len(y_shape) - channel_dim
+                bias = reshape(
+                    bias, [1] * channel_dim + y_shape + [1] * padding
+                )
+                helper.append_op(
+                    type='elementwise_add',
+                    inputs={'X': [pre_bias], 'Y': [bias]},
+                    outputs={'Out': [out]},
+                    attrs={'axis': -1, 'use_mkldnn': False},
+                )
         else:
             out = pre_bias