From 8857e3911f5063d45ae8ceebddc09acc18e8ad5c Mon Sep 17 00:00:00 2001
From: wawltor <fangzeyang0904@hotmail.com>
Date: Sun, 6 Sep 2020 17:22:41 +0800
Subject: [PATCH] add the dynamic dtype check for the  argmin/argma

update the check for the dtype check for the argmin, argmax
---
 paddle/fluid/operators/arg_min_max_op_base.h       | 14 +++++++++++++-
 .../tests/unittests/test_arg_min_max_v2_op.py      | 14 ++++++++++++++
 python/paddle/tensor/search.py                     | 12 ++++++++++++
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h
index 69365357084..c296ddcfbef 100644
--- a/paddle/fluid/operators/arg_min_max_op_base.h
+++ b/paddle/fluid/operators/arg_min_max_op_base.h
@@ -166,10 +166,22 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
         platform::errors::InvalidArgument(
             "'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
 
+    const int& dtype = ctx->Attrs().Get<int>("dtype");
+    PADDLE_ENFORCE_EQ(
+        (dtype < 0 || dtype == 2 || dtype == 3), true,
+        platform::errors::InvalidArgument(
+            "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
+            "received [%s]",
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64),
+            paddle::framework::DataTypeToString(
+                static_cast<framework::proto::VarType::Type>(dtype))));
+
     auto x_rank = x_dims.size();
     if (axis < 0) axis += x_rank;
     if (ctx->IsRuntime()) {
-      const int& dtype = ctx->Attrs().Get<int>("dtype");
       if (dtype == framework::proto::VarType::INT32) {
         int64_t all_element_num = 0;
         if (flatten) {
diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
index 0fd9863948a..1b1b1d7c983 100644
--- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
@@ -322,6 +322,20 @@ class TestArgMinMaxOpError(unittest.TestCase):
 
             self.assertRaises(TypeError, test_argmin_axis_type)
 
+            def test_argmax_dtype_type():
+                data = paddle.static.data(
+                    name="test_argmax", shape=[10], dtype="float32")
+                output = paddle.argmax(x=data, dtype=1)
+
+            self.assertRaises(TypeError, test_argmax_dtype_type)
+
+            def test_argmin_dtype_type():
+                data = paddle.static.data(
+                    name="test_argmin", shape=[10], dtype="float32")
+                output = paddle.argmin(x=data, dtype=1)
+
+            self.assertRaises(TypeError, test_argmin_dtype_type)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index 7543b091383..ce03d0ef15f 100644
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -166,6 +166,12 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
         raise TypeError(
             "The type of 'axis'  must be int or None in argmax, but received %s."
             % (type(axis)))
+
+    if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)):
+        raise TypeError(
+            "the type of 'dtype' in argmax must be str or np.dtype, but received {}".
+            format(type(dtype)))
+
     var_dtype = convert_np_dtype_to_dtype_(dtype)
     check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
     flatten = False
@@ -238,6 +244,12 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
         raise TypeError(
             "The type of 'axis'  must be int or None in argmin, but received %s."
             % (type(axis)))
+
+    if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)):
+        raise TypeError(
+            "the type of 'dtype' in argmin must be str or np.dtype, but received {}".
+            format(dtype(dtype)))
+
     var_dtype = convert_np_dtype_to_dtype_(dtype)
     check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
     flatten = False
-- 
GitLab