diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index cfcb64732dc03d6c7aeca95eda0d4851ba17698d..8ec77513530438cad5caecaa2955afa59d474826 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -279,6 +279,15 @@ Natural logarithm of x. )DOC"; +UNUSED constexpr char Log1pDoc[] = R"DOC( +Log Activation Operator. + +$out = \ln(x+1)$ + +Natural logarithm of x. + +)DOC"; + UNUSED constexpr char SquareDoc[] = R"DOC( The OP square each elements of the inputs. @@ -634,6 +643,7 @@ REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc); REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc); REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc); REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc); +REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc); REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc); REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc); REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index cb210fca31a7b1271c0b11b721d4bd4eeb9cf82a..ec3c39097a01c1404f10455c32c585bdc090900e 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -737,6 +737,26 @@ struct LogGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +// log1p(x) = natural logarithm of x+1 +template +struct Log1pFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = (static_cast(1) + x).log(); + } +}; + +template +struct Log1pGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast(1) / (x + static_cast(1))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { @@ -1718,6 +1738,7 @@ class PowGradKernel __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(log, Log, LogFunctor, LogGradFunctor); \ + __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \ __macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 7a9101f101e09a04843e0690fcadf41c9d7aaa15..2b70cbc3ecf69ce0e2b4ca3ed779cf9ef728e846 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -43,7 +43,7 @@ import paddle.nn # from .tensor.creation import create_random_int_lod.tensor #DEFINE_ALIAS # from .tensor.creation import crop_.tensor #DEFINE_ALIAS # from .tensor.creation import diag #DEFINE_ALIAS -# from .tensor.creation import eye #DEFINE_ALIAS +from .tensor.creation import eye #DEFINE_ALIAS from .tensor.creation import fill_constant #DEFINE_ALIAS # from .tensor.creation import get_.tensor_from_selected_rows #DEFINE_ALIAS from .tensor.creation import linspace #DEFINE_ALIAS @@ -131,15 +131,15 @@ from .tensor.math import sum #DEFINE_ALIAS # from .tensor.math import sums #DEFINE_ALIAS from .tensor.math import tanh #DEFINE_ALIAS from .tensor.math import elementwise_sum #DEFINE_ALIAS -# from .tensor.math import max #DEFINE_ALIAS -# from .tensor.math import min #DEFINE_ALIAS +from .tensor.math import max #DEFINE_ALIAS +from .tensor.math import min #DEFINE_ALIAS from .tensor.math import mm #DEFINE_ALIAS from .tensor.math import div #DEFINE_ALIAS from .tensor.math import add #DEFINE_ALIAS # from .tensor.math import atan #DEFINE_ALIAS from .tensor.math import logsumexp #DEFINE_ALIAS # from .tensor.math import inverse #DEFINE_ALIAS -# from .tensor.math import log1p #DEFINE_ALIAS +from .tensor.math import log1p #DEFINE_ALIAS # from .tensor.math import erf #DEFINE_ALIAS # from .tensor.math import addcmul #DEFINE_ALIAS from .tensor.math import addmm #DEFINE_ALIAS @@ -153,7 +153,7 @@ from .tensor.linalg import dot #DEFINE_ALIAS from .tensor.linalg import norm #DEFINE_ALIAS # from .tensor.linalg import transpose #DEFINE_ALIAS from .tensor.linalg import dist #DEFINE_ALIAS -# from .tensor.linalg import t #DEFINE_ALIAS +from .tensor.linalg import t #DEFINE_ALIAS # from .tensor.linalg import cross #DEFINE_ALIAS # from .tensor.linalg import cholesky #DEFINE_ALIAS # from .tensor.linalg import .tensordot #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 79ef30f85d5e1795a2e3c27845ddaa7eab875480..48fe1ca2ca11dc2878c081b0e1625de09bc0d121 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -775,6 +775,57 @@ class TestLog(TestActivation): self.assertRaises(TypeError, fluid.layers.log, in2) +class TestLog1p(TestActivation): + def setUp(self): + self.op_type = "log1p" + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.log1p(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out') + + def test_api(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input_x = np.random.uniform(0.1, 1, [11, 17]).astype("float64") + data_x = fluid.layers.data( + name="data_x", + shape=[11, 17], + append_batch_size=False, + dtype="float64") + res_log1p = fluid.layers.data( + name="res_log1p", + shape=[11, 17], + append_batch_size=False, + dtype="float64") + + out1 = paddle.log1p(data_x) + out2 = paddle.log1p(data_x, out=res_log1p) + exe = fluid.Executor(place=fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + res1, res_in = exe.run(fluid.default_main_program(), + feed={"data_x": input_x}, + fetch_list=[out1, res_log1p]) + expected_res = np.log1p(input_x) + np.testing.assert_allclose(res1, expected_res) + np.testing.assert_allclose(res_in, expected_res) + + # dygraph + with fluid.dygraph.guard(): + np_x = np.random.uniform(0.1, 1, [11, 17]).astype("float64") + data_x = fluid.dygraph.to_variable(np_x) + z = paddle.log1p(data_x) + np_z = z.numpy() + z_expected = np.array(np.log1p(np_x)) + np.testing.assert_allclose(np_z, z_expected) + + class TestSquare(TestActivation): def setUp(self): self.op_type = "square" @@ -1173,6 +1224,7 @@ create_test_act_fp16_class(TestSoftRelu) create_test_act_fp16_class(TestELU) create_test_act_fp16_class(TestReciprocal) create_test_act_fp16_class(TestLog) +create_test_act_fp16_class(TestLog1p, grad_atol=0.9) create_test_act_fp16_class(TestSquare) create_test_act_fp16_class(TestPow, atol=5e-2) create_test_act_fp16_class(TestPow_factor_tensor, atol=5e-2) diff --git a/python/paddle/fluid/tests/unittests/test_eye_op.py b/python/paddle/fluid/tests/unittests/test_eye_op.py index ea37584b6a5e1d72badc65c294898bdf08f32a2a..fbbf01abae63829d3e6c34e636bcbc23762334d2 100644 --- a/python/paddle/fluid/tests/unittests/test_eye_op.py +++ b/python/paddle/fluid/tests/unittests/test_eye_op.py @@ -18,6 +18,8 @@ import unittest import numpy as np from op_test import OpTest +import paddle +import paddle.fluid as fluid import paddle.fluid.framework as framework @@ -70,5 +72,45 @@ class TestEyeOp2(OpTest): self.check_output() +class API_TestTensorEye(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program()): + data = paddle.eye(10) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + result, = exe.run(fetch_list=[data]) + expected_result = np.eye(10, dtype="float32") + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data = paddle.eye(10, num_columns=7, dtype="float64") + place = fluid.CPUPlace() + exe = fluid.Executor(place) + result, = exe.run(fetch_list=[data]) + expected_result = np.eye(10, 7, dtype="float64") + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data = paddle.eye(10, dtype="int64") + place = fluid.CPUPlace() + exe = fluid.Executor(place) + result, = exe.run(fetch_list=[data]) + expected_result = np.eye(10, dtype="int64") + self.assertEqual((result == expected_result).all(), True) + + def test_errors(self): + with fluid.program_guard(fluid.Program()): + + def test_num_rows_type_check(): + paddle.eye(-1, dtype="int64") + + self.assertRaises(TypeError, test_num_rows_type_check) + + def test_num_columns_type_check(): + paddle.eye(10, num_columns=5.2, dtype="int64") + + self.assertRaises(TypeError, test_num_columns_type_check) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index ae55a7844d16b23b05577d0b4959ccd329057a9b..644e0ca81671bf9fbeda9a9ed1829a1ada25cfc8 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -574,5 +574,69 @@ class API_TestSumOp(unittest.TestCase): self.assertEqual((np_z == z_expected).all(), True) +class API_TestMaxOp(unittest.TestCase): + def test_1(self): + # type: float + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.data("data", shape=[10, 10], dtype="float32") + result_max = paddle.max(input=data, dim=1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input_data = np.random.rand(10, 10).astype(np.float32) + res, = exe.run(feed={"data": input_data}, fetch_list=[result_max]) + self.assertEqual((res == np.max(input_data, axis=1)).all(), True) + + # type: int + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.data("data", shape=[10, 10], dtype="int64") + result_max = paddle.max(input=data, dim=1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input_data = np.random.randint(10, size=(10, 10)).astype(np.int64) + res, = exe.run(feed={"data": input_data}, fetch_list=[result_max]) + self.assertEqual((res == np.max(input_data, axis=1)).all(), True) + + # dygraph + with fluid.dygraph.guard(): + np_x = np.array([10, 10]).astype('float64') + x = fluid.dygraph.to_variable(np_x) + z = paddle.max(x, dim=0) + np_z = z.numpy() + z_expected = np.array(np.max(np_x, axis=0)) + self.assertEqual((np_z == z_expected).all(), True) + + +class API_TestMinOp(unittest.TestCase): + def test_1(self): + # type: float + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.data("data", shape=[10, 10], dtype="float32") + result_min = paddle.min(input=data, dim=1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input_data = np.random.rand(10, 10).astype(np.float32) + res, = exe.run(feed={"data": input_data}, fetch_list=[result_min]) + self.assertEqual((res == np.min(input_data, axis=1)).all(), True) + + # type: int + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.data("data", shape=[10, 10], dtype="int64") + result_min = paddle.min(input=data, dim=1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input_data = np.random.randint(10, size=(10, 10)).astype(np.int64) + res, = exe.run(feed={"data": input_data}, fetch_list=[result_min]) + self.assertEqual((res == np.min(input_data, axis=1)).all(), True) + + # dygraph + with fluid.dygraph.guard(): + np_x = np.array([10, 10]).astype('float64') + x = fluid.dygraph.to_variable(np_x) + z = paddle.min(x, dim=0) + np_z = z.numpy() + z_expected = np.array(np.min(np_x, axis=0)) + self.assertEqual((np_z == z_expected).all(), True) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index c8e3a1f2c26177687ef42accc4cf3e74b88c2663..d5d1fdc5b20b9d786f1861d63bb4a646117b80ed 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid from paddle.fluid import Program, program_guard @@ -137,5 +138,71 @@ class TestTransposeOpError(unittest.TestCase): self.assertRaises(ValueError, test_each_elem_value_check) +class TestTAPI(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program()): + data = fluid.data(shape=[10], dtype="float64", name="data") + data_t = paddle.t(data) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + data_np = np.random.random([10]).astype("float64") + result, = exe.run(feed={"data": data_np}, fetch_list=[data_t]) + expected_result = np.transpose(data_np) + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data = fluid.data(shape=[10, 5], dtype="float64", name="data") + data_t = paddle.t(data) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + data_np = np.random.random([10, 5]).astype("float64") + result, = exe.run(feed={"data": data_np}, fetch_list=[data_t]) + expected_result = np.transpose(data_np) + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data = fluid.data(shape=[1, 5], dtype="float64", name="data") + data_t = paddle.t(data) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + data_np = np.random.random([1, 5]).astype("float64") + result, = exe.run(feed={"data": data_np}, fetch_list=[data_t]) + expected_result = np.transpose(data_np) + self.assertEqual((result == expected_result).all(), True) + + with fluid.dygraph.guard(): + np_x = np.random.random([10]).astype("float64") + data = fluid.dygraph.to_variable(np_x) + z = paddle.t(data) + np_z = z.numpy() + z_expected = np.array(np.transpose(np_x)) + self.assertEqual((np_z == z_expected).all(), True) + + with fluid.dygraph.guard(): + np_x = np.random.random([10, 5]).astype("float64") + data = fluid.dygraph.to_variable(np_x) + z = paddle.t(data) + np_z = z.numpy() + z_expected = np.array(np.transpose(np_x)) + self.assertEqual((np_z == z_expected).all(), True) + + with fluid.dygraph.guard(): + np_x = np.random.random([1, 5]).astype("float64") + data = fluid.dygraph.to_variable(np_x) + z = paddle.t(data) + np_z = z.numpy() + z_expected = np.array(np.transpose(np_x)) + self.assertEqual((np_z == z_expected).all(), True) + + def test_errors(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data(name='x', shape=[10, 5, 3], dtype='float64') + + def test_x_dimension_check(): + paddle.t(x) + + self.assertRaises(ValueError, test_x_dimension_check) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index e9cd06942fe929074cdeea74de39487bfb55f897..710e42fb6f42b117e4fb20b772ca36eba6704f8a 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + +#from .math import * +#from .creation import * +#from .linalg import * # TODO: define alias in tensor and framework directory # from .creation import create_tensor #DEFINE_ALIAS @@ -18,7 +23,7 @@ # from .creation import create_random_int_lod #DEFINE_ALIAS # from .creation import crop_tensor #DEFINE_ALIAS # from .creation import diag #DEFINE_ALIAS -# from .creation import eye #DEFINE_ALIAS +from .creation import eye #DEFINE_ALIAS # from .creation import fill_constant #DEFINE_ALIAS # from .creation import get__from_selected_rows #DEFINE_ALIAS from .creation import linspace #DEFINE_ALIAS @@ -106,15 +111,15 @@ from .math import sum #DEFINE_ALIAS # from .math import sums #DEFINE_ALIAS from .math import tanh #DEFINE_ALIAS from .math import elementwise_sum #DEFINE_ALIAS -# from .math import max #DEFINE_ALIAS -# from .math import min #DEFINE_ALIAS +from .math import max #DEFINE_ALIAS +from .math import min #DEFINE_ALIAS from .math import mm #DEFINE_ALIAS from .math import div #DEFINE_ALIAS from .math import add #DEFINE_ALIAS # from .math import atan #DEFINE_ALIAS from .math import logsumexp #DEFINE_ALIAS # from .math import inverse #DEFINE_ALIAS -# from .math import log1p #DEFINE_ALIAS +from .math import log1p #DEFINE_ALIAS # from .math import erf #DEFINE_ALIAS # from .math import addcmul #DEFINE_ALIAS from .math import addmm #DEFINE_ALIAS @@ -128,7 +133,7 @@ from .linalg import dot #DEFINE_ALIAS from .linalg import norm #DEFINE_ALIAS # from .linalg import transpose #DEFINE_ALIAS from .linalg import dist #DEFINE_ALIAS -# from .linalg import t #DEFINE_ALIAS +from .linalg import t #DEFINE_ALIAS # from .linalg import cross #DEFINE_ALIAS # from .linalg import cholesky #DEFINE_ALIAS # from .manipulation import cast #DEFINE_ALIAS diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index fb199ede4e4eb756021d2396755f3a236adf9bee..232cb6f1b28e936b24c151fff3956391cbd06034 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -38,7 +38,7 @@ __all__ = [ 'zeros', 'zeros_like', # 'arrange', - # 'eye', + 'eye', 'full', 'full_like', 'triu', @@ -396,6 +396,66 @@ def zeros_like(input, dtype=None, device=None, name=None): return out +def eye(num_rows, + num_columns=None, + out=None, + dtype='float32', + stop_gradient=True, + name=None): + """ + **eye** + This function constructs an identity tensor, or a batch of tensor. + Args: + num_rows(int): the number of rows in each batch tensor. + num_columns(int, optional): the number of columns in each batch tensor. + If None, default: num_rows. + out(Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + if out is None, a new Varibale will be create to store the result. + dtype(string, optional): The data type of the returned tensor. + It should be int32, int64, float16, float32, float64. + stop_gradient(bool, optional): Whether stop calculating gradients. Default:True. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + Returns: + Variable: An identity Tensor or LoDTensor of shape [num_rows, num_columns]. + Examples: + .. code-block:: python + import paddle + data = paddle.eye(3, dtype='int32') + # [[1, 0, 0] + # [0, 1, 0] + # [0, 0, 1]] + data = paddle.eye(2, 3, dtype='int32') + # [[1, 0, 0] + # [0, 1, 0]] + """ + + helper = LayerHelper("eye", **locals()) + if not isinstance(num_rows, int) or num_rows < 0: + raise TypeError("num_rows should be a non-negative int") + if num_columns is not None: + if not isinstance(num_columns, int) or num_columns < 0: + raise TypeError("num_columns should be a non-negative int") + else: + num_columns = num_rows + if out is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + c_dtype = convert_np_dtype_to_dtype_(dtype) + helper.append_op( + type='eye', + inputs={}, + outputs={'Out': [out]}, + attrs={ + 'num_rows': num_rows, + 'num_columns': num_columns, + 'dtype': c_dtype + }, + stop_gradient=True) + out.stop_gradient = stop_gradient + return out + + def full(shape, fill_value, out=None, diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index aade3e87a567210df1ad13185fcbee7a7ba7899a..a62524c8e677c1b0ad0b361452e54457d1ebc1b3 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -23,7 +23,7 @@ __all__ = [ 'norm', # 'transpose', 'dist', - # 't', + 't', # 'cross', # 'cholesky', # 'tensordot' @@ -458,3 +458,74 @@ def dot(x, y, name=None): type="dot", inputs={'X': x, 'Y': y}, attrs={}, outputs={"Out": out}) return out + + +def t(input, name=None): + """ + Transpose <=2-D tensor. + 0-D and 1-D tensors are returned as it is and 2-D tensor is equal to + the fluid.layers.transpose function which perm dimensions set 0 and 1. + + Args: + input (Variable): The input Tensor. It is a N-D (N<=2) Tensor of data types float32, float64, int32. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + Returns: + Variable: A transposed n-D Tensor, with data type being float32, float64, int32, int64. + + For Example: + .. code-block:: text + # Example 1 (0-D tensor) + x = tensor([0.79]) + paddle.t(x) = tensor([0.79]) + # Example 2 (1-D tensor) + x = tensor([0.79, 0.84, 0.32]) + paddle.t(x) = tensor([0.79, 0.84, 0.32]) + + # Example 3 (2-D tensor) + x = tensor([0.79, 0.84, 0.32], + [0.64, 0.14, 0.57]) + paddle.t(x) = tensor([0.79, 0.64], + [0.84, 0.14], + [0.32, 0.57]) + + Examples: + .. code-block:: python + import paddle + import paddle.fluid as fluid + x = fluid.data(name='x', shape=[2, 3], + dtype='float32') + x_transposed = paddle.t(x) + print x_transposed.shape + #(3L, 2L) + """ + if len(input.shape) > 2: + raise ValueError( + "Input(input) only support N-D (N<=2) tensor, but received " + "length of Input(input) is %s. Perhaps you can use paddle." + "tensor.transpose() instead." % len(input.shape)) + if in_dygraph_mode(): + if len(input.shape) == 1: + return input + # 2-D tensor + perm = [1, 0] + out, _ = core.ops.transpose2(input, 'axis', perm) + return out + + check_variable_and_dtype( + input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'], + 'transpose') + + helper = LayerHelper('t', **locals()) + out = helper.create_variable_for_type_inference(input.dtype) + input_shape = helper.create_variable_for_type_inference(input.dtype) + if len(input.shape) == 1: + out = input + else: + helper.append_op( + type='transpose2', + inputs={'X': [input]}, + outputs={'Out': [out], + 'XShape': [input_shape]}, + attrs={'axis': [1, 0]}) + return out diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 71de0b7284ad8180eacb565d416021fcb4b4feb5..d503f9538794b435e0f2e913638534cd99bc9586 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -65,15 +65,15 @@ __all__ = [ # 'sums', 'tanh', 'elementwise_sum', -# 'max', -# 'min', + 'max', + 'min', 'mm', 'div', 'add', # 'atan', 'logsumexp', # 'inverse', -# 'log1p', + 'log1p', # 'erf', # 'addcmul', 'addmm' @@ -1062,3 +1062,196 @@ Examples: return out return layers.log(sum_out, name) + + +def max(input, dim=None, keep_dim=False, out=None, name=None): + """ + Computes the maximum of tensor elements over the given dimension. + + Args: + input (Variable): The input variable which is a Tensor, the data type is float32, + float64, int32, int64. + dim (list|int, optional): The dimension along which the maximum is computed. + If :attr:`None`, compute the maximum over all elements of + :attr:`input` and return a Tensor variable with a single element, + otherwise must be in the range :math:`[-rank(input), rank(input))`. + If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`. + keep_dim (bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension + than the :attr:`input` unless :attr:`keep_dim` is true, default + value is False. + out(Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + if out is None, a new Varibale will be create to store the result. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Variable: Tensor, results of maximum on the specified dim of input tensor, + it's data type is the same as input's Tensor. + + Examples: + .. code-block:: python + import paddle + import paddle.fluid as fluid + + # x is a Tensor variable with following elements: + # [[0.2, 0.3, 0.5, 0.9] + # [0.1, 0.2, 0.6, 0.7]] + # Each example is followed by the corresponding output tensor. + x = fluid.data(name='x', shape=[2, 4], dtype='float32') + paddle.max(x) # [0.9] + paddle.max(x, dim=0) # [0.2, 0.3, 0.6, 0.9] + paddle.max(x, dim=-1) # [0.9, 0.7] + paddle.max(x, dim=1, keep_dim=True) # [[0.9], [0.7]] + # y is a Tensor variable with shape [2, 2, 2] and elements as below: + # [[[1.0, 2.0], [3.0, 4.0]], + # [[5.0, 6.0], [7.0, 8.0]]] + # Each example is followed by the corresponding output tensor. + y = fluid.data(name='y', shape=[2, 2, 2], dtype='float32') + paddle.max(y, dim=[1, 2]) # [4.0, 8.0] + paddle.max(y, dim=[0, 1]) # [7.0, 8.0] + """ + + helper = LayerHelper('max', **locals()) + if out is None: + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] + + check_variable_and_dtype( + input, 'input', ['float32', 'float64', 'int32', 'int64'], 'max') + + reduce_all = True if dim == None or dim == [] else False + dim = dim if dim != None and dim != [] else [0] + + if in_dygraph_mode(): + return core.ops.reduce_max(input, 'dim', dim, 'keep_dim', keep_dim, + 'reduce_all', reduce_all) + helper.append_op( + type='reduce_max', + inputs={'X': input}, + outputs={'Out': out}, + attrs={ + 'dim': dim, + 'keep_dim': keep_dim, + 'reduce_all': reduce_all + }) + return out + + +def min(input, dim=None, keep_dim=False, out=None, name=None): + """ + Computes the minimum of tensor elements over the given dimension. + Args: + input (Variable): The input variable which is a Tensor, the data type is float32, + float64, int32, int64. + dim (list|int, optional): The dimensions along which the minimum is computed. + If :attr:`None`, compute the minimum over all elements of + :attr:`input` and return a Tensor variable with a single element, + otherwise must be in the range :math:`[-rank(input), rank(input))`. + If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`. + keep_dim (bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension + than the :attr:`input` unless :attr:`keep_dim` is true, default + value is False. + out(Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + if out is None, a new Varibale will be create to store the result. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + Returns: + Variable: Tensor, result of minimum on the specified dim of input tensor, + it's data type is the same as input's Tensor. + Examples: + .. code-block:: python + import paddle + import paddle.fluid as fluid + # x is a Tensor variable with following elements: + # [[0.2, 0.3, 0.5, 0.9] + # [0.1, 0.2, 0.6, 0.7]] + # Each example is followed by the corresponding output tensor. + x = fluid.data(name='x', shape=[2, 4], dtype='float32') + paddle.min(x) # [0.1] + paddle.min(x, dim=0) # [0.1, 0.2, 0.5, 0.7] + paddle.min(x, dim=-1) # [0.2, 0.1] + paddle.min(x, dim=1, keep_dim=True) # [[0.2], [0.1]] + # y is a Tensor variable with shape [2, 2, 2] and elements as below: + # [[[1.0, 2.0], [3.0, 4.0]], + # [[5.0, 6.0], [7.0, 8.0]]] + # Each example is followed by the corresponding output tensor. + y = fluid.data(name='y', shape=[2, 2, 2], dtype='float32') + paddle.min(y, dim=[1, 2]) # [1.0, 5.0] + paddle.min(y, dim=[0, 1]) # [1.0, 2.0] + """ + + helper = LayerHelper('min', **locals()) + if out is None: + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] + + check_variable_and_dtype( + input, 'input', ['float32', 'float64', 'int32', 'int64'], 'max') + + reduce_all = True if dim == None or dim == [] else False + dim = dim if dim != None and dim != [] else [0] + + if in_dygraph_mode(): + return core.ops.reduce_min(input, 'dim', dim, 'keep_dim', keep_dim, + 'reduce_all', reduce_all) + helper.append_op( + type='reduce_min', + inputs={'X': input}, + outputs={'Out': out}, + attrs={ + 'dim': dim, + 'keep_dim': keep_dim, + 'reduce_all': reduce_all + }) + return out + + +def log1p(x, out=None, name=None): + """ + Calculates the natural log of the given input tensor, element-wise. + .. math:: + Out = \\ln(x+1) + Args: + x (Variable): Input LoDTensor or Tensor. Must be one of the following types: float32, float64. + out(Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + if out is None, a new Varibale will be create to store the result. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + Returns: + Variable: The natural log of the input LoDTensor or Tensor computed element-wise. + Examples: + .. code-block:: python + import paddle + import paddle.fluid as fluid + import numpy as np + # Graph Organizing + x = fluid.data(name="x", shape=[2,1], dtype="float32") + res = paddle.log1p(x) + # Create an executor using CPU as an example + exe = fluid.Executor(fluid.CPUPlace()) + # Execute + x_i = np.array([[0], [1]]).astype(np.float32) + res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res]) + print(res_val) # [[0.], [0.6931472]] + """ + + if in_dygraph_mode(): + return core.ops.log1p(x) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], "log1p") + inputs = {'X': [x]} + helper = LayerHelper('log1p', **locals()) + dtype = helper.input_dtype(input_param_name='x') + if out is None: + out = helper.create_variable_for_type_inference(dtype) + helper.append_op(type="log1p", inputs={"X": x}, outputs={"Out": out}) + return out