From e6af7c0dd8597ac46b6cee14b8dcc80974ec315e Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 19 Mar 2021 15:04:05 +0800 Subject: [PATCH] [NPU] fix some bugs of npu op (#31739) * fix softmax * fix mean * fix lookup_table_v2 --- paddle/fluid/operators/CMakeLists.txt | 2 -- paddle/fluid/operators/lookup_table_v2_op_npu.cc | 1 + paddle/fluid/operators/softmax_op.cc | 7 ++++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 618d3109856..6fe18f24794 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -187,6 +187,4 @@ endif() if(WITH_ASCEND_CL) cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor) -cc_test(mean_op_npu_test SRCS mean_op_npu_test.cc DEPS op_registry mean_op scope device_context enforce executor) endif() - diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index e7cc048ed3c..8ab4d70fd3f 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -54,6 +54,7 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel { auto *table_t = ctx.Input("W"); auto *table_grad_t = ctx.Output(framework::GradVarName("W")); + table_grad_t->mutable_data(ctx.GetPlace()); framework::NPUAttributeMap attr_input = {{"use_locking", true}}; auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t}, diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 7d4194227b4..a38800da87f 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -206,9 +206,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); if (input_data_type == framework::proto::VarType::FP16) { - PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, - platform::errors::InvalidArgument( - "float16 can only be used on GPU place")); + if (!(platform::is_gpu_place(ctx.GetPlace()) || + platform::is_npu_place(ctx.GetPlace()))) + PADDLE_THROW(platform::errors::InvalidArgument( + "float16 can only be used on GPU/NPU place")); } return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, -- GitLab