未验证 提交 f9815bfe 编写于 作者: Y Yuang Liu 提交者: GitHub

Scatter 0D index for gather, 0D index and 0D updates for scatter. (#48452)

上级 a3ae080a
...@@ -1268,37 +1268,69 @@ void GatherInferMeta(const MetaTensor& x, ...@@ -1268,37 +1268,69 @@ void GatherInferMeta(const MetaTensor& x,
index_dims[1])); index_dims[1]));
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index_dims.size(), index_dims.size() == 1 || index_dims.size() == 0,
1, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d", "The index should be 0D or 1D, when it is not 2D, but we get %d",
index_dims.size())); index_dims.size()));
} }
auto input_dim = x.dims(); auto input_dim = x.dims();
auto axis_v = axis.to<int>(); auto axis_v = axis.to<int>();
if (axis.FromTensor() || axis_v == 0) { if (index_dims.size() == 0) {
// if axis.FromTensor(), we can not obtain correct shape of output // 0D index will decrease the dimension
int batch_size = index_dims[0]; if (input_dim.size() == 1) {
phi::DDim output_dims(input_dim); // the index is a 0D tensor and the x is a 1D tensor
output_dims[0] = batch_size; out->set_dims(phi::DDim(phi::Dim<0>()));
out->set_dims(output_dims); } else {
out->set_dtype(x.dtype()); if (axis.FromTensor() || axis_v == 0) {
out->share_lod(x); // decrease the output dimension
} else { std::vector<int> out_dim_vec;
int index_size = index_dims[0]; for (int i = 1; i < input_dim.size(); ++i) {
std::vector<int> out_dim_vec; out_dim_vec.emplace_back(input_dim[i]);
for (int i = 0; i < axis_v; i++) { }
out_dim_vec.push_back(input_dim[i]); auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
} }
out_dim_vec.push_back(index_size); } else {
for (int i = axis_v + 1; i < input_dim.size(); i++) { if (axis.FromTensor() || axis_v == 0) {
out_dim_vec.push_back(input_dim[i]); // if axis.FromTensor(), we can not obtain correct shape of output
int batch_size = index_dims[0];
phi::DDim output_dims(input_dim);
output_dims[0] = batch_size;
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} }
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} }
} }
......
...@@ -995,31 +995,34 @@ void ScatterInferMeta(const MetaTensor& x, ...@@ -995,31 +995,34 @@ void ScatterInferMeta(const MetaTensor& x,
"index is a 2D tensor, but we get %d.", "index is a 2D tensor, but we get %d.",
index_dims[1])); index_dims[1]));
} else { } else {
PADDLE_ENFORCE_EQ(index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be a 0D or 1D tensor when the "
"index is not a 2D tensor, but we get %d.",
index_dims.size()));
}
if (index_dims.size() != 0) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index_dims.size(), (ref_dims.size() == updates_dims.size()),
1, true,
phi::errors::InvalidArgument("The index should be a 1D tensor when the " phi::errors::InvalidArgument(
"index is not a 2D tensor, but we get %d.", "When the Input(Updates) is not a 0D tensor, the "
index_dims.size())); "Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_EQ(
updates_dims[0],
index_dims[0],
phi::errors::InvalidArgument(
"Input(Updates) and Input(Ids) should have same batch-size, but"
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
} }
PADDLE_ENFORCE_EQ(
ref_dims.size(),
updates_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_EQ(
updates_dims[0],
index_dims[0],
phi::errors::InvalidArgument(
"Input(Updates) and Input(Ids) should have same batch-size, but"
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
out->set_dims(ref_dims); out->set_dims(ref_dims);
out->share_lod(x); out->share_lod(x);
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
......
...@@ -94,12 +94,9 @@ void GPUGather(const phi::GPUContext& ctx, ...@@ -94,12 +94,9 @@ void GPUGather(const phi::GPUContext& ctx,
} }
// index size // index size
int64_t index_size = index.dims()[0]; int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
if (index_size == 0) return;
auto src_dims = src.dims(); auto src_dims = src.dims();
phi::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size // slice size
int64_t slice_size = 1; int64_t slice_size = 1;
...@@ -246,7 +243,9 @@ void GatherV2CUDAFunction(const DenseTensor* input, ...@@ -246,7 +243,9 @@ void GatherV2CUDAFunction(const DenseTensor* input,
inner_dim_size *= input_dim[i]; inner_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]); out_dim_vec.push_back(input_dim[i]);
} }
out_dim_vec.push_back(index_size); if (index->dims().size() != 0) {
out_dim_vec.push_back(index_size);
}
for (int i = axis_index + 1; i < input_dim.size(); i++) { for (int i = axis_index + 1; i < input_dim.size(); i++) {
outer_dim_size *= input_dim[i]; outer_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]); out_dim_vec.push_back(input_dim[i]);
......
...@@ -38,7 +38,6 @@ void CPUGather(const phi::CPUContext& ctx, ...@@ -38,7 +38,6 @@ void CPUGather(const phi::CPUContext& ctx,
const DenseTensor& src, const DenseTensor& src,
const DenseTensor& index, const DenseTensor& index,
DenseTensor* output) { DenseTensor* output) {
// check index of shape 1-D
if (index.dims().size() == 2) { if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index.dims()[1], index.dims()[1],
...@@ -48,14 +47,15 @@ void CPUGather(const phi::CPUContext& ctx, ...@@ -48,14 +47,15 @@ void CPUGather(const phi::CPUContext& ctx,
"in gather_op, but received value is [%d].", "in gather_op, but received value is [%d].",
index.dims()[1])); index.dims()[1]));
} else { } else {
PADDLE_ENFORCE_EQ(index.dims().size(), PADDLE_ENFORCE_EQ(
1, index.dims().size() == 1 || index.dims().size() == 0,
phi::errors::InvalidArgument( true,
"index.dims().size() should be 1 or 2 in gather_op," phi::errors::InvalidArgument(
"but received shape's size is [%d].", "The index should be 0D or 1D, when it is not 2D, but we get %d",
index.dims().size())); index.dims().size()));
} }
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims(); auto src_dims = src.dims();
...@@ -188,7 +188,9 @@ void GatherV2Function(const phi::CPUContext& ctx, ...@@ -188,7 +188,9 @@ void GatherV2Function(const phi::CPUContext& ctx,
inner_dim_size *= input_dim[i]; inner_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]); out_dim_vec.push_back(input_dim[i]);
} }
out_dim_vec.push_back(index_size); if (index->dims().size() != 0) {
out_dim_vec.push_back(index_size);
}
for (int i = axis_index + 1; i < input_dim.size(); i++) { for (int i = axis_index + 1; i < input_dim.size(); i++) {
outer_dim_size *= input_dim[i]; outer_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]); out_dim_vec.push_back(input_dim[i]);
...@@ -224,7 +226,13 @@ void GatherV2GradFunction(const phi::CPUContext& ctx, ...@@ -224,7 +226,13 @@ void GatherV2GradFunction(const phi::CPUContext& ctx,
if (input->numel() == 0) return; if (input->numel() == 0) return;
int axis_index = axis; int axis_index = axis;
int64_t input_index_dim_size = input_dim[axis_index]; int64_t input_index_dim_size;
if (input_dim.size() == out->dims().size()) {
input_index_dim_size = input_dim[axis_index];
} else {
// 0d index
input_index_dim_size = 1;
}
int64_t inner_dim_size = 1; int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1; int64_t outer_dim_size = 1;
......
...@@ -122,7 +122,6 @@ void GPUScatterAssign(const phi::GPUContext& ctx, ...@@ -122,7 +122,6 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
const DenseTensor& index, const DenseTensor& index,
DenseTensor* output, DenseTensor* output,
bool overwrite = true) { bool overwrite = true) {
// check index of shape 1-D
if (index.dims().size() == 2) { if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index.dims()[1], index.dims()[1],
...@@ -132,26 +131,33 @@ void GPUScatterAssign(const phi::GPUContext& ctx, ...@@ -132,26 +131,33 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
"But received value is [%d]", "But received value is [%d]",
index.dims()[1])); index.dims()[1]));
} else { } else {
PADDLE_ENFORCE_EQ(index.dims().size(), PADDLE_ENFORCE_EQ(
1, index.dims().size() == 1 || index.dims().size() == 0,
phi::errors::InvalidArgument( true,
"index.dims().size() should be 1 or 2 in scatter_op." phi::errors::InvalidArgument(
"But received value is [%d]", "index.dims().size() should be 0, 1 or 2 in scatter_op."
index.dims().size())); "But received value is [%d]",
index.dims().size()));
} }
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims(); auto src_dims = src.dims();
phi::DDim output_dims(src_dims); phi::DDim output_dims(src_dims);
output_dims[0] = index_size; output_dims[0] = index_size;
// slice size // slice size
int64_t slice_size = 1; size_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
} else {
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
}
const T* p_src = src.data<T>(); const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>(); const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>(); T* p_output = output->data<T>();
const size_t& slice_bytes = slice_size * sizeof(T); const size_t& slice_bytes = slice_size * sizeof(T);
// set block and grid num // set block and grid num
......
...@@ -76,7 +76,6 @@ void ScatterAssign(const phi::CPUContext& ctx, ...@@ -76,7 +76,6 @@ void ScatterAssign(const phi::CPUContext& ctx,
const DenseTensor& src, const DenseTensor& src,
const DenseTensor& index, const DenseTensor& index,
DenseTensor* output) { DenseTensor* output) {
// check index of shape 1-D
if (index.dims().size() == 2) { if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index.dims()[1], index.dims()[1],
...@@ -86,14 +85,15 @@ void ScatterAssign(const phi::CPUContext& ctx, ...@@ -86,14 +85,15 @@ void ScatterAssign(const phi::CPUContext& ctx,
"But received value is [%d]", "But received value is [%d]",
index.dims()[1])); index.dims()[1]));
} else { } else {
PADDLE_ENFORCE_EQ(index.dims().size(), PADDLE_ENFORCE_EQ(index.dims().size() == 1 || index.dims().size() == 0,
1, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op." "index.dims().size() should be 0, 1 or 2 in "
"But received value is [%d]", "scatter_op. But received value is [%d]",
index.dims().size())); index.dims().size()));
} }
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = output->dims(); auto dst_dims = output->dims();
...@@ -102,23 +102,29 @@ void ScatterAssign(const phi::CPUContext& ctx, ...@@ -102,23 +102,29 @@ void ScatterAssign(const phi::CPUContext& ctx,
const IndexT* p_index = index.data<IndexT>(); const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>(); T* p_output = output->data<T>();
// check src shape and dst shape should match if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); i++) // check src shape and dst shape should match
PADDLE_ENFORCE_EQ( for (int i = 1; i < src_dims.size(); i++)
src_dims[i], PADDLE_ENFORCE_EQ(
dst_dims[i], src_dims[i],
phi::errors::InvalidArgument( dst_dims[i],
"The dimensions of the source tensor and target tensor should" phi::errors::InvalidArgument(
" match, but received source tensor's %d-th dimension is %d," "The dimensions of the source tensor and target tensor should"
"target tensor's %d-th dimension is %d.", " match, but received source tensor's %d-th dimension is %d,"
i, "target tensor's %d-th dimension is %d.",
src_dims[i], i,
i, src_dims[i],
dst_dims[i])); i,
dst_dims[i]));
}
// slice size // slice size
size_t slice_size = 1; size_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
} else {
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
}
const size_t slice_bytes = slice_size * sizeof(T); const size_t slice_bytes = slice_size * sizeof(T);
...@@ -143,43 +149,48 @@ void ScatterAssignAdd(const phi::CPUContext& ctx, ...@@ -143,43 +149,48 @@ void ScatterAssignAdd(const phi::CPUContext& ctx,
const DenseTensor& src, const DenseTensor& src,
const DenseTensor& index, const DenseTensor& index,
DenseTensor* output) { DenseTensor* output) {
// check index of shape 1-D
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 1 || index.dims().size() == 0 ||
(index.dims().size() == 2 && index.dims()[1] == 1), (index.dims().size() == 2 && index.dims()[1] == 1),
true, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"index's shape is error, " "index's shape is error, "
"expect index'dims shape is 1 or 2 and index.dims[1] is 1" "expect index'dims shape is 0, 1, 2 (index.dims[1] should "
"but got index'dims shape is %d", "be 1), but got index'dims shape is %d",
index.dims().size())); index.dims().size()));
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = output->dims(); auto dst_dims = output->dims();
const T* p_src = src.data<T>(); const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>(); const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>(); T* p_output = output->data<T>();
// check src shape and dst shape should match if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); i++) // check src shape and dst shape should match
PADDLE_ENFORCE_EQ( for (int i = 1; i < src_dims.size(); i++)
src_dims[i], PADDLE_ENFORCE_EQ(
dst_dims[i], src_dims[i],
phi::errors::InvalidArgument( dst_dims[i],
"The dimensions of the source tensor and target tensor should" phi::errors::InvalidArgument(
" match, but received source tensor's %d-th dimension is %d," "The dimensions of the source tensor and target tensor should"
"target tensor's %d-th dimension is %d.", " match, but received source tensor's %d-th dimension is %d,"
i, "target tensor's %d-th dimension is %d.",
src_dims[i], i,
i, src_dims[i],
dst_dims[i])); i,
dst_dims[i]));
}
// slice size // slice size
size_t slice_size = 1; size_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
} else {
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
}
const size_t& slice_bytes = slice_size * sizeof(T); const size_t& slice_bytes = slice_size * sizeof(T);
......
...@@ -44,10 +44,10 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -44,10 +44,10 @@ void GatherGradKernel(const Context& dev_ctx,
index_dims[1])); index_dims[1]));
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index_dims.size(), index_dims.size() == 1 || index_dims.size() == 0,
1, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d", "The index should be 0D or 1D, when it is not 2D, but we get %d",
index_dims.size())); index_dims.size()));
} }
std::vector<int> xshape(x_grad->dims().size()); std::vector<int> xshape(x_grad->dims().size());
...@@ -66,7 +66,7 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -66,7 +66,7 @@ void GatherGradKernel(const Context& dev_ctx,
index.data<int>(), index.data<int>(),
reinterpret_cast<XPUType*>(x_grad->data<T>()), reinterpret_cast<XPUType*>(x_grad->data<T>()),
xshape, xshape,
index.dims()[0], index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v, axis_v,
overwrite); overwrite);
} else { } else {
...@@ -84,7 +84,7 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -84,7 +84,7 @@ void GatherGradKernel(const Context& dev_ctx,
index_int_ptr_l3, index_int_ptr_l3,
reinterpret_cast<XPUType*>(x_grad->data<T>()), reinterpret_cast<XPUType*>(x_grad->data<T>()),
xshape, xshape,
index.dims()[0], index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v, axis_v,
overwrite); overwrite);
} }
......
...@@ -41,10 +41,10 @@ void GatherKernel(const Context& dev_ctx, ...@@ -41,10 +41,10 @@ void GatherKernel(const Context& dev_ctx,
index_dims[1])); index_dims[1]));
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index_dims.size(), index_dims.size() == 1 || index_dims.size() == 0,
1, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d", "The index should be 0D, 1D, when it is not 2D, but we get %d",
index_dims.size())); index_dims.size()));
} }
std::vector<int> xshape(x.dims().size()); std::vector<int> xshape(x.dims().size());
...@@ -56,13 +56,14 @@ void GatherKernel(const Context& dev_ctx, ...@@ -56,13 +56,14 @@ void GatherKernel(const Context& dev_ctx,
int r = XPU_SUCCESS; int r = XPU_SUCCESS;
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
r = xpu::gather<XPUType, int>(dev_ctx.x_context(), r = xpu::gather<XPUType, int>(
reinterpret_cast<const XPUType*>(x.data<T>()), dev_ctx.x_context(),
index.data<int>(), reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()), index.data<int>(),
xshape, reinterpret_cast<XPUType*>(out->data<T>()),
index.dims()[0], xshape,
axis_v); index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v);
} else { } else {
r = xpu::gather<XPUType, int64_t>( r = xpu::gather<XPUType, int64_t>(
dev_ctx.x_context(), dev_ctx.x_context(),
...@@ -70,7 +71,7 @@ void GatherKernel(const Context& dev_ctx, ...@@ -70,7 +71,7 @@ void GatherKernel(const Context& dev_ctx,
index.data<int64_t>(), index.data<int64_t>(),
reinterpret_cast<XPUType*>(out->data<T>()), reinterpret_cast<XPUType*>(out->data<T>()),
xshape, xshape,
index.dims()[0], index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v); axis_v);
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -43,30 +43,34 @@ void ScatterKernel(const Context &ctx, ...@@ -43,30 +43,34 @@ void ScatterKernel(const Context &ctx,
// check index of shape 1-D // check index of shape 1-D
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 1 || index.dims().size() == 0 ||
(index.dims().size() == 2 && index.dims()[1] == 1), (index.dims().size() == 2 && index.dims()[1] == 1),
true, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"index's shape is error, " "index's shape is error, "
"expect index'dims shape is 1 or 2 and index.dims[1] is 1" "expect index'dims shape is 0, 1, 2 (index.dims[1] should "
"but got index'dims shape is %d", "be 1), 0 but got index'dims shape is %d",
index.dims().size())); index.dims().size()));
int index_size = static_cast<int>(index.dims()[0]); int index_size =
static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
auto x_dims = x.dims(); auto x_dims = x.dims();
auto update_dims = updates.dims(); auto update_dims = updates.dims();
for (int i = 1; i < x_dims.size(); i++) if (index.dims().size() != 0) {
PADDLE_ENFORCE_EQ( // only check when the index tensor is not a 0D tensor
x_dims[i], for (int i = 1; i < x_dims.size(); i++)
update_dims[i], PADDLE_ENFORCE_EQ(
phi::errors::InvalidArgument( x_dims[i],
"The dimensions of the source tensor and target tensor should" update_dims[i],
" match, but received source tensor's %d-th dimension is %d," phi::errors::InvalidArgument(
"target tensor's %d-th dimension is %d.", "The dimensions of the source tensor and target tensor should"
i, " match, but received source tensor's %d-th dimension is %d,"
x_dims[i], "target tensor's %d-th dimension is %d.",
i, i,
update_dims[i])); x_dims[i],
i,
update_dims[i]));
}
int dim0 = static_cast<int>(x.dims()[0]); int dim0 = static_cast<int>(x.dims()[0]);
int dim1 = int dim1 =
......
...@@ -598,6 +598,61 @@ class TestSundryAPI(unittest.TestCase): ...@@ -598,6 +598,61 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, []) self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 0) self.assertEqual(out.numpy(), 0)
def test_gather_1D(self):
    # Gather from a 1-D tensor using a 0-D index; both the result and its
    # gradient are expected to be 0-D.
    values = [1.0, 3.0, 5.0, 7.0, 9.0]
    source = paddle.to_tensor(values, stop_gradient=False)
    scalar_idx = paddle.full([], 2, 'int64')

    picked = paddle.gather(x=source, index=scalar_idx)
    picked.backward()

    # Position 2 of the source holds 5.0.
    self.assertEqual(picked.shape, [])
    self.assertEqual(picked.numpy(), 5)
    self.assertEqual(picked.grad.shape, [])
