未验证 提交 3e2c6a56 编写于 作者: Y Yuang Liu 提交者: GitHub

[bug fix] fix scatter 0d index grad error (#55738)

上级 9c101490
...@@ -1060,7 +1060,7 @@ void ScatterInferMeta(const MetaTensor& x, ...@@ -1060,7 +1060,7 @@ void ScatterInferMeta(const MetaTensor& x,
(ref_dims.size() == updates_dims.size()), (ref_dims.size() == updates_dims.size()),
true, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"When the Input(Updates) is not a 0D tensor, the " "When the Input(Index) is not a 0D tensor, the "
"Input(X) and Input(Updates) should have the same shape size, " "Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of " "but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.", "Input(Updates)'s shape is %d.",
...@@ -1075,6 +1075,17 @@ void ScatterInferMeta(const MetaTensor& x, ...@@ -1075,6 +1075,17 @@ void ScatterInferMeta(const MetaTensor& x,
"batch-size is %d.", "batch-size is %d.",
updates_dims[0], updates_dims[0],
index_dims[0])); index_dims[0]));
} else {
PADDLE_ENFORCE_EQ(
(ref_dims.size() - 1 == updates_dims.size()),
true,
phi::errors::InvalidArgument(
"When the Input(Index) is a 0D tensor, the "
"Input(Updates) should have the shape size as Input(X)'s "
"shape size - 1. But received the size of Input(x)'s shape is %d, "
" the size of Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
} }
out->set_dims(ref_dims); out->set_dims(ref_dims);
out->share_lod(x); out->share_lod(x);
......
...@@ -195,12 +195,8 @@ void GPUScatterGradForX(const phi::GPUContext& ctx, ...@@ -195,12 +195,8 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto dst_dims = output->dims(); auto dst_dims = output->dims();
// slice size // slice size
int64_t slice_size = 1; // slice size int64_t slice_size = 1;
if (index.dims().size() != 0) { for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
} else {
for (int i = 0; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
}
const IndexT* p_index = index.data<IndexT>(); const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>(); T* p_output = output->data<T>();
const size_t& slice_bytes = slice_size * sizeof(T); const size_t& slice_bytes = slice_size * sizeof(T);
......
...@@ -244,7 +244,7 @@ template <typename T, typename IndexT = int> ...@@ -244,7 +244,7 @@ template <typename T, typename IndexT = int>
void CPUScatterGradForX(const phi::CPUContext& ctx UNUSED, void CPUScatterGradForX(const phi::CPUContext& ctx UNUSED,
const DenseTensor& index, const DenseTensor& index,
DenseTensor* output) { DenseTensor* output) {
int64_t index_size = index.dims()[0]; int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto dst_dims = output->dims(); auto dst_dims = output->dims();
const IndexT* p_index = index.data<IndexT>(); const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>(); T* p_output = output->data<T>();
......
...@@ -1916,6 +1916,36 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1916,6 +1916,36 @@ class TestSundryAPI(unittest.TestCase):
np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0]) np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0])
self.assertEqual(out.grad.shape, [2, 3]) self.assertEqual(out.grad.shape, [2, 3])
def test_scatter_shape_check(self):
    # With a 0-D index, paddle.scatter requires the updates tensor to have
    # rank equal to x's rank minus one; both cases below violate that and
    # must raise ValueError at the shape-inference stage.
    src = paddle.to_tensor([1.0, 2.0, 3.0])
    idx = paddle.to_tensor(1)
    upd = paddle.to_tensor([3.0])  # rank 1, but rank 0 is expected
    with self.assertRaises(ValueError):
        paddle.scatter(src, idx, upd)

    src = paddle.to_tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
    idx = paddle.to_tensor(1)
    upd = paddle.to_tensor([[5.0, 5.0]])  # rank 2, but rank 1 is expected
    with self.assertRaises(ValueError):
        paddle.scatter(src, idx, upd)
def test_scatter_0D_index(self):
    # Forward + backward of paddle.scatter with a 0-D (scalar) index:
    # the overwritten row receives no gradient from x, so x.grad at that
    # row must be zero.
    src = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
    result = paddle.scatter(
        src, paddle.to_tensor(1), paddle.to_tensor(3.0)
    )
    result.backward()
    np.testing.assert_array_equal(src.grad.numpy()[1], 0.0)

    # Same check for a 2-D x with a rank-1 updates tensor.
    src = paddle.to_tensor(
        [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], stop_gradient=False
    )
    result = paddle.scatter(
        src, paddle.to_tensor(1), paddle.to_tensor([5.0, 5.0])
    )
    result.backward()
    np.testing.assert_array_equal(src.grad.numpy()[1], [0.0, 0.0])
def test_diagflat(self): def test_diagflat(self):
x1 = paddle.rand([]) x1 = paddle.rand([])
x2 = paddle.rand([]) x2 = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册