提交 adfa4688 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/functional): fix scatter doctest failed for GPU platform issue

GitOrigin-RevId: b5f92c39dd46f13aee07e9cad441e017069aaa3a
上级 4f8e6080
......@@ -236,6 +236,14 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.
.. note::
Please notice that, due to performance issues, the result is uncertain on the GPU device
if scatter difference positions from source to the same destination position
regard to index tensor.
Show the case using the following examples, the oup[0][2] is maybe
from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5339
if set the index[1][2] from 1 to 0.
:param inp: the inp tensor which to be scattered
:param axis: the axis along which to index
......@@ -252,17 +260,16 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
index = tensor([[0,2,0,2,1],[2,0,0,1,2]])
index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
oup = F.scatter(inp, 0, index,source)
print(oup.numpy())
Outputs:
.. testoutput::
:options: +SKIP
[[0.9935 0.0718 0.5939 0. 0. ]
[0. 0. 0. 0.357 0.4396]
[[0.9935 0.0718 0.2256 0. 0. ]
[0. 0. 0.5939 0.357 0.4396]
[0.7723 0.9465 0. 0.8926 0.4576]]
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册