def test_gather_xD_axis_0(self):
    # Gather row 1 of a 2x3 tensor with a 0-D index along the default axis;
    # the leading dimension is expected to disappear from the output.
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    source = paddle.to_tensor(rows, stop_gradient=False)
    scalar_idx = paddle.full([], 1, 'int64')

    picked = paddle.gather(x=source, index=scalar_idx)
    picked.backward()

    self.assertEqual(picked.shape, [3])
    # Element-wise comparison against row 1 of the input.
    for col in range(3):
        self.assertEqual(picked.numpy()[col], source.numpy()[1][col])
    self.assertEqual(picked.grad.shape, [3])
def test_gather_xD_axis_1(self):
    # Gather column 1 of a 2x3 tensor with a 0-D index along axis 1;
    # the output is expected to keep only the row dimension.
    source = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    scalar_idx = paddle.full([], 1, 'int64')

    picked = paddle.gather(x=source, index=scalar_idx, axis=1)

    self.assertEqual(picked.shape, [2])
    # Each output element must equal column 1 of the matching input row.
    for row in range(2):
        self.assertEqual(picked.numpy()[row], source.numpy()[row][1])
def test_scatter_1D(self):
    # Scatter a 0-D update into a 1-D tensor at a 0-D index; the output
    # keeps the input's length and position 2 takes the update value.
    values = [1.0, 3.0, 5.0, 7.0, 9.0]
    target = paddle.to_tensor(values, stop_gradient=False)
    scalar_idx = paddle.full([], 2, 'int64')
    scalar_update = paddle.full([], 4.0)

    result = paddle.scatter(x=target, index=scalar_idx, updates=scalar_update)
    result.backward()

    self.assertEqual(result.grad.shape, [5])
    self.assertEqual(result.numpy()[2], 4)
