From 340dfb26135aa9be903575aa29691402ccd40467 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Tue, 28 Dec 2021 14:49:31 +0800 Subject: [PATCH] Add Amax and Amin API (#38417) * add amax/amin * support axis is list --- .../operators/reduce_ops/reduce_amax_op.cc | 34 +++ .../operators/reduce_ops/reduce_amax_op.cu | 23 ++ .../reduce_ops/reduce_amax_op.part.cu | 25 ++ .../operators/reduce_ops/reduce_amin_op.cc | 34 +++ .../operators/reduce_ops/reduce_amin_op.cu | 23 ++ .../reduce_ops/reduce_amin_op.part.cu | 25 ++ .../operators/reduce_ops/reduce_min_max_op.h | 89 +++++++ python/paddle/__init__.py | 4 + ...min_op.py => test_max_min_amax_amin_op.py} | 56 +++- python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/math.py | 249 ++++++++++++++++-- 11 files changed, 531 insertions(+), 35 deletions(-) create mode 100644 paddle/fluid/operators/reduce_ops/reduce_amax_op.cc create mode 100644 paddle/fluid/operators/reduce_ops/reduce_amax_op.cu create mode 100644 paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu create mode 100644 paddle/fluid/operators/reduce_ops/reduce_amin_op.cc create mode 100644 paddle/fluid/operators/reduce_ops/reduce_amin_op.cu create mode 100644 paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu rename python/paddle/fluid/tests/unittests/{test_max_min_op.py => test_max_min_amax_amin_op.py} (73%) diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc b/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc new file mode 100644 index 0000000000..c5bc66e23c --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" + +REGISTER_REDUCE_OP(reduce_amax); +REGISTER_OP_CPU_KERNEL( + reduce_amax, ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL( + reduce_amax_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.cu b/paddle/fluid/operators/reduce_ops/reduce_amax_op.cu new file mode 100644 index 0000000000..16c7a4794b --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_amax_op.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" + +// reduce_max +REGISTER_OP_CUDA_KERNEL( + reduce_amax, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu new file mode 100644 index 0000000000..27f2e2b70c --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" + +REGISTER_OP_CUDA_KERNEL( + reduce_amax_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc b/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc new file mode 100644 index 0000000000..027bf8ea00 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" + +REGISTER_REDUCE_OP(reduce_amin); +REGISTER_OP_CPU_KERNEL( + reduce_amin, ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL( + reduce_amin_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.cu b/paddle/fluid/operators/reduce_ops/reduce_amin_op.cu new file mode 100644 index 0000000000..f9f015804e --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_amin_op.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" + +// reduce_min +REGISTER_OP_CUDA_KERNEL( + reduce_amin, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu new file mode 100644 index 0000000000..a296c4c5d6 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" + +REGISTER_OP_CUDA_KERNEL( + reduce_amin_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_max_op.h b/paddle/fluid/operators/reduce_ops/reduce_min_max_op.h index 2557e8dd48..dfd0c9d74d 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_min_max_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_min_max_op.h @@ -46,5 +46,94 @@ struct MaxOrMinGradFunctor { } }; +#define HANDLE_AXIS_DIM(BROADCAST_DIM, AXIS_DIM) \ + if (broadcast_dim_size == BROADCAST_DIM && rank == AXIS_DIM) { \ + AMaxOrAMinAxisIsListGradFunctor( \ + place, x, y, dx, dy, dim, axis_dim); \ + } + +template +void AMaxOrAMinAxisIsListGradFunctor(const DeviceContext& place, X* x, Y* y, + DX* dx, DY* dy, const Dim& dim, + const std::vector& axis_dim) { + // R is x->dimensions().size(); + // D is axis_dim->dimensions().size(); + auto axis = Eigen::array(); + auto reshape_x = Eigen::array(); + auto reshape_y = Eigen::array(); + + for (int i = 0; i < D; i++) axis[i] = axis_dim[i]; + for (int i = 0; i < R; i++) { + reshape_x[i] = x->dimensions()[i]; + reshape_y[i] = y->dimensions()[i]; + } + + auto equals = (*x) == y->broadcast(dim); + auto ones = dx->constant(1); + auto zeros = dx->constant(0); + auto mask = equals.select(ones, zeros); + dx->device(place) = + dy->broadcast(dim) * mask / + mask.reshape(reshape_x).sum(axis).reshape(reshape_y).broadcast(dim); +} + +struct AMaxOrAMinGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + auto equals = (*x) == y->broadcast(dim); + auto ones = dx->constant(1); + auto zeros = dx->constant(0); + auto mask = equals.select(ones, zeros); + + // If there are multiple minimum or maximum elements, + // we evenly distribute gradient between these equal values + size_t x_numel = 1; + for (size_t i = 0; i < x->dimensions().size(); i++) + x_numel *= x->dimensions()[i]; + // reduce_all + if (size == static_cast(x_numel)) { + auto equal_number = mask.sum() + .reshape(Eigen::array({1})) + .broadcast(Eigen::array({size})); + dx->device(place) = dy->broadcast(dim) * mask / equal_number; + return; + } + + // compute forward reduce axis_dim by dim (which is broadcast_dim) + std::vector 
axis_dim; + int broadcast_dim_size = static_cast(dim.size()); + for (int i = 0; i < broadcast_dim_size; i++) { + if (dim[i] > 1) { + axis_dim.push_back(i); + } + } + + int rank = static_cast(axis_dim.size()); + // axis is a int element + if (rank == 1) { + auto axis = Eigen::array({axis_dim[0]}); + dx->device(place) = + dy->broadcast(dim) * mask / + mask.sum(axis).reshape(dy->dimensions()).broadcast(dim); + return; + } + // axis is list, HANDLE_AXIS_DIM(broadcast_dim_size, rank) + HANDLE_AXIS_DIM(3, 2); + HANDLE_AXIS_DIM(4, 2); + HANDLE_AXIS_DIM(4, 3); + HANDLE_AXIS_DIM(5, 2); + HANDLE_AXIS_DIM(5, 3); + HANDLE_AXIS_DIM(5, 4); + HANDLE_AXIS_DIM(6, 2); + HANDLE_AXIS_DIM(6, 3); + HANDLE_AXIS_DIM(6, 4); + HANDLE_AXIS_DIM(6, 5); + } +}; + } // namespace operators } // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 77aceef50f..e6311ea2e6 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -203,8 +203,10 @@ from .tensor.math import tanh_ # noqa: F401 from .tensor.math import add_n # noqa: F401 from .tensor.math import max # noqa: F401 from .tensor.math import maximum # noqa: F401 +from .tensor.math import amax # noqa: F401 from .tensor.math import min # noqa: F401 from .tensor.math import minimum # noqa: F401 +from .tensor.math import amin # noqa: F401 from .tensor.math import mm # noqa: F401 from .tensor.math import divide # noqa: F401 from .tensor.math import floor_divide # noqa: F401 @@ -400,6 +402,7 @@ __all__ = [ # noqa 'mv', 'in_dynamic_mode', 'min', + 'amin', 'any', 'slice', 'normal', @@ -442,6 +445,7 @@ __all__ = [ # noqa 'roll', 'batch', 'max', + 'amax', 'logical_or', 'bitwise_and', 'bitwise_or', diff --git a/python/paddle/fluid/tests/unittests/test_max_min_op.py b/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py similarity index 73% rename from python/paddle/fluid/tests/unittests/test_max_min_op.py rename to python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py index 1a42715dc1..fe00a825ba 100644 --- a/python/paddle/fluid/tests/unittests/test_max_min_op.py +++ b/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py @@ -25,7 +25,7 @@ from op_test import OpTest paddle.enable_static() -class TestMaxMinAPI(unittest.TestCase): +class TestMaxMinAmaxAminAPI(unittest.TestCase): def setUp(self): self.init_case() self.cal_np_out_and_gradient() @@ -36,38 +36,54 @@ class TestMaxMinAPI(unittest.TestCase): self.x_np = np.array([[0.2, 0.3, 0.5, 0.9], [0.1, 0.2, 0.6, 0.7]]) self.shape = [2, 4] self.dtype = 'float64' - self.axis = None + self.axis = 0 self.keepdim = False - # If there are multiple minimum or maximum elements, max/min/ is non-derivable, + # If there are multiple minimum or maximum elements, max/min/amax/amin is non-derivable, # its gradient check is not supported by unittest framework, # thus we calculate the gradient by numpy function. 
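+ # For example, reducing x = [1., 3., 3.] over all elements gives out = 3. for both max and amax,
+ # but the reference gradients differ: max -> [0., 1., 1.] (every tied element gets the full
+ # gradient), while amax -> [0., 0.5, 0.5] (the gradient is split evenly among the ties).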
def cal_np_out_and_gradient(self): def _cal_np_out_and_gradient(func): - if func is 'max': + if func is 'amax': + out = np.amax(self.x_np, axis=self.axis, keepdims=self.keepdim) + elif func is 'amin': + out = np.amin(self.x_np, axis=self.axis, keepdims=self.keepdim) + elif func is 'max': out = np.max(self.x_np, axis=self.axis, keepdims=self.keepdim) elif func is 'min': out = np.min(self.x_np, axis=self.axis, keepdims=self.keepdim) else: - print('This unittest only test max/min, but now is', func) + print('This unittest only test amax/amin/max/min, but now is', + func) self.np_out[func] = out grad = np.zeros(self.shape) - out_b = np.broadcast_to(out, self.shape) + out_b = np.broadcast_to(out.view(), self.shape) grad[self.x_np == out_b] = 1 + if func in ['amax', 'amin']: + grad_sum = grad.sum(self.axis).reshape(out.shape) + grad_b = np.broadcast_to(grad_sum, self.shape) + grad /= grad_sum + self.np_grad[func] = grad self.np_out = dict() self.np_grad = dict() + _cal_np_out_and_gradient('amax') + _cal_np_out_and_gradient('amin') _cal_np_out_and_gradient('max') _cal_np_out_and_gradient('min') def _choose_paddle_func(self, func, x): - if func is 'max': + if func is 'amax': + out = paddle.amax(x, self.axis, self.keepdim) + elif func is 'amin': + out = paddle.amin(x, self.axis, self.keepdim) + elif func is 'max': out = paddle.max(x, self.axis, self.keepdim) elif func is 'min': out = paddle.min(x, self.axis, self.keepdim) else: - print('This unittest only test max/min, but now is', func) + print('This unittest only test amax/amin/max/min, but now is', func) return out # We check the output between paddle API and numpy in static graph. @@ -86,6 +102,8 @@ class TestMaxMinAPI(unittest.TestCase): fetch_list=[out]) self.assertTrue((np.array(res[0]) == self.np_out[func]).all()) + _test_static_graph('amax') + _test_static_graph('amin') _test_static_graph('max') _test_static_graph('min') @@ -104,12 +122,14 @@ class TestMaxMinAPI(unittest.TestCase): self.assertEqual(np.allclose(self.np_grad[func], x.grad), True) paddle.enable_static() + _test_dygraph('amax') + _test_dygraph('amin') _test_dygraph('max') _test_dygraph('min') -# test multiple minimum or maximum elements -class TestMaxMinAPI2(TestMaxMinAPI): + # test two minimum or maximum elements +class TestMaxMinAmaxAminAPI2(TestMaxMinAmaxAminAPI): def init_case(self): self.x_np = np.array([[0.2, 0.3, 0.9, 0.9], [0.1, 0.1, 0.6, 0.7]]) self.shape = [2, 4] @@ -119,7 +139,7 @@ class TestMaxMinAPI2(TestMaxMinAPI): # test different axis -class TestMaxMinAPI3(TestMaxMinAPI): +class TestMaxMinAmaxAminAPI3(TestMaxMinAmaxAminAPI): def init_case(self): self.x_np = np.array([[0.2, 0.3, 0.9, 0.9], [0.1, 0.1, 0.6, 0.7]]) self.shape = [2, 4] @@ -129,7 +149,7 @@ class TestMaxMinAPI3(TestMaxMinAPI): # test keepdim = True -class TestMaxMinAPI4(TestMaxMinAPI): +class TestMaxMinAmaxAminAPI4(TestMaxMinAmaxAminAPI): def init_case(self): self.x_np = np.array([[0.2, 0.3, 0.9, 0.9], [0.1, 0.1, 0.6, 0.7]]) self.shape = [2, 4] @@ -139,7 +159,7 @@ class TestMaxMinAPI4(TestMaxMinAPI): # test axis is tuple -class TestMaxMinAPI5(TestMaxMinAPI): +class TestMaxMinAmaxAminAPI5(TestMaxMinAmaxAminAPI): def init_case(self): self.x_np = np.array( [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(np.int32) @@ -147,3 +167,13 @@ class TestMaxMinAPI5(TestMaxMinAPI): self.dtype = 'int32' self.axis = (0, 1) self.keepdim = False + + +# test multiple minimum or maximum elements +class TestMaxMinAmaxAminAPI6(TestMaxMinAmaxAminAPI): + def init_case(self): + self.x_np = np.array([[0.2, 0.9, 0.9, 0.9], [0.9, 
0.9, 0.2, 0.2]]) + self.shape = [2, 4] + self.dtype = 'float64' + self.axis = None + self.keepdim = False diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index ce6c3e5350..957a42fc69 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -162,8 +162,10 @@ from .math import tanh # noqa: F401 from .math import tanh_ # noqa: F401 from .math import add_n # noqa: F401 from .math import max # noqa: F401 +from .math import amax # noqa: F401 from .math import maximum # noqa: F401 from .math import min # noqa: F401 +from .math import amin # noqa: F401 from .math import minimum # noqa: F401 from .math import mm # noqa: F401 from .math import divide # noqa: F401 @@ -321,8 +323,10 @@ tensor_method_func = [ #noqa 'tanh_', 'add_n', 'max', + 'amax', 'maximum', 'min', + 'amin', 'minimum', 'fmax', 'fmin', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 2e2d443e5c..7d790934c3 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1546,12 +1546,35 @@ def inverse(x, name=None): type='inverse', inputs={'Input': [x] }, outputs={'Output': [out]}) return out +def _get_reduce_all_value(axis): + """ + Internal function for max, min, amax and amin. + It computes the attribute reduce_all value based on axis. + """ + if axis is not None and not isinstance(axis, list): + if isinstance(axis, tuple): + axis = list(axis) + elif isinstance(axis, int): + axis= [axis] + else: + raise TypeError( + "The type of axis must be int, list or tuple, but received {}".format(type(axis))) + + reduce_all = True if axis == None or axis == [] else False + axis = axis if axis != None and axis != [] else [0] + return reduce_all, axis def max(x, axis=None, keepdim=False, name=None): """ Computes the maximum of tensor elements over the given axis. + Note: + The difference between max and amax is: If there are multiple maximum elements, + amax evenly distributes gradient between these equal values, + while max propagates gradient to all of them. + + Args: x(Tensor): A tensor, the data type is float32, float64, int32, int64. axis(int|list|tuple, optional): The axis along which the maximum is computed. @@ -1620,17 +1643,7 @@ def max(x, axis=None, keepdim=False, name=None): #[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]] """ - if axis is not None and not isinstance(axis, list): - if isinstance(axis, tuple): - axis = list(axis) - elif isinstance(axis, int): - axis= [axis] - else: - raise TypeError( - "The type of axis must be int, list or tuple, but received {}".format(type(axis))) - - reduce_all = True if axis == None or axis == [] else False - axis = axis if axis != None and axis != [] else [0] + reduce_all, axis = _get_reduce_all_value(axis) if in_dygraph_mode(): return _C_ops.reduce_max(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) @@ -1657,6 +1670,11 @@ def min(x, axis=None, keepdim=False, name=None): Computes the minimum of tensor elements over the given axis + Note: + The difference between min and amin is: If there are multiple minimum elements, + amin evenly distributes gradient between these equal values, + while min propagates gradient to all of them. + Args: x(Tensor): A tensor, the data type is float32, float64, int32, int64. axis(int|list|tuple, optional): The axis along which the minimum is computed. 
@@ -1725,17 +1743,7 @@ def min(x, axis=None, keepdim=False, name=None): #[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]] """ - if axis is not None and not isinstance(axis, list): - if isinstance(axis, tuple): - axis = list(axis) - elif isinstance(axis, int): - axis= [axis] - else: - raise TypeError( - "The type of axis must be int, list or tuple, but received {}".format(type(axis))) - - reduce_all = True if axis == None or axis == [] else False - axis = axis if axis != None and axis != [] else [0] + reduce_all, axis = _get_reduce_all_value(axis) if in_dygraph_mode(): return _C_ops.reduce_min(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) @@ -1757,6 +1765,203 @@ def min(x, axis=None, keepdim=False, name=None): }) return out +def amax(x, axis=None, keepdim=False, name=None): + """ + Computes the maximum of tensor elements over the given axis. + + Note: + The difference between max and amax is: If there are multiple maximum elements, + amax evenly distributes gradient between these equal values, + while max propagates gradient to all of them. + + Args: + x(Tensor): A tensor, the data type is float32, float64, int32, int64. + axis(int|list|tuple, optional): The axis along which the maximum is computed. + If :attr:`None`, compute the maximum over all elements of + `x` and return a Tensor with a single element, + otherwise must be in the range :math:`[-x.ndim, x.ndim)`. + If :math:`axis[i] < 0`, the axis to reduce is :math:`x.ndim + axis[i]`. + keepdim(bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension + than `x` unless :attr:`keepdim` is true, default + value is False. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tensor, results of maximum on the specified axis of input tensor, + its data type is the same as `x`. + + Examples: + ..
code-block:: python + + import paddle + # data_x is a Tensor with shape [2, 4] with multiple maximum elements + # the axis is an int element + + x = paddle.to_tensor([[0.1, 0.9, 0.9, 0.9], + [0.9, 0.9, 0.6, 0.7]], + dtype='float64', stop_gradient=False) + result1 = paddle.amax(x) + result1.backward() + print(result1, x.grad) + #[0.9], [[0., 0.2, 0.2, 0.2], [0.2, 0.2, 0., 0.]] + + x.clear_grad() + result2 = paddle.amax(x, axis=0) + result2.backward() + print(result2, x.grad) + #[0.9, 0.9, 0.9, 0.9], [[0., 0.5, 1., 1.], [1., 0.5, 0., 0.]] + + x.clear_grad() + result3 = paddle.amax(x, axis=-1) + result3.backward() + print(result3, x.grad) + #[0.9, 0.9], [[0., 0.3333, 0.3333, 0.3333], [0.5, 0.5, 0., 0.]] + + x.clear_grad() + result4 = paddle.amax(x, axis=1, keepdim=True) + result4.backward() + print(result4, x.grad) + #[[0.9], [0.9]], [[0., 0.3333, 0.3333, 0.3333], [0.5, 0.5, 0., 0.]] + + # data_y is a Tensor with shape [2, 2, 2] + # the axis is a list + y = paddle.to_tensor([[[0.1, 0.9], [0.9, 0.9]], + [[0.9, 0.9], [0.6, 0.7]]], + dtype='float64', stop_gradient=False) + result5 = paddle.amax(y, axis=[1, 2]) + result5.backward() + print(result5, y.grad) + #[0.9, 0.9], [[[0., 0.3333], [0.3333, 0.3333]], [[0.5, 0.5], [0., 0.]]] + + y.clear_grad() + result6 = paddle.amax(y, axis=[0, 1]) + result6.backward() + print(result6, y.grad) + #[0.9, 0.9], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [0., 0.]]] + """ + + reduce_all, axis = _get_reduce_all_value(axis) + if in_dygraph_mode(): + return _C_ops.reduce_amax(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) + + helper = LayerHelper('amax', **locals()) + check_variable_and_dtype( + x, 'x', ['float32', 'float64', 'int32', 'int64'], 'amax') + + out = helper.create_variable_for_type_inference( + dtype=x.dtype) + helper.append_op( + type='reduce_amax', + inputs={'X': x}, + outputs={'Out': out}, + attrs={ + 'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all + }) + return out + +def amin(x, axis=None, keepdim=False, name=None): + """ + + Computes the minimum of tensor elements over the given axis. + + Note: + The difference between min and amin is: If there are multiple minimum elements, + amin evenly distributes gradient between these equal values, + while min propagates gradient to all of them. + + Args: + x(Tensor): A tensor, the data type is float32, float64, int32, int64. + axis(int|list|tuple, optional): The axis along which the minimum is computed. + If :attr:`None`, compute the minimum over all elements of + `x` and return a Tensor with a single element, + otherwise must be in the range :math:`[-x.ndim, x.ndim)`. + If :math:`axis[i] < 0`, the axis to reduce is :math:`x.ndim + axis[i]`. + keepdim(bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension + than `x` unless :attr:`keepdim` is true, default + value is False. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tensor, results of minimum on the specified axis of input tensor, + its data type is the same as `x`. + + Examples: + ..
code-block:: python + + import paddle + # data_x is a Tensor with shape [2, 4] with multiple minimum elements + # the axis is an int element + + x = paddle.to_tensor([[0.2, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.7]], + dtype='float64', stop_gradient=False) + result1 = paddle.amin(x) + result1.backward() + print(result1, x.grad) + #[0.1], [[0., 0.2, 0.2, 0.2], [0.2, 0.2, 0., 0.]] + + x.clear_grad() + result2 = paddle.amin(x, axis=0) + result2.backward() + print(result2, x.grad) + #[0.1, 0.1, 0.1, 0.1], [[0., 0.5, 1., 1.], [1., 0.5, 0., 0.]] + + x.clear_grad() + result3 = paddle.amin(x, axis=-1) + result3.backward() + print(result3, x.grad) + #[0.1, 0.1], [[0., 0.3333, 0.3333, 0.3333], [0.5, 0.5, 0., 0.]] + + x.clear_grad() + result4 = paddle.amin(x, axis=1, keepdim=True) + result4.backward() + print(result4, x.grad) + #[[0.1], [0.1]], [[0., 0.3333, 0.3333, 0.3333], [0.5, 0.5, 0., 0.]] + + # data_y is a Tensor with shape [2, 2, 2] + # the axis is a list + y = paddle.to_tensor([[[0.2, 0.1], [0.1, 0.1]], + [[0.1, 0.1], [0.6, 0.7]]], + dtype='float64', stop_gradient=False) + result5 = paddle.amin(y, axis=[1, 2]) + result5.backward() + print(result5, y.grad) + #[0.1, 0.1], [[[0., 0.3333], [0.3333, 0.3333]], [[0.5, 0.5], [0., 0.]]] + + y.clear_grad() + result6 = paddle.amin(y, axis=[0, 1]) + result6.backward() + print(result6, y.grad) + #[0.1, 0.1], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [0., 0.]]] + """ + + reduce_all, axis = _get_reduce_all_value(axis) + if in_dygraph_mode(): + return _C_ops.reduce_amin(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) + + helper = LayerHelper('amin', **locals()) + check_variable_and_dtype( + x, 'x', ['float32', 'float64', 'int32', 'int64'], 'amin') + + out = helper.create_variable_for_type_inference( + dtype=x.dtype) + helper.append_op( + type='reduce_amin', + inputs={'X': x}, + outputs={'Out': out}, + attrs={ + 'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all + }) + return out + def log1p(x, name=None): r""" Calculates the natural log of the given input tensor, element-wise.
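The gradient rule that the amax/amin docstrings and AMaxOrAMinGradFunctor describe can be checked with a minimal NumPy sketch (not part of the patch; the array reuses the first amax example above): tied extrema share the incoming gradient equally, while max/min give every tied element the full gradient.

    import numpy as np

    x = np.array([[0.1, 0.9, 0.9, 0.9],
                  [0.9, 0.9, 0.6, 0.7]])
    out = np.amax(x)                    # reduce over all elements -> 0.9
    mask = (x == out).astype(x.dtype)   # 1.0 at every tied maximum, 0.0 elsewhere
    amax_grad = mask / mask.sum()       # even split: 0.2 at each of the five ties
    max_grad = mask                     # max-style rule: every tie receives the full 1.0
    print(amax_grad)                    # matches the x.grad comment for result1 in the amax docstring
    print(max_grad)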