diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 204f854a380abb5110e9b899834d0ee00579254e..b9a92c2207d8e9b86cc95be8285ce6b2e6db597b 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -250,6 +250,20 @@ $$out = sin(x)$$ )DOC"; +UNUSED constexpr char SinhDoc[] = R"DOC( +Sinh Activation Operator. + +$$out = sinh(x)$$ + +)DOC"; + +UNUSED constexpr char CoshDoc[] = R"DOC( +Cosh Activation Operator. + +$$out = cosh(x)$$ + +)DOC"; + UNUSED constexpr char RoundDoc[] = R"DOC( The OP rounds the values in the input to the nearest integer value. @@ -642,6 +656,8 @@ REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc); REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc); REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc); REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc); +REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc); +REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc); REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc); REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc); REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index b3784ed0744095c2032dd8a0de7bd6b12827cf5c..3aac7ae8a5e8a9e889242b59f42a29af08ad1c46 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -584,6 +584,72 @@ struct SinFunctor : public BaseActivationFunctor { } }; +template +struct Sinh { + HOSTDEVICE T operator()(const T& val) const { return sinh(val); } +}; + +template <> +struct Sinh { + HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { + return platform::float16(sinhf(static_cast(val))); + } +}; + +template +struct Cosh { + HOSTDEVICE T operator()(const T& val) const { return cosh(val); } +}; + +template <> +struct Cosh { + HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { + return platform::float16(coshf(static_cast(val))); + } +}; + +// sinh(x) = sinh(x) +template +struct SinhFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Sinh()); + } +}; + +// cosh(x) = cosh(x) +template +struct CoshFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Cosh()); + } +}; + +// sinh'(x) = cosh(x) +template +struct SinhGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * x.unaryExpr(Cosh()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +// cosh'(x) = sinh(x) +template +struct CoshGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * x.unaryExpr(Sinh()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template struct Acos { HOSTDEVICE T operator()(const T& val) const { return acos(val); } @@ -1752,6 +1818,8 @@ class PowGradKernel __macro(acos, Acos, AcosFunctor, AcosGradFunctor); \ __macro(sin, Sin, SinFunctor, SinGradFunctor); \ __macro(asin, Asin, AsinFunctor, AsinGradFunctor); \ + __macro(sinh, Sinh, SinhFunctor, SinhGradFunctor); \ + __macro(cosh, Cosh, CoshFunctor, CoshGradFunctor); \ __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(log, Log, LogFunctor, LogGradFunctor); \ diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 4163d5ed955829b66956b39a7b26a6753ef0b367..0d572599a6678990cca16c18522ff504f4de12e0 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -132,6 +132,7 @@ from .tensor.math import asin #DEFINE_ALIAS from .tensor.math import atan #DEFINE_ALIAS from .tensor.math import ceil #DEFINE_ALIAS from .tensor.math import cos #DEFINE_ALIAS +from .tensor.math import cosh #DEFINE_ALIAS from .tensor.math import cumsum #DEFINE_ALIAS from .tensor.math import elementwise_add #DEFINE_ALIAS from .tensor.math import elementwise_div #DEFINE_ALIAS @@ -157,6 +158,7 @@ from .tensor.math import rsqrt #DEFINE_ALIAS from .tensor.math import scale #DEFINE_ALIAS from .tensor.math import sign #DEFINE_ALIAS from .tensor.math import sin #DEFINE_ALIAS +from .tensor.math import sinh #DEFINE_ALIAS from .tensor.math import sqrt #DEFINE_ALIAS from .tensor.math import square #DEFINE_ALIAS from .tensor.math import stanh #DEFINE_ALIAS diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index c4b6da5629ae45d49b2f63496e73665e693c9efb..3adb243c8f83d0dc0d8c89daf5a630cb4bdce1fd 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -35,6 +35,8 @@ __activations_noattr__ = [ 'acos', 'asin', 'sin', + 'sinh', + 'cosh', 'round', 'reciprocal', 'square', @@ -80,9 +82,9 @@ def softshrink(x, alpha=None): softshrink.__doc__ = """ - :alias_main: paddle.nn.functional.softshrink - :alias: paddle.nn.functional.softshrink,paddle.nn.functional.activation.softshrink - :old_api: paddle.fluid.layers.softshrink + :alias_main: paddle.nn.functional.softshrink + :alias: paddle.nn.functional.softshrink,paddle.nn.functional.activation.softshrink + :old_api: paddle.fluid.layers.softshrink :strong:`Softshrink Activation Operator` @@ -127,9 +129,9 @@ def hard_shrink(x, threshold=None): hard_shrink.__doc__ = _hard_shrink_.__doc__ + """ - :alias_main: paddle.nn.functional.hard_shrink - :alias: paddle.nn.functional.hard_shrink,paddle.nn.functional.activation.hard_shrink - :old_api: paddle.fluid.layers.hard_shrink + :alias_main: paddle.nn.functional.hard_shrink + :alias: paddle.nn.functional.hard_shrink,paddle.nn.functional.activation.hard_shrink + :old_api: paddle.fluid.layers.hard_shrink Examples: @@ -154,9 +156,9 @@ def cumsum(x, axis=None, exclusive=None, reverse=None): cumsum.__doc__ = """ - :alias_main: paddle.cumsum - :alias: paddle.cumsum,paddle.tensor.cumsum,paddle.tensor.math.cumsum - :old_api: paddle.fluid.layers.cumsum + :alias_main: paddle.cumsum + :alias: paddle.cumsum,paddle.tensor.cumsum,paddle.tensor.math.cumsum + :old_api: paddle.fluid.layers.cumsum The cumulative sum of the elements along a given axis. By default, the first element of the result is the same of the first element of the input. If exlusive is true, the first element of the result is 0. @@ -196,9 +198,9 @@ def thresholded_relu(x, threshold=None): thresholded_relu.__doc__ = """ - :alias_main: paddle.nn.functional.thresholded_relu - :alias: paddle.nn.functional.thresholded_relu,paddle.nn.functional.activation.thresholded_relu - :old_api: paddle.fluid.layers.thresholded_relu + :alias_main: paddle.nn.functional.thresholded_relu + :alias: paddle.nn.functional.thresholded_relu,paddle.nn.functional.activation.thresholded_relu + :old_api: paddle.fluid.layers.thresholded_relu :strong:`Thresholded ReLU Activation Operator` @@ -282,9 +284,9 @@ def gelu(x, approximate=False): gelu.__doc__ = """ - :alias_main: paddle.nn.functional.gelu - :alias: paddle.nn.functional.gelu,paddle.nn.functional.activation.gelu - :old_api: paddle.fluid.layers.gelu + :alias_main: paddle.nn.functional.gelu + :alias: paddle.nn.functional.gelu,paddle.nn.functional.activation.gelu + :old_api: paddle.fluid.layers.gelu :strong:`GeLU Activation Operator` For more details, see [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415). @@ -370,9 +372,9 @@ def erf(x): erf.__doc__ = """ - :alias_main: paddle.erf - :alias: paddle.erf,paddle.tensor.erf,paddle.tensor.math.erf,paddle.nn.functional.erf,paddle.nn.functional.activation.erf - :old_api: paddle.fluid.layers.erf + :alias_main: paddle.erf + :alias: paddle.erf,paddle.tensor.erf,paddle.tensor.math.erf,paddle.nn.functional.erf,paddle.nn.functional.activation.erf + :old_api: paddle.fluid.layers.erf :strong:`Erf Operator` For more details, see [Error function](https://en.wikipedia.org/wiki/Error_function). diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 5b9e7bfe62b7f4804c49d43c449d7e3e366f4942..7d687dbd0f85f9235161dfa8528c849975aa5af3 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -183,6 +183,148 @@ class TestAtan(TestActivation, TestParameter): self.assertEqual(z, z_expected) +class TestSinh(TestActivation): + def setUp(self): + self.op_type = "sinh" + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.sinh(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out') + + def test_dygraph(self): + with fluid.dygraph.guard(): + np_x = np.array([0.1]) + x = fluid.dygraph.to_variable(np_x) + z = fluid.layers.sinh(x).numpy() + z_expected = np.sinh(np_x) + self.assertTrue(np.allclose(z, z_expected)) + + def test_api(self): + test_data_shape = [11, 17] + with fluid.program_guard(fluid.Program(), fluid.Program()): + input_x = np.random.uniform(0.1, 1, + test_data_shape).astype("float32") + data_x = fluid.layers.data( + name="data_x", + shape=test_data_shape, + append_batch_size=False, + dtype="float32") + + pd_sinh_out = fluid.layers.sinh(data_x) + exe = fluid.Executor(place=fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + np_sinh_res = exe.run(fluid.default_main_program(), + feed={"data_x": input_x}, + fetch_list=[pd_sinh_out]) + + expected_res = np.sinh(input_x) + self.assertTrue(np.allclose(np_sinh_res, expected_res)) + + def test_backward(self): + test_data_shape = [11, 17] + with fluid.dygraph.guard(): + input_x = np.random.uniform(0.1, 1, + test_data_shape).astype("float32") + var = fluid.dygraph.to_variable(input_x) + var.stop_gradient = False + loss = fluid.layers.sinh(var) + loss.backward() + grad_var = var.gradient() + self.assertEqual(grad_var.shape, input_x.shape) + + +class TestSinhOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program()): + # The input type must be Variable. + self.assertRaises(TypeError, fluid.layers.sinh, 1) + # The input dtype must be float16, float32, float64. + x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, fluid.layers.sinh, x_int32) + # support the input dtype is float16 + x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16') + fluid.layers.sinh(x_fp16) + + +class TestCosh(TestActivation): + def setUp(self): + self.op_type = "cosh" + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.cosh(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out') + + def test_dygraph(self): + with fluid.dygraph.guard(): + np_x = np.array([0.1]) + x = fluid.dygraph.to_variable(np_x) + z = fluid.layers.cosh(x).numpy() + z_expected = np.cosh(np_x) + self.assertTrue(np.allclose(z, z_expected)) + + def test_api(self): + test_data_shape = [11, 17] + with fluid.program_guard(fluid.Program(), fluid.Program()): + input_x = np.random.uniform(0.1, 1, + test_data_shape).astype("float32") + data_x = fluid.layers.data( + name="data_x", + shape=test_data_shape, + append_batch_size=False, + dtype="float32") + + pd_cosh_out = paddle.cosh(data_x) + exe = fluid.Executor(place=fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + np_cosh_res = exe.run(fluid.default_main_program(), + feed={"data_x": input_x}, + fetch_list=[pd_cosh_out]) + + expected_res = np.cosh(input_x) + self.assertTrue(np.allclose(np_cosh_res, expected_res)) + + def test_backward(self): + test_data_shape = [11, 17] + with fluid.dygraph.guard(): + input_x = np.random.uniform(0.1, 1, + test_data_shape).astype("float32") + var = fluid.dygraph.to_variable(input_x) + var.stop_gradient = False + loss = fluid.layers.cosh(var) + loss.backward() + grad_var = var.gradient() + self.assertEqual(grad_var.shape, input_x.shape) + + +class TestCoshOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program()): + # The input type must be Variable. + self.assertRaises(TypeError, fluid.layers.cosh, 1) + # The input dtype must be float16, float32, float64. + x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, fluid.layers.cosh, x_int32) + # support the input dtype is float16 + x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16') + fluid.layers.cosh(x_fp16) + + class TestTanhShrink(TestActivation): def setUp(self): self.op_type = "tanh_shrink" @@ -1204,8 +1346,10 @@ create_test_act_fp16_class(TestAbs) create_test_act_fp16_class(TestCeil, grad_check=False) create_test_act_fp16_class(TestFloor, grad_check=False) create_test_act_fp16_class(TestCos, grad_atol=0.85) +create_test_act_fp16_class(TestCosh, grad_atol=0.85) create_test_act_fp16_class(TestAcos, grad_atol=0.85) create_test_act_fp16_class(TestSin) +create_test_act_fp16_class(TestSinh) create_test_act_fp16_class(TestAsin) create_test_act_fp16_class(TestAtan) create_test_act_fp16_class(TestRound, grad_check=False) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 7a583e0c38dd620f700e95ba06d9e3ec41042fb0..8ffe9613995a8f4a086d44ca1ed5f2fa691bdaf8 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -105,6 +105,7 @@ from .math import asin #DEFINE_ALIAS from .math import atan #DEFINE_ALIAS from .math import ceil #DEFINE_ALIAS from .math import cos #DEFINE_ALIAS +from .math import cosh #DEFINE_ALIAS from .math import cumsum #DEFINE_ALIAS from .math import elementwise_add #DEFINE_ALIAS from .math import elementwise_div #DEFINE_ALIAS @@ -130,6 +131,7 @@ from .math import rsqrt #DEFINE_ALIAS from .math import scale #DEFINE_ALIAS from .math import sign #DEFINE_ALIAS from .math import sin #DEFINE_ALIAS +from .math import sinh #DEFINE_ALIAS from .math import sqrt #DEFINE_ALIAS from .math import square #DEFINE_ALIAS from .math import stanh #DEFINE_ALIAS diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 89b6b4e47393a61b5a12c6730ff0e0aee1b8f81d..72cf76c5c725bc6664eb91f7aa8e18074a287169 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -31,6 +31,8 @@ from ..fluid.layers import acos #DEFINE_ALIAS from ..fluid.layers import asin #DEFINE_ALIAS from ..fluid.layers import ceil #DEFINE_ALIAS from ..fluid.layers import cos #DEFINE_ALIAS +from ..fluid.layers import sinh #DEFINE_ALIAS +from ..fluid.layers import cosh #DEFINE_ALIAS from ..fluid.layers import cumsum #DEFINE_ALIAS from ..fluid.layers import elementwise_add #DEFINE_ALIAS from ..fluid.layers import elementwise_div #DEFINE_ALIAS @@ -69,6 +71,7 @@ __all__ = [ 'atan', 'ceil', 'cos', + 'cosh', 'cumsum', 'elementwise_add', 'elementwise_div', @@ -95,6 +98,7 @@ __all__ = [ 'scale', 'sign', 'sin', + 'sinh', 'sqrt', 'square', 'stanh', @@ -171,7 +175,7 @@ Examples: .. code-block:: python import numpy as np - + import paddle import paddle.fluid as fluid @@ -201,9 +205,9 @@ def pow(input, exponent, out=None, name=None): Args: input(Variable): A ``Tensor`` or ``LoDTensor`` . The data type is ``float32`` or ``float64``. exponent(float32|Variable): A scalar with type ``float32`` or a ``Tensor`` with shape [1] and type ``float32``. - out (Variable, optional): The Variable that stores results of the operation. + out (Variable, optional): The Variable that stores results of the operation. If out is None, a new Variable will be created to store the results. - name(str, optional): The default value is None. Normally there is no need for user to set this property. + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Returns: @@ -666,7 +670,7 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None): Tensor variable with a single element, otherwise must be in the range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`. - dtype(str, optional): The dtype of output tensor. The default value is None, the dtype + dtype(str, optional): The dtype of output tensor. The default value is None, the dtype of output is the same as input tensor. keep_dim (bool, optional): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension @@ -681,7 +685,7 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None): Raises: ValueError, the :attr:`dtype` must be float64 or int64. - + Examples: .. code-block:: python @@ -762,7 +766,7 @@ def elementwise_sum(inputs, name=None): :alias: paddle.elementwise_sum,paddle.tensor.elementwise_sum,paddle.tensor.math.elementwise_sum ${comment} - + Case 1: :: Input: @@ -794,13 +798,13 @@ def elementwise_sum(inputs, name=None): [14, 16, 18]] Args: - inputs (Variable|list(Variable)): A Varaible list. The shape and data type of the list elementsshould be consistent. - Variable can be multi-dimensional Tensoror LoDTensor, and data types can be: float32, float64, int32, int64. + inputs (Variable|list(Variable)): A Varaible list. The shape and data type of the list elementsshould be consistent. + Variable can be multi-dimensional Tensoror LoDTensor, and data types can be: float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` Returns: - Variable: the sum of input :math:`inputs` . its shape and data types are consistent with :math:`inputs` . + Variable: the sum of input :math:`inputs` . its shape and data types are consistent with :math:`inputs` . Examples: .. code-block:: python @@ -826,8 +830,8 @@ def elementwise_sum(inputs, name=None): # the sum of input0 and input1 is 2-D Tensor with shape [2,3]. # dtype is the corresponding C++ data type, which may vary in different environments. - # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t, - # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, + # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t, + # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, # and '__int64' on Windows. They both represent 64-bit integer variables. """ @@ -872,7 +876,7 @@ def mm(input, mat2, out=None, name=None): Args: x (Variable): The input variable which is a Tensor or LoDTensor. mat2 (Variable): The input variable which is a Tensor or LoDTensor. - out(Variable, optional): Optional output which can be any created + out(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of operation. if out is None, a new Varibale will be create to store the result. name(str, optional): The default value is None. Normally there is no need for @@ -1004,7 +1008,7 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None): place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() else fluid.CPUPlace() exe = fluid.Executor(place) - results = exe.run(fluid.default_main_program(), + results = exe.run(fluid.default_main_program(), fetch_list=[out], feed={"input": data_input, 'x': data_x, "y": data_y}) print( np.array(results[0]) ) # [[10.5 10.5] @@ -1113,7 +1117,7 @@ def inverse(input, out=None, name=None): dimensions should be equal. When the number of dimensions is greater than 2, it is treated as batches of square matrix. The data type can be float32 and float64. - out (Variable, optional): Optional output which can be any created + out (Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of operation. If out is None, a new Varibale will be create to store the result. name (str, optional): The default value is None. Normally there is no need for @@ -1136,7 +1140,7 @@ def inverse(input, out=None, name=None): # example for static graph input = fluid.data("input", shape=[2, 2], dtype="float32") out = paddle.inverse(input) - + place = fluid.CPUPlace() exe = fluid.Executor(place) results = exe.run(feed={"input": mat_np }, @@ -1193,10 +1197,10 @@ def max(input, dim=None, keep_dim=False, out=None, name=None): output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true, default value is False. - out(Variable, optional): Optional output which can be any created + out(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of operation. if out is None, a new Varibale will be create to store the result. - name(str, optional): The default value is None. Normally there is no need for + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` Returns: @@ -1273,10 +1277,10 @@ def min(input, dim=None, keep_dim=False, out=None, name=None): output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true, default value is False. - out(Variable, optional): Optional output which can be any created + out(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of operation. if out is None, a new Varibale will be create to store the result. - name(str, optional): The default value is None. Normally there is no need for + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` Returns: @@ -1437,17 +1441,17 @@ def clamp(input, min=None, max=None, output=None, name=None): .. math:: - Out = MIN(MAX(x, min), max) + Out = MIN(MAX(x, min), max) Args: - input (Variable): An input N-D Tensor or LoDTensor - with data type float32, float64. + input (Variable): An input N-D Tensor or LoDTensor + with data type float32, float64. min (float32|Variable): The lower bound with type ``float32`` or a ``Tensor`` with shape [1] and type ``int32``, ``float32``, ``float64``. max (float32|Variable): The upper bound with type ``float32`` or a ``Tensor`` with shape [1] and type ``int32``, ``float32``, ``float64``. - output (Variable, optional): A tensor or LoDTensor. If :attr:`output` is None, - a new tensor will be created as :attr:`output`. Default: None. + output (Variable, optional): A tensor or LoDTensor. If :attr:`output` is None, + a new tensor will be created as :attr:`output`. Default: None. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -1519,11 +1523,11 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): :alias: paddle.trace,paddle.tensor.trace,paddle.tensor.math.trace This OP computes the sum along diagonals of the input tensor x. - - If ``x`` is 2D, returns the sum of diagonal. + + If ``x`` is 2D, returns the sum of diagonal. If ``x`` has larger dimensions, then returns an tensor of diagonals sum, diagonals be taken from - the 2D planes specified by axis1 and axis2. By default, the 2D planes formed by the first and second axes + the 2D planes specified by axis1 and axis2. By default, the 2D planes formed by the first and second axes of the input tensor x. The argument ``offset`` determines where diagonals are taken from input tensor x: @@ -1531,7 +1535,7 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): - If offset = 0, it is the main diagonal. - If offset > 0, it is above the main diagonal. - If offset < 0, it is below the main diagonal. - + Args: x(Variable): The input tensor x. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64. offset(int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals). @@ -1547,11 +1551,11 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): import paddle import numpy as np - + case1 = np.random.randn(2, 3).astype('float32') case2 = np.random.randn(3, 10, 10).astype('float32') case3 = np.random.randn(3, 10, 5, 10).astype('float32') - + paddle.enable_imperative() case1 = paddle.imperative.to_variable(case1) @@ -1615,17 +1619,17 @@ def kron(x, y, out=None, name=None): ${comment} Args: - x (Variable): the fist operand of kron op, data type: float16, float32, + x (Variable): the fist operand of kron op, data type: float16, float32, float64, int32 or int64. - y (Variable): the second operand of kron op, data type: float16, - float32, float64, int32 or int64. Its data type should be the same + y (Variable): the second operand of kron op, data type: float16, + float32, float64, int32 or int64. Its data type should be the same with x. - out (Variable, optional): Optional output which can be any created - Variable that meets the requirements to store the result of - operation. If out is None, a new Varibale will be create to store + out (Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of + operation. If out is None, a new Varibale will be create to store the result. Defaults to None. - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -1633,7 +1637,7 @@ ${comment} Examples: .. code-block:: python - + import paddle from paddle import fluid import paddle.fluid.dygraph as dg