未验证 提交 e92c3b26 编写于 作者: W wawltor 提交者: GitHub

cherry-pick PR, add the dynamic dtype check for the argmin/argmax (#27083)

update the check for the dtype check for the argmin, argmax
上级 a497123e
......@@ -166,10 +166,22 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
const int& dtype = ctx->Attrs().Get<int>("dtype");
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3), true,
platform::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<framework::proto::VarType::Type>(dtype))));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (ctx->IsRuntime()) {
const int& dtype = ctx->Attrs().Get<int>("dtype");
if (dtype == framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
......
......@@ -322,6 +322,20 @@ class TestArgMinMaxOpError(unittest.TestCase):
self.assertRaises(TypeError, test_argmin_axis_type)
def test_argmax_dtype_type():
data = paddle.static.data(
name="test_argmax", shape=[10], dtype="float32")
output = paddle.argmax(x=data, dtype=1)
self.assertRaises(TypeError, test_argmax_dtype_type)
def test_argmin_dtype_type():
data = paddle.static.data(
name="test_argmin", shape=[10], dtype="float32")
output = paddle.argmin(x=data, dtype=1)
self.assertRaises(TypeError, test_argmin_dtype_type)
if __name__ == '__main__':
unittest.main()
......@@ -166,6 +166,12 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
raise TypeError(
"The type of 'axis' must be int or None in argmax, but received %s."
% (type(axis)))
if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)):
raise TypeError(
"the type of 'dtype' in argmax must be str or np.dtype, but received {}".
format(type(dtype)))
var_dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
flatten = False
......@@ -238,6 +244,12 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
raise TypeError(
"The type of 'axis' must be int or None in argmin, but received %s."
% (type(axis)))
if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)):
raise TypeError(
"the type of 'dtype' in argmin must be str or np.dtype, but received {}".
format(dtype(dtype)))
var_dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
flatten = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册