diff --git a/paddle/fluid/operators/arg_max_op_xpu.cc b/paddle/fluid/operators/arg_max_op_xpu.cc index ba2ef81b5cdf1d9d74a12b68b5b1404ca0592233..e2acd84bd4e9db9c7feeea6948ab37a8e562766e 100644 --- a/paddle/fluid/operators/arg_max_op_xpu.cc +++ b/paddle/fluid/operators/arg_max_op_xpu.cc @@ -28,12 +28,15 @@ class ArgMaxXPUKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); auto dtype = ctx.Attr("dtype"); PADDLE_ENFORCE_EQ( - (dtype < 0 || dtype == 3), true, + (dtype < 0 || dtype == 2 || dtype == 3), true, platform::errors::InvalidArgument( - "The attribute of dtype in xpu argmin/argmax must be [%s], but " + "The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], " + "but " "received [%s]", paddle::framework::DataTypeToString( framework::proto::VarType::INT64), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), paddle::framework::DataTypeToString( static_cast(dtype))));