diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 3d85c9b0a45d60efd2353b1738df33cbdaf1f101..d468e4a17f6b349feb7e2e4feeb0112421c78ce8 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -62,10 +62,11 @@ class SoftmaxOp : public framework::OperatorWithKernel {
           platform::is_gpu_place(ctx.GetPlace()) ||
               platform::is_npu_place(ctx.GetPlace()) ||
               platform::is_xpu_place(ctx.GetPlace()) ||
-              platform::is_mlu_place(ctx.GetPlace()),
+              platform::is_mlu_place(ctx.GetPlace()) ||
+              platform::is_custom_place(ctx.GetPlace()),
           true,
           platform::errors::InvalidArgument(
-              "float16 can only be used on GPU/NPU/XPU/MLU place"));
+              "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
     }
     return framework::OpKernelType(
@@ -176,9 +177,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
     if (!(platform::is_gpu_place(ctx.GetPlace()) ||
           platform::is_npu_place(ctx.GetPlace()) ||
           platform::is_xpu_place(ctx.GetPlace()) ||
-          platform::is_mlu_place(ctx.GetPlace())))
+          platform::is_mlu_place(ctx.GetPlace()) ||
+          platform::is_custom_place(ctx.GetPlace())))
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "float16 can only be used on GPU/NPU/XPU/MLU place"));
+          "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
     }
     return framework::OpKernelType(
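Both hunks make the same change: the float16 place whitelist in SoftmaxOp::GetExpectedKernelType and SoftmaxOpGrad::GetExpectedKernelType gains platform::is_custom_place(), so plug-in (custom device) backends are no longer rejected when the input dtype is float16. The standalone sketch below illustrates the resulting predicate; the Place enum and SupportsFp16 helper are hypothetical stand-ins for illustration, not Paddle's actual API.

#include <iostream>

// Hypothetical stand-ins for Paddle's place taxonomy; illustrative only,
// not the real platform::Place types.
enum class Place { kCPU, kGPU, kNPU, kXPU, kMLU, kCustom };

// Mirrors the patched whitelist: float16 is accepted on GPU/NPU/XPU/MLU,
// and after this change also on custom (plug-in device) places.
bool SupportsFp16(Place p) {
  return p == Place::kGPU || p == Place::kNPU || p == Place::kXPU ||
         p == Place::kMLU || p == Place::kCustom;
}

int main() {
  std::cout << std::boolalpha
            << SupportsFp16(Place::kCustom) << '\n'  // true after the patch
            << SupportsFp16(Place::kCPU) << '\n';    // false: CPU still rejected
  return 0;
}

Note that the two call sites express the same check differently: the forward op enforces it via PADDLE_ENFORCE_EQ(cond, true, ...), while the grad op uses an explicit if (!cond) PADDLE_THROW(...). Both raise the same InvalidArgument error for unsupported places, so the patch updates the condition and the message text in each.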