未验证 提交 6cd67a81 编写于 作者: L liym27 提交者: GitHub

[API 2.0] Fix api sum:(1)input->x;(2)dim->axis;(3)keep_dim->keepdim (#26337)

* 1.Fix api sum:(1) input->sum; (2)dim->axis; (3)keep_dim->keepdim.

* 2. fix bug when len(axis) == len(x.shape). 
上级 029390b1
...@@ -580,10 +580,10 @@ class API_TestSumOpError(unittest.TestCase): ...@@ -580,10 +580,10 @@ class API_TestSumOpError(unittest.TestCase):
class API_TestSumOp(unittest.TestCase): class API_TestSumOp(unittest.TestCase):
def test_1(self): def test_static(self):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="float32") data = fluid.data("data", shape=[10, 10], dtype="float32")
result_sum = paddle.sum(input=data, dim=1, dtype="float64") result_sum = paddle.sum(x=data, axis=1, dtype="float64")
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input_data = np.random.rand(10, 10).astype(np.float32) input_data = np.random.rand(10, 10).astype(np.float32)
...@@ -593,7 +593,7 @@ class API_TestSumOp(unittest.TestCase): ...@@ -593,7 +593,7 @@ class API_TestSumOp(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="int32") data = fluid.data("data", shape=[10, 10], dtype="int32")
result_sum = paddle.sum(input=data, dim=1, dtype="int64") result_sum = paddle.sum(x=data, axis=1, dtype="int64")
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32) input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
...@@ -603,7 +603,7 @@ class API_TestSumOp(unittest.TestCase): ...@@ -603,7 +603,7 @@ class API_TestSumOp(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="int32") data = fluid.data("data", shape=[10, 10], dtype="int32")
result_sum = paddle.sum(input=data, dim=1) result_sum = paddle.sum(x=data, axis=1)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32) input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
...@@ -612,20 +612,41 @@ class API_TestSumOp(unittest.TestCase): ...@@ -612,20 +612,41 @@ class API_TestSumOp(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=[10, 10], dtype="int32") data = fluid.data("data", shape=[10, 10], dtype="int32")
result_sum = paddle.sum(input=data, dim=1) result_sum = paddle.sum(x=data, axis=1)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int32) input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum]) res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
self.assertEqual((res == np.sum(input_data, axis=1)).all(), True) self.assertEqual((res == np.sum(input_data, axis=1)).all(), True)
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data = np.random.randint(10, size=(5, 5, 5)).astype(np.int32)
data = fluid.data("data", shape=[5, 5, 5], dtype="int32")
sum1 = paddle.sum(x=data, axis=[0, 1])
sum2 = paddle.sum(x=data, axis=())
place = fluid.CPUPlace()
exe = fluid.Executor(place)
res1, res2 = exe.run(feed={"data": input_data},
fetch_list=[sum1, sum2])
self.assertEqual((res1 == np.sum(input_data, axis=(0, 1))).all(), True)
self.assertEqual(
(res2 == np.sum(input_data, axis=(0, 1, 2))).all(), True)
def test_dygraph(self):
np_x = np.random.random([2, 3, 4]).astype('int32')
with fluid.dygraph.guard(): with fluid.dygraph.guard():
np_x = np.array([10, 10]).astype('float64')
x = fluid.dygraph.to_variable(np_x) x = fluid.dygraph.to_variable(np_x)
z = paddle.sum(x, dim=0) out0 = paddle.sum(x).numpy()
np_z = z.numpy() out1 = paddle.sum(x, axis=0).numpy()
z_expected = np.array(np.sum(np_x, axis=0)) out2 = paddle.sum(x, axis=(0, 1)).numpy()
self.assertEqual((np_z == z_expected).all(), True) out3 = paddle.sum(x, axis=(0, 1, 2)).numpy()
self.assertTrue((out0 == np.sum(np_x, axis=(0, 1, 2))).all())
self.assertTrue((out1 == np.sum(np_x, axis=0)).all())
self.assertTrue((out2 == np.sum(np_x, axis=(0, 1))).all())
self.assertTrue((out3 == np.sum(np_x, axis=(0, 1, 2))).all())
class API_TestReduceMeanOp(unittest.TestCase): class API_TestReduceMeanOp(unittest.TestCase):
......
...@@ -330,8 +330,8 @@ def sum(input, dim=None, keep_dim=False, name=None): ...@@ -330,8 +330,8 @@ def sum(input, dim=None, keep_dim=False, name=None):
""" """
complex_variable_exists([input], "sum") complex_variable_exists([input], "sum")
real = math.sum(input.real, dim=dim, keep_dim=keep_dim, name=name) real = math.sum(input.real, axis=dim, keepdim=keep_dim, name=name)
imag = math.sum(input.imag, dim=dim, keep_dim=keep_dim, name=name) imag = math.sum(input.imag, axis=dim, keepdim=keep_dim, name=name)
return ComplexVariable(real, imag) return ComplexVariable(real, imag)
......
...@@ -532,75 +532,84 @@ for func in [ ...@@ -532,75 +532,84 @@ for func in [
}) + """\n""" + str(func.__doc__) }) + """\n""" + str(func.__doc__)
def sum(input, dim=None, dtype=None, keep_dim=False, name=None): def sum(x, axis=None, dtype=None, keepdim=False, name=None):
""" """
:alias_main: paddle.sum
:alias: paddle.sum,paddle.tensor.sum,paddle.tensor.math.sum
Computes the sum of tensor elements over the given dimension. Computes the sum of tensor elements over the given dimension.
Args: Args:
input (Variable): The input variable which is a Tensor, the data type is float32, x (Tensor): An N-D Tensor, the data type is float32, float64, int32 or int64.
float64, int32, int64. axis (int|list|tuple, optional): The dimensions along which the sum is performed. If
dim (list|int, optional): The dimensions along which the sum is performed. If :attr:`None`, sum all elements of :attr:`x` and return a
:attr:`None`, sum all elements of :attr:`input` and return a
Tensor variable with a single element, otherwise must be in the Tensor variable with a single element, otherwise must be in the
range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`, range :math:`[-rank(x), rank(x))`. If :math:`axis[i] < 0`,
the dimension to reduce is :math:`rank + dim[i]`. the dimension to reduce is :math:`rank + axis[i]`.
dtype(str, optional): The dtype of output tensor. The default value is None, the dtype dtype (str, optional): The dtype of output Tensor. The default value is None, the dtype
of output is the same as input tensor. of output is the same as input Tensor `x`.
keep_dim (bool, optional): Whether to reserve the reduced dimension in the keepdim (bool, optional): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension output Tensor. The result Tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true, default than the :attr:`x` unless :attr:`keepdim` is true, default
value is False. value is False.
name(str, optional): The default value is None. Normally there is no need for name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name` user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns: Returns:
Variable: Tensor, results of summation operation on the specified dim of input tensor, Tensor: Results of summation operation on the specified axis of input Tensor `x`,
it's data type is the same as input's Tensor. it's data type is the same as `x`.
Raises: Raises:
ValueError, the :attr:`dtype` must be float64 or int64. ValueError: The :attr:`dtype` must be float64 or int64.
TypeError: The type of :attr:`axis` must be int, list or tuple.
Examples: Examples:
.. code-block:: python .. code-block:: python
import numpy as np
import paddle import paddle
import paddle.fluid as fluid paddle.disable_static()
# x is a Tensor variable with following elements: # x is a Tensor variable with following elements:
# [[0.2, 0.3, 0.5, 0.9] # [[0.2, 0.3, 0.5, 0.9]
# [0.1, 0.2, 0.6, 0.7]] # [0.1, 0.2, 0.6, 0.7]]
# Each example is followed by the corresponding output tensor. # Each example is followed by the corresponding output tensor.
x = fluid.data(name='x', shape=[2, 4], dtype='float32') x_data = np.array([[0.2, 0.3, 0.5, 0.9],[0.1, 0.2, 0.6, 0.7]]).astype('float32')
x = paddle.to_variable(x_data)
out1 = paddle.sum(x) # [3.5] out1 = paddle.sum(x) # [3.5]
out2 = paddle.sum(x, dim=0) # [0.3, 0.5, 1.1, 1.6] out2 = paddle.sum(x, axis=0) # [0.3, 0.5, 1.1, 1.6]
out3 = paddle.sum(x, dim=-1) # [1.9, 1.6] out3 = paddle.sum(x, axis=-1) # [1.9, 1.6]
out4 = paddle.sum(x, dim=1, keep_dim=True) # [[1.9], [1.6]] out4 = paddle.sum(x, axis=1, keepdim=True) # [[1.9], [1.6]]
# y is a Tensor variable with shape [2, 2, 2] and elements as below: # y is a Tensor variable with shape [2, 2, 2] and elements as below:
# [[[1, 2], [3, 4]], # [[[1, 2], [3, 4]],
# [[5, 6], [7, 8]]] # [[5, 6], [7, 8]]]
# Each example is followed by the corresponding output tensor. # Each example is followed by the corresponding output tensor.
y = fluid.data(name='y', shape=[2, 2, 2], dtype='float32') y_data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype('float32')
out5 = paddle.sum(y, dim=[1, 2]) # [10, 26] y = paddle.to_variable(y_data)
out6 = paddle.sum(y, dim=[0, 1]) # [16, 20] out5 = paddle.sum(y, axis=[1, 2]) # [10, 26]
out6 = paddle.sum(y, axis=[0, 1]) # [16, 20]
""" """
if dim is not None and not isinstance(dim, list): if axis is not None and not isinstance(axis, (list, tuple)):
dim = [dim] axis = [axis]
if not axis:
reduce_all_flag = True
else:
if len(axis) == len(x.shape):
reduce_all_flag = True
else:
reduce_all_flag = False
attrs = { attrs = {
'dim': dim if dim != None and dim != [] else [0], 'dim': axis if axis != None and axis != [] and axis != () else [0],
'keep_dim': keep_dim, 'keep_dim': keepdim,
'reduce_all': True if dim == None or dim == [] else False, 'reduce_all': reduce_all_flag
} }
dtype_flag = False dtype_flag = False
if dtype is not None: if dtype is not None:
if dtype in ['float64', 'int64']: if dtype in ['float64', 'int64']:
if (convert_dtype(input.dtype) == "float32" and dtype == "float64") or \ if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
(convert_dtype(input.dtype) == "int32" and dtype == "int64"): (convert_dtype(x.dtype) == "int32" and dtype == "int64"):
attrs.update({ attrs.update({
'in_dtype': input.dtype, 'in_dtype': x.dtype,
'out_dtype': convert_np_dtype_to_dtype_(dtype) 'out_dtype': convert_np_dtype_to_dtype_(dtype)
}) })
dtype_flag = True dtype_flag = True
...@@ -610,27 +619,28 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None): ...@@ -610,27 +619,28 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
format(dtype)) format(dtype))
if in_dygraph_mode(): if in_dygraph_mode():
reduce_all = True if dim == None or dim == [] else False axis = axis if axis != None and axis != [] else [0]
dim = dim if dim != None and dim != [] else [0]
if dtype_flag: if dtype_flag:
return core.ops.reduce_sum(input, 'dim', dim, 'keep_dim', keep_dim, return core.ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim,
'reduce_all', reduce_all, 'in_dtype', 'reduce_all', reduce_all_flag, 'in_dtype',
input.dtype, 'out_dtype', x.dtype, 'out_dtype',
convert_np_dtype_to_dtype_(dtype)) convert_np_dtype_to_dtype_(dtype))
else: else:
return core.ops.reduce_sum(input, 'dim', dim, 'keep_dim', keep_dim, return core.ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim,
'reduce_all', reduce_all) 'reduce_all', reduce_all_flag)
check_variable_and_dtype( check_variable_and_dtype(
input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_sum') x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum')
check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum')
helper = LayerHelper('sum', **locals()) helper = LayerHelper('sum', **locals())
if dtype_flag: if dtype_flag:
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=convert_np_dtype_to_dtype_(dtype)) dtype=convert_np_dtype_to_dtype_(dtype))
else: else:
out = helper.create_variable_for_type_inference(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='reduce_sum', type='reduce_sum',
inputs={'X': input}, inputs={'X': x},
outputs={'Out': out}, outputs={'Out': out},
attrs=attrs) attrs=attrs)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册