未验证 提交 a0bbfbd4 编写于 作者: A Aganlengzi 提交者: GitHub

support fp16 softmax on custom place (#45177)

上级 e26f80ad
...@@ -62,10 +62,11 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -62,10 +62,11 @@ class SoftmaxOp : public framework::OperatorWithKernel {
platform::is_gpu_place(ctx.GetPlace()) || platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) || platform::is_npu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace()) || platform::is_xpu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace()), platform::is_mlu_place(ctx.GetPlace()) ||
platform::is_custom_place(ctx.GetPlace()),
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU place")); "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
} }
return framework::OpKernelType( return framework::OpKernelType(
...@@ -176,9 +177,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -176,9 +177,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
if (!(platform::is_gpu_place(ctx.GetPlace()) || if (!(platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) || platform::is_npu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace()) || platform::is_xpu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace()))) platform::is_mlu_place(ctx.GetPlace()) ||
platform::is_custom_place(ctx.GetPlace())))
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU place")); "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
} }
return framework::OpKernelType( return framework::OpKernelType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册