def test_scatter_XD(self):
    # Scatter a 1-D update into row 1 of a 2x3 tensor via a 0-D index;
    # row 1 of the output must equal the update and the shape is kept.
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    target = paddle.to_tensor(rows, stop_gradient=False)
    scalar_idx = paddle.full([], 1, 'int64')
    new_row = paddle.to_tensor([1.0, 2.0, 3.0])

    result = paddle.scatter(x=target, index=scalar_idx, updates=new_row)
    result.backward()

    for col in range(3):
        self.assertEqual(result.numpy()[1][col], new_row.numpy()[col])
    self.assertEqual(result.grad.shape, [2, 3])
class TestSundryAPIStatic(unittest.TestCase): class TestSundryAPIStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -679,6 +734,68 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -679,6 +734,68 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], 0) self.assertEqual(res[0], 0)
@prog_scope()
def test_gather_1D(self):
    # Static-graph variant: gather from a length-10 tensor of ones with a
    # 0-D index and expect a 0-D fetched result equal to 1.
    source = paddle.full([10], 1.0, 'float32')
    scalar_idx = paddle.full([], 2, 'int64')
    picked = paddle.gather(x=source, index=scalar_idx)
    paddle.static.append_backward(picked)

    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(main_prog, fetch_list=[picked])

    self.assertEqual(fetched[0].shape, ())
    self.assertEqual(fetched[0], 1)
