未验证 提交 3e2c6a56 编写于 作者: Y Yuang Liu 提交者: GitHub

[bug fix] fix scatter 0d index grad error (#55738)

上级 9c101490
......@@ -1060,7 +1060,7 @@ void ScatterInferMeta(const MetaTensor& x,
(ref_dims.size() == updates_dims.size()),
true,
phi::errors::InvalidArgument(
"When the Input(Updates) is not a 0D tensor, the "
"When the Input(Index) is not a 0D tensor, the "
"Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
......@@ -1075,6 +1075,17 @@ void ScatterInferMeta(const MetaTensor& x,
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
} else {
PADDLE_ENFORCE_EQ(
(ref_dims.size() - 1 == updates_dims.size()),
true,
phi::errors::InvalidArgument(
"When the Input(Index) is a 0D tensor, the "
"Input(Updates) should have the shape size as Input(X)'s "
"shape size - 1. But received the size of Input(x)'s shape is %d, "
" the size of Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
}
out->set_dims(ref_dims);
out->share_lod(x);
......
......@@ -195,12 +195,8 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto dst_dims = output->dims();
// slice size
int64_t slice_size = 1; // slice size
if (index.dims().size() != 0) {
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
} else {
for (int i = 0; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
}
int64_t slice_size = 1;
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
const size_t& slice_bytes = slice_size * sizeof(T);
......
......@@ -244,7 +244,7 @@ template <typename T, typename IndexT = int>
void CPUScatterGradForX(const phi::CPUContext& ctx UNUSED,
const DenseTensor& index,
DenseTensor* output) {
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto dst_dims = output->dims();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
......
......@@ -1916,6 +1916,36 @@ class TestSundryAPI(unittest.TestCase):
np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0])
self.assertEqual(out.grad.shape, [2, 3])
def test_scatter_shape_check(self):
x = paddle.to_tensor([1.0, 2.0, 3.0])
index = paddle.to_tensor(1)
updates = paddle.to_tensor([3.0])
with self.assertRaises(ValueError):
out = paddle.scatter(x, index, updates)
x = paddle.to_tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
index = paddle.to_tensor(1)
updates = paddle.to_tensor([[5.0, 5.0]])
with self.assertRaises(ValueError):
out = paddle.scatter(x, index, updates)
def test_scatter_0D_index(self):
x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
index = paddle.to_tensor(1)
updates = paddle.to_tensor(3.0)
out = paddle.scatter(x, index, updates)
out.backward()
np.testing.assert_array_equal(x.grad.numpy()[1], 0.0)
x = paddle.to_tensor(
[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], stop_gradient=False
)
index = paddle.to_tensor(1)
updates = paddle.to_tensor([5.0, 5.0])
out = paddle.scatter(x, index, updates)
out.backward()
np.testing.assert_array_equal(x.grad.numpy()[1], [0.0, 0.0])
def test_diagflat(self):
x1 = paddle.rand([])
x2 = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册