From 3ed69e0d3ad26c50bfc732f2f3d4fa84a5881889 Mon Sep 17 00:00:00 2001
From: hong <43953930+phlrain@users.noreply.github.com>
Date: Mon, 8 Aug 2022 17:17:18 +0800
Subject: [PATCH] move reduce_all_flag from python to c++ (#44926)

* move reduce_all_flag from python to c++

* fix infer shape bug

* fix bug;

* fix sum infer meta bug

* fix reduce sum grad gpu bug

* fix amin amax bug;
---
 paddle/phi/infermeta/unary.cc                 |  6 +++
 .../phi/kernels/gpu/reduce_amin_amax_common.h |  3 ++
 .../phi/kernels/gpu/reduce_sum_grad_kernel.cu |  3 ++
 paddle/phi/kernels/impl/reduce_grad.h         |  3 ++
 paddle/phi/kernels/reduce_amax_kernel.cc      |  3 ++
 paddle/phi/kernels/reduce_amin_kernel.cc      |  3 ++
 paddle/phi/kernels/reduce_max_kernel.cc       |  3 ++
 paddle/phi/kernels/reduce_min_kernel.cc       |  3 ++
 paddle/phi/kernels/reduce_sum_kernel.cc       |  3 ++
 .../unittests/test_max_min_amax_amin_op.py    |  4 ++
 .../fluid/tests/unittests/test_mse_loss.py    |  1 +
 .../fluid/tests/unittests/test_split_op.py    |  1 +
 python/paddle/nn/functional/distance.py       |  4 +-
 python/paddle/tensor/math.py                  | 54 +++++++++++--------
 14 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 14f0951c3d..d6395c8a2e 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2586,6 +2586,9 @@ void ReduceInferMeta(const MetaTensor& x,
                      bool keep_dim,
                      MetaTensor* out) {
   bool reduce_all = false;
+  if (axis.size() == 0) {
+    reduce_all = true;
+  }
   ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out);
 }
 
@@ -3254,6 +3257,9 @@ void SumInferMeta(const MetaTensor& x,
                   bool keep_dim,
                   MetaTensor* out) {
   bool reduce_all = false;
+  if (axis.size() == 0) {
+    reduce_all = true;
+  }
   SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out);
 }
 
diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
index fe3cd89d5b..5d90433ad2 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
+++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
@@ -38,6 +38,9 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
   auto* d_x = x_grad;
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all);
   auto update_dims = vectorize(d_x->dims());
   int reduce_num = 1;
diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
index 8b111641cf..c0955cd742 100644
--- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -37,6 +37,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
 
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims, dim_size, reduce_all);
 
