Unverified commit 8dba3d0b, authored by Chen Weihang, committed by GitHub

[Cherry-pick] Two detail fix prs (#41728)

* [Eager] Remove elementwise add in conv (#41515)

* remove elementwise add in conv

* use reshape

* fix warpctc grad kernel dep error (#41598)
Parent 0053ba8a
@@ -61,7 +61,7 @@ kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_re
 kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
 kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
 kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
-kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale)
+kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)

 # 4. auto parse and build kernel targets by cmake
 register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
......
@@ -127,8 +127,12 @@ def _conv_nd(x,
             x, weight, stride, padding, padding_algorithm, groups, dilation,
             data_format, False, -1, False)
         if bias is not None:
-            out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
-            return out
+            channel_dim = channel_dim + len(
+                x.shape) if channel_dim < 0 else channel_dim
+            tmp_bias = _C_ops.final_state_reshape(
+                bias, bias.shape +
+                [1 for i in range(len(x.shape) - channel_dim - 1)])
+            return _C_ops.final_state_add(pre_bias, tmp_bias)
         else:
             return pre_bias
     if in_dynamic_mode():
......
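Why the reshape-based add works: appending one singleton dimension to the bias for every axis after the channel axis lets ordinary broadcasting align the bias with the channel dimension, which is what elementwise_add(..., axis=channel_dim) previously did. Below is a minimal NumPy sketch of that equivalence (illustration only, not the Paddle code; an NCHW layout with channel_dim = 1 is assumed):

# Minimal NumPy sketch (not the Paddle API) of the broadcast equivalence used above.
import numpy as np

pre_bias = np.random.rand(2, 8, 5, 5)   # conv output in NCHW layout, channel_dim = 1
bias = np.random.rand(8)                # one bias value per output channel
channel_dim = 1

# Append a singleton axis for every dimension after the channel axis: (8,) -> (8, 1, 1).
tmp_bias = bias.reshape(bias.shape + (1,) * (pre_bias.ndim - channel_dim - 1))

out_new = pre_bias + tmp_bias                    # plain broadcasted add (new path)
out_old = pre_bias + bias.reshape(1, 8, 1, 1)    # effect of the old axis=channel_dim add
assert np.allclose(out_new, out_old)

Adding a pre-reshaped bias with a plain add lets the eager path drop the legacy elementwise_add op, which is the point of the "[Eager] Remove elementwise add in conv" change above.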