Unverified commit e6af7c0d, authored by Leo Chen, committed by GitHub

[NPU] fix some bugs of npu op (#31739)

* fix softmax

* fix mean

* fix lookup_table_v2
Parent 17862b72
@@ -187,6 +187,4 @@ endif()
if(WITH_ASCEND_CL)
cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor)
cc_test(mean_op_npu_test SRCS mean_op_npu_test.cc DEPS op_registry mean_op scope device_context enforce executor)
endif()
@@ -54,6 +54,7 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
auto *table_t = ctx.Input<framework::LoDTensor>("W");
auto *table_grad_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
table_grad_t->mutable_data<T>(ctx.GetPlace());
framework::NPUAttributeMap attr_input = {{"use_locking", true}};
auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t},
......
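Note: the hunk above is cut off at the NpuOpRunner construction. For readers unfamiliar with Paddle's NPU kernels, the sketch below shows the usual way such a runner is completed and launched; the output list, the stream retrieval, and the Run call are assumptions for illustration, not part of this commit's diff.

    // Assumed completion of the pattern above (illustrative sketch, not from the diff):
    // the runner takes the op type, input tensors, output tensors, and attributes,
    // and is launched on the NPU stream of the current device context.
    auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t},
                              {*table_grad_t}, attr_input);
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>().stream();
    runner.Run(stream);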
@@ -206,9 +206,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
     if (input_data_type == framework::proto::VarType::FP16) {
-      PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                        platform::errors::InvalidArgument(
-                            "float16 can only be used on GPU place"));
+      if (!(platform::is_gpu_place(ctx.GetPlace()) ||
+            platform::is_npu_place(ctx.GetPlace())))
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "float16 can only be used on GPU/NPU place"));
     }
     return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
......
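The effect of the softmax grad change above: when the grad tensor dtype is float16, the kernel-type check now accepts NPU places in addition to GPU places and only raises InvalidArgument elsewhere (e.g. CPU). As a standalone illustration (the helper name below is hypothetical, not from the source), the acceptance rule amounts to:

    // Hypothetical helper expressing the new FP16 placement rule (illustration only).
    bool Fp16PlaceSupported(const paddle::platform::Place &place) {
      return paddle::platform::is_gpu_place(place) ||
             paddle::platform::is_npu_place(place);
    }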