未验证 提交 368f1dda 编写于 作者: Y ykkk2333 提交者: GitHub

fix arg_max for int type, *test=kunlun (#41522)

上级 9287d5a1
...@@ -28,12 +28,15 @@ class ArgMaxXPUKernel : public framework::OpKernel<T> { ...@@ -28,12 +28,15 @@ class ArgMaxXPUKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::LoDTensor>("Out"); auto* out = ctx.Output<framework::LoDTensor>("Out");
auto dtype = ctx.Attr<int>("dtype"); auto dtype = ctx.Attr<int>("dtype");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 3), true, (dtype < 0 || dtype == 2 || dtype == 3), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The attribute of dtype in xpu argmin/argmax must be [%s], but " "The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], "
"but "
"received [%s]", "received [%s]",
paddle::framework::DataTypeToString( paddle::framework::DataTypeToString(
framework::proto::VarType::INT64), framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString( paddle::framework::DataTypeToString(
static_cast<framework::proto::VarType::Type>(dtype)))); static_cast<framework::proto::VarType::Type>(dtype))));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册