Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
1b5969d9
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
1b5969d9
编写于
1月 25, 2022
作者:
W
wjj19950828
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'upstream/develop' into Fixed_pytorch_readme
上级
b4eceac8
e6ddfb4d
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
32 addition
and
1 deletion
+32
-1
docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
+32
-1
未找到文件。
docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
浏览文件 @
1b5969d9
...
...
@@ -46,8 +46,8 @@ if axis == 1:
将tensor_list中的tensor沿axis轴拼接
```
### 代码示例
```
python
# PyTorch示例:
t
=
torch
.
tensor
([[
1
,
2
],
[
3
,
4
]])
...
...
@@ -66,3 +66,34 @@ paddle.gather(t, paddle.to_tensor([1, 0]), 1)
# [[2, 1],
# [4, 3]])
```
### 组合实现
```
python
def
paddle_gather
(
x
,
dim
,
index
):
index_shape
=
index
.
shape
index_flatten
=
index
.
flatten
()
if
dim
<
0
:
dim
=
len
(
x
.
shape
)
+
dim
nd_index
=
[]
for
k
in
range
(
len
(
x
.
shape
)):
if
k
==
dim
:
nd_index
.
append
(
index_flatten
)
else
:
reshape_shape
=
[
1
]
*
len
(
x
.
shape
)
reshape_shape
[
k
]
=
x
.
shape
[
k
]
x_arange
=
paddle
.
arange
(
x
.
shape
[
k
],
dtype
=
index
.
dtype
)
x_arange
=
x_arange
.
reshape
(
reshape_shape
)
dim_index
=
paddle
.
expand
(
x_arange
,
index_shape
).
flatten
()
nd_index
.
append
(
dim_index
)
ind2
=
paddle
.
transpose
(
paddle
.
stack
(
nd_index
),
[
1
,
0
]).
astype
(
"int64"
)
paddle_out
=
paddle
.
gather_nd
(
x
,
ind2
).
reshape
(
index_shape
)
return
paddle_out
t
=
paddle
.
to_tensor
([[
1
,
2
],
[
3
,
4
]])
paddle_gather
(
t
,
1
,
paddle
.
to_tensor
([[
0
,
0
],
[
1
,
0
]]))
# 输出
# Tensor(shape=[2, 2], dtype=int32, place=CPUPlace, stop_gradient=True,
# [[1, 1],
# [4, 3]])
```
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录