diff --git a/docs/pytorch_project_convertor/API_docs/ops/torch.gather.md b/docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
index d0b0ce6adfed1945788aa5b3be9e25f6dde75a89..edd86ec1e200e7c90838e16ee6dd583f84c8c8c3 100644
--- a/docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
+++ b/docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
@@ -46,8 +46,8 @@ if axis == 1:
     Concatenate the tensors in tensor_list along the axis dimension
 ```
 
-
 ### Code Example
+
 ``` python
 # PyTorch example:
 t = torch.tensor([[1, 2], [3, 4]])
@@ -66,3 +66,34 @@ paddle.gather(t, paddle.to_tensor([1, 0]), 1)
 # [[2, 1],
 #  [4, 3]])
 ```
+
+### Composite Implementation
+
+```python
+def paddle_gather(x, dim, index):
+    index_shape = index.shape
+    index_flatten = index.flatten()
+    if dim < 0:
+        dim = len(x.shape) + dim
+    nd_index = []
+    for k in range(len(x.shape)):
+        if k == dim:
+            nd_index.append(index_flatten)
+        else:
+            reshape_shape = [1] * len(x.shape)
+            reshape_shape[k] = x.shape[k]
+            x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
+            x_arange = x_arange.reshape(reshape_shape)
+            dim_index = paddle.expand(x_arange, index_shape).flatten()
+            nd_index.append(dim_index)
+    ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
+    paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
+    return paddle_out
+
+t = paddle.to_tensor([[1, 2], [3, 4]])
+paddle_gather(t, 1, paddle.to_tensor([[0, 0], [1, 0]]))
+# Output
+# Tensor(shape=[2, 2], dtype=int32, place=CPUPlace, stop_gradient=True,
+#        [[1, 1],
+#         [4, 3]])
+```
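
Note (not part of the patch): the composite implementation expands the per-axis gather indices into full n-dimensional coordinates and passes them to `paddle.gather_nd`; since `paddle.expand` only broadcasts size-1 dimensions, this appears to assume that `index` has the same extent as `x` along every axis other than `dim`. A minimal sanity check, assuming `torch`, `paddle`, and `numpy` are installed, might look like the sketch below; the `paddle_gather` helper is a lightly commented copy of the one in the diff.

```python
import numpy as np
import paddle
import torch

def paddle_gather(x, dim, index):
    # Copy of the helper from the patch: emulate torch.gather via paddle.gather_nd.
    index_shape = index.shape
    index_flatten = index.flatten()
    if dim < 0:
        dim = len(x.shape) + dim
    nd_index = []
    for k in range(len(x.shape)):
        if k == dim:
            # Along the gather axis, use the caller-supplied indices.
            nd_index.append(index_flatten)
        else:
            # Along every other axis, use that axis's own coordinates,
            # broadcast to the shape of `index`.
            reshape_shape = [1] * len(x.shape)
            reshape_shape[k] = x.shape[k]
            x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
            x_arange = x_arange.reshape(reshape_shape)
            dim_index = paddle.expand(x_arange, index_shape).flatten()
            nd_index.append(dim_index)
    # Stack per-axis coordinates into an (numel, ndim) index for gather_nd.
    ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
    return paddle.gather_nd(x, ind2).reshape(index_shape)

# The documented example: gather along dim=1 with index [[0, 0], [1, 0]].
data = [[1, 2], [3, 4]]
idx = [[0, 0], [1, 0]]

torch_out = torch.gather(torch.tensor(data), 1, torch.tensor(idx))
paddle_out = paddle_gather(paddle.to_tensor(data), 1, paddle.to_tensor(idx))

# Both should equal [[1, 1], [4, 3]].
assert np.array_equal(torch_out.numpy(), paddle_out.numpy())
```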