未验证 提交 70cee22f 编写于 作者: C cnn 提交者: GitHub

New features, add sinh and cosh op, test=develop (#25495)

* New features, add sinh and cosh op, test=develop

* remove duplicate test function and remove out paramters, test=develop

* Add out paramters temporary, remove later. test=develop

* remove out args, PR 25570, test=develop

* remove TestParameter, test=developx

* add test api for static dygraph, test=develop

* add backword unittests for sinh and cosh, test=develop
上级 2f95e663
......@@ -250,6 +250,20 @@ $$out = sin(x)$$
)DOC";
UNUSED constexpr char SinhDoc[] = R"DOC(
Sinh Activation Operator.
$$out = sinh(x)$$
)DOC";
UNUSED constexpr char CoshDoc[] = R"DOC(
Cosh Activation Operator.
$$out = cosh(x)$$
)DOC";
UNUSED constexpr char RoundDoc[] = R"DOC(
The OP rounds the values in the input to the nearest integer value.
......@@ -642,6 +656,8 @@ REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
......
......@@ -584,6 +584,72 @@ struct SinFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct Sinh {
HOSTDEVICE T operator()(const T& val) const { return sinh(val); }
};
template <>
struct Sinh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(sinhf(static_cast<float>(val)));
}
};
template <typename T>
struct Cosh {
HOSTDEVICE T operator()(const T& val) const { return cosh(val); }
};
template <>
struct Cosh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(coshf(static_cast<float>(val)));
}
};
// sinh(x) = sinh(x)
template <typename T>
struct SinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Sinh<T>());
}
};
// cosh(x) = cosh(x)
template <typename T>
struct CoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Cosh<T>());
}
};
// sinh'(x) = cosh(x)
template <typename T>
struct SinhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Cosh<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cosh'(x) = sinh(x)
template <typename T>
struct CoshGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Sinh<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Acos {
HOSTDEVICE T operator()(const T& val) const { return acos(val); }
......@@ -1752,6 +1818,8 @@ class PowGradKernel
__macro(acos, Acos, AcosFunctor, AcosGradFunctor); \
__macro(sin, Sin, SinFunctor, SinGradFunctor); \
__macro(asin, Asin, AsinFunctor, AsinGradFunctor); \
__macro(sinh, Sinh, SinhFunctor, SinhGradFunctor); \
__macro(cosh, Cosh, CoshFunctor, CoshGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, Log, LogFunctor, LogGradFunctor); \
......
......@@ -132,6 +132,7 @@ from .tensor.math import asin #DEFINE_ALIAS
from .tensor.math import atan #DEFINE_ALIAS
from .tensor.math import ceil #DEFINE_ALIAS
from .tensor.math import cos #DEFINE_ALIAS
from .tensor.math import cosh #DEFINE_ALIAS
from .tensor.math import cumsum #DEFINE_ALIAS
from .tensor.math import elementwise_add #DEFINE_ALIAS
from .tensor.math import elementwise_div #DEFINE_ALIAS
......@@ -157,6 +158,7 @@ from .tensor.math import rsqrt #DEFINE_ALIAS
from .tensor.math import scale #DEFINE_ALIAS
from .tensor.math import sign #DEFINE_ALIAS
from .tensor.math import sin #DEFINE_ALIAS
from .tensor.math import sinh #DEFINE_ALIAS
from .tensor.math import sqrt #DEFINE_ALIAS
from .tensor.math import square #DEFINE_ALIAS
from .tensor.math import stanh #DEFINE_ALIAS
......
......@@ -35,6 +35,8 @@ __activations_noattr__ = [
'acos',
'asin',
'sin',
'sinh',
'cosh',
'round',
'reciprocal',
'square',
......@@ -80,9 +82,9 @@ def softshrink(x, alpha=None):
softshrink.__doc__ = """
:alias_main: paddle.nn.functional.softshrink
:alias: paddle.nn.functional.softshrink,paddle.nn.functional.activation.softshrink
:old_api: paddle.fluid.layers.softshrink
:alias_main: paddle.nn.functional.softshrink
:alias: paddle.nn.functional.softshrink,paddle.nn.functional.activation.softshrink
:old_api: paddle.fluid.layers.softshrink
:strong:`Softshrink Activation Operator`
......@@ -127,9 +129,9 @@ def hard_shrink(x, threshold=None):
hard_shrink.__doc__ = _hard_shrink_.__doc__ + """
:alias_main: paddle.nn.functional.hard_shrink
:alias: paddle.nn.functional.hard_shrink,paddle.nn.functional.activation.hard_shrink
:old_api: paddle.fluid.layers.hard_shrink
:alias_main: paddle.nn.functional.hard_shrink
:alias: paddle.nn.functional.hard_shrink,paddle.nn.functional.activation.hard_shrink
:old_api: paddle.fluid.layers.hard_shrink
Examples:
......@@ -154,9 +156,9 @@ def cumsum(x, axis=None, exclusive=None, reverse=None):
cumsum.__doc__ = """
:alias_main: paddle.cumsum
:alias: paddle.cumsum,paddle.tensor.cumsum,paddle.tensor.math.cumsum
:old_api: paddle.fluid.layers.cumsum
:alias_main: paddle.cumsum
:alias: paddle.cumsum,paddle.tensor.cumsum,paddle.tensor.math.cumsum
:old_api: paddle.fluid.layers.cumsum
The cumulative sum of the elements along a given axis. By default, the first element of the result is the same of the first element of the input. If exlusive is true, the first element of the result is 0.
......@@ -196,9 +198,9 @@ def thresholded_relu(x, threshold=None):
thresholded_relu.__doc__ = """
:alias_main: paddle.nn.functional.thresholded_relu
:alias: paddle.nn.functional.thresholded_relu,paddle.nn.functional.activation.thresholded_relu
:old_api: paddle.fluid.layers.thresholded_relu
:alias_main: paddle.nn.functional.thresholded_relu
:alias: paddle.nn.functional.thresholded_relu,paddle.nn.functional.activation.thresholded_relu
:old_api: paddle.fluid.layers.thresholded_relu
:strong:`Thresholded ReLU Activation Operator`
......@@ -282,9 +284,9 @@ def gelu(x, approximate=False):
gelu.__doc__ = """
:alias_main: paddle.nn.functional.gelu
:alias: paddle.nn.functional.gelu,paddle.nn.functional.activation.gelu
:old_api: paddle.fluid.layers.gelu
:alias_main: paddle.nn.functional.gelu
:alias: paddle.nn.functional.gelu,paddle.nn.functional.activation.gelu
:old_api: paddle.fluid.layers.gelu
:strong:`GeLU Activation Operator`
For more details, see [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
......@@ -370,9 +372,9 @@ def erf(x):
erf.__doc__ = """
:alias_main: paddle.erf
:alias: paddle.erf,paddle.tensor.erf,paddle.tensor.math.erf,paddle.nn.functional.erf,paddle.nn.functional.activation.erf
:old_api: paddle.fluid.layers.erf
:alias_main: paddle.erf
:alias: paddle.erf,paddle.tensor.erf,paddle.tensor.math.erf,paddle.nn.functional.erf,paddle.nn.functional.activation.erf
:old_api: paddle.fluid.layers.erf
:strong:`Erf Operator`
For more details, see [Error function](https://en.wikipedia.org/wiki/Error_function).
......
......@@ -183,6 +183,148 @@ class TestAtan(TestActivation, TestParameter):
self.assertEqual(z, z_expected)
class TestSinh(TestActivation):
def setUp(self):
self.op_type = "sinh"
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.sinh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.array([0.1])
x = fluid.dygraph.to_variable(np_x)
z = fluid.layers.sinh(x).numpy()
z_expected = np.sinh(np_x)
self.assertTrue(np.allclose(z, z_expected))
def test_api(self):
test_data_shape = [11, 17]
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_x = np.random.uniform(0.1, 1,
test_data_shape).astype("float32")
data_x = fluid.layers.data(
name="data_x",
shape=test_data_shape,
append_batch_size=False,
dtype="float32")
pd_sinh_out = fluid.layers.sinh(data_x)
exe = fluid.Executor(place=fluid.CPUPlace())
exe.run(fluid.default_startup_program())
np_sinh_res = exe.run(fluid.default_main_program(),
feed={"data_x": input_x},
fetch_list=[pd_sinh_out])
expected_res = np.sinh(input_x)
self.assertTrue(np.allclose(np_sinh_res, expected_res))
def test_backward(self):
test_data_shape = [11, 17]
with fluid.dygraph.guard():
input_x = np.random.uniform(0.1, 1,
test_data_shape).astype("float32")
var = fluid.dygraph.to_variable(input_x)
var.stop_gradient = False
loss = fluid.layers.sinh(var)
loss.backward()
grad_var = var.gradient()
self.assertEqual(grad_var.shape, input_x.shape)
class TestSinhOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.sinh, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.sinh, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.sinh(x_fp16)
class TestCosh(TestActivation):
def setUp(self):
self.op_type = "cosh"
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.cosh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.array([0.1])
x = fluid.dygraph.to_variable(np_x)
z = fluid.layers.cosh(x).numpy()
z_expected = np.cosh(np_x)
self.assertTrue(np.allclose(z, z_expected))
def test_api(self):
test_data_shape = [11, 17]
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_x = np.random.uniform(0.1, 1,
test_data_shape).astype("float32")
data_x = fluid.layers.data(
name="data_x",
shape=test_data_shape,
append_batch_size=False,
dtype="float32")
pd_cosh_out = paddle.cosh(data_x)
exe = fluid.Executor(place=fluid.CPUPlace())
exe.run(fluid.default_startup_program())
np_cosh_res = exe.run(fluid.default_main_program(),
feed={"data_x": input_x},
fetch_list=[pd_cosh_out])
expected_res = np.cosh(input_x)
self.assertTrue(np.allclose(np_cosh_res, expected_res))
def test_backward(self):
test_data_shape = [11, 17]
with fluid.dygraph.guard():
input_x = np.random.uniform(0.1, 1,
test_data_shape).astype("float32")
var = fluid.dygraph.to_variable(input_x)
var.stop_gradient = False
loss = fluid.layers.cosh(var)
loss.backward()
grad_var = var.gradient()
self.assertEqual(grad_var.shape, input_x.shape)
class TestCoshOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.cosh, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.cosh, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.cosh(x_fp16)
class TestTanhShrink(TestActivation):
def setUp(self):
self.op_type = "tanh_shrink"
......@@ -1204,8 +1346,10 @@ create_test_act_fp16_class(TestAbs)
create_test_act_fp16_class(TestCeil, grad_check=False)
create_test_act_fp16_class(TestFloor, grad_check=False)
create_test_act_fp16_class(TestCos, grad_atol=0.85)
create_test_act_fp16_class(TestCosh, grad_atol=0.85)
create_test_act_fp16_class(TestAcos, grad_atol=0.85)
create_test_act_fp16_class(TestSin)
create_test_act_fp16_class(TestSinh)
create_test_act_fp16_class(TestAsin)
create_test_act_fp16_class(TestAtan)
create_test_act_fp16_class(TestRound, grad_check=False)
......
......@@ -105,6 +105,7 @@ from .math import asin #DEFINE_ALIAS
from .math import atan #DEFINE_ALIAS
from .math import ceil #DEFINE_ALIAS
from .math import cos #DEFINE_ALIAS
from .math import cosh #DEFINE_ALIAS
from .math import cumsum #DEFINE_ALIAS
from .math import elementwise_add #DEFINE_ALIAS
from .math import elementwise_div #DEFINE_ALIAS
......@@ -130,6 +131,7 @@ from .math import rsqrt #DEFINE_ALIAS
from .math import scale #DEFINE_ALIAS
from .math import sign #DEFINE_ALIAS
from .math import sin #DEFINE_ALIAS
from .math import sinh #DEFINE_ALIAS
from .math import sqrt #DEFINE_ALIAS
from .math import square #DEFINE_ALIAS
from .math import stanh #DEFINE_ALIAS
......
......@@ -31,6 +31,8 @@ from ..fluid.layers import acos #DEFINE_ALIAS
from ..fluid.layers import asin #DEFINE_ALIAS
from ..fluid.layers import ceil #DEFINE_ALIAS
from ..fluid.layers import cos #DEFINE_ALIAS
from ..fluid.layers import sinh #DEFINE_ALIAS
from ..fluid.layers import cosh #DEFINE_ALIAS
from ..fluid.layers import cumsum #DEFINE_ALIAS
from ..fluid.layers import elementwise_add #DEFINE_ALIAS
from ..fluid.layers import elementwise_div #DEFINE_ALIAS
......@@ -69,6 +71,7 @@ __all__ = [
'atan',
'ceil',
'cos',
'cosh',
'cumsum',
'elementwise_add',
'elementwise_div',
......@@ -95,6 +98,7 @@ __all__ = [
'scale',
'sign',
'sin',
'sinh',
'sqrt',
'square',
'stanh',
......@@ -171,7 +175,7 @@ Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
......@@ -201,9 +205,9 @@ def pow(input, exponent, out=None, name=None):
Args:
input(Variable): A ``Tensor`` or ``LoDTensor`` . The data type is ``float32`` or ``float64``.
exponent(float32|Variable): A scalar with type ``float32`` or a ``Tensor`` with shape [1] and type ``float32``.
out (Variable, optional): The Variable that stores results of the operation.
out (Variable, optional): The Variable that stores results of the operation.
If out is None, a new Variable will be created to store the results.
name(str, optional): The default value is None. Normally there is no need for user to set this property.
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
......@@ -666,7 +670,7 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
Tensor variable with a single element, otherwise must be in the
range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`,
the dimension to reduce is :math:`rank + dim[i]`.
dtype(str, optional): The dtype of output tensor. The default value is None, the dtype
dtype(str, optional): The dtype of output tensor. The default value is None, the dtype
of output is the same as input tensor.
keep_dim (bool, optional): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
......@@ -681,7 +685,7 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
Raises:
ValueError, the :attr:`dtype` must be float64 or int64.
Examples:
.. code-block:: python
......@@ -762,7 +766,7 @@ def elementwise_sum(inputs, name=None):
:alias: paddle.elementwise_sum,paddle.tensor.elementwise_sum,paddle.tensor.math.elementwise_sum
${comment}
Case 1:
::
Input:
......@@ -794,13 +798,13 @@ def elementwise_sum(inputs, name=None):
[14, 16, 18]]
Args:
inputs (Variable|list(Variable)): A Varaible list. The shape and data type of the list elementsshould be consistent.
Variable can be multi-dimensional Tensoror LoDTensor, and data types can be: float32, float64, int32, int64.
inputs (Variable|list(Variable)): A Varaible list. The shape and data type of the list elementsshould be consistent.
Variable can be multi-dimensional Tensoror LoDTensor, and data types can be: float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Variable: the sum of input :math:`inputs` . its shape and data types are consistent with :math:`inputs` .
Variable: the sum of input :math:`inputs` . its shape and data types are consistent with :math:`inputs` .
Examples:
.. code-block:: python
......@@ -826,8 +830,8 @@ def elementwise_sum(inputs, name=None):
# the sum of input0 and input1 is 2-D Tensor with shape [2,3].
# dtype is the corresponding C++ data type, which may vary in different environments.
# Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t,
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t,
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# and '__int64' on Windows. They both represent 64-bit integer variables.
"""
......@@ -872,7 +876,7 @@ def mm(input, mat2, out=None, name=None):
Args:
x (Variable): The input variable which is a Tensor or LoDTensor.
mat2 (Variable): The input variable which is a Tensor or LoDTensor.
out(Variable, optional): Optional output which can be any created
out(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result.
name(str, optional): The default value is None. Normally there is no need for
......@@ -1004,7 +1008,7 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
results = exe.run(fluid.default_main_program(),
results = exe.run(fluid.default_main_program(),
fetch_list=[out], feed={"input": data_input, 'x': data_x, "y": data_y})
print( np.array(results[0]) )
# [[10.5 10.5]
......@@ -1113,7 +1117,7 @@ def inverse(input, out=None, name=None):
dimensions should be equal. When the number of dimensions is
greater than 2, it is treated as batches of square matrix. The data
type can be float32 and float64.
out (Variable, optional): Optional output which can be any created
out (Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
If out is None, a new Varibale will be create to store the result.
name (str, optional): The default value is None. Normally there is no need for
......@@ -1136,7 +1140,7 @@ def inverse(input, out=None, name=None):
# example for static graph
input = fluid.data("input", shape=[2, 2], dtype="float32")
out = paddle.inverse(input)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
results = exe.run(feed={"input": mat_np },
......@@ -1193,10 +1197,10 @@ def max(input, dim=None, keep_dim=False, out=None, name=None):
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true, default
value is False.
out(Variable, optional): Optional output which can be any created
out(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result.
name(str, optional): The default value is None. Normally there is no need for
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
......@@ -1273,10 +1277,10 @@ def min(input, dim=None, keep_dim=False, out=None, name=None):
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true, default
value is False.
out(Variable, optional): Optional output which can be any created
out(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result.
name(str, optional): The default value is None. Normally there is no need for
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
......@@ -1437,17 +1441,17 @@ def clamp(input, min=None, max=None, output=None, name=None):
.. math::
Out = MIN(MAX(x, min), max)
Out = MIN(MAX(x, min), max)
Args:
input (Variable): An input N-D Tensor or LoDTensor
with data type float32, float64.
input (Variable): An input N-D Tensor or LoDTensor
with data type float32, float64.
min (float32|Variable): The lower bound with type ``float32`` or a ``Tensor``
with shape [1] and type ``int32``, ``float32``, ``float64``.
max (float32|Variable): The upper bound with type ``float32`` or a ``Tensor``
with shape [1] and type ``int32``, ``float32``, ``float64``.
output (Variable, optional): A tensor or LoDTensor. If :attr:`output` is None,
a new tensor will be created as :attr:`output`. Default: None.
output (Variable, optional): A tensor or LoDTensor. If :attr:`output` is None,
a new tensor will be created as :attr:`output`. Default: None.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
......@@ -1519,11 +1523,11 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
:alias: paddle.trace,paddle.tensor.trace,paddle.tensor.math.trace
This OP computes the sum along diagonals of the input tensor x.
If ``x`` is 2D, returns the sum of diagonal.
If ``x`` is 2D, returns the sum of diagonal.
If ``x`` has larger dimensions, then returns an tensor of diagonals sum, diagonals be taken from
the 2D planes specified by axis1 and axis2. By default, the 2D planes formed by the first and second axes
the 2D planes specified by axis1 and axis2. By default, the 2D planes formed by the first and second axes
of the input tensor x.
The argument ``offset`` determines where diagonals are taken from input tensor x:
......@@ -1531,7 +1535,7 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
Args:
x(Variable): The input tensor x. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64.
offset(int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals).
......@@ -1547,11 +1551,11 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
import paddle
import numpy as np
case1 = np.random.randn(2, 3).astype('float32')
case2 = np.random.randn(3, 10, 10).astype('float32')
case3 = np.random.randn(3, 10, 5, 10).astype('float32')
paddle.enable_imperative()
case1 = paddle.imperative.to_variable(case1)
......@@ -1615,17 +1619,17 @@ def kron(x, y, out=None, name=None):
${comment}
Args:
x (Variable): the fist operand of kron op, data type: float16, float32,
x (Variable): the fist operand of kron op, data type: float16, float32,
float64, int32 or int64.
y (Variable): the second operand of kron op, data type: float16,
float32, float64, int32 or int64. Its data type should be the same
y (Variable): the second operand of kron op, data type: float16,
float32, float64, int32 or int64. Its data type should be the same
with x.
out (Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of
operation. If out is None, a new Varibale will be create to store
out (Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of
operation. If out is None, a new Varibale will be create to store
the result. Defaults to None.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
......@@ -1633,7 +1637,7 @@ ${comment}
Examples:
.. code-block:: python
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册