未验证 提交 ac48599d 编写于 作者: L liym27 提交者: GitHub

[cherry-pick2.0beta][Api2.0] sum: bug fix - support attr(dtype) is float32 or...

[cherry-pick2.0beta][Api2.0] sum: bug fix - support attr(dtype) is float32 or int32 and add ValueError (#26946) (#27047)
上级 88256df6
......@@ -475,87 +475,71 @@ class API_TestSumOpError(unittest.TestCase):
def test_errors(self):
def test_dtype1():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
paddle.sum(data, dtype="int32")
data = fluid.data(name="data", shape=[10], dtype="float64")
paddle.sum(data, dtype="float32")
self.assertRaises(ValueError, test_dtype1)
def test_dtype2():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
paddle.sum(data, dtype="float32")
data = fluid.data(name="data", shape=[10], dtype="int64")
paddle.sum(data, dtype="int32")
self.assertRaises(ValueError, test_dtype2)
def test_dtype3():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="int32")
paddle.sum(data, dtype="bool")
data = fluid.data(name="data", shape=[10], dtype="float64")
paddle.sum(data, dtype="int32")
self.assertRaises(ValueError, test_dtype3)
def test_dtype4():
def test_type():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="int32")
paddle.sum(data, dtype="int32")
paddle.sum(data, dtype="bool")
self.assertRaises(ValueError, test_dtype3)
self.assertRaises(TypeError, test_type)
class API_TestSumOp(unittest.TestCase):
def test_static(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="float32")
result_sum = paddle.sum(x=data, axis=1, dtype="float64")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input_data = np.random.rand(10, 10).astype(np.float32)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
self.assertEqual(
(res == np.sum(input_data.astype(np.float64), axis=1)).all(), True)
def run_static(self,
shape,
x_dtype,
attr_axis,
attr_dtype=None,
np_axis=None):
if np_axis is None:
np_axis = attr_axis
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="int32")
result_sum = paddle.sum(x=data, axis=1, dtype="int64")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
self.assertEqual(
(res == np.sum(input_data.astype(np.int64), axis=1)).all(), True)
data = fluid.data("data", shape=shape, dtype=x_dtype)
result_sum = paddle.sum(x=data, axis=attr_axis, dtype=attr_dtype)
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="int32")
result_sum = paddle.sum(x=data, axis=1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
exe = fluid.Executor(fluid.CPUPlace())
input_data = np.random.rand(*shape).astype(x_dtype)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
self.assertEqual((res == np.sum(input_data, axis=1)).all(), True)
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="int32")
result_sum = paddle.sum(x=data, axis=1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
self.assertEqual((res == np.sum(input_data, axis=1)).all(), True)
self.assertTrue(
np.allclose(
res, np.sum(input_data.astype(attr_dtype), axis=np_axis)))
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data = np.random.randint(10, size=(5, 5, 5)).astype(np.int32)
data = fluid.data("data", shape=[5, 5, 5], dtype="int32")
sum1 = paddle.sum(x=data, axis=[0, 1])
sum2 = paddle.sum(x=data, axis=())
place = fluid.CPUPlace()
exe = fluid.Executor(place)
res1, res2 = exe.run(feed={"data": input_data},
fetch_list=[sum1, sum2])
self.assertEqual((res1 == np.sum(input_data, axis=(0, 1))).all(), True)
self.assertEqual(
(res2 == np.sum(input_data, axis=(0, 1, 2))).all(), True)
def test_static(self):
shape = [10, 10]
axis = 1
self.run_static(shape, "int32", axis, attr_dtype=None)
self.run_static(shape, "int32", axis, attr_dtype="int32")
self.run_static(shape, "int32", axis, attr_dtype="int64")
self.run_static(shape, "float32", axis, attr_dtype=None)
self.run_static(shape, "float32", axis, attr_dtype="float32")
self.run_static(shape, "float32", axis, attr_dtype="float64")
shape = [5, 5, 5]
self.run_static(shape, "int32", (0, 1), attr_dtype="int32")
self.run_static(
shape, "int32", (), attr_dtype="int32", np_axis=(0, 1, 2))
def test_dygraph(self):
np_x = np.random.random([2, 3, 4]).astype('int32')
......
......@@ -760,7 +760,8 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
it's data type is the same as `x`.
Raises:
ValueError: The :attr:`dtype` must be float64 or int64.
ValueError: If the data type of `x` is float64, :attr:`dtype` can not be float32 or int32.
ValueError: If the data type of `x` is int64, :attr:`dtype` can not be int32.
TypeError: The type of :attr:`axis` must be int, list or tuple.
Examples:
......@@ -815,10 +816,6 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
'out_dtype': convert_np_dtype_to_dtype_(dtype)
})
dtype_flag = True
else:
raise ValueError(
"The value of 'dtype' in sum op must be float64, int64, but received of {}".
format(dtype))
if in_dygraph_mode():
axis = axis if axis != None and axis != [] else [0]
......@@ -832,6 +829,17 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
'reduce_all', reduce_all_flag)
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum')
if dtype is not None:
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'sum')
x_dtype = convert_dtype(x.dtype)
if (x_dtype == "float64" and dtype in ["float32", "int32"]) or \
(x_dtype == "int64" and dtype == "int32"):
raise ValueError("The input(x)'s dtype is {} but the attr(dtype) of sum is {}, "
"which may cause data type overflows. Please reset attr(dtype) of sum."
.format(x_dtype, dtype))
check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum')
helper = LayerHelper('sum', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册