未验证 提交 413a743e 编写于 作者: T tanzhipeng 提交者: GitHub

remove unnecessary constant fill in sequence conv test=kunlun. (#40126)

上级 81d4142b
......@@ -184,9 +184,6 @@ class SequenceConvGradXPUKernel : public framework::OpKernel<T> {
col_data, paddle::platform::errors::Fatal("XPU memory is not enough"));
if (in_g || filter_g) {
int r = xpu::constant<T>(xpu_context, col_data, col_numel, T(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
bool trans_a = false;
bool trans_b = true;
int m = out_g->dims()[0];
......@@ -208,7 +205,7 @@ class SequenceConvGradXPUKernel : public framework::OpKernel<T> {
const T* data_b = filter->data<T>();
T* data_c = col_data;
r = xpu::fc_fusion<T, T, T, int32_t>(
int r = xpu::fc_fusion<T, T, T, int32_t>(
xpu_context, data_a, data_b, data_c, m, n, k, trans_a, trans_b,
nullptr, nullptr, nullptr, lda, ldb, ldc, alpha, beta, nullptr,
xpu::Activation_t::LINEAR);
......@@ -222,7 +219,6 @@ class SequenceConvGradXPUKernel : public framework::OpKernel<T> {
in_g->mutable_data<T>(context.GetPlace());
in_g->set_lod(in->lod());
xpu::constant<T>(xpu_context, in_g->data<T>(), in_g->numel(), T(0));
int r = xpu::sequence_context_projection_grad<T, int>(
xpu_context, in_g->data<T>(), col_data, nullptr, lodx, sequence_width,
......@@ -232,8 +228,6 @@ class SequenceConvGradXPUKernel : public framework::OpKernel<T> {
if (filter_g) {
filter_g->mutable_data<T>(context.GetPlace());
xpu::constant<T>(xpu_context, filter_g->data<T>(), filter_g->numel(),
T(0));
int r = xpu::sequence_context_projection<T, int>(
xpu_context, in->data<T>(), col_data, nullptr, lodx, sequence_width,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册