Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7dac3226
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7dac3226
编写于
4月 24, 2020
作者:
L
Li Fuchen
提交者:
GitHub
4月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modified the example of diag_embed english doc, test=develop (#24012)
上级
2961a4f0
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
52 addition
and
13 deletion
+52
-13
python/paddle/nn/functional/extension.py
python/paddle/nn/functional/extension.py
+52
-13
未找到文件。
python/paddle/nn/functional/extension.py
浏览文件 @
7dac3226
...
...
@@ -46,27 +46,68 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1):
This OP creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2)
are filled by ``input``. By default, a 2D plane formed by the last two dimensions
of the returned tensor will be selected.
The argument ``offset`` determines which diagonal is generated:
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
Args:
input(Variable|numpy.ndarray): The input tensor. Must be at least 1-dimensional. The input data type should be float32, float64, int32, int64.
offset(int, optional): Which diagonal to consider. Default: 0 (main diagonal).
dim1(int, optional): The first dimension with respect to which to take diagonal. Default: -2.
dim2(int, optional): The second dimension with respect to which to take diagonal. Default: -1.
Returns:
Variable, the output data type is the same as input data type.
Examples:
.. code-block:: python
import paddle.nn.functional as F
import paddle.fluid.dygraph as dg
import numpy as np
diag_embed = np.random.randn(2, 3).astype('float32')
# [[ 0.7545889 , -0.25074545, 0.5929117 ],
# [-0.6097662 , -0.01753256, 0.619769 ]]
with dg.guard():
data1 = F.diag_embed(diag_embed)
data2 = F.diag_embed(diag_embed, offset=1, dim1=0, dim2=2)
data1.numpy()
# [[[ 0.7545889 , 0. , 0. ],
# [ 0. , -0.25074545, 0. ],
# [ 0. , 0. , 0.5929117 ]],
# [[-0.6097662 , 0. , 0. ],
# [ 0. , -0.01753256, 0. ],
# [ 0. , 0. , 0.619769 ]]]
data2 = F.diag_embed(diag_embed, offset=-1, dim1=0, dim2=2)
data2.numpy()
# [[[ 0. , 0. , 0. , 0. ],
# [ 0.7545889 , 0. , 0. , 0. ],
# [ 0. , -0.25074545, 0. , 0. ],
# [ 0. , 0. , 0.5929117 , 0. ]],
#
# [[ 0. , 0. , 0. , 0. ],
# [-0.6097662 , 0. , 0. , 0. ],
# [ 0. , -0.01753256, 0. , 0. ],
# [ 0. , 0. , 0.619769 , 0. ]]]
data3 = F.diag_embed(diag_embed, offset=1, dim1=0, dim2=2)
data3.numpy()
# [[[ 0. , 0.7545889 , 0. , 0. ],
# [ 0. , -0.6097662 , 0. , 0. ]],
#
# [[ 0. , 0. , -0.25074545, 0. ],
# [ 0. , 0. , -0.01753256, 0. ]],
#
# [[ 0. , 0. , 0. , 0.5929117 ],
# [ 0. , 0. , 0. , 0.619769 ]],
#
# [[ 0. , 0. , 0. , 0. ],
# [ 0. , 0. , 0. , 0. ]]]
"""
inputs
=
{
'Input'
:
[
input
]}
attrs
=
{
'offset'
:
offset
,
'dim1'
:
dim1
,
'dim2'
:
dim2
}
...
...
@@ -80,26 +121,24 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1):
'diag_embed'
)
input_shape
=
list
(
input
.
shape
)
assert
(
len
(
input_shape
)
>=
1
,
\
assert
len
(
input_shape
)
>=
1
,
\
"Input must be at least 1-dimensional, "
\
"But received Input's dimensional: %s.
\n
"
%
\
len
(
input_shape
)
)
len
(
input_shape
)
assert
(
np
.
abs
(
dim1
)
<=
len
(
input_shape
),
"Dim1 is out of range (expected to be in range of [%d, %d], but got %d).
\n
"
%
(
-
(
len
(
input_shape
)
+
1
),
len
(
input_shape
),
dim1
))
assert
np
.
abs
(
dim1
)
<=
len
(
input_shape
),
\
"Dim1 is out of range (expected to be in range of [%d, %d], but got %d).
\n
"
\
%
(
-
(
len
(
input_shape
)
+
1
),
len
(
input_shape
),
dim1
)
assert
(
np
.
abs
(
dim2
)
<=
len
(
input_shape
),
"Dim2 is out of range (expected to be in range of [%d, %d], but got %d).
\n
"
%
(
-
(
len
(
input_shape
)
+
1
),
len
(
input_shape
),
dim2
))
assert
np
.
abs
(
dim2
)
<=
len
(
input_shape
),
\
"Dim2 is out of range (expected to be in range of [%d, %d], but got %d).
\n
"
\
%
(
-
(
len
(
input_shape
)
+
1
),
len
(
input_shape
),
dim2
)
dim1_
=
dim1
if
dim1
>=
0
else
len
(
input_shape
)
+
dim1
+
1
dim2_
=
dim2
if
dim2
>=
0
else
len
(
input_shape
)
+
dim2
+
1
assert
(
dim1_
!=
dim2_
,
assert
dim1_
!=
dim2_
,
\
"dim1 and dim2 cannot be the same dimension."
\
"But received dim1 = %d, dim2 = %d
\n
"
%
(
dim1
,
dim2
)
)
"But received dim1 = %d, dim2 = %d
\n
"
%
(
dim1
,
dim2
)
if
not
in_dygraph_mode
():
__check_input
(
input
,
offset
,
dim1
,
dim2
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录