@prog_scope()
def test_gather_XD_axis_0(self):
    # Static-graph variant: gather one row of a 2x3 tensor of ones with a
    # 0-D index along the default axis; expect a (3,) result of ones.
    source = paddle.full([2, 3], 1.0, 'float32')
    scalar_idx = paddle.full([], 1, 'int64')
    picked = paddle.gather(x=source, index=scalar_idx)
    paddle.static.append_backward(picked)

    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(main_prog, fetch_list=[picked])

    self.assertEqual(fetched[0].shape, (3,))
    for col in range(3):
        self.assertEqual(fetched[0][col], 1)
@prog_scope()
def test_gather_XD_axis_1(self):
    # Static-graph variant: gather one column of a 2x3 tensor of ones with
    # a 0-D index along axis 1; expect a (2,) result of ones.
    source = paddle.full([2, 3], 1.0, 'float32')
    scalar_idx = paddle.full([], 1, 'int64')
    picked = paddle.gather(x=source, index=scalar_idx, axis=1)

    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(main_prog, fetch_list=[picked])

    self.assertEqual(fetched[0].shape, (2,))
    for row in range(2):
        self.assertEqual(fetched[0][row], 1)
@prog_scope()
def test_scatter_1D(self):
    # Static-graph variant: scatter a 0-D update (value 4) into a length-10
    # tensor of ones at a 0-D index of 2.
    target = paddle.full([10], 1.0, 'float32')
    scalar_idx = paddle.full([], 2, 'int64')
    scalar_update = paddle.full([], 4, 'float32')
    result = paddle.scatter(x=target, index=scalar_idx, updates=scalar_update)
    paddle.static.append_backward(result)

    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(main_prog, fetch_list=[result])

    # Only the scattered position is checked.
    self.assertEqual(fetched[0][2], 4)
