From 4c9e34dc2eb3ea516dca9ba4ae184ad6cb99aaa9 Mon Sep 17 00:00:00 2001
From: Weilong Wu
Date: Mon, 20 Mar 2023 19:07:02 +0800
Subject: [PATCH] fix conv axis is not the default value -1 (#51486)

* fix add axis is not default -1

* polish conv logic

* Don't import paddle

* fix error
---
 python/paddle/nn/functional/conv.py | 56 +++++++++++++++++++++++++----
 1 file changed, 49 insertions(+), 7 deletions(-)

diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index 23eff58805e..816fd3266f1 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -20,6 +20,7 @@ from paddle.device import (
     is_compiled_with_rocm,
 )
 from paddle.fluid.framework import _global_flags, in_dygraph_mode
+from paddle.tensor.manipulation import reshape
 from paddle.tensor.math import _add_with_axis
 
 from ...common_ops_import import Variable
@@ -247,12 +248,30 @@ def _conv_nd(
         )
         if bias is not None:
             out = helper.create_variable_for_type_inference(dtype)
-            helper.append_op(
-                type='elementwise_add',
-                inputs={'X': [pre_bias], 'Y': [bias]},
-                outputs={'Out': [out]},
-                attrs={'axis': channel_dim, 'use_mkldnn': use_mkldnn},
-            )
+            x_shape = list(pre_bias.shape)
+            y_shape = list(bias.shape)
+            if channel_dim == -1 or len(x_shape) == len(y_shape):
+                helper.append_op(
+                    type='elementwise_add',
+                    inputs={'X': [pre_bias], 'Y': [bias]},
+                    outputs={'Out': [out]},
+                    attrs={'axis': -1, 'use_mkldnn': use_mkldnn},
+                )
+            else:
+                assert len(x_shape) > len(
+                    y_shape
+                ), 'The length of pre_bias must greater than the length of bias'
+                padding = len(x_shape) - len(y_shape) - channel_dim
+                bias = reshape(
+                    bias, [1] * channel_dim + y_shape + [1] * padding
+                )
+
+                helper.append_op(
+                    type='elementwise_add',
+                    inputs={'X': [pre_bias], 'Y': [bias]},
+                    outputs={'Out': [out]},
+                    attrs={'axis': -1, 'use_mkldnn': use_mkldnn},
+                )
         else:
             out = pre_bias
         return out
@@ -1335,7 +1354,30 @@ def conv2d_transpose(
         )
         if bias is not None:
-            out = _add_with_axis(pre_bias, bias, axis=channel_dim)
+            out = helper.create_variable_for_type_inference(x.dtype)
+            x_shape = list(pre_bias.shape)
+            y_shape = list(bias.shape)
+            if channel_dim == -1 or len(x_shape) == len(y_shape):
+                helper.append_op(
+                    type='elementwise_add',
+                    inputs={'X': [pre_bias], 'Y': [bias]},
+                    outputs={'Out': [out]},
+                    attrs={'axis': -1, 'use_mkldnn': False},
+                )
+            else:
+                assert len(x_shape) > len(
+                    y_shape
+                ), 'The length of pre_bias must greater than the length of bias'
+                padding = len(x_shape) - len(y_shape) - channel_dim
+                bias = reshape(
+                    bias, [1] * channel_dim + y_shape + [1] * padding
+                )
+                helper.append_op(
+                    type='elementwise_add',
+                    inputs={'X': [pre_bias], 'Y': [bias]},
+                    outputs={'Out': [out]},
+                    attrs={'axis': -1, 'use_mkldnn': False},
+                )
         else:
             out = pre_bias
--
GitLab
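
Note on the fix: a conv output in a channels-first layout (e.g. NCHW) keeps its channel axis at position 1, so the old call to elementwise_add with axis=channel_dim is replaced by reshaping the 1-D bias to [1] * channel_dim + bias_shape + [1] * padding and adding with axis=-1, letting ordinary trailing-dimension broadcasting line the bias up with the channel axis. The sketch below reproduces only that shape arithmetic, with NumPy addition standing in for Paddle's elementwise_add op (an assumption made so the example is self-contained); add_bias_along_channel and the sample shapes are hypothetical names used for illustration, not part of the patch.

import numpy as np

def add_bias_along_channel(pre_bias, bias, channel_dim):
    # Same shape arithmetic as the patch: pad the bias shape with leading
    # ones up to channel_dim and trailing ones for the remaining dims, then
    # rely on broadcasting (what elementwise_add with axis=-1 does when the
    # shapes are broadcast-compatible).
    x_shape = list(pre_bias.shape)
    y_shape = list(bias.shape)
    if channel_dim == -1 or len(x_shape) == len(y_shape):
        return pre_bias + bias
    assert len(x_shape) > len(y_shape)
    padding = len(x_shape) - len(y_shape) - channel_dim
    bias = bias.reshape([1] * channel_dim + y_shape + [1] * padding)
    return pre_bias + bias

# NCHW conv output: the per-channel bias must land on axis 1, not the last axis.
pre_bias = np.zeros((2, 3, 4, 4))   # N, C, H, W
bias = np.array([1.0, 2.0, 3.0])    # one value per output channel
out = add_bias_along_channel(pre_bias, bias, channel_dim=1)
print(out[0, :, 0, 0])              # -> [1. 2. 3.]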