From 8857e3911f5063d45ae8ceebddc09acc18e8ad5c Mon Sep 17 00:00:00 2001 From: wawltor <fangzeyang0904@hotmail.com> Date: Sun, 6 Sep 2020 17:22:41 +0800 Subject: [PATCH] add the dynamic dtype check for the argmin/argma update the check for the dtype check for the argmin, argmax --- paddle/fluid/operators/arg_min_max_op_base.h | 14 +++++++++++++- .../tests/unittests/test_arg_min_max_v2_op.py | 14 ++++++++++++++ python/paddle/tensor/search.py | 12 ++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 69365357084..c296ddcfbef 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -166,10 +166,22 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size())); + const int& dtype = ctx->Attrs().Get<int>("dtype"); + PADDLE_ENFORCE_EQ( + (dtype < 0 || dtype == 2 || dtype == 3), true, + platform::errors::InvalidArgument( + "The attribute of dtype in argmin/argmax must be [%s] or [%s], but " + "received [%s]", + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64), + paddle::framework::DataTypeToString( + static_cast<framework::proto::VarType::Type>(dtype)))); + auto x_rank = x_dims.size(); if (axis < 0) axis += x_rank; if (ctx->IsRuntime()) { - const int& dtype = ctx->Attrs().Get<int>("dtype"); if (dtype == framework::proto::VarType::INT32) { int64_t all_element_num = 0; if (flatten) { diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index 0fd9863948a..1b1b1d7c983 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -322,6 +322,20 @@ class TestArgMinMaxOpError(unittest.TestCase): self.assertRaises(TypeError, test_argmin_axis_type) + def test_argmax_dtype_type(): + data = paddle.static.data( + name="test_argmax", shape=[10], dtype="float32") + output = paddle.argmax(x=data, dtype=1) + + self.assertRaises(TypeError, test_argmax_dtype_type) + + def test_argmin_dtype_type(): + data = paddle.static.data( + name="test_argmin", shape=[10], dtype="float32") + output = paddle.argmin(x=data, dtype=1) + + self.assertRaises(TypeError, test_argmin_dtype_type) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 7543b091383..ce03d0ef15f 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -166,6 +166,12 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): raise TypeError( "The type of 'axis' must be int or None in argmax, but received %s." % (type(axis))) + + if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): + raise TypeError( + "the type of 'dtype' in argmax must be str or np.dtype, but received {}". + format(type(dtype))) + var_dtype = convert_np_dtype_to_dtype_(dtype) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') flatten = False @@ -238,6 +244,12 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): raise TypeError( "The type of 'axis' must be int or None in argmin, but received %s." % (type(axis))) + + if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): + raise TypeError( + "the type of 'dtype' in argmin must be str or np.dtype, but received {}". + format(dtype(dtype))) + var_dtype = convert_np_dtype_to_dtype_(dtype) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') flatten = False -- GitLab