diff --git a/paddle/phi/kernels/impl/reduce_grad.h b/paddle/phi/kernels/impl/reduce_grad.h
index 8dcd3c2ba8..40b62cc83f 100644
--- a/paddle/phi/kernels/impl/reduce_grad.h
+++ b/paddle/phi/kernels/impl/reduce_grad.h
@@ -91,6 +91,9 @@ void ReduceGradKernel(const Context& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       DenseTensor* x_grad) {
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   if (x.dtype() != out_grad.dtype()) {
     DenseTensorMeta x_grad_meta(
         out_grad.dtype(), x_grad->dims(), x_grad->layout());
diff --git a/paddle/phi/kernels/reduce_amax_kernel.cc b/paddle/phi/kernels/reduce_amax_kernel.cc
index acec25d83d..47b5e97467 100644
--- a/paddle/phi/kernels/reduce_amax_kernel.cc
+++ b/paddle/phi/kernels/reduce_amax_kernel.cc
@@ -26,6 +26,9 @@ void AMaxKernel(const Context& dev_ctx,
                 bool keep_dim,
                 DenseTensor* out) {
   bool reduce_all = false;
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   AMaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_amin_kernel.cc b/paddle/phi/kernels/reduce_amin_kernel.cc
index 28e6e587f4..8da4f3afd9 100644
--- a/paddle/phi/kernels/reduce_amin_kernel.cc
+++ b/paddle/phi/kernels/reduce_amin_kernel.cc
@@ -26,6 +26,9 @@ void AMinKernel(const Context& dev_ctx,
                 bool keep_dim,
                 DenseTensor* out) {
   bool reduce_all = false;
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   AMinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_max_kernel.cc b/paddle/phi/kernels/reduce_max_kernel.cc
index 26b8bc196c..7bdf9ba2bb 100644
--- a/paddle/phi/kernels/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/reduce_max_kernel.cc
@@ -26,6 +26,9 @@ void MaxKernel(const Context& dev_ctx,
               bool keep_dim,
               DenseTensor* out) {
   bool reduce_all = false;
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc
index 75d906aa4b..69725759e4 100644
--- a/paddle/phi/kernels/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/reduce_min_kernel.cc
@@ -26,6 +26,9 @@ void MinKernel(const Context& dev_ctx,
               bool keep_dim,
               DenseTensor* out) {
   bool reduce_all = false;
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   MinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc
index 0d79fa34bc..c9622768c4 100644
--- a/paddle/phi/kernels/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/reduce_sum_kernel.cc
@@ -27,6 +27,9 @@ void SumKernel(const Context& dev_ctx,
               bool keep_dim,
               DenseTensor* out) {
   bool reduce_all = false;
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
 }
 
diff --git a/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py b/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py
index cadbca93ad..608fad131f 100644
--- a/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py
+++ b/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py
@@ -187,3 +187,7 @@ class TestMaxMinAmaxAminAPI6(TestMaxMinAmaxAminAPI):
         self.dtype = 'float64'
         self.axis = None
         self.keepdim = False
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_mse_loss.py b/python/paddle/fluid/tests/unittests/test_mse_loss.py
index b32833916e..d3dd0d2774 100644
--- a/python/paddle/fluid/tests/unittests/test_mse_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_mse_loss.py
@@ -316,4 +316,5 @@ class TestNNFunctionalMseLoss(unittest.TestCase):
 
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py
index 4f438e26a7..4fb33e53ba 100644
--- a/python/paddle/fluid/tests/unittests/test_split_op.py
+++ b/python/paddle/fluid/tests/unittests/test_split_op.py
@@ -540,4 +540,5 @@ class API_TestEmptySplit(unittest.TestCase):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/nn/functional/distance.py b/python/paddle/nn/functional/distance.py
index 8c672ffc69..4d6f447d67 100644
--- a/python/paddle/nn/functional/distance.py
+++ b/python/paddle/nn/functional/distance.py
@@ -67,12 +67,12 @@ def pairwise_distance(x, y, p=2., epsilon=1e-6, keepdim=False, name=None):
     check_type(epsilon, 'epsilon', (float), 'PairwiseDistance')
     check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')
     if in_dygraph_mode():
-        sub = _C_ops.elementwise_sub(x, y)
+        sub = _C_ops.final_state_subtract(x, y)
         # p_norm op has not uesd epsilon, so change it to the following.
         if epsilon != 0.0:
             epsilon = paddle.fluid.dygraph.base.to_variable([epsilon],
                                                             dtype=sub.dtype)
-            sub = _C_ops.elementwise_add(sub, epsilon)
+            sub = _C_ops.final_state_add(sub, epsilon)
         return _C_ops.final_state_p_norm(sub, p, -1, 0., keepdim, False)
 
     if _in_legacy_dygraph():
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 86b3b71998..b94329eb92 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1162,12 +1162,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
         axis = [axis]
 
     if not axis:
-        reduce_all_flag = True
-    else:
-        if len(axis) == len(x.shape):
-            reduce_all_flag = True
-        else:
-            reduce_all_flag = False
+        axis = []
 
     dtype_flag = False
     if dtype is not None:
@@ -1175,13 +1170,16 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
         dtype = convert_np_dtype_to_dtype_(dtype)
 
     if in_dygraph_mode():
-        if reduce_all_flag:
-            axis = range(len(x.shape))
-        else:
-            axis = axis if axis != None and axis != [] else [0]
-
         return _C_ops.final_state_sum(x, axis, dtype, keepdim)
 
+    if len(axis) == 0:
+        reduce_all_flag = True
+    else:
+        if len(axis) == len(x.shape):
+            reduce_all_flag = True
+        else:
+            reduce_all_flag = False
+
     if _in_legacy_dygraph():
         axis = axis if axis != None and axis != [] else [0]
         if dtype_flag:
@@ -2056,6 +2054,24 @@ def inverse(x, name=None):
         type='inverse', inputs={'Input': [x] }, outputs={'Output': [out]})
     return out
 
+def _get_reduce_axis(axis):
+    """
+    Internal function for max, min, amax and amin.
+    It computes the attribute reduce_all value based on axis.
+    """
+    if axis is not None and not isinstance(axis, list):
+        if isinstance(axis, tuple):
+            axis = list(axis)
+        elif isinstance(axis, int):
+            axis= [axis]
+        else:
+            raise TypeError(
+                "The type of axis must be int, list or tuple, but received {}".format(type(axis)))
+    reduce_all = True if axis == None or axis == [] else False
+    if axis == None:
+        axis = []
+    return reduce_all, axis
+
 def _get_reduce_all_value(axis):
     """
     Internal function for max, min, amax and amin.
@@ -2152,10 +2168,8 @@ def max(x, axis=None, keepdim=False, name=None):
            #[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]]
     """
 
-    reduce_all, axis = _get_reduce_all_value(axis)
+    reduce_all, axis = _get_reduce_axis(axis)
     if in_dygraph_mode():
-        if reduce_all:
-            axis = range(len(x.shape))
         return _C_ops.final_state_max(x, axis, keepdim)
     if _in_legacy_dygraph():
         return _C_ops.reduce_max(x, 'dim', axis, 'keep_dim', keepdim,
@@ -2255,10 +2269,8 @@ def min(x, axis=None, keepdim=False, name=None):
           #[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]]
     """
 
-    reduce_all, axis = _get_reduce_all_value(axis)
+    reduce_all, axis = _get_reduce_axis(axis)
    if in_dygraph_mode():
-        if reduce_all:
-            axis = range(len(x.shape))
         return _C_ops.final_state_min(x, axis, keepdim)
 
     if _in_legacy_dygraph():
@@ -2372,10 +2384,8 @@ def amax(x, axis=None, keepdim=False, name=None):
           #[0.9., 0.9], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
     """
 
-    reduce_all, axis = _get_reduce_all_value(axis)
+    reduce_all, axis = _get_reduce_axis(axis)
     if in_dygraph_mode():
-        if reduce_all:
-            axis = range(len(x.shape))
         return _C_ops.final_state_amax(x, axis, keepdim)
     if _in_legacy_dygraph():
         return _C_ops.reduce_amax(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all)
@@ -2488,10 +2498,8 @@ def amin(x, axis=None, keepdim=False, name=None):
           #[0.1., 0.1], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
     """
 
-    reduce_all, axis = _get_reduce_all_value(axis)
+    reduce_all, axis = _get_reduce_axis( axis )
     if in_dygraph_mode():
-        if reduce_all:
-            axis = range(len(x.shape))
         return _C_ops.final_state_amin(x, axis, keepdim)
     elif _in_legacy_dygraph():
         return _C_ops.reduce_amin(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all)
-- 
GitLab