From ff062a43f1bdf333eb85455966aa5b2cc687b14a Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Thu, 5 Aug 2021 17:14:59 +0800 Subject: [PATCH] fix output dtype for paddle.sum (#34313) * support bool dtype for paddle.sum --- .../dygraph_to_static/test_break_continue.py | 2 +- .../fluid/tests/unittests/test_reduce_op.py | 43 +++++------------- python/paddle/tensor/math.py | 44 ++++++++----------- 3 files changed, 31 insertions(+), 58 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py index 8423c056b2d..95b5235aaa3 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py @@ -184,7 +184,7 @@ def test_optim_break_in_while(x): class TestContinueInFor(unittest.TestCase): def setUp(self): - self.input = np.zeros((1)).astype('int32') + self.input = np.zeros((1)).astype('int64') self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( ) else fluid.CPUPlace() self.init_dygraph_func() diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 2dd5bcb8113..04736614558 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -748,37 +748,6 @@ class TestReduceSumOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.reduce_sum, x2) -class API_TestSumOpError(unittest.TestCase): - def test_errors(self): - def test_dtype1(): - with fluid.program_guard(fluid.Program(), fluid.Program()): - data = fluid.data(name="data", shape=[10], dtype="float64") - paddle.sum(data, dtype="float32") - - self.assertRaises(ValueError, test_dtype1) - - def test_dtype2(): - with fluid.program_guard(fluid.Program(), fluid.Program()): - data = fluid.data(name="data", shape=[10], dtype="int64") - paddle.sum(data, dtype="int32") - - self.assertRaises(ValueError, test_dtype2) - - def test_dtype3(): - with fluid.program_guard(fluid.Program(), fluid.Program()): - data = fluid.data(name="data", shape=[10], dtype="float64") - paddle.sum(data, dtype="int32") - - self.assertRaises(ValueError, test_dtype3) - - def test_type(): - with fluid.program_guard(fluid.Program(), fluid.Program()): - data = fluid.data(name="data", shape=[10], dtype="int32") - paddle.sum(data, dtype="bool") - - self.assertRaises(TypeError, test_type) - - class API_TestSumOp(unittest.TestCase): def run_static(self, shape, @@ -805,14 +774,26 @@ class API_TestSumOp(unittest.TestCase): shape = [10, 10] axis = 1 + self.run_static(shape, "bool", axis, attr_dtype=None) + self.run_static(shape, "bool", axis, attr_dtype="int32") + self.run_static(shape, "bool", axis, attr_dtype="int64") + self.run_static(shape, "int32", axis, attr_dtype=None) self.run_static(shape, "int32", axis, attr_dtype="int32") self.run_static(shape, "int32", axis, attr_dtype="int64") + self.run_static(shape, "int64", axis, attr_dtype=None) + self.run_static(shape, "int64", axis, attr_dtype="int64") + self.run_static(shape, "int64", axis, attr_dtype="int32") + self.run_static(shape, "float32", axis, attr_dtype=None) self.run_static(shape, "float32", axis, attr_dtype="float32") self.run_static(shape, "float32", axis, attr_dtype="float64") + self.run_static(shape, "float64", axis, attr_dtype=None) + self.run_static(shape, "float64", axis, attr_dtype="float32") + self.run_static(shape, "float64", axis, attr_dtype="float64") + shape = [5, 5, 5] self.run_static(shape, "int32", (0, 1), attr_dtype="int32") self.run_static( diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 01be63c5dfe..394d46b9161 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -716,13 +716,15 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): else: reduce_all_flag = False - dtype_flag = False - if dtype is not None: - if dtype in ['float64', 'int64']: - if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \ - (convert_dtype(x.dtype) == "int32" and dtype == "int64"): - dtype_flag = True - + def get_dtype(x, dtype): + if dtype is not None: + return (True, dtype) + src_type = convert_dtype(x.dtype) + if src_type in ['bool','int32', 'int64']: + return (True, 'int64') + return (False, src_type) + + dtype_flag, dtype = get_dtype(x, dtype) if in_dygraph_mode(): axis = axis if axis != None and axis != [] else [0] if dtype_flag: @@ -740,27 +742,17 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): 'reduce_all': reduce_all_flag } - if dtype is not None: - if dtype in ['float64', 'int64']: - if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \ - (convert_dtype(x.dtype) == "int32" and dtype == "int64"): - attrs.update({ - 'in_dtype': x.dtype, - 'out_dtype': convert_np_dtype_to_dtype_(dtype) - }) + if dtype_flag: + attrs.update({ + 'in_dtype': x.dtype, + 'out_dtype': convert_np_dtype_to_dtype_(dtype) + }) check_variable_and_dtype( - x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum') - - if dtype is not None: - check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'sum') - x_dtype = convert_dtype(x.dtype) - - if (x_dtype == "float64" and dtype in ["float32", "int32"]) or \ - (x_dtype == "int64" and dtype == "int32"): - raise ValueError("The input(x)'s dtype is {} but the attr(dtype) of sum is {}, " - "which may cause data type overflows. Please reset attr(dtype) of sum." - .format(x_dtype, dtype)) + x, 'x', ['bool', 'float16', 'float32', 'float64', + 'int32', 'int64', 'complex64', 'complex128', + u'bool', u'float16', u'float32', u'float64', + u'int32', u'int64', u'complex64', u'complex128'], 'sum') check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum') -- GitLab