From 368f1dda7e3c9ae5a50ba70344c4577215e6a6cf Mon Sep 17 00:00:00 2001 From: ykkk2333 <77383312+ykkk2333@users.noreply.github.com> Date: Mon, 11 Apr 2022 15:35:09 +0800 Subject: [PATCH] fix arg_max for int type, *test=kunlun (#41522) --- paddle/fluid/operators/arg_max_op_xpu.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/arg_max_op_xpu.cc b/paddle/fluid/operators/arg_max_op_xpu.cc index ba2ef81b5cd..e2acd84bd4e 100644 --- a/paddle/fluid/operators/arg_max_op_xpu.cc +++ b/paddle/fluid/operators/arg_max_op_xpu.cc @@ -28,12 +28,15 @@ class ArgMaxXPUKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); auto dtype = ctx.Attr("dtype"); PADDLE_ENFORCE_EQ( - (dtype < 0 || dtype == 3), true, + (dtype < 0 || dtype == 2 || dtype == 3), true, platform::errors::InvalidArgument( - "The attribute of dtype in xpu argmin/argmax must be [%s], but " + "The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], " + "but " "received [%s]", paddle::framework::DataTypeToString( framework::proto::VarType::INT64), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), paddle::framework::DataTypeToString( static_cast(dtype)))); -- GitLab