paddle v1.5: scatter zeroes out the input in dygraph (dynamic graph) mode
Created by: githubutilities
Minimal reproduction of the scatter bug:
```python
import numpy as np
import paddle.fluid as fluid


def test_scatter_add():
    input = fluid.dygraph.to_variable(
        np.array([[1, 2],
                  [5, 6]], dtype='float32'))
    index = fluid.dygraph.to_variable(
        np.array([0, 1], dtype=np.int32))
    updates = fluid.dygraph.to_variable(
        np.array([[3, 4],
                  [3, 4]], dtype='float32'))
    # With overwrite=False, updates[i] should be added into row index[i] of input
    output = fluid.layers.scatter(input, index, updates, overwrite=False)
    print(output.numpy())


# Also reproducible with fluid.CPUPlace()
with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    test_scatter_add()
```
Actual output (reproducible on both CPU and GPU):
```
[[3. 4.]
 [3. 4.]]
```
Expected output:
```
[[ 4.  6.]
 [ 8. 10.]]
```
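To make the mismatch concrete, here is a minimal NumPy sketch of two candidate semantics for `overwrite=False`. Both helper functions (`scatter_add_reference`, `scatter_zero_then_add`) are hypothetical and are not part of Paddle. The first adds each `updates[i]` into row `index[i]` of the original input, which yields the expected output above. The second first zeroes the indexed rows and then accumulates the updates, which reproduces the observed output. This only shows that the observed result matches "zero, then add" arithmetically; it makes no claim about how Paddle's kernel is actually implemented.

```python
import numpy as np


def scatter_add_reference(x, index, updates):
    """Hypothetical helper: add updates[i] into row index[i] of a copy of x."""
    out = x.copy()
    # np.add.at accumulates correctly even if index contains duplicates
    np.add.at(out, index, updates)
    return out


def scatter_zero_then_add(x, index, updates):
    """Hypothetical helper: zero the indexed rows first, then accumulate updates."""
    out = x.copy()
    out[index] = 0
    np.add.at(out, index, updates)
    return out


# Same inputs as the reproduction script
x = np.array([[1, 2], [5, 6]], dtype='float32')
index = np.array([0, 1], dtype=np.int32)
updates = np.array([[3, 4], [3, 4]], dtype='float32')

print(scatter_add_reference(x, index, updates))   # matches the expected output
print(scatter_zero_then_add(x, index, updates))   # matches the observed output
```

`np.add.at` is used rather than `out[index] += updates` because the unbuffered version sums repeated indices instead of keeping only the last write. That distinction does not matter for this input, but it does for any `index` with duplicates.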