diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 3c66523aefffea990de1e9bdc30cdeca2f1f0f22..150da6d59b9ff9780fed5974150b0b9a70891d11 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2685,7 +2685,7 @@ DDim ReduceInferDim(const MetaTensor& x, bool full_dim = true; std::set dims_set(formated_axis.begin(), formated_axis.end()); - for (int64_t i = 0; i < x.dims().size(); ++i) { + for (int64_t i = 0; i < x_rank; ++i) { if (dims_set.find(i) == dims_set.end()) { full_dim = false; break; @@ -2695,7 +2695,7 @@ DDim ReduceInferDim(const MetaTensor& x, std::vector out_dim_vector; if (keep_dim) { - for (int64_t i = 0; i < x.dims().size(); ++i) { + for (int64_t i = 0; i < x_rank; ++i) { if (reduce_all || dims_set.find(i) != dims_set.end()) { out_dim_vector.push_back(1); } else { @@ -2703,7 +2703,7 @@ DDim ReduceInferDim(const MetaTensor& x, } } } else { - for (int64_t i = 0; i < x.dims().size(); ++i) { + for (int64_t i = 0; i < x_rank; ++i) { if (reduce_all || dims_set.find(i) != dims_set.end()) { continue; } else { @@ -2711,7 +2711,7 @@ DDim ReduceInferDim(const MetaTensor& x, } } - if (out_dim_vector.size() == 0) { + if (x_rank > 0 && out_dim_vector.size() == 0) { out_dim_vector.push_back(1); } } @@ -3013,6 +3013,7 @@ void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) { phi::errors::InvalidArgument( "The rank of input should be less than 7, but received %d.", in_dims.size())); + out->set_dims(in_dims); } void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) { diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 59c3df0fce5e503691818cc0781fc73d12b99848..22ed5b29d77bc021d1ea0da64113b38121fe0121 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -44,7 +44,7 @@ struct DimensionsTransform { int64_t in_idx = 0; if (in_dim.size() < dim_size) { DimVector tmp_dim(dim_size, 1); - do { + for (; in_idx < in_dim.size();) { if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) { tmp_dim[axis] = in_dim[in_idx]; in_idx++; @@ -59,11 +59,11 @@ struct DimensionsTransform { out_dims[axis], in_dim[in_idx])); } - } while (in_idx < in_dim.size()); + } in_dim.resize(dim_size); std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin()); } else { - do { + for (; in_idx < dim_size;) { if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) { in_idx++; } else { @@ -76,7 +76,7 @@ struct DimensionsTransform { out_dims[in_idx], in_dim[in_idx])); } - } while (in_idx < dim_size); + } } std::reverse(in_dim.begin(), in_dim.end()); } diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index 92fe3885b42f0a5a6c23bf6f0ea9445658b09a3d..9138fd85e65aa4ad91d5847024a22c14adfe7465 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -1063,6 +1063,14 @@ void ReduceKernel(const KPDevice& dev_ctx, dev_ctx.Alloc(y); auto x_dim = phi::vectorize(x.dims()); + + if (x_dim.size() == 0) { + std::vector inputs = {&x}; + std::vector outputs = {y}; + funcs::ElementwiseKernel(dev_ctx, inputs, &outputs, transform); + return; + } + auto config = ReduceConfig(origin_reduce_dims, x_dim); config.Run(dev_ctx); int numel = x.numel(); diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu index 7da2502a5eea7039322a6f849be1736faefe5517..40c317e1262c5b44f0e74ce53e38ae4faa3abd96 100644 --- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu @@ -16,8 +16,8 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/reduce_function.h" -#include "paddle/phi/kernels/gpu/reduce_grad.h" namespace phi { @@ -29,23 +29,34 @@ void ReduceMeanGradKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { + // get reduce_dim and reduce_num for reduce_mean_grad int dim_size = x.dims().size(); + if (dims.size() == 0) { + reduce_all = true; + } std::vector reduce_dims = funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all); + + auto update_dims = vectorize(x.dims()); int reduce_num = 1; for (auto i : reduce_dims) { reduce_num *= (x.dims())[i]; + update_dims[i] = 1; } + + // make new tensor + DenseTensor new_out_grad(out_grad.dtype()); + new_out_grad.ShareDataWith(out_grad); + new_out_grad.Resize(phi::make_ddim(update_dims)); + + // call BroadcastKernel + dev_ctx.Alloc(x_grad, x.dtype()); + std::vector inputs = {&new_out_grad}; + std::vector outputs = {x_grad}; + using MPType = typename kps::details::MPTypeTrait::Type; - ReduceGradKernel>( - dev_ctx, - x, - out_grad, - dims.GetData(), - keep_dim, - reduce_all, - x_grad, - kps::DivideFunctor(reduce_num)); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, 0, kps::DivideFunctor(reduce_num)); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu index 2230b4b8525b3d4eb9bb42ae1fab7cfce21d262f..74209afe374673b90610361e4c361aeab0a7c760 100644 --- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu @@ -29,42 +29,32 @@ void ReduceSumGradKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { - using MPType = typename kps::details::MPTypeTrait::Type; - auto out_dtype = x.dtype(); - auto* in_x = &x; - auto* d_out = &out_grad; - auto* d_x = x_grad; - - // get reduce_dim and reduce_num for reduce_mean_grad - int dim_size = in_x->dims().size(); + // get reduce_dim for reduce_mean_grad + int dim_size = x.dims().size(); if (dims.size() == 0) { reduce_all = true; } std::vector reduce_dims = funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all); - auto update_dims = vectorize(d_x->dims()); - int reduce_num = 1; + auto update_dims = vectorize(x.dims()); for (auto i : reduce_dims) { - reduce_num *= (in_x->dims())[i]; update_dims[i] = 1; } + // make new tensor - DenseTensor new_d_out(d_out->dtype()); - new_d_out.ShareDataWith(*d_out); - new_d_out.Resize(phi::make_ddim(update_dims)); + DenseTensor new_out_grad(out_grad.dtype()); + new_out_grad.ShareDataWith(out_grad); + new_out_grad.Resize(phi::make_ddim(update_dims)); - dev_ctx.Alloc(d_x, x.dtype()); - auto pt_out_dtype = x.dtype(); - auto pt_d_out = new_d_out; - auto pt_d_x = *d_x; - std::vector inputs = {&pt_d_out}; - std::vector outputs = {&pt_d_x}; + // call ReduceGrad + dev_ctx.Alloc(x_grad, x.dtype()); + using MPType = typename kps::details::MPTypeTrait::Type; phi::ReduceGrad>( dev_ctx, - &pt_d_out, - &pt_d_x, - pt_out_dtype, + &new_out_grad, + x_grad, + x.dtype(), kps::IdentityFunctor()); } diff --git a/paddle/phi/kernels/reduce_mean_kernel.cc b/paddle/phi/kernels/reduce_mean_kernel.cc index 375172fdb37330aba3dbace00a01bf208c929691..aa615a6bb1ef1ccaecf016155c0bd1a6c07a4feb 100644 --- a/paddle/phi/kernels/reduce_mean_kernel.cc +++ b/paddle/phi/kernels/reduce_mean_kernel.cc @@ -26,6 +26,9 @@ void MeanKernel(const Context& dev_ctx, bool keep_dim, DenseTensor* out) { bool reduce_all = false; + if (dims.size() == 0) { + reduce_all = true; + } MeanRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9f7cbb1141193da67298cc5657efdc3878736b0a..4a5dbe4a106c29cc7aa55fbfcc85778a58d47cf3 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5096,9 +5096,6 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_sum(y, dim=[0, 1]) # [16, 20] """ - if dim is not None and not isinstance(dim, list): - dim = [dim] - reduce_all, dim = _get_reduce_dim(dim, input) if in_dygraph_mode(): diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index ed9313b054696bf0556251e50c98e8bb8b4a5c54..68e88c9ba2a8132d7938ea80e46fd31ee99c5511 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -58,6 +58,21 @@ class TestMeanOp(OpTest): self.check_grad(['X'], 'Out', check_eager=True) +class TestMeanOp_ZeroDim(OpTest): + def setUp(self): + self.op_type = "mean" + self.python_api = paddle.mean + self.dtype = np.float64 + self.inputs = {'X': np.random.random([]).astype(self.dtype)} + self.outputs = {'Out': np.mean(self.inputs["X"])} + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_checkout_grad(self): + self.check_grad(['X'], 'Out', check_eager=True) + + class TestMeanOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 8fa448a6927dd5866504b5dd5bd8bdad832dc02b..bf0a968bdb1ff7d83d77954101e13c1807754f31 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -37,6 +37,21 @@ class TestSumOp(OpTest): self.check_grad(['X'], 'Out', check_eager=True) +class TestSumOp_ZeroDim(OpTest): + def setUp(self): + self.python_api = paddle.sum + self.op_type = "reduce_sum" + self.inputs = {'X': np.random.random([]).astype("float64")} + self.outputs = {'Out': self.inputs['X'].sum(axis=None)} + self.attrs = {'dim': [], 'reduce_all': True} + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad(['X'], 'Out', check_eager=True) + + class TestSumOp_fp16(OpTest): def setUp(self): self.python_api = paddle.sum diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py b/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py index df4fa96d4a36cf29e118f69fedf8fd5c87142a07..0cab423aa7b98e5793e61dde2130ebda7f671191 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py @@ -17,6 +17,7 @@ import paddle.fluid as fluid import numpy as np import unittest + unary_api_list = [ paddle.nn.functional.elu, paddle.nn.functional.gelu, @@ -159,5 +160,55 @@ class TestUnaryAPI(unittest.TestCase): paddle.disable_static() +reduce_api_list = [ + paddle.sum, + paddle.mean, + paddle.nansum, + paddle.nanmean, +] + + +class TestReduceAPI(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + for api in reduce_api_list: + x = paddle.rand([]) + x.stop_gradient = False + out = api(x, None) + out.backward() + + self.assertEqual(x.shape, []) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + + paddle.enable_static() + + def test_static(self): + paddle.enable_static() + for api in reduce_api_list: + main_prog = fluid.Program() + with fluid.program_guard(main_prog, fluid.Program()): + x = paddle.rand([]) + + x.stop_gradient = False + out = api(x, None) + fluid.backward.append_backward(out) + + # Test compile shape, grad is always [1] + self.assertEqual(x.shape, ()) + self.assertEqual(out.shape, ()) + + exe = fluid.Executor() + result = exe.run(main_prog, fetch_list=[x, out]) + + # Test runtime shape + self.assertEqual(result[0].shape, ()) + self.assertEqual(result[1].shape, ()) + + paddle.disable_static() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 91388a6f99a02bf7ebde19c3d6c2dcce1409fa58..34bc3b006b3d9021869ae1f511679d195b017623 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1265,22 +1265,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): out8 = paddle.sum(x, axis=0) # [1, 1, 1, 1] out9 = paddle.sum(x, axis=1) # [4, 0] """ - if isinstance(axis, Variable): - reduce_all_flag = True if axis.shape[0] == len(x.shape) else False - else: - if axis is not None and not isinstance(axis, (list, tuple)): - axis = [axis] - - if not axis: - axis = [] - - if len(axis) == 0: - reduce_all_flag = True - else: - if len(axis) == len(x.shape): - reduce_all_flag = True - else: - reduce_all_flag = False + reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) dtype_flag = False if dtype is not None: @@ -1290,11 +1275,6 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): if in_dygraph_mode(): return _C_ops.sum(x, axis, dtype, keepdim) - if not isinstance(axis, Variable): - axis = axis if axis != None and axis != [] and axis != () else [0] - if utils._contain_var(axis): - axis = utils._convert_to_tensor_list(axis) - if _in_legacy_dygraph(): if dtype_flag: return _legacy_C_ops.reduce_sum( @@ -1304,7 +1284,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): 'keep_dim', keepdim, 'reduce_all', - reduce_all_flag, + reduce_all, 'in_dtype', x.dtype, 'out_dtype', @@ -1318,10 +1298,10 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): 'keep_dim', keepdim, 'reduce_all', - reduce_all_flag, + reduce_all, ) - attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all_flag} + attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all} if dtype_flag: attrs.update({'in_dtype': x.dtype, 'out_dtype': dtype}) @@ -2304,13 +2284,13 @@ def inverse(x, name=None): return out -def _get_reduce_axis(axis): +def _get_reduce_axis(axis, x): """ Internal function for max, min, amax and amin. It computes the attribute reduce_all value based on axis. """ if axis is not None and not isinstance(axis, list): - if isinstance(axis, tuple): + if isinstance(axis, (tuple, range)): axis = list(axis) elif isinstance(axis, int): axis = [axis] @@ -2320,37 +2300,25 @@ def _get_reduce_axis(axis): type(axis) ) ) - reduce_all = True if axis == None or axis == [] else False - if axis == None: + if axis is None: axis = [] + if axis == [] or len(axis) == len(x.shape): + reduce_all = True + else: + reduce_all = False return reduce_all, axis -def _get_reduce_axis_with_tensor(axis): +def _get_reduce_axis_with_tensor(axis, x): if isinstance(axis, Variable): - return False, axis - return _get_reduce_axis(axis) - - -def _get_reduce_all_value(axis): - """ - Internal function for max, min, amax and amin. - It computes the attribute reduce_all value based on axis. - """ - if axis is not None and not isinstance(axis, list): - if isinstance(axis, tuple): - axis = list(axis) - elif isinstance(axis, int): - axis = [axis] + if axis.shape[0] == len(x.shape): + reduce_all = True else: - raise TypeError( - "The type of axis must be int, list or tuple, but received {}".format( - type(axis) - ) - ) - - reduce_all = True if axis == None or axis == [] else False - axis = axis if axis != None and axis != [] else [0] + reduce_all = False + else: + reduce_all, axis = _get_reduce_axis(axis, x) + if utils._contain_var(axis): + axis = utils._convert_to_tensor_list(axis) return reduce_all, axis @@ -2432,7 +2400,7 @@ def max(x, axis=None, keepdim=False, name=None): #[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]] """ - reduce_all, axis = _get_reduce_axis_with_tensor(axis) + reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) if in_dygraph_mode(): return _C_ops.max(x, axis, keepdim) if _in_legacy_dygraph(): @@ -2534,7 +2502,7 @@ def min(x, axis=None, keepdim=False, name=None): #[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]] """ - reduce_all, axis = _get_reduce_axis_with_tensor(axis) + reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) if in_dygraph_mode(): return _C_ops.min(x, axis, keepdim) @@ -2650,7 +2618,7 @@ def amax(x, axis=None, keepdim=False, name=None): #[0.9., 0.9], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]] """ - reduce_all, axis = _get_reduce_axis(axis) + reduce_all, axis = _get_reduce_axis(axis, x) if in_dygraph_mode(): return _C_ops.amax(x, axis, keepdim) if _in_legacy_dygraph(): @@ -2764,7 +2732,7 @@ def amin(x, axis=None, keepdim=False, name=None): #[0.1., 0.1], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]] """ - reduce_all, axis = _get_reduce_axis(axis) + reduce_all, axis = _get_reduce_axis(axis, x) if in_dygraph_mode(): return _C_ops.amin(x, axis, keepdim) elif _in_legacy_dygraph(): diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 371e3fafd057e5beda0fecd824ab974da2bbb3db..ad061673ab9f4f2f411aa8bc4a8bfb3067b07b74 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -20,9 +20,9 @@ from ..framework import core from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode from .search import where from ..fluid.data_feeder import check_type, check_variable_and_dtype -from ..fluid.layers import utils import paddle from paddle import _C_ops, _legacy_C_ops +from .math import _get_reduce_axis_with_tensor __all__ = [] @@ -80,22 +80,9 @@ def mean(x, axis=None, keepdim=False, name=None): # [ 8.5 12.5 16.5] """ - if isinstance(axis, Variable): - reduce_all = True if axis.shape[0] == len(x.shape) else False - else: - if isinstance(axis, int): - axis = [axis] - reduce_all = ( - True - if axis is None or len(axis) == 0 or len(axis) == len(x.shape) - else False - ) - if axis is None or len(axis) == 0: - axis = [0] + reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) if in_dygraph_mode(): - if reduce_all: - axis = list(range(len(x.shape))) return _C_ops.mean(x, axis, keepdim) if _in_legacy_dygraph(): return _legacy_C_ops.reduce_mean( @@ -122,8 +109,6 @@ def mean(x, axis=None, keepdim=False, name=None): helper = LayerHelper('mean', **locals()) - if not isinstance(axis, Variable) and utils._contain_var(axis): - axis = utils._convert_to_tensor_list(axis) attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all} out = helper.create_variable_for_type_inference(x.dtype) helper.append_op(