未验证 提交 ff062a43 编写于 作者: G Guoxia Wang 提交者: GitHub

fix output dtype for paddle.sum (#34313)

* support bool dtype for paddle.sum
上级 a842828a
......@@ -184,7 +184,7 @@ def test_optim_break_in_while(x):
class TestContinueInFor(unittest.TestCase):
def setUp(self):
self.input = np.zeros((1)).astype('int32')
self.input = np.zeros((1)).astype('int64')
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.init_dygraph_func()
......
......@@ -748,37 +748,6 @@ class TestReduceSumOpError(unittest.TestCase):
self.assertRaises(TypeError, fluid.layers.reduce_sum, x2)
class API_TestSumOpError(unittest.TestCase):
def test_errors(self):
def test_dtype1():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float64")
paddle.sum(data, dtype="float32")
self.assertRaises(ValueError, test_dtype1)
def test_dtype2():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="int64")
paddle.sum(data, dtype="int32")
self.assertRaises(ValueError, test_dtype2)
def test_dtype3():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float64")
paddle.sum(data, dtype="int32")
self.assertRaises(ValueError, test_dtype3)
def test_type():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="int32")
paddle.sum(data, dtype="bool")
self.assertRaises(TypeError, test_type)
class API_TestSumOp(unittest.TestCase):
def run_static(self,
shape,
......@@ -805,14 +774,26 @@ class API_TestSumOp(unittest.TestCase):
shape = [10, 10]
axis = 1
self.run_static(shape, "bool", axis, attr_dtype=None)
self.run_static(shape, "bool", axis, attr_dtype="int32")
self.run_static(shape, "bool", axis, attr_dtype="int64")
self.run_static(shape, "int32", axis, attr_dtype=None)
self.run_static(shape, "int32", axis, attr_dtype="int32")
self.run_static(shape, "int32", axis, attr_dtype="int64")
self.run_static(shape, "int64", axis, attr_dtype=None)
self.run_static(shape, "int64", axis, attr_dtype="int64")
self.run_static(shape, "int64", axis, attr_dtype="int32")
self.run_static(shape, "float32", axis, attr_dtype=None)
self.run_static(shape, "float32", axis, attr_dtype="float32")
self.run_static(shape, "float32", axis, attr_dtype="float64")
self.run_static(shape, "float64", axis, attr_dtype=None)
self.run_static(shape, "float64", axis, attr_dtype="float32")
self.run_static(shape, "float64", axis, attr_dtype="float64")
shape = [5, 5, 5]
self.run_static(shape, "int32", (0, 1), attr_dtype="int32")
self.run_static(
......
......@@ -716,13 +716,15 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
else:
reduce_all_flag = False
dtype_flag = False
def get_dtype(x, dtype):
if dtype is not None:
if dtype in ['float64', 'int64']:
if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
(convert_dtype(x.dtype) == "int32" and dtype == "int64"):
dtype_flag = True
return (True, dtype)
src_type = convert_dtype(x.dtype)
if src_type in ['bool','int32', 'int64']:
return (True, 'int64')
return (False, src_type)
dtype_flag, dtype = get_dtype(x, dtype)
if in_dygraph_mode():
axis = axis if axis != None and axis != [] else [0]
if dtype_flag:
......@@ -740,27 +742,17 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
'reduce_all': reduce_all_flag
}
if dtype is not None:
if dtype in ['float64', 'int64']:
if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
(convert_dtype(x.dtype) == "int32" and dtype == "int64"):
if dtype_flag:
attrs.update({
'in_dtype': x.dtype,
'out_dtype': convert_np_dtype_to_dtype_(dtype)
})
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum')
if dtype is not None:
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'sum')
x_dtype = convert_dtype(x.dtype)
if (x_dtype == "float64" and dtype in ["float32", "int32"]) or \
(x_dtype == "int64" and dtype == "int32"):
raise ValueError("The input(x)'s dtype is {} but the attr(dtype) of sum is {}, "
"which may cause data type overflows. Please reset attr(dtype) of sum."
.format(x_dtype, dtype))
x, 'x', ['bool', 'float16', 'float32', 'float64',
'int32', 'int64', 'complex64', 'complex128',
u'bool', u'float16', u'float32', u'float64',
u'int32', u'int64', u'complex64', u'complex128'], 'sum')
check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册