Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2cd10fc4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2cd10fc4
编写于
11月 17, 2020
作者:
Z
zhupengyang
提交者:
GitHub
11月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix 2.0 api docs (#28445)
上级
a083c76a
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
154 addition
and
170 deletion
+154
-170
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+16
-24
python/paddle/nn/functional/activation.py
python/paddle/nn/functional/activation.py
+20
-25
python/paddle/nn/layer/activation.py
python/paddle/nn/layer/activation.py
+16
-23
python/paddle/tensor/creation.py
python/paddle/tensor/creation.py
+19
-35
python/paddle/tensor/random.py
python/paddle/tensor/random.py
+75
-43
python/paddle/tensor/stat.py
python/paddle/tensor/stat.py
+8
-20
未找到文件。
python/paddle/fluid/layers/nn.py
浏览文件 @
2cd10fc4
...
...
@@ -9730,15 +9730,13 @@ def swish(x, beta=1.0, name=None):
return out
@deprecated(since="2.0.0", update_to="paddle.
nn.functional
.prelu")
@deprecated(since="2.0.0", update_to="paddle.
static.nn
.prelu")
def prelu(x, mode, param_attr=None, name=None):
"""
:api_attr: Static Graph
Equation:
prelu activation.
.. math::
y = \max(0, x) + \\alpha * \
min(0, x)
prelu(x) = max(0, x) + \\alpha *
min(0, x)
There are three modes for the activation:
...
...
@@ -9748,34 +9746,28 @@ def prelu(x, mode, param_attr=None, name=None):
channel: Elements in same channel share same alpha.
element: All elements do not share alpha. Each element has its own alpha.
Arg
s:
x (
Variable
): The input Tensor or LoDTensor with data type float32.
Parameter
s:
x (
Tensor
): The input Tensor or LoDTensor with data type float32.
mode (str): The mode for weight sharing.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight (alpha), it can be create by ParamAttr. None by default.
For detailed information, please refer to :ref:`api_fluid_ParamAttr`.
name(str|None): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
param_attr (ParamAttr|None, optional): The parameter attribute for the learnable
weight (alpha), it can be create by ParamAttr. None by default.
For detailed information, please refer to :ref:`api_fluid_ParamAttr`.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable:
output(Variable): The tensor or LoDTensor with the same shape as input.
The data type is float32.
Tensor: A tensor with the same shape and data type as x.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
from paddle.fluid.param_attr import ParamAttr
x = fluid.data(name="x", shape=[None,5,10,10], dtype="float32")
mode = 'channel'
output = fluid.layers.prelu(
x,mode,param_attr=ParamAttr(name='alpha'))
x = paddle.to_tensor([-1., 2., 3.])
param = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0.2))
out = paddle.static.nn.prelu(x, 'all', param)
# [-0.2, 2., 3.]
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'prelu')
...
...
python/paddle/nn/functional/activation.py
浏览文件 @
2cd10fc4
...
...
@@ -79,9 +79,8 @@ def elu(x, alpha=1.0, name=None):
import paddle
import paddle.nn.functional as F
import numpy as np
x = paddle.to_tensor(
np.array([[-1,6],[1,15.6]])
)
x = paddle.to_tensor(
[[-1., 6.], [1., 15.6]]
)
out = F.elu(x, alpha=0.2)
# [[-0.12642411 6. ]
# [ 1. 15.6 ]]
...
...
@@ -131,11 +130,14 @@ def gelu(x, approximate=False, name=None):
import paddle
import paddle.nn.functional as F
import numpy as np
x = paddle.to_tensor(np.array([[-1, 0.5],[1, 1.5]]))
out1 = F.gelu(x) # [-0.158655 0.345731 0.841345 1.39979]
out2 = F.gelu(x, True) # [-0.158808 0.345714 0.841192 1.39957]
x = paddle.to_tensor([[-1, 0.5], [1, 1.5]])
out1 = F.gelu(x)
# [[-0.15865529, 0.34573123],
# [ 0.84134471, 1.39978933]]
out2 = F.gelu(x, True)
# [[-0.15880799, 0.34571400],
# [ 0.84119201, 1.39957154]]
"""
if
in_dygraph_mode
():
...
...
@@ -181,11 +183,8 @@ def hardshrink(x, threshold=0.5, name=None):
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(
np.array([-1, 0.3, 2.5])
)
x = paddle.to_tensor(
[-1, 0.3, 2.5]
)
out = F.hardshrink(x) # [-1., 0., 2.5]
"""
...
...
@@ -385,11 +384,8 @@ def leaky_relu(x, negative_slope=0.01, name=None):
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(
np.array([-2, 0, 1], 'float32')
)
x = paddle.to_tensor(
[-2., 0., 1.]
)
out = F.leaky_relu(x) # [-0.02, 0., 1.]
"""
...
...
@@ -1147,8 +1143,10 @@ def log_softmax(x, axis=-1, dtype=None, name=None):
.. math::
log
\\
_softmax[i, j] = log(softmax(x))
= log(
\\
frac{\exp(X[i, j])}{
\\
sum_j(exp(X[i, j])})
\\
begin{aligned}
log
\\
_softmax[i, j] &= log(softmax(x))
\\\\
&= log(
\\
frac{
\\
exp(X[i, j])}{
\\
sum_j(
\\
exp(X[i, j])})
\\
end{aligned}
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
...
...
@@ -1174,16 +1172,13 @@ def log_softmax(x, axis=-1, dtype=None, name=None):
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x =
np.array(
[[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]], 'float32')
x = [[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]]
x = paddle.to_tensor(x)
out1 = F.log_softmax(x)
out2 = F.log_softmax(x, dtype='float64')
...
...
python/paddle/nn/layer/activation.py
浏览文件 @
2cd10fc4
...
...
@@ -70,9 +70,8 @@ class ELU(layers.Layer):
.. code-block:: python
import paddle
import numpy as np
x = paddle.to_tensor(
np.array([[-1,6],[1,15.6]])
)
x = paddle.to_tensor(
[[-1. ,6.], [1., 15.6]]
)
m = paddle.nn.ELU(0.2)
out = m(x)
# [[-0.12642411 6. ]
...
...
@@ -166,11 +165,8 @@ class Hardshrink(layers.Layer):
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-1, 0.3, 2.5]))
x = paddle.to_tensor([-1, 0.3, 2.5])
m = paddle.nn.Hardshrink()
out = m(x) # [-1., 0., 2.5]
"""
...
...
@@ -293,11 +289,10 @@ class Hardtanh(layers.Layer):
.. code-block:: python
import paddle
import numpy as np
x = paddle.to_tensor(
np.array([-1.5, 0.3, 2.5])
)
x = paddle.to_tensor(
[-1.5, 0.3, 2.5]
)
m = paddle.nn.Hardtanh()
out = m(x) #
#
[-1., 0.3, 1.]
out = m(x) # [-1., 0.3, 1.]
"""
def
__init__
(
self
,
min
=-
1.0
,
max
=
1.0
,
name
=
None
):
...
...
@@ -397,9 +392,8 @@ class ReLU(layers.Layer):
.. code-block:: python
import paddle
import numpy as np
x = paddle.to_tensor(
np.array([-2, 0, 1]).astype('float32')
)
x = paddle.to_tensor(
[-2., 0., 1.]
)
m = paddle.nn.ReLU()
out = m(x) # [0., 0., 1.]
"""
...
...
@@ -613,7 +607,7 @@ class Hardsigmoid(layers.Layer):
import paddle
m = paddle.nn.
S
igmoid()
m = paddle.nn.
Hards
igmoid()
x = paddle.to_tensor([-4., 5., 1.])
out = m(x) # [0., 1, 0.666667]
"""
...
...
@@ -1016,8 +1010,10 @@ class LogSoftmax(layers.Layer):
.. math::
Out[i, j] = log(softmax(x))
= log(
\\
frac{\exp(X[i, j])}{
\\
sum_j(exp(X[i, j])})
\\
begin{aligned}
Out[i, j] &= log(softmax(x))
\\\\
&= log(
\\
frac{
\\
exp(X[i, j])}{
\\
sum_j(
\\
exp(X[i, j])})
\\
end{aligned}
Parameters:
axis (int, optional): The axis along which to perform log_softmax
...
...
@@ -1035,16 +1031,13 @@ class LogSoftmax(layers.Layer):
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x =
np.array(
[[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]])
x = [[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]]
m = paddle.nn.LogSoftmax()
x = paddle.to_tensor(x)
out = m(x)
...
...
python/paddle/tensor/creation.py
浏览文件 @
2cd10fc4
...
...
@@ -300,9 +300,6 @@ def ones(shape, dtype=None, name=None):
def
ones_like
(
x
,
dtype
=
None
,
name
=
None
):
"""
:alias_main: paddle.ones_like
:alias: paddle.tensor.ones_like, paddle.tensor.creation.ones_like
This OP returns a Tensor filled with the value 1, with the same shape and
data type (use ``dtype`` if ``dtype`` is not None) as ``x``.
...
...
@@ -323,18 +320,16 @@ def ones_like(x, dtype=None, name=None):
Raise:
TypeError: If ``dtype`` is not None and is not bool, float16, float32,
float64, int32 or int64.
float64, int32 or int64.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
x = paddle.to_tensor([1,2,3])
out1 = paddle.
zero
s_like(x) # [1., 1., 1.]
out2 = paddle.
zero
s_like(x, dtype='int32') # [1, 1, 1]
out1 = paddle.
one
s_like(x) # [1., 1., 1.]
out2 = paddle.
one
s_like(x, dtype='int32') # [1, 1, 1]
"""
return
full_like
(
x
=
x
,
fill_value
=
1
,
dtype
=
dtype
,
name
=
name
)
...
...
@@ -380,9 +375,6 @@ def zeros(shape, dtype=None, name=None):
def
zeros_like
(
x
,
dtype
=
None
,
name
=
None
):
"""
:alias_main: paddle.zeros_like
:alias: paddle.tensor.zeros_like, paddle.tensor.creation.zeros_like
This OP returns a Tensor filled with the value 0, with the same shape and
data type (use ``dtype`` if ``dtype`` is not None) as ``x``.
...
...
@@ -403,16 +395,14 @@ def zeros_like(x, dtype=None, name=None):
Raise:
TypeError: If ``dtype`` is not None and is not bool, float16, float32,
float64, int32 or int64.
float64, int32 or int64.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
x = paddle.to_tensor([1,2,3])
x = paddle.to_tensor([1, 2, 3])
out1 = paddle.zeros_like(x) # [0., 0., 0.]
out2 = paddle.zeros_like(x, dtype='int32') # [0, 0, 0]
...
...
@@ -519,9 +509,6 @@ def full(shape, fill_value, dtype=None, name=None):
def
arange
(
start
=
0
,
end
=
None
,
step
=
1
,
dtype
=
None
,
name
=
None
):
"""
:alias_main: paddle.arange
:alias: paddle.tensor.arange, paddle.tensor.creation.arange
This OP returns a 1-D Tensor with spaced values within a given interval.
Values are generated into the half-open interval [``start``, ``end``) with
...
...
@@ -552,33 +539,30 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
Returns:
Tensor: A 1-D Tensor with values from the interval [``start``, ``end``)
taken with common difference ``step`` beginning from ``start``. Its
data type is set by ``dtype``.
taken with common difference ``step`` beginning from ``start``. Its
data type is set by ``dtype``.
Raises:
TypeError: If ``dtype`` is not int32, int64, float32, float64.
examples:
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
import paddle
out1 = paddle.arange(5)
# [0, 1, 2, 3, 4]
out1 = paddle.arange(5)
# [0, 1, 2, 3, 4]
out2 = paddle.arange(3, 9, 2.0)
# [3, 5, 7]
out2 = paddle.arange(3, 9, 2.0)
# [3, 5, 7]
# use 4.999 instead of 5.0 to avoid floating point rounding errors
out3 = paddle.arange(4.999, dtype='float32')
# [0., 1., 2., 3., 4.]
# use 4.999 instead of 5.0 to avoid floating point rounding errors
out3 = paddle.arange(4.999, dtype='float32')
# [0., 1., 2., 3., 4.]
start_var = paddle.to_tensor([3])
out4 = paddle.arange(start_var, 7)
# [3, 4, 5, 6]
start_var = paddle.to_tensor([3])
out4 = paddle.arange(start_var, 7)
# [3, 4, 5, 6]
"""
if
dtype
is
None
:
...
...
python/paddle/tensor/random.py
浏览文件 @
2cd10fc4
...
...
@@ -252,16 +252,14 @@ def standard_normal(shape, dtype=None, name=None):
import paddle
paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor.
out1 = paddle.standard_normal(shape=[2, 3])
# [[-2.923464 , 0.11934398, -0.51249987], # random
# [ 0.39632758, 0.08177969, 0.2692008 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim1 = paddle.
full([1], 2, "int64"
)
dim2 = paddle.
full([1], 3, "int32"
)
dim1 = paddle.
to_tensor([2], 'int64'
)
dim2 = paddle.
to_tensor([3], 'int32'
)
out2 = paddle.standard_normal(shape=[dim1, dim2, 2])
# [[[-2.8852394 , -0.25898588], # random
# [-0.47420555, 0.17683524], # random
...
...
@@ -272,8 +270,7 @@ def standard_normal(shape, dtype=None, name=None):
# example 3: attr shape is a Tensor, the data type must be int64 or int32.
shape_tensor = paddle.to_tensor([2, 3])
result_3 = paddle.standard_normal(shape_tensor)
out3 = paddle.standard_normal(shape_tensor)
# [[-2.878077 , 0.17099959, 0.05111201] # random
# [-0.3761474, -1.044801 , 1.1870178 ]] # random
...
...
@@ -281,7 +278,58 @@ def standard_normal(shape, dtype=None, name=None):
return
gaussian
(
shape
=
shape
,
mean
=
0.0
,
std
=
1.0
,
dtype
=
dtype
,
name
=
name
)
randn
=
standard_normal
def
randn
(
shape
,
dtype
=
None
,
name
=
None
):
"""
This OP returns a Tensor filled with random values sampled from a standard
normal distribution with mean 0 and standard deviation 1, with ``shape``
and ``dtype``.
Args:
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype``
for details).
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with random values sampled from a standard
normal distribution with mean 0 and standard deviation 1, with
``shape`` and ``dtype``.
Examples:
.. code-block:: python
import paddle
# example 1: attr shape is a list which doesn't contain Tensor.
out1 = paddle.randn(shape=[2, 3])
# [[-2.923464 , 0.11934398, -0.51249987], # random
# [ 0.39632758, 0.08177969, 0.2692008 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim1 = paddle.to_tensor([2], 'int64')
dim2 = paddle.to_tensor([3], 'int32')
out2 = paddle.randn(shape=[dim1, dim2, 2])
# [[[-2.8852394 , -0.25898588], # random
# [-0.47420555, 0.17683524], # random
# [-0.7989969 , 0.00754541]], # random
# [[ 0.85201347, 0.32320443], # random
# [ 1.1399018 , 0.48336947], # random
# [ 0.8086993 , 0.6868893 ]]] # random
# example 3: attr shape is a Tensor, the data type must be int64 or int32.
shape_tensor = paddle.to_tensor([2, 3])
out3 = paddle.randn(shape_tensor)
# [[-2.878077 , 0.17099959, 0.05111201] # random
# [-0.3761474, -1.044801 , 1.1870178 ]] # random
"""
return
standard_normal
(
shape
,
dtype
,
name
)
def
normal
(
mean
=
0.0
,
std
=
1.0
,
shape
=
None
,
name
=
None
):
...
...
@@ -322,8 +370,6 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
import paddle
paddle.disable_static()
out1 = paddle.normal(shape=[2, 3])
# [[ 0.17501129 0.32364586 1.561118 ] # random
# [-1.7232178 1.1545963 -0.76156676]] # random
...
...
@@ -381,7 +427,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
Examples:
::
.. code-block:: text
Input:
shape = [1, 2]
...
...
@@ -423,33 +469,27 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
import paddle
paddle.disable_static()
# example 1:
# attr shape is a list which doesn't contain Tensor.
result_1 = paddle.tensor.random
.uniform(shape=[3, 4])
# [[ 0.84524226, 0.6921872, 0.56528175, 0.71690357],
# [-0.34646994, -0.45116323, -0.09902662, -0.11397249],
# [ 0.433519, 0.39483607, -0.8660099, 0.83664286]]
out1 = paddle
.uniform(shape=[3, 4])
# [[ 0.84524226, 0.6921872, 0.56528175, 0.71690357],
# random
# [-0.34646994, -0.45116323, -0.09902662, -0.11397249],
# random
# [ 0.433519, 0.39483607, -0.8660099, 0.83664286]]
# random
# example 2:
# attr shape is a list which contains Tensor.
dim
_1 = paddle.full([1], 2, "int64"
)
dim
_2 = paddle.full([1], 3, "int32"
)
result_2 = paddle.tensor.random.uniform(shape=[dim_1, dim_
2])
# [[-0.9951253, 0.30757582, 0.9899647 ],
# [ 0.5864527, 0.6607096, -0.8886161
]]
dim
1 = paddle.to_tensor([2], 'int64'
)
dim
2 = paddle.to_tensor([3], 'int32'
)
out2 = paddle.uniform(shape=[dim1, dim
2])
# [[-0.9951253, 0.30757582, 0.9899647 ],
# random
# [ 0.5864527, 0.6607096, -0.8886161
]] # random
# example 3:
# attr shape is a Tensor, the data type must be int64 or int32.
shape_tensor = paddle.to_tensor([2, 3])
result_3 = paddle.tensor.random.uniform(shape_tensor)
# if shape_tensor's value is [2, 3]
# result_3 is:
# [[-0.8517412, -0.4006908, 0.2551912 ],
# [ 0.3364414, 0.36278176, -0.16085452]]
out3 = paddle.uniform(shape_tensor)
# [[-0.8517412, -0.4006908, 0.2551912 ], # random
# [ 0.3364414, 0.36278176, -0.16085452]] # random
"""
if
dtype
is
None
:
dtype
=
paddle
.
framework
.
get_default_dtype
()
...
...
@@ -517,8 +557,6 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
import paddle
paddle.disable_static()
# example 1:
# attr shape is a list which doesn't contain Tensor.
out1 = paddle.randint(low=-5, high=5, shape=[3])
...
...
@@ -526,18 +564,16 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
# example 2:
# attr shape is a list which contains Tensor.
dim1 = paddle.
full([1], 2, "int64"
)
dim2 = paddle.
full([1], 3, "int32"
)
out2 = paddle.randint(low=-5, high=5, shape=[dim1, dim2]
, dtype="int32"
)
dim1 = paddle.
to_tensor([2], 'int64'
)
dim2 = paddle.
to_tensor([3], 'int32'
)
out2 = paddle.randint(low=-5, high=5, shape=[dim1, dim2])
# [[0, -1, -3], # random
# [4, -2, 0]] # random
# example 3:
# attr shape is a Tensor
shape_tensor = paddle.to_tensor(3)
result_3 = paddle.randint(low=-5, high=5, shape=shape_tensor)
out3 = paddle.randint(low=-5, high=5, shape=shape_tensor)
# [-2, 2, 3] # random
# example 4:
...
...
@@ -611,8 +647,6 @@ def randperm(n, dtype="int64", name=None):
import paddle
paddle.disable_static()
out1 = paddle.randperm(5)
# [4, 1, 2, 3, 0] # random
...
...
@@ -668,15 +702,14 @@ def rand(shape, dtype=None, name=None):
import paddle
paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor.
out1 = paddle.rand(shape=[2, 3])
# [[0.451152 , 0.55825245, 0.403311 ], # random
# [0.22550228, 0.22106001, 0.7877319 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim1 = paddle.
full([1], 2, "int64"
)
dim2 = paddle.
full([1], 3, "int32"
)
dim1 = paddle.
to_tensor([2], 'int64'
)
dim2 = paddle.
to_tensor([3], 'int32'
)
out2 = paddle.rand(shape=[dim1, dim2, 2])
# [[[0.8879919 , 0.25788337], # random
# [0.28826773, 0.9712097 ], # random
...
...
@@ -687,8 +720,7 @@ def rand(shape, dtype=None, name=None):
# example 3: attr shape is a Tensor, the data type must be int64 or int32.
shape_tensor = paddle.to_tensor([2, 3])
result_3 = paddle.rand(shape_tensor)
out3 = paddle.rand(shape_tensor)
# [[0.22920267, 0.841956 , 0.05981819], # random
# [0.4836288 , 0.24573246, 0.7516129 ]] # random
...
...
python/paddle/tensor/stat.py
浏览文件 @
2cd10fc4
...
...
@@ -56,17 +56,13 @@ def mean(x, axis=None, keepdim=False, name=None):
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = np.array([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]], 'float32')
x = paddle.to_tensor(x)
x = paddle.to_tensor([[[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]])
out1 = paddle.mean(x)
# [12.5]
out2 = paddle.mean(x, axis=-1)
...
...
@@ -145,12 +141,8 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = np.array([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
x = paddle.to_tensor(x)
x = paddle.to_tensor([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
out1 = paddle.var(x)
# [2.66666667]
out2 = paddle.var(x, axis=1)
...
...
@@ -208,12 +200,8 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None):
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = np.array([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
x = paddle.to_tensor(x)
x = paddle.to_tensor([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
out1 = paddle.std(x)
# [1.63299316]
out2 = paddle.std(x, axis=1)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录