未验证 提交 05c9c0a5 编写于 作者: Y Yuang Liu 提交者: GitHub

Fix gather, scatter op 0d tensor GPU error. (#50271)

上级 0dd41a2a
......@@ -285,7 +285,8 @@ void GatherV2GradCUDAFunction(const DenseTensor* input,
if (input->numel() == 0) return;
int axis_index = axis;
int64_t input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size =
index->dims().size() == 0 ? 1 : input_dim[axis_index];
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
......
......@@ -182,11 +182,15 @@ template <typename T, typename IndexT = int>
void GPUScatterGradForX(const phi::GPUContext& ctx,
const DenseTensor& index,
DenseTensor* output) {
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto dst_dims = output->dims();
// slice size
int64_t slice_size = 1;
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
int64_t slice_size = 1; // slice size
if (index.dims().size() != 0) {
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
} else {
for (int i = 0; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
}
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
const size_t& slice_bytes = slice_size * sizeof(T);
......
......@@ -887,7 +887,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [3])
def _test_gather_xD_axis_1(self):
def test_gather_xD_axis_1(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
......@@ -901,7 +901,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [2])
def _test_scatter_1D(self):
def test_scatter_1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0)
......@@ -913,7 +913,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.numpy()[2], 4)
self.assertEqual(out.grad.shape, [5])
def _test_scatter_XD(self):
def test_scatter_XD(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
......@@ -1925,7 +1925,7 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[2].shape, (3,))
@prog_scope()
def _test_gather_XD_axis_1(self):
def test_gather_XD_axis_1(self):
x = paddle.full([2, 3], 1.0, 'float32')
x.stop_gradient = False
index = paddle.full([], 1, 'int64')
......@@ -1940,7 +1940,7 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[2].shape, (2,))
@prog_scope()
def _test_scatter_1D(self):
def test_scatter_1D(self):
x = paddle.full([10], 1.0, 'float32')
x.stop_gradient = False
index = paddle.full([], 2, 'int64')
......@@ -1956,7 +1956,7 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[2].shape, (10,))
@prog_scope()
def _test_scatter_XD(self):
def test_scatter_XD(self):
x = paddle.full([2, 3], 1.0, 'float32')
x.stop_gradient = False
index = paddle.full([], 1, 'int64')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册