diff --git a/python/paddle/fluid/tests/unittests/test_max_min_op.py b/python/paddle/fluid/tests/unittests/test_max_min_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1a42715dc11314a0188dedfaf2ecd0fe982217ef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_max_min_op.py @@ -0,0 +1,149 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard +from op_test import OpTest + +paddle.enable_static() + + +class TestMaxMinAPI(unittest.TestCase): + def setUp(self): + self.init_case() + self.cal_np_out_and_gradient() + self.place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + def init_case(self): + self.x_np = np.array([[0.2, 0.3, 0.5, 0.9], [0.1, 0.2, 0.6, 0.7]]) + self.shape = [2, 4] + self.dtype = 'float64' + self.axis = None + self.keepdim = False + + # If there are multiple minimum or maximum elements, max/min/ is non-derivable, + # its gradient check is not supported by unittest framework, + # thus we calculate the gradient by numpy function. + def cal_np_out_and_gradient(self): + def _cal_np_out_and_gradient(func): + if func is 'max': + out = np.max(self.x_np, axis=self.axis, keepdims=self.keepdim) + elif func is 'min': + out = np.min(self.x_np, axis=self.axis, keepdims=self.keepdim) + else: + print('This unittest only test max/min, but now is', func) + self.np_out[func] = out + grad = np.zeros(self.shape) + out_b = np.broadcast_to(out, self.shape) + grad[self.x_np == out_b] = 1 + self.np_grad[func] = grad + + self.np_out = dict() + self.np_grad = dict() + _cal_np_out_and_gradient('max') + _cal_np_out_and_gradient('min') + + def _choose_paddle_func(self, func, x): + if func is 'max': + out = paddle.max(x, self.axis, self.keepdim) + elif func is 'min': + out = paddle.min(x, self.axis, self.keepdim) + else: + print('This unittest only test max/min, but now is', func) + return out + + # We check the output between paddle API and numpy in static graph. + def test_static_graph(self): + def _test_static_graph(func): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(startup_program, train_program): + x = fluid.data(name='input', dtype=self.dtype, shape=self.shape) + x.stop_gradient = False + out = self._choose_paddle_func(func, x) + + exe = fluid.Executor(self.place) + res = exe.run(fluid.default_main_program(), + feed={'input': self.x_np}, + fetch_list=[out]) + self.assertTrue((np.array(res[0]) == self.np_out[func]).all()) + + _test_static_graph('max') + _test_static_graph('min') + + # As dygraph is easy to compute gradient, we check the gradient between + # paddle API and numpy in dygraph. + def test_dygraph(self): + def _test_dygraph(func): + paddle.disable_static() + x = paddle.to_tensor( + self.x_np, dtype=self.dtype, stop_gradient=False) + out = self._choose_paddle_func(func, x) + grad_tensor = paddle.ones_like(x) + paddle.autograd.backward([out], [grad_tensor], True) + + self.assertEqual(np.allclose(self.np_out[func], out.numpy()), True) + self.assertEqual(np.allclose(self.np_grad[func], x.grad), True) + paddle.enable_static() + + _test_dygraph('max') + _test_dygraph('min') + + +# test multiple minimum or maximum elements +class TestMaxMinAPI2(TestMaxMinAPI): + def init_case(self): + self.x_np = np.array([[0.2, 0.3, 0.9, 0.9], [0.1, 0.1, 0.6, 0.7]]) + self.shape = [2, 4] + self.dtype = 'float64' + self.axis = None + self.keepdim = False + + +# test different axis +class TestMaxMinAPI3(TestMaxMinAPI): + def init_case(self): + self.x_np = np.array([[0.2, 0.3, 0.9, 0.9], [0.1, 0.1, 0.6, 0.7]]) + self.shape = [2, 4] + self.dtype = 'float64' + self.axis = 0 + self.keepdim = False + + +# test keepdim = True +class TestMaxMinAPI4(TestMaxMinAPI): + def init_case(self): + self.x_np = np.array([[0.2, 0.3, 0.9, 0.9], [0.1, 0.1, 0.6, 0.7]]) + self.shape = [2, 4] + self.dtype = 'float64' + self.axis = 1 + self.keepdim = True + + +# test axis is tuple +class TestMaxMinAPI5(TestMaxMinAPI): + def init_case(self): + self.x_np = np.array( + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(np.int32) + self.shape = [2, 2, 2] + self.dtype = 'int32' + self.axis = (0, 1) + self.keepdim = False diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e878aede1332903f10f5cbb615b4d41cb82b2145..47ca3fed0bdff8e02b070398aa523e4fa7c9f4fc 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1437,8 +1437,7 @@ def max(x, axis=None, keepdim=False, name=None): Computes the maximum of tensor elements over the given axis. Args: - x(Tensor): A tensor, the data type is float32, - float64, int32, int64. + x(Tensor): A tensor, the data type is float32, float64, int32, int64. axis(int|list|tuple, optional): The axis along which the maximum is computed. If :attr:`None`, compute the maximum over all elements of `x` and return a Tensor with a single element, @@ -1462,34 +1461,47 @@ def max(x, axis=None, keepdim=False, name=None): # data_x is a Tensor with shape [2, 4] # the axis is a int element - x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9], - [0.1, 0.2, 0.6, 0.7]]) + [0.1, 0.2, 0.6, 0.7]], + dtype='float64', stop_gradient=False) result1 = paddle.max(x) - print(result1) - #[0.9] + result1.backward() + print(result1, x.grad) + #[0.9], [[0., 0., 0., 1.], [0., 0., 0., 0.]] + + x.clear_grad() result2 = paddle.max(x, axis=0) - print(result2) - #[0.2 0.3 0.6 0.9] + result2.backward() + print(result2, x.grad) + #[0.2, 0.3, 0.6, 0.9], [[1., 1., 0., 1.], [0., 0., 1., 0.]] + + x.clear_grad() result3 = paddle.max(x, axis=-1) - print(result3) - #[0.9 0.7] + result3.backward() + print(result3, x.grad) + #[0.9, 0.7], [[0., 0., 0., 1.], [0., 0., 0., 1.]] + + x.clear_grad() result4 = paddle.max(x, axis=1, keepdim=True) - print(result4) - #[[0.9] - # [0.7]] + result4.backward() + print(result4, x.grad) + #[[0.9], [0.7]], [[0., 0., 0., 1.], [0., 0., 0., 1.]] # data_y is a Tensor with shape [2, 2, 2] # the axis is list - y = paddle.to_tensor([[[1.0, 2.0], [3.0, 4.0]], - [[5.0, 6.0], [7.0, 8.0]]]) + [[5.0, 6.0], [7.0, 8.0]]], + dtype='float64', stop_gradient=False) result5 = paddle.max(y, axis=[1, 2]) - print(result5) - #[4. 8.] + result5.backward() + print(result5, y.grad) + #[4., 8.], [[[0., 0.], [0., 1.]], [[0., 0.], [0., 1.]]] + + y.clear_grad() result6 = paddle.max(y, axis=[0, 1]) - print(result6) - #[7. 8.] + result6.backward() + print(result6, y.grad) + #[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]] """ if axis is not None and not isinstance(axis, list): @@ -1552,34 +1564,49 @@ def min(x, axis=None, keepdim=False, name=None): import paddle - # x is a tensor with shape [2, 4] + # data_x is a Tensor with shape [2, 4] # the axis is a int element x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9], - [0.1, 0.2, 0.6, 0.7]]) + [0.1, 0.2, 0.6, 0.7]], + dtype='float64', stop_gradient=False) result1 = paddle.min(x) - print(result1) - #[0.1] + result1.backward() + print(result1, x.grad) + #[0.1], [[0., 0., 0., 0.], [1., 0., 0., 0.]] + + x.clear_grad() result2 = paddle.min(x, axis=0) - print(result2) - #[0.1 0.2 0.5 0.7] + result2.backward() + print(result2, x.grad) + #[0.1, 0.2, 0.5, 0.7], [[0., 0., 1., 0.], [1., 1., 0., 1.]] + + x.clear_grad() result3 = paddle.min(x, axis=-1) - print(result3) - #[0.2 0.1] + result3.backward() + print(result3, x.grad) + #[0.2, 0.1], [[1., 0., 0., 0.], [1., 0., 0., 0.]] + + x.clear_grad() result4 = paddle.min(x, axis=1, keepdim=True) - print(result4) - #[[0.2] - # [0.1]] + result4.backward() + print(result4, x.grad) + #[[0.2], [0.1]], [[1., 0., 0., 0.], [1., 0., 0., 0.]] - # y is a Tensor with shape [2, 2, 2] + # data_y is a Tensor with shape [2, 2, 2] # the axis is list y = paddle.to_tensor([[[1.0, 2.0], [3.0, 4.0]], - [[5.0, 6.0], [7.0, 8.0]]]) + [[5.0, 6.0], [7.0, 8.0]]], + dtype='float64', stop_gradient=False) result5 = paddle.min(y, axis=[1, 2]) - print(result5) - #[1. 5.] + result5.backward() + print(result5, y.grad) + #[1., 5.], [[[1., 0.], [0., 0.]], [[1., 0.], [0., 0.]]] + + y.clear_grad() result6 = paddle.min(y, axis=[0, 1]) - print(result6) - #[1. 2.] + result6.backward() + print(result6, y.grad) + #[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]] """ if axis is not None and not isinstance(axis, list): @@ -1590,6 +1617,7 @@ def min(x, axis=None, keepdim=False, name=None): else: raise TypeError( "The type of axis must be int, list or tuple, but received {}".format(type(axis))) + reduce_all = True if axis == None or axis == [] else False axis = axis if axis != None and axis != [] else [0] if in_dygraph_mode(): @@ -1613,7 +1641,6 @@ def min(x, axis=None, keepdim=False, name=None): }) return out - def log1p(x, name=None): r""" Calculates the natural log of the given input tensor, element-wise.