未验证 提交 f9815bfe 编写于 作者: Y Yuang Liu 提交者: GitHub

Scatter 0D index for gather, 0D index and 0D updates for scatter. (#48452)

上级 a3ae080a
......@@ -1268,37 +1268,69 @@ void GatherInferMeta(const MetaTensor& x,
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
"The index should be 0D or 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
auto input_dim = x.dims();
auto axis_v = axis.to<int>();
if (axis.FromTensor() || axis_v == 0) {
// if axis.FromTensor(), we can not obtain correct shape of output
int batch_size = index_dims[0];
phi::DDim output_dims(input_dim);
output_dims[0] = batch_size;
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
if (index_dims.size() == 0) {
// 0D index will decrease the dimension
if (input_dim.size() == 1) {
// the index is a 0D tensor and the x is a 1D tensor
out->set_dims(phi::DDim(phi::Dim<0>()));
} else {
if (axis.FromTensor() || axis_v == 0) {
// decrease the output dimension
std::vector<int> out_dim_vec;
for (int i = 1; i < input_dim.size(); ++i) {
out_dim_vec.emplace_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
}
out_dim_vec.push_back(index_size);
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
} else {
if (axis.FromTensor() || axis_v == 0) {
// if axis.FromTensor(), we can not obtain correct shape of output
int batch_size = index_dims[0];
phi::DDim output_dims(input_dim);
output_dims[0] = batch_size;
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
}
......
......@@ -995,31 +995,34 @@ void ScatterInferMeta(const MetaTensor& x,
"index is a 2D tensor, but we get %d.",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be a 0D or 1D tensor when the "
"index is not a 2D tensor, but we get %d.",
index_dims.size()));
}
if (index_dims.size() != 0) {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
phi::errors::InvalidArgument("The index should be a 1D tensor when the "
"index is not a 2D tensor, but we get %d.",
index_dims.size()));
(ref_dims.size() == updates_dims.size()),
true,
phi::errors::InvalidArgument(
"When the Input(Updates) is not a 0D tensor, the "
"Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_EQ(
updates_dims[0],
index_dims[0],
phi::errors::InvalidArgument(
"Input(Updates) and Input(Ids) should have same batch-size, but"
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
}
PADDLE_ENFORCE_EQ(
ref_dims.size(),
updates_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_EQ(
updates_dims[0],
index_dims[0],
phi::errors::InvalidArgument(
"Input(Updates) and Input(Ids) should have same batch-size, but"
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
out->set_dims(ref_dims);
out->share_lod(x);
out->set_dtype(x.dtype());
......
......@@ -94,12 +94,9 @@ void GPUGather(const phi::GPUContext& ctx,
}
// index size
int64_t index_size = index.dims()[0];
if (index_size == 0) return;
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims();
phi::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int64_t slice_size = 1;
......@@ -246,7 +243,9 @@ void GatherV2CUDAFunction(const DenseTensor* input,
inner_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
if (index->dims().size() != 0) {
out_dim_vec.push_back(index_size);
}
for (int i = axis_index + 1; i < input_dim.size(); i++) {
outer_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
......
......@@ -38,7 +38,6 @@ void CPUGather(const phi::CPUContext& ctx,
const DenseTensor& src,
const DenseTensor& index,
DenseTensor* output) {
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(
index.dims()[1],
......@@ -48,14 +47,15 @@ void CPUGather(const phi::CPUContext& ctx,
"in gather_op, but received value is [%d].",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(),
1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in gather_op,"
"but received shape's size is [%d].",
index.dims().size()));
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be 0D or 1D, when it is not 2D, but we get %d",
index.dims().size()));
}
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims();
......@@ -188,7 +188,9 @@ void GatherV2Function(const phi::CPUContext& ctx,
inner_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
if (index->dims().size() != 0) {
out_dim_vec.push_back(index_size);
}
for (int i = axis_index + 1; i < input_dim.size(); i++) {
outer_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
......@@ -224,7 +226,13 @@ void GatherV2GradFunction(const phi::CPUContext& ctx,
if (input->numel() == 0) return;
int axis_index = axis;
int64_t input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size;
if (input_dim.size() == out->dims().size()) {
input_index_dim_size = input_dim[axis_index];
} else {
// 0d index
input_index_dim_size = 1;
}
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
......
......@@ -122,7 +122,6 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
const DenseTensor& index,
DenseTensor* output,
bool overwrite = true) {
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(
index.dims()[1],
......@@ -132,26 +131,33 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
"But received value is [%d]",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(),
1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]",
index.dims().size()));
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 0,
true,
phi::errors::InvalidArgument(
"index.dims().size() should be 0, 1 or 2 in scatter_op."
"But received value is [%d]",
index.dims().size()));
}
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims();
phi::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
size_t slice_size = 1;
if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
} else {
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
}
const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
const size_t& slice_bytes = slice_size * sizeof(T);
// set block and grid num
......
......@@ -76,7 +76,6 @@ void ScatterAssign(const phi::CPUContext& ctx,
const DenseTensor& src,
const DenseTensor& index,
DenseTensor* output) {
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(
index.dims()[1],
......@@ -86,14 +85,15 @@ void ScatterAssign(const phi::CPUContext& ctx,
"But received value is [%d]",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(),
1,
PADDLE_ENFORCE_EQ(index.dims().size() == 1 || index.dims().size() == 0,
true,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]",
"index.dims().size() should be 0, 1 or 2 in "
"scatter_op. But received value is [%d]",
index.dims().size()));
}
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims();
auto dst_dims = output->dims();
......@@ -102,23 +102,29 @@ void ScatterAssign(const phi::CPUContext& ctx,
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ(
src_dims[i],
dst_dims[i],
phi::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i,
src_dims[i],
i,
dst_dims[i]));
if (index.dims().size() != 0) {
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ(
src_dims[i],
dst_dims[i],
phi::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i,
src_dims[i],
i,
dst_dims[i]));
}
// slice size
size_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
} else {
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
}
const size_t slice_bytes = slice_size * sizeof(T);
......@@ -143,43 +149,48 @@ void ScatterAssignAdd(const phi::CPUContext& ctx,
const DenseTensor& src,
const DenseTensor& index,
DenseTensor* output) {
// check index of shape 1-D
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 ||
index.dims().size() == 1 || index.dims().size() == 0 ||
(index.dims().size() == 2 && index.dims()[1] == 1),
true,
phi::errors::InvalidArgument(
"index's shape is error, "
"expect index'dims shape is 1 or 2 and index.dims[1] is 1"
"but got index'dims shape is %d",
"expect index'dims shape is 0, 1, 2 (index.dims[1] should "
"be 1), but got index'dims shape is %d",
index.dims().size()));
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims();
auto dst_dims = output->dims();
const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ(
src_dims[i],
dst_dims[i],
phi::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i,
src_dims[i],
i,
dst_dims[i]));
if (index.dims().size() != 0) {
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ(
src_dims[i],
dst_dims[i],
phi::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i,
src_dims[i],
i,
dst_dims[i]));
}
// slice size
size_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
} else {
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
}
const size_t& slice_bytes = slice_size * sizeof(T);
......
......@@ -44,10 +44,10 @@ void GatherGradKernel(const Context& dev_ctx,
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
"The index should be 0D or 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
std::vector<int> xshape(x_grad->dims().size());
......@@ -66,7 +66,7 @@ void GatherGradKernel(const Context& dev_ctx,
index.data<int>(),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
xshape,
index.dims()[0],
index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v,
overwrite);
} else {
......@@ -84,7 +84,7 @@ void GatherGradKernel(const Context& dev_ctx,
index_int_ptr_l3,
reinterpret_cast<XPUType*>(x_grad->data<T>()),
xshape,
index.dims()[0],
index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v,
overwrite);
}
......
......@@ -41,10 +41,10 @@ void GatherKernel(const Context& dev_ctx,
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
"The index should be 0D, 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
std::vector<int> xshape(x.dims().size());
......@@ -56,13 +56,14 @@ void GatherKernel(const Context& dev_ctx,
int r = XPU_SUCCESS;
if (index_type == DataType::INT32) {
r = xpu::gather<XPUType, int>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
index.data<int>(),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
index.dims()[0],
axis_v);
r = xpu::gather<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
index.data<int>(),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v);
} else {
r = xpu::gather<XPUType, int64_t>(
dev_ctx.x_context(),
......@@ -70,7 +71,7 @@ void GatherKernel(const Context& dev_ctx,
index.data<int64_t>(),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
index.dims()[0],
index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v);
}
PADDLE_ENFORCE_EQ(
......
......@@ -43,30 +43,34 @@ void ScatterKernel(const Context &ctx,
// check index of shape 1-D
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 ||
index.dims().size() == 1 || index.dims().size() == 0 ||
(index.dims().size() == 2 && index.dims()[1] == 1),
true,
phi::errors::InvalidArgument(
"index's shape is error, "
"expect index'dims shape is 1 or 2 and index.dims[1] is 1"
"but got index'dims shape is %d",
"expect index'dims shape is 0, 1, 2 (index.dims[1] should "
"be 1), 0 but got index'dims shape is %d",
index.dims().size()));
int index_size = static_cast<int>(index.dims()[0]);
int index_size =
static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
auto x_dims = x.dims();
auto update_dims = updates.dims();
for (int i = 1; i < x_dims.size(); i++)
PADDLE_ENFORCE_EQ(
x_dims[i],
update_dims[i],
phi::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i,
x_dims[i],
i,
update_dims[i]));
if (index.dims().size() != 0) {
// only check when the updates tensor is not a 0D tensor
for (int i = 1; i < x_dims.size(); i++)
PADDLE_ENFORCE_EQ(
x_dims[i],
update_dims[i],
phi::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i,
x_dims[i],
i,
update_dims[i]));
}
int dim0 = static_cast<int>(x.dims()[0]);
int dim1 =
......
......@@ -598,6 +598,61 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 0)
def test_gather_1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64')
out = paddle.gather(x, index)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 5)
self.assertEqual(out.grad.shape, [])
def test_gather_xD_axis_0(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index)
out.backward()
self.assertEqual(out.shape, [3])
for i in range(3):
self.assertEqual(out.numpy()[i], x.numpy()[1][i])
self.assertEqual(out.grad.shape, [3])
def test_gather_xD_axis_1(self):
x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index, axis=1)
self.assertEqual(out.shape, [2])
for i in range(2):
self.assertEqual(out.numpy()[i], x.numpy()[i][1])
def test_scatter_1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0)
out = paddle.scatter(x, index, updates)
out.backward()
self.assertEqual(out.grad.shape, [5])
self.assertEqual(out.numpy()[2], 4)
def test_scatter_XD(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
index = paddle.full([], 1, 'int64')
updates = paddle.to_tensor([1.0, 2.0, 3.0])
out = paddle.scatter(x, index, updates)
out.backward()
for i in range(3):
self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
self.assertEqual(out.grad.shape, [2, 3])
class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -679,6 +734,68 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], 0)
@prog_scope()
def test_gather_1D(self):
x = paddle.full([10], 1.0, 'float32')
index = paddle.full([], 2, 'int64')
out = paddle.gather(x, index)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], 1)
@prog_scope()
def test_gather_XD_axis_0(self):
x = paddle.full([2, 3], 1.0, 'float32')
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, (3,))
for i in range(3):
self.assertEqual(res[0][i], 1)
@prog_scope()
def test_gather_XD_axis_1(self):
x = paddle.full([2, 3], 1.0, 'float32')
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index, axis=1)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, (2,))
for i in range(2):
self.assertEqual(res[0][i], 1)
@prog_scope()
def test_scatter_1D(self):
x = paddle.full([10], 1.0, 'float32')
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4, 'float32')
out = paddle.scatter(x, index, updates)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0][2], 4)
@prog_scope()
def test_scatter_XD(self):
x = paddle.full([2, 3], 1.0, 'float32')
index = paddle.full([], 1, 'int64')
updates = paddle.full([3], 4, 'float32')
out = paddle.scatter(x, index, updates)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
for i in range(3):
self.assertEqual(res[0][1][i], 4)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
......@@ -426,6 +426,55 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 0)
def test_gather_1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64')
out = paddle.gather(x, index)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 5)
self.assertEqual(out.grad.shape, [])
def test_gather_xD_axis_0(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index)
out.backward()
self.assertEqual(out.shape, [3])
for i in range(3):
self.assertEqual(out.numpy()[i], x.numpy()[1][i])
self.assertEqual(out.grad.shape, [3])
def test_gather_xD_axis_1(self):
x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index, axis=1)
self.assertEqual(out.shape, [2])
for i in range(2):
self.assertEqual(out.numpy()[i], x.numpy()[i][1])
def test_scatter_1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0)
out = paddle.scatter(x, index, updates)
self.assertEqual(out.numpy()[2], 4)
def test_scatter_XD(self):
x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
index = paddle.full([], 1, 'int64')
updates = paddle.to_tensor([1.0, 2.0, 3.0])
out = paddle.scatter(x, index, updates)
for i in range(3):
self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
......@@ -2728,13 +2728,13 @@ def gather(x, index, axis=None, name=None):
x (Tensor): The source input tensor with rank>=1. Supported data type is
int32, int64, float32, float64 and uint8 (only for CPU),
float16 (only for GPU).
index (Tensor): The index input tensor with rank=1. Data type is int32 or int64.
index (Tensor): The index input tensor with rank=0 or rank=1. Data type is int32 or int64.
axis (Tensor|int, optional): The axis of input to be gathered, it's can be int or a Tensor with data type is int32 or int64. The default value is None, if None, the ``axis`` is 0.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
output (Tensor), The output is a tensor with the same rank as ``x``.
output (Tensor), If the index is a 1-D tensor, the output is a tensor with the same shape as ``x``. If the index is a 0-D tensor, the output will reduce the dimension where the axis pointing.
Examples:
......@@ -2888,8 +2888,8 @@ def scatter(x, index, updates, overwrite=True, name=None):
Args:
x (Tensor): The input N-D Tensor with ndim>=1. Data type can be float32, float64.
index (Tensor): The index 1-D Tensor. Data type can be int32, int64. The length of index cannot exceed updates's length, and the value in index cannot exceed input's length.
updates (Tensor): update input with updates parameter based on index. shape should be the same as input, and dim value with dim > 1 should be the same as input.
index (Tensor): The index is a 1-D or 0-D Tensor. Data type can be int32, int64. The length of index cannot exceed updates's length, and the value in index cannot exceed input's length.
updates (Tensor): Update input with updates parameter based on index. When the index is a 1-D tensor, the updates shape should be the same as input, and dim value with dim > 1 should be the same as input. When the index is a 0-D tensor, the updates should be a (N-1)-D tensor, the ith dim of the updates should be queal with the (i+1)th dim of the input.
overwrite (bool): The mode that updating the output when there are same indices.
If True, use the overwrite mode to update the output of the same index,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册