未验证 提交 086b5406 编写于 作者: 欧阳罢笔's avatar 欧阳罢笔 提交者: GitHub

Add gather api combined implementation (#730)

* Add gather api combined implementation

* Update gather api combined implementation
上级 26d2ba03
......@@ -46,8 +46,8 @@ if axis == 1:
将tensor_list中的tensor沿axis轴拼接
```
### 代码示例
``` python
# PyTorch示例:
t = torch.tensor([[1, 2], [3, 4]])
......@@ -66,3 +66,34 @@ paddle.gather(t, paddle.to_tensor([1, 0]), 1)
# [[2, 1],
# [4, 3]])
```
### 组合实现
```python
def paddle_gather(x, dim, index):
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out
t = paddle.to_tensor([[1, 2], [3, 4]])
paddle_gather(t, 1, paddle.to_tensor([[0, 0], [1, 0]]))
# 输出
# Tensor(shape=[2, 2], dtype=int32, place=CPUPlace, stop_gradient=True,
# [[1, 1],
# [4, 3]])
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册