From 2ad1e4c708dcfcc6bb56efb5717f4ad01140697f Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 30 Jul 2021 16:50:28 +0800 Subject: [PATCH] [NPU] disable EmbeddingDenseGrad temporarily (#34498) --- .../fluid/operators/lookup_table_v2_op_npu.cc | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index 2a8f4746234..c65fa634070 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -65,8 +65,9 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel { ctx.template device_context() .stream(); - int embedding_dim = table_grad_t->dims()[1]; + /* EmbeddingDenseGrad has bug on large shape, temporarily disable it. + int embedding_dim = table_grad_t->dims()[1]; if (embedding_dim % 32 == 0) { // NOTE(pangyoki): The embedding_dim of Tensor used in // EmbeddingDenseGrad must be an integer multiple of 32. @@ -77,19 +78,21 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel { {"padding_idx", -1}, {"scale_grad_by_freq", false}}); runner.Run(stream); - } else { - const auto &runner_zeros = - NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t}); - runner_zeros.Run(stream); - - // NOTE(zhiqiu): It seems in cann 20.1, the first input and output - // can be different tensor, but in cann 20.2+, it does inplace operation. - // Thus, the first input and output should be same tensor. - const auto &runner_scatter = - NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t}, - {*table_grad_t}, {{"use_locking", true}}); - runner_scatter.Run(stream); + return; } + */ + + const auto &runner_zeros = + NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t}); + runner_zeros.Run(stream); + + // NOTE(zhiqiu): It seems in cann 20.1, the first input and output + // can be different tensor, but in cann 20.2+, it does inplace operation. + // Thus, the first input and output should be same tensor. + const auto &runner_scatter = + NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t}, + {*table_grad_t}, {{"use_locking", true}}); + runner_scatter.Run(stream); } }; } // namespace operators -- GitLab