未验证 提交 368f1dda 编写于 作者: Y ykkk2333 提交者: GitHub

fix arg_max for int type, *test=kunlun (#41522)

上级 9287d5a1
......@@ -28,12 +28,15 @@ class ArgMaxXPUKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto dtype = ctx.Attr<int>("dtype");
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 3), true,
(dtype < 0 || dtype == 2 || dtype == 3), true,
platform::errors::InvalidArgument(
"The attribute of dtype in xpu argmin/argmax must be [%s], but "
"The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], "
"but "
"received [%s]",
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
static_cast<framework::proto::VarType::Type>(dtype))));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册