diff --git a/docs/pytorch_project_convertor/API_docs/ops/README.md b/docs/pytorch_project_convertor/API_docs/ops/README.md index 086c1f41a439129da5d3a1b5e95fc447ac1816a6..00fc26dde7ee78ec3dd5674fbd5cc34a4a0f419f 100644 --- a/docs/pytorch_project_convertor/API_docs/ops/README.md +++ b/docs/pytorch_project_convertor/API_docs/ops/README.md @@ -152,7 +152,7 @@ | 147 | [torch.matmul](https://pytorch.org/docs/stable/generated/torch.matmul.html?highlight=matmul#torch.matmul) | [paddle.matmul](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/matmul_cn.html) | [差异对比](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.matmul.md) | | 148 | [torch.mm](https://pytorch.org/docs/stable/generated/torch.mm.html?highlight=mm#torch.mm) | [paddle.matmul](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/matmul_cn.html) | [差异对比](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.mm.md) | | 149 | [torch.mv](https://pytorch.org/docs/stable/generated/torch.mv.html?highlight=mv#torch.mv) | 无对应实现 | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.mv.md) | - +| 150 | [torch.scatter](https://pytorch.org/docs/stable/generated/torch.scatter.html?highlight=scatter#torch.scatter) | [paddle.scatter_nd_add](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/scatter_nd_add_cn.html) | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md) | diff --git a/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md b/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md new file mode 100644 index 0000000000000000000000000000000000000000..2eb4bb9a854109d6ee5d2770afe552abe8f206f3 --- /dev/null +++ b/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md @@ -0,0 +1,75 @@ +## torch.scatter +### [torch.scatter](https://pytorch.org/docs/stable/generated/torch.scatter.html?highlight=scatter#torch.scatter) + +```python +torch.scatter(tensor, + dim, + index, + src) +``` + +### [paddle.scatter_nd_add](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/scatter_nd_add_cn.html) + +```python +paddle.scatter_nd_add(x, + index, + updates, + name=None) +``` + +### 参数差异 +| PyTorch | PaddlePaddle | 备注 | +| ------------- | ------------ | ------------------------------------------------------ | +| tensor | x | 表示输入Tensor。 | +| dim | - | 表示在哪一个维度scatter,Paddle无此参数 | +| index | index | 输入的索引张量 | +| src | updates | 输入的更新张量 | + + + +### 功能差异 + +#### 使用方式 +因 torch.scatter 与 paddle.scatter_nd_add 差异较大,必须使用 paddle.flatten + paddle.meshgrid + paddle.scatter_nd_add 组合实现,看如下例子 + + +### 代码示例 +``` python +# PyTorch 示例: +src = torch.arange(1, 11).reshape((2, 5)) +# 输出 +# tensor([[ 1, 2, 3, 4, 5], +# [ 6, 7, 8, 9, 10]]) +index = torch.tensor([[0, 1, 2], [0, 1, 4]]) +torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) +# 输出 +# tensor([[1, 2, 3, 0, 0], +# [6, 7, 0, 0, 8], +# [0, 0, 0, 0, 0]]) +``` + +``` python +# PaddlePaddle 组合实现: +x = paddle.zeros([3, 5], dtype="int64") +updates = paddle.arange(1, 11).reshape([2,5]) +# 输出 +# Tensor(shape=[2, 5], dtype=int64, place=CUDAPlace(0), stop_gradient=True, +# [[1 , 2 , 3 , 4 , 5 ], +# [6 , 7 , 8 , 9 , 10]]) +index = paddle.to_tensor([[0, 1, 2], [0, 1, 4]]) +i, j = index.shape +grid_x , grid_y = paddle.meshgrid(paddle.arange(i), paddle.arange(j)) +# 若 PyTorch 的 dim 取 0 +# index = paddle.stack([index.flatten(), grid_y.flatten()], axis=1) +# 若 PyTorch 的 dim 取 1 +index = paddle.stack([grid_x.flatten(), index.flatten()], axis=1) +# PaddlePaddle updates 的 shape 大小必须与 index 对应 +updates_index = paddle.stack([grid_x.flatten(), grid_y.flatten()], axis=1) +updates = paddle.gather_nd(updates, index=updates_index) +paddle.scatter_nd_add(x, index, updates) +# 输出 +# Tensor(shape=[3, 5], dtype=int64, place=CUDAPlace(0), stop_gradient=True, +# [[1, 2, 3, 0, 0], +# [6, 7, 0, 0, 8], +# [0, 0, 0, 0, 0]]) +```