未验证 提交 ac48599d 编写于 作者: L liym27 提交者: GitHub

[cherry-pick2.0beta][Api2.0] sum: bug fix - support attr(dtype) is float32 or...

[cherry-pick2.0beta][Api2.0] sum: bug fix - support attr(dtype) is float32 or int32 and add ValueError (#26946) (#27047)
上级 88256df6
...@@ -475,87 +475,71 @@ class API_TestSumOpError(unittest.TestCase): ...@@ -475,87 +475,71 @@ class API_TestSumOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
def test_dtype1(): def test_dtype1():
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32") data = fluid.data(name="data", shape=[10], dtype="float64")
paddle.sum(data, dtype="int32") paddle.sum(data, dtype="float32")
self.assertRaises(ValueError, test_dtype1) self.assertRaises(ValueError, test_dtype1)
def test_dtype2(): def test_dtype2():
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32") data = fluid.data(name="data", shape=[10], dtype="int64")
paddle.sum(data, dtype="float32") paddle.sum(data, dtype="int32")
self.assertRaises(ValueError, test_dtype2) self.assertRaises(ValueError, test_dtype2)
def test_dtype3(): def test_dtype3():
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="int32") data = fluid.data(name="data", shape=[10], dtype="float64")
paddle.sum(data, dtype="bool") paddle.sum(data, dtype="int32")
self.assertRaises(ValueError, test_dtype3) self.assertRaises(ValueError, test_dtype3)
def test_dtype4(): def test_type():
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="int32") data = fluid.data(name="data", shape=[10], dtype="int32")
paddle.sum(data, dtype="int32") paddle.sum(data, dtype="bool")
self.assertRaises(ValueError, test_dtype3) self.assertRaises(TypeError, test_type)
class API_TestSumOp(unittest.TestCase): class API_TestSumOp(unittest.TestCase):
def test_static(self): def run_static(self,
with fluid.program_guard(fluid.Program(), fluid.Program()): shape,
data = fluid.data("data", shape=[10, 10], dtype="float32") x_dtype,
result_sum = paddle.sum(x=data, axis=1, dtype="float64") attr_axis,
place = fluid.CPUPlace() attr_dtype=None,
exe = fluid.Executor(place) np_axis=None):
input_data = np.random.rand(10, 10).astype(np.float32) if np_axis is None:
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum]) np_axis = attr_axis
self.assertEqual(
(res == np.sum(input_data.astype(np.float64), axis=1)).all(), True)
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="int32") data = fluid.data("data", shape=shape, dtype=x_dtype)
result_sum = paddle.sum(x=data, axis=1, dtype="int64") result_sum = paddle.sum(x=data, axis=attr_axis, dtype=attr_dtype)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
self.assertEqual(
(res == np.sum(input_data.astype(np.int64), axis=1)).all(), True)
with fluid.program_guard(fluid.Program(), fluid.Program()): exe = fluid.Executor(fluid.CPUPlace())
data = fluid.data("data", shape=[10, 10], dtype="int32") input_data = np.random.rand(*shape).astype(x_dtype)
result_sum = paddle.sum(x=data, axis=1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum]) res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
self.assertEqual((res == np.sum(input_data, axis=1)).all(), True)
with fluid.program_guard(fluid.Program(), fluid.Program()): self.assertTrue(
data = fluid.data("data", shape=[10, 10], dtype="int32") np.allclose(
result_sum = paddle.sum(x=data, axis=1) res, np.sum(input_data.astype(attr_dtype), axis=np_axis)))
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
self.assertEqual((res == np.sum(input_data, axis=1)).all(), True)
with fluid.program_guard(fluid.Program(), fluid.Program()): def test_static(self):
input_data = np.random.randint(10, size=(5, 5, 5)).astype(np.int32) shape = [10, 10]
data = fluid.data("data", shape=[5, 5, 5], dtype="int32") axis = 1
sum1 = paddle.sum(x=data, axis=[0, 1])
sum2 = paddle.sum(x=data, axis=()) self.run_static(shape, "int32", axis, attr_dtype=None)
self.run_static(shape, "int32", axis, attr_dtype="int32")
place = fluid.CPUPlace() self.run_static(shape, "int32", axis, attr_dtype="int64")
exe = fluid.Executor(place)
res1, res2 = exe.run(feed={"data": input_data}, self.run_static(shape, "float32", axis, attr_dtype=None)
fetch_list=[sum1, sum2]) self.run_static(shape, "float32", axis, attr_dtype="float32")
self.run_static(shape, "float32", axis, attr_dtype="float64")
self.assertEqual((res1 == np.sum(input_data, axis=(0, 1))).all(), True)
self.assertEqual( shape = [5, 5, 5]
(res2 == np.sum(input_data, axis=(0, 1, 2))).all(), True) self.run_static(shape, "int32", (0, 1), attr_dtype="int32")
self.run_static(
shape, "int32", (), attr_dtype="int32", np_axis=(0, 1, 2))
def test_dygraph(self): def test_dygraph(self):
np_x = np.random.random([2, 3, 4]).astype('int32') np_x = np.random.random([2, 3, 4]).astype('int32')
......
...@@ -760,7 +760,8 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): ...@@ -760,7 +760,8 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
it's data type is the same as `x`. it's data type is the same as `x`.
Raises: Raises:
ValueError: The :attr:`dtype` must be float64 or int64. ValueError: If the data type of `x` is float64, :attr:`dtype` can not be float32 or int32.
ValueError: If the data type of `x` is int64, :attr:`dtype` can not be int32.
TypeError: The type of :attr:`axis` must be int, list or tuple. TypeError: The type of :attr:`axis` must be int, list or tuple.
Examples: Examples:
...@@ -815,10 +816,6 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): ...@@ -815,10 +816,6 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
'out_dtype': convert_np_dtype_to_dtype_(dtype) 'out_dtype': convert_np_dtype_to_dtype_(dtype)
}) })
dtype_flag = True dtype_flag = True
else:
raise ValueError(
"The value of 'dtype' in sum op must be float64, int64, but received of {}".
format(dtype))
if in_dygraph_mode(): if in_dygraph_mode():
axis = axis if axis != None and axis != [] else [0] axis = axis if axis != None and axis != [] else [0]
...@@ -832,6 +829,17 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): ...@@ -832,6 +829,17 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
'reduce_all', reduce_all_flag) 'reduce_all', reduce_all_flag)
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum') x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum')
if dtype is not None:
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'sum')
x_dtype = convert_dtype(x.dtype)
if (x_dtype == "float64" and dtype in ["float32", "int32"]) or \
(x_dtype == "int64" and dtype == "int32"):
raise ValueError("The input(x)'s dtype is {} but the attr(dtype) of sum is {}, "
"which may cause data type overflows. Please reset attr(dtype) of sum."
.format(x_dtype, dtype))
check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum') check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum')
helper = LayerHelper('sum', **locals()) helper = LayerHelper('sum', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册