未验证 提交 a0bbfbd4 编写于 作者: A Aganlengzi 提交者: GitHub

support fp16 softmax on custom place (#45177)

上级 e26f80ad
......@@ -62,10 +62,11 @@ class SoftmaxOp : public framework::OperatorWithKernel {
platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace()),
platform::is_mlu_place(ctx.GetPlace()) ||
platform::is_custom_place(ctx.GetPlace()),
true,
platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU place"));
"float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
}
return framework::OpKernelType(
......@@ -176,9 +177,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
if (!(platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace())))
platform::is_mlu_place(ctx.GetPlace()) ||
platform::is_custom_place(ctx.GetPlace())))
PADDLE_THROW(platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU place"));
"float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
}
return framework::OpKernelType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册