@prog_scope()
def test_scatter_XD(self):
    # Static-graph variant: scatter a length-3 update of fours into row 1
    # of a 2x3 tensor of ones using a 0-D index.
    target = paddle.full([2, 3], 1.0, 'float32')
    scalar_idx = paddle.full([], 1, 'int64')
    new_row = paddle.full([3], 4, 'float32')
    result = paddle.scatter(x=target, index=scalar_idx, updates=new_row)
    paddle.static.append_backward(result)

    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(main_prog, fetch_list=[result])

    # Every element of the overwritten row must carry the update value.
    for col in range(3):
        self.assertEqual(fetched[0][1][col], 4)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
...@@ -426,6 +426,55 @@ class TestSundryAPI(unittest.TestCase): ...@@ -426,6 +426,55 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, []) self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 0) self.assertEqual(out.numpy(), 0)
def test_gather_1D(self):
    # A 0-D index applied to a 1-D input should produce a 0-D output whose
    # gradient is 0-D as well.
    data = paddle.to_tensor(
        [1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False
    )
    idx = paddle.full([], 2, 'int64')

    gathered = paddle.gather(x=data, index=idx)
    gathered.backward()

    self.assertEqual(gathered.shape, [])
    # The element at position 2 is 5.0.
    self.assertEqual(gathered.numpy(), 5)
    self.assertEqual(gathered.grad.shape, [])
def test_gather_xD_axis_0(self):
    # A 0-D index along the default axis of a 2x3 input should select one
    # row and yield a [3]-shaped output and gradient.
    data = paddle.to_tensor(
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
    )
    idx = paddle.full([], 1, 'int64')

    gathered = paddle.gather(x=data, index=idx)
    gathered.backward()

    self.assertEqual(gathered.shape, [3])
    # Compare each gathered element with row 1 of the input.
    for k in range(3):
        self.assertEqual(gathered.numpy()[k], data.numpy()[1][k])
    self.assertEqual(gathered.grad.shape, [3])
def test_gather_xD_axis_1(self):
    # A 0-D index along axis 1 of a 2x3 input should select one column and
    # yield a [2]-shaped output.
    data = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    idx = paddle.full([], 1, 'int64')

    gathered = paddle.gather(x=data, index=idx, axis=1)

    self.assertEqual(gathered.shape, [2])
    # Compare each gathered element with column 1 of the input.
    for k in range(2):
        self.assertEqual(gathered.numpy()[k], data.numpy()[k][1])
def test_scatter_1D(self):
    # Writing a 0-D update at a 0-D index into a 1-D input should replace
    # exactly the addressed element.
    data = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
    idx = paddle.full([], 2, 'int64')
    new_value = paddle.full([], 4.0)

    scattered = paddle.scatter(x=data, index=idx, updates=new_value)

    self.assertEqual(scattered.numpy()[2], 4)
def test_scatter_XD(self):
    # Writing a 1-D update at a 0-D index into a 2x3 input should replace
    # the addressed row with the update values.
    data = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    idx = paddle.full([], 1, 'int64')
    new_row = paddle.to_tensor([1.0, 2.0, 3.0])

    scattered = paddle.scatter(x=data, index=idx, updates=new_row)

    for k in range(3):
        self.assertEqual(scattered.numpy()[1][k], new_row.numpy()[k])
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
...@@ -2728,13 +2728,13 @@ def gather(x, index, axis=None, name=None): ...@@ -2728,13 +2728,13 @@ def gather(x, index, axis=None, name=None):
x (Tensor): The source input tensor with rank>=1. Supported data type is x (Tensor): The source input tensor with rank>=1. Supported data type is
int32, int64, float32, float64 and uint8 (only for CPU), int32, int64, float32, float64 and uint8 (only for CPU),
float16 (only for GPU). float16 (only for GPU).
index (Tensor): The index input tensor with rank=1. Data type is int32 or int64. index (Tensor): The index input tensor with rank=0 or rank=1. Data type is int32 or int64.
axis (Tensor|int, optional): The axis of input to be gathered. It can be an int or a Tensor with data type int32 or int64. The default value is None; if None, the ``axis`` is 0. axis (Tensor|int, optional): The axis of input to be gathered. It can be an int or a Tensor with data type int32 or int64. The default value is None; if None, the ``axis`` is 0.
name (str, optional): The default value is None. Normally there is no need for user to set this property. name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` . For more information, please refer to :ref:`api_guide_Name` .
Returns: Returns:
output (Tensor), The output is a tensor with the same rank as ``x``. output (Tensor), If the index is a 1-D tensor, the output is a tensor with the same rank as ``x``. If the index is a 0-D tensor, the output drops the dimension that ``axis`` points to.
Examples: Examples:
...@@ -2888,8 +2888,8 @@ def scatter(x, index, updates, overwrite=True, name=None): ...@@ -2888,8 +2888,8 @@ def scatter(x, index, updates, overwrite=True, name=None):
Args: Args:
x (Tensor): The input N-D Tensor with ndim>=1. Data type can be float32, float64. x (Tensor): The input N-D Tensor with ndim>=1. Data type can be float32, float64.
index (Tensor): The index 1-D Tensor. Data type can be int32, int64. The length of index cannot exceed updates's length, and the value in index cannot exceed input's length. index (Tensor): The index is a 1-D or 0-D Tensor. Data type can be int32, int64. The length of index cannot exceed updates's length, and the value in index cannot exceed input's length.
updates (Tensor): update input with updates parameter based on index. shape should be the same as input, and dim value with dim > 1 should be the same as input. updates (Tensor): Update input with updates parameter based on index. When the index is a 1-D tensor, the updates shape should be the same as input, and dim value with dim > 1 should be the same as input. When the index is a 0-D tensor, the updates should be a (N-1)-D tensor, and the ith dim of the updates should be equal to the (i+1)th dim of the input.
overwrite (bool): The mode that updating the output when there are same indices. overwrite (bool): The mode that updating the output when there are same indices.
If True, use the overwrite mode to update the output of the same index, If True, use the overwrite mode to update the output of the same index,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册