diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 618d31098563b95bbbf3a413c7b961ace428e9cf..6fe18f2479478a49819da2608dc7c3a0bf5d3017 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -187,6 +187,4 @@ endif()
 if(WITH_ASCEND_CL)
 cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor)
-cc_test(mean_op_npu_test SRCS mean_op_npu_test.cc DEPS op_registry mean_op scope device_context enforce executor)
 endif()
-
 
diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc
index e7cc048ed3ce4b7986e74ecebc3c6603f4d5553a..8ab4d70fd3ff6ba96f5eb434899ad618bf800d77 100644
--- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc
@@ -54,6 +54,7 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
     auto *table_t = ctx.Input<framework::LoDTensor>("W");
     auto *table_grad_t =
         ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
+    table_grad_t->mutable_data<T>(ctx.GetPlace());
 
     framework::NPUAttributeMap attr_input = {{"use_locking", true}};
     auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t},
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 7d4194227b4cd4b989d2c4a274cd730a2b78f90a..a38800da87fd1e07b569f649ca1d2149d46b8406 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -206,9 +206,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
     if (input_data_type == framework::proto::VarType::FP16) {
-      PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                        platform::errors::InvalidArgument(
-                            "float16 can only be used on GPU place"));
+      if (!(platform::is_gpu_place(ctx.GetPlace()) ||
+            platform::is_npu_place(ctx.GetPlace())))
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "float16 can only be used on GPU/NPU place"));
     }
     return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,