From 8dba3d0b00a55bc4ec5d14f9a2a333d45a6270d0 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Wed, 13 Apr 2022 16:05:28 +0800
Subject: [PATCH] [Cherry-pick] Two detail fix prs (#41728)

* [Eager] Remove elementwise add in conv (#41515)

* remove elementwise add in conv

* use reshape

* fix warpctc grad kernel dep error (#41598)
---
 paddle/phi/kernels/CMakeLists.txt   | 2 +-
 python/paddle/nn/functional/conv.py | 8 ++++++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index 5aae2bbe36e..c063c389df8 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -61,7 +61,7 @@ kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_re
 kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
 kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
 kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
-kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale)
+kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
 
 # 4. auto parse and build kernel targets by cmake
 register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index 086ae789194..84aadbbac64 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -127,8 +127,12 @@ def _conv_nd(x,
             x, weight, stride, padding, padding_algorithm, groups, dilation,
             data_format, False, -1, False)
         if bias is not None:
-            out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
-            return out
+            channel_dim = channel_dim + len(
+                x.shape) if channel_dim < 0 else channel_dim
+            tmp_bias = _C_ops.final_state_reshape(
+                bias, bias.shape +
+                [1 for i in range(len(x.shape) - channel_dim - 1)])
+            return _C_ops.final_state_add(pre_bias, tmp_bias)
         else:
             return pre_bias
     if in_dynamic_mode():
-- 
GitLab
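
A minimal standalone sketch of the bias-broadcasting idea behind the conv.py hunk
above, written with NumPy instead of Paddle's _C_ops (the NumPy calls and the toy
shapes are illustrative assumptions, not part of the patch; only the reshape logic
mirrors the diff):

    import numpy as np

    # NCHW conv output: shape (N, C, H, W); bias has one value per channel.
    pre_bias = np.zeros((2, 3, 4, 4), dtype=np.float32)
    bias = np.array([0.1, 0.2, 0.3], dtype=np.float32)
    channel_dim = 1  # channel axis for NCHW

    # Normalize a negative axis, as the patched code does for channel_dim.
    channel_dim = channel_dim + pre_bias.ndim if channel_dim < 0 else channel_dim

    # Reshape bias from (C,) to (C, 1, 1) so a plain add broadcasts it across
    # H and W -- the role played by final_state_reshape in the patch.
    tmp_bias = bias.reshape(list(bias.shape) +
                            [1] * (pre_bias.ndim - channel_dim - 1))

    # Plain broadcasted add replaces elementwise_add(..., axis=channel_dim).
    out = pre_bias + tmp_bias
    assert out.shape == pre_bias.shape

With the bias expanded to trailing singleton dimensions, an ordinary add suffices,
which is what lets the patch drop the axis-aware elementwise_add call.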