From f3393f494541aca303e3292c1cb78104e90d89c8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?=
Date: Wed, 10 May 2023 19:23:32 +0800
Subject: [PATCH] add index_put api (#52886)

* add index_put api

* fix value broadcast in backward and add test case in static

* add timeout=120s for index_put

* add op_compat for index_put

* add inplace index_put test

* add test case when index tensor in indices is int32 when indices.size less than x.dims

* add index_put api backward in cpu place

* add backward test case

* refactor code to delete some duplicated code

* replace reshape with resize for decrease extra memcpy

* add datatype flag in backward yaml

* fix bug in documentation

* Update python/paddle/tensor/manipulation.py

---------

Co-authored-by: Ligoml <39876205+Ligoml@users.noreply.github.com>
---
 paddle/phi/api/yaml/backward.yaml             |  11 +
 paddle/phi/api/yaml/ops.yaml                  |  11 +
 paddle/phi/infermeta/multiary.cc              |  16 +-
 paddle/phi/infermeta/multiary.h               |   6 +
 .../phi/kernels/cpu/index_put_grad_kernel.cc  | 225 +++++
 paddle/phi/kernels/cpu/index_put_kernel.cc    | 166 ++++
 paddle/phi/kernels/funcs/index_put_utils.h    | 348 ++++++++
 .../phi/kernels/gpu/index_put_grad_kernel.cu  | 287 ++++++
 paddle/phi/kernels/gpu/index_put_kernel.cu    | 198 +++++
 paddle/phi/kernels/index_put_grad_kernel.h    |  30 +
 paddle/phi/kernels/index_put_kernel.h         |  29 +
 python/paddle/__init__.py                     |   4 +
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../tests/unittests/test_index_put_op.py      | 826 ++++++++++++++++++
 python/paddle/tensor/__init__.py              |   4 +
 python/paddle/tensor/manipulation.py          | 103 ++-
 16 files changed, 2263 insertions(+), 2 deletions(-)
 create mode 100644 paddle/phi/kernels/cpu/index_put_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/index_put_kernel.cc
 create mode 100644 paddle/phi/kernels/funcs/index_put_utils.h
 create mode 100644 paddle/phi/kernels/gpu/index_put_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/index_put_kernel.cu
 create mode 100644 paddle/phi/kernels/index_put_grad_kernel.h
 create mode 100644 paddle/phi/kernels/index_put_kernel.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_index_put_op.py

diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index bde673e60b6..6ae8c190064 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -796,6 +796,17 @@
     data_type : out_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : index_put_grad
+  forward : index_put (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) -> Tensor(out)
+  args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false)
+  output : Tensor(x_grad), Tensor(value_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, value]
+  kernel :
+    func : index_put_grad
+    data_type : out_grad
+
 - backward_op : index_sample_grad
   forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
   args : (Tensor x, Tensor index, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index c254a75f98e..de9d6c4f29c 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -940,6 +940,17 @@
   inplace : (x -> out)
   backward : index_add_grad
 
+- op : index_put
+  args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false)
+  output : Tensor(out)
+  infer_meta :
+    func : IndexPutInferMeta
+  kernel :
+    func : index_put
+    data_type : x
+  inplace : (x -> out)
+  backward : index_put_grad
+
 - op : index_sample
   args : (Tensor x, Tensor index)
   output : Tensor
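The two YAML entries above register the forward op, its in-place form (inplace : (x -> out)) and its backward op. The user-facing call the patch ultimately exposes is paddle.index_put(x, indices, value, accumulate), which writes value into x at the positions selected by the tuple of index tensors, adding instead of overwriting when accumulate is true; x and value must share the same dtype, which the kernels below enforce. A minimal usage sketch follows (shapes and values are illustrative only, not taken from the patch):

    import paddle

    x = paddle.zeros([3, 3])
    # one index tensor per dimension being indexed
    indices = (paddle.to_tensor([0, 2]), paddle.to_tensor([1, 1]))
    value = paddle.to_tensor([10.0, 20.0])

    # out[0, 1] == 10.0 and out[2, 1] == 20.0; every other element is copied from x.
    # Passing True as the last argument accumulates into x[indices] instead of overwriting.
    out = paddle.index_put(x, indices, value, False)

The in-place variant paddle.index_put_, added in the same patch, modifies x directly rather than returning a new tensor.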
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 5a8e38e21fd..8ea2dc65d9a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1962,6 +1962,21 @@ void InterpolateInferMeta( } } +void IndexPutInferMeta(const MetaTensor& x, + const std::vector& indices, + const MetaTensor& value, + bool accumulate, + MetaTensor* out) { + auto in_dims = x.dims(); + PADDLE_ENFORCE_LT( + in_dims.size(), + 7, + phi::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", + in_dims.size())); + out->share_meta(x); +} + void LambInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, @@ -3295,6 +3310,5 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row, out_count->set_dims({-1}); out_count->set_dtype(DataType::INT32); } - } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 993e6c21ff6..2a924ecbb30 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -332,6 +332,12 @@ void InterpolateInferMeta( MetaTensor* output, MetaConfig config = MetaConfig()); +void IndexPutInferMeta(const MetaTensor& x, + const std::vector& indices, + const MetaTensor& value, + bool accumulate, + MetaTensor* out); + void LambInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc new file mode 100644 index 00000000000..7374bcd403d --- /dev/null +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -0,0 +1,225 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/index_put_grad_kernel.h" +#include +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/funcs/index_put_utils.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +namespace phi { + +template +void set_zero_kernel(const int64_t N, + const int64_t** indices, + const phi::DDim& stride, + const phi::DDim& shape, + T* out) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t idx = 0; idx < N; ++idx) { + int64_t cur_ix = 0; + int64_t offset = 0; + + for (int i = 0; i < shape.size(); ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + *(out + offset) = 0; + } +} + +template +void index_put_grad_kernel(const int64_t N, + const T* out_grad, + const int64_t** indices, + const phi::DDim& stride, + const phi::DDim& shape, + T* value_grad) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t idx = 0; idx < N; ++idx) { + int64_t cur_ix = 0; + int64_t offset = 0; + + for (int i = 0; i < shape.size(); ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + *(value_grad + idx) = *(out_grad + offset); + } +} + +template +void LaunchIndexPutGradKernel(const Context& dev_ctx, + const std::vector& indices, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* value_grad, + DenseTensor* x_grad) { + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices.size(); ++i) { + pd_indices[i] = indices[i]->data(); + } + + if (x_grad) { + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); + if (!accumulate) { + T* x_grad_data = x_grad->data(); + + auto x_grad_dims = x_grad->dims(); + const int64_t numel = indices[0]->numel(); + auto x_grad_stride = phi::stride(x_grad_dims); + + set_zero_kernel( + numel, pd_indices, x_grad_stride, x_grad_dims, x_grad_data); + } + } + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices[0]->numel(); + auto out_grad_stride = phi::stride(out_grad_dims); + + if (value_grad) { + if (value_grad->numel() == 1) { + DenseTensor tmp_value_grad(value_grad->dtype()); + tmp_value_grad.Resize(indices[0]->dims()); + + T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); + auto out_grad_data = out_grad.data(); + + index_put_grad_kernel(numel, + out_grad_data, + pd_indices, + out_grad_stride, + out_grad_dims, + tmp_value_grad_data); + + std::vector v_dims(tmp_value_grad.dims().size()); + std::iota(v_dims.begin(), v_dims.end(), 0); + IntArray v_axis(v_dims); + SumKernel(dev_ctx, + tmp_value_grad, + v_axis, + value_grad->dtype(), + false, + value_grad); + } else if (value_grad->numel() == indices[0]->numel()) { + T* value_grad_data = dev_ctx.template Alloc(value_grad); + auto out_grad_data = out_grad.data(); + + index_put_grad_kernel(numel, + out_grad_data, + pd_indices, + out_grad_stride, + out_grad_dims, + value_grad_data); + } else { + DenseTensor tmp_value_grad(value_grad->dtype()); + tmp_value_grad.Resize(indices[0]->dims()); + + T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); + auto out_grad_data = out_grad.data(); + + index_put_grad_kernel(numel, + out_grad_data, + pd_indices, + out_grad_stride, + out_grad_dims, + tmp_value_grad_data); + + std::vector after_dims = phi::vectorize(tmp_value_grad.dims()); + std::vector before_dims = phi::vectorize(value_grad->dims()); + std::vector compress_dims; + 
std::vector dims_without_1; + + funcs::CalCompressedDimsWith1AndWithout1( + &after_dims, &before_dims, &compress_dims, &dims_without_1); + + auto pre_dims = value_grad->dims(); + value_grad->Resize(phi::make_ddim(dims_without_1)); + IntArray v_axis(compress_dims); + SumKernel(dev_ctx, + tmp_value_grad, + v_axis, + value_grad->dtype(), + false, + value_grad); + value_grad->Resize(pre_dims); + } + } +} + +template +void IndexPutGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices, + const DenseTensor& value, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* x_grad, + DenseTensor* value_grad) { + PADDLE_ENFORCE_EQ( + x.dtype(), + value.dtype(), + phi::errors::InvalidArgument( + "The data type of tensor in indices must be same to the data type " + "of tensor x.")); + std::vector tmp_args; + std::vector int_indices_v = + funcs::DealWithBoolIndices(dev_ctx, indices, &tmp_args); + auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); + + std::vector res_dim_v(phi::vectorize(bd_dim)); + std::vector res_indices_v(x.dims().size(), nullptr); + std::vector tmp_res_indices_v; + std::vector range_tensor_v; + + for (int i = indices.size(); i < x.dims().size(); ++i) { + range_tensor_v.emplace_back(funcs::GetRangeTensor( + dev_ctx, x.dims()[i], phi::DataType::INT64)); + } + + funcs::DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); + + LaunchIndexPutGradKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); +} +} // namespace phi + +PD_REGISTER_KERNEL(index_put_grad, + CPU, + ALL_LAYOUT, + phi::IndexPutGradKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc new file mode 100644 index 00000000000..da3e37ac242 --- /dev/null +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/index_put_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/funcs/index_put_utils.h" + +namespace phi { + +template +void index_put_kernel(const int64_t N, + const T* x, + const T* vals, + const int64_t** indices, + const phi::DDim& stride, + const phi::DDim& shape, + int64_t is_single_val_tensor, + bool accumulate, + T* out) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t idx = 0; idx < N; ++idx) { + int64_t cur_ix = 0; + int64_t offset = 0; + + for (int i = 0; i < shape.size(); ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + + if (accumulate) { + *(out + offset) += *(vals + (idx & is_single_val_tensor)); + } else { + *(out + offset) = *(vals + (idx & is_single_val_tensor)); + } + } +} + +template +void LaunchIndexPutKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices, + const DenseTensor& value, + bool accumulate, + DenseTensor* out) { + auto* x_data = x.data(); + auto* val_data = value.data(); + bool is_initialized = out->initialized(); + T* out_data = dev_ctx.template Alloc(out); + + if (!is_initialized) { + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + } + + auto x_dims = x.dims(); + const int64_t numel = indices[0]->numel(); + auto x_stride = phi::stride(x_dims); + + int64_t is_single_val_tensor = (value.numel() == 1) ? 0 : INT64_MAX; + + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices.size(); ++i) { + pd_indices[i] = indices[i]->data(); + } + + index_put_kernel(numel, + x_data, + val_data, + pd_indices, + x_stride, + x_dims, + is_single_val_tensor, + accumulate, + out_data); +} + +template +void IndexPutKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices, + const DenseTensor& value, + bool accumulate, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + x.dtype(), + value.dtype(), + phi::errors::InvalidArgument( + "The data type of tensor in indices must be same to the data type " + "of tensor x.")); + PADDLE_ENFORCE_EQ(indices.empty(), + false, + phi::errors::InvalidArgument("Indices cannot be empty.")); + + const size_t total_dims = x.dims().size(); + PADDLE_ENFORCE_LE(total_dims, + 6, + phi::errors::InvalidArgument( + "Dims of input tensor should be less than 7.")); + + std::vector tmp_args; + std::vector int_indices_v = + funcs::DealWithBoolIndices(dev_ctx, indices, &tmp_args); + + auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); + + std::vector res_dim_v(phi::vectorize(bd_dim)); + std::vector res_indices_v(x.dims().size(), nullptr); + std::vector tmp_res_indices_v; + std::vector tmp_value_v; + std::vector range_tensor_v; + const DenseTensor* ptr_value = nullptr; + + for (int i = indices.size(); i < x.dims().size(); ++i) { + range_tensor_v.emplace_back(funcs::GetRangeTensor( + dev_ctx, x.dims()[i], phi::DataType::INT64)); + } + + funcs::DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); + if (value.numel() != 1) { + tmp_value_v.emplace_back( + DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v))); + ExpandKernel( + dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]); + ptr_value = &tmp_value_v[0]; + } else { + ptr_value = &value; + } + + LaunchIndexPutKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); +} +} // 
namespace phi + +PD_REGISTER_KERNEL(index_put, + CPU, + ALL_LAYOUT, + phi::IndexPutKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h new file mode 100644 index 00000000000..51e918c8523 --- /dev/null +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -0,0 +1,348 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/nonzero_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/split_kernel.h" + +#if defined(__NVCC__) || defined(__HIPCC__) +#ifdef __NVCC__ +#include +#include +#elif defined(__HIPCC__) +#include +#endif +#endif + +namespace phi { + +namespace funcs { + +template +phi::DenseTensor GetReshapeAndExpandTensor(const Context& dev_ctx, + const phi::DenseTensor& tensor, + const phi::DDim& res_dim, + const phi::DDim& bd_dim, + int index) { + std::vector before_dims = phi::vectorize(tensor.dims()); + std::vector mid_dims(res_dim.size(), 1); + + if (index == 0) { + for (size_t i = 0; i < before_dims.size(); ++i) { + mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1]; + } + } else { + mid_dims[index] = before_dims[0]; + } + + phi::DenseTensor mid_tensor(tensor.dtype()); + mid_tensor.Resize(phi::make_ddim(mid_dims)); + ReshapeInferKernel(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor); + + phi::DenseTensor res_tensor(tensor.dtype()); + res_tensor.Resize(res_dim); + ExpandKernel( + dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor); + return res_tensor; +} + +template +std::vector DealWithBoolIndices( + const Context& dev_ctx, + const std::vector& indices_v, + std::vector* tmp_indices_v) { + std::vector res(indices_v.begin(), indices_v.end()); + bool contains_bool_tensor = false; + for (size_t i = 0; i < indices_v.size(); ++i) { + if (indices_v[i]->dtype() == phi::DataType::BOOL) { + contains_bool_tensor = true; + } else if ((indices_v[i]->dtype() == phi::DataType::INT64) || + (indices_v[i]->dtype() == phi::DataType::INT32)) { + PADDLE_ENFORCE_EQ( + contains_bool_tensor, + false, + phi::errors::InvalidArgument( + "indices contains bool tensor and int32/int64 tensor at the same " + "time")); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "data type of tensor in indices must be int32, int64 or bool")); + } + } + + if (contains_bool_tensor) { + if (indices_v.size() != 1) { + PADDLE_THROW(phi::errors::InvalidArgument( + "the size of indices must be 1 when it containts bool tensor")); + } + int rank = indices_v[0]->dims().size(); + PADDLE_ENFORCE_GE( + rank, + 1UL, + 
phi::errors::InvalidArgument("the only bool tensor in indices should " + "have number of dimension at least 1")); + phi::DenseTensor nonzero_indices(phi::DataType::INT64); + nonzero_indices.Resize(phi::make_ddim({-1, rank})); + NonZeroKernel(dev_ctx, *indices_v[0], &nonzero_indices); + + std::vector integer_indices(rank, nullptr); + for (int i = 0; i < rank; ++i) { + tmp_indices_v->emplace_back( + DenseTensor(phi::DataType::INT64) + .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); + } + for (int i = 0; i < rank; ++i) { + integer_indices[i] = &((*tmp_indices_v)[i]); + } + SplitWithNumKernel( + dev_ctx, nonzero_indices, rank, 1, integer_indices); + + std::vector res_tmp(integer_indices.size(), + nullptr); + for (int i = 0; i < rank; ++i) { + res_tmp[i] = &((*tmp_indices_v)[i]); + } + res.swap(res_tmp); + } + return res; +} + +static phi::DDim BroadCastTensorsDims( + const std::vector& tensors) { + int target_rank = 0; + for (const auto& tensor : tensors) { + target_rank = std::max(target_rank, tensor->dims().size()); + } + + PADDLE_ENFORCE_GT(target_rank, + 0, + errors::InvalidArgument("BroadCastTensorsDims requires at " + "least one input tensor to have " + "rank greater than zero")); + + std::vector target_dims(target_rank, 0); + for (int index = 0; index < target_rank; index++) { + int target_dim_size = 1; + for (const auto& tensor : tensors) { + auto input_ddim = tensor->dims(); + int axis = static_cast(input_ddim.size()) - index - 1; + int dim_size = 1; + if (axis >= 0) { + dim_size = input_ddim[axis]; + } + + if (target_dim_size != 1 && dim_size != 1 && + target_dim_size != dim_size) { + PADDLE_THROW(errors::InvalidArgument( + "BroadCastTensorsDims inputs does not satisfy bcast semantics, " + "please check axis = %d in reverse order", + index)); + } + + target_dim_size = dim_size == 1 ? 
target_dim_size : dim_size; + } + target_dims[target_rank - index - 1] = target_dim_size; + } + return phi::make_ddim(target_dims); +} + +template +T** GetDevicePointerArray(const Context& ctx, + const std::vector& indices_v) { + std::vector h_indices_v(indices_v.size()); + for (int i = 0; i < indices_v.size(); ++i) { + h_indices_v[i] = indices_v[i]->data(); + } + auto d_indices_data = phi::memory_utils::Alloc( + ctx.GetPlace(), + h_indices_v.size() * sizeof(T*), + phi::Stream(reinterpret_cast(ctx.stream()))); + phi::memory_utils::Copy(ctx.GetPlace(), + d_indices_data->ptr(), + phi::CPUPlace(), + reinterpret_cast(h_indices_v.data()), + h_indices_v.size() * sizeof(T*), + ctx.stream()); + return reinterpret_cast(d_indices_data->ptr()); +} + +template +void DealWithIndices(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& int_indices_v, + std::vector* res_indices_v, + std::vector* tmp_res_indices_v, + const std::vector& range_tensor_v, + const phi::DDim& bd_dim, + std::vector* res_dim_v) { + size_t total_dims = x.dims().size(); + if (int_indices_v.size() < total_dims) { + std::vector tmp_x_dims = phi::vectorize(x.dims()); + int len_bd_dim = bd_dim.size(); + res_dim_v->insert(res_dim_v->end(), + tmp_x_dims.begin() + int_indices_v.size(), + tmp_x_dims.end()); + + std::vector reshaped_indices_v; + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + reshaped_indices_v.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + reshaped_indices_v.emplace_back(*int_indices_v[i]); + } + } + reshaped_indices_v.insert( + reshaped_indices_v.end(), range_tensor_v.begin(), range_tensor_v.end()); + + phi::DDim res_dim = phi::make_ddim(*res_dim_v); + + for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { + tmp_res_indices_v->emplace_back( + GetReshapeAndExpandTensor( + dev_ctx, + reshaped_indices_v[i], + res_dim, + bd_dim, + ((i < int_indices_v.size()) + ? 
0 + : i - int_indices_v.size() + len_bd_dim))); + } + for (size_t i = 0; i < res_indices_v->size(); ++i) { + (*res_indices_v)[i] = &(*tmp_res_indices_v)[i]; + } + + } else { + std::vector int_indices_v_tmp; + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + int_indices_v_tmp.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + int_indices_v_tmp.emplace_back(*int_indices_v[i]); + } + } + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (bd_dim != int_indices_v[i]->dims()) { + tmp_res_indices_v->emplace_back( + DenseTensor(phi::DataType::INT64).Resize(bd_dim)); + ExpandKernel( + dev_ctx, + int_indices_v_tmp[i], + IntArray(phi::vectorize(bd_dim)), + &(*tmp_res_indices_v)[i]); + } else { + tmp_res_indices_v->emplace_back(int_indices_v_tmp[i]); + } + } + + for (size_t i = 0; i < res_indices_v->size(); ++i) { + (*res_indices_v)[i] = &(*tmp_res_indices_v)[i]; + } + } +} + +static void CalCompressedDimsWith1AndWithout1( + std::vector* after_dims, + std::vector* before_dims, + std::vector* compress_dims, + std::vector* dims_without_1) { + int i = static_cast(after_dims->size()) - 1; + int j = static_cast(before_dims->size()) - 1; + if (i < j) { + PADDLE_THROW(phi::errors::InvalidArgument( + "shape of value can't not be broadcast to shape of x[indices]")); + } + + while ((i >= 0) && (j >= 0)) { + if ((*after_dims)[i] == (*before_dims)[j]) { + dims_without_1->push_back((*before_dims)[j]); + i--; + j--; + continue; + } else if ((*before_dims)[j] == 1) { + compress_dims->push_back(i); + i--; + j--; + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "shape of value can't not be broadcast to shape of x[indices]")); + } + } + while (i >= 0) { + compress_dims->push_back(i); + i--; + } +} + +#if defined(__NVCC__) || defined(__HIPCC__) +template +__global__ void range_cuda_kernel(int64_t N, T* out) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + if (idx >= N) { + return; + } + out[idx] = idx; +} + +template +phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx, + int64_t N, + phi::DataType dtype) { + phi::DenseTensor res(dtype); + res.Resize(phi::make_ddim({N})); + DenseTensor* p_res = &res; + T* out = dev_ctx.template Alloc(p_res); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N); + range_cuda_kernel + <<>>( + N, out); + return res; +} +#endif + +template +void range_kernel(int64_t N, T* out) { + for (int64_t idx = 0; idx < N; ++idx) { + out[idx] = idx; + } +} + +template +phi::DenseTensor GetRangeTensor(const Context& dev_ctx, + int64_t N, + phi::DataType dtype) { + phi::DenseTensor res(dtype); + res.Resize(phi::make_ddim({N})); + DenseTensor* p_res = &res; + T* out = dev_ctx.template Alloc(p_res); + range_kernel(N, out); + return res; +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu new file mode 100644 index 00000000000..7ae1e42c067 --- /dev/null +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -0,0 +1,287 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_put_grad_kernel.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/funcs/index_put_utils.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +namespace phi { + +template +__global__ void set_zero_cuda_kernel(const int64_t N, + int64_t** indices, + phi::Array stride, + phi::Array shape, + T* out) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t cur_ix = 0; + + if (idx >= N) { + return; + } + int64_t offset = 0; + for (int i = 0; i < Rank; ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + + *(out + offset) = 0; +} + +template +__global__ void index_put_grad_cuda_kernel(const int64_t N, + const T* out_grad, + int64_t** indices, + phi::Array stride, + phi::Array shape, + T* value_grad) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t cur_ix = 0; + + if (idx >= N) { + return; + } + int64_t offset = 0; + for (int i = 0; i < Rank; ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + + *(value_grad + idx) = *(out_grad + offset); +} + +template +void LaunchIndexPutGradCudaKernel( + const Context& dev_ctx, + const std::vector& indices, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* value_grad, + DenseTensor* x_grad) { + if (x_grad) { + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); + if (!accumulate) { + T* x_grad_data = x_grad->data(); + + auto x_grad_dims = x_grad->dims(); + const int64_t numel = indices[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto x_grad_stride = phi::stride(x_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = x_grad_stride[idx]; + shape_a[idx] = x_grad_dims[idx]; + } + + auto pd_indices = + funcs::GetDevicePointerArray(dev_ctx, indices); + set_zero_cuda_kernel<<>>( + numel, pd_indices, stride_a, shape_a, x_grad_data); + } + } + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto out_grad_stride = phi::stride(out_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = out_grad_stride[idx]; + shape_a[idx] = out_grad_dims[idx]; + } + + auto pd_indices = + funcs::GetDevicePointerArray(dev_ctx, indices); + + if (value_grad) { + if (value_grad->numel() == 1) { + DenseTensor tmp_value_grad(value_grad->dtype()); + tmp_value_grad.Resize(indices[0]->dims()); + + T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); + auto out_grad_data = out_grad.data(); + + index_put_grad_cuda_kernel + <<>>(numel, + out_grad_data, + pd_indices, + stride_a, + shape_a, + tmp_value_grad_data); + + std::vector 
v_dims(tmp_value_grad.dims().size()); + std::iota(v_dims.begin(), v_dims.end(), 0); + IntArray v_axis(v_dims); + SumKernel(dev_ctx, + tmp_value_grad, + v_axis, + value_grad->dtype(), + false, + value_grad); + } else if (value_grad->numel() == indices[0]->numel()) { + T* value_grad_data = dev_ctx.template Alloc(value_grad); + auto out_grad_data = out_grad.data(); + + index_put_grad_cuda_kernel<<>>( + numel, out_grad_data, pd_indices, stride_a, shape_a, value_grad_data); + } else { + DenseTensor tmp_value_grad(value_grad->dtype()); + tmp_value_grad.Resize(indices[0]->dims()); + + T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); + auto out_grad_data = out_grad.data(); + + index_put_grad_cuda_kernel + <<>>(numel, + out_grad_data, + pd_indices, + stride_a, + shape_a, + tmp_value_grad_data); + + std::vector after_dims = phi::vectorize(tmp_value_grad.dims()); + std::vector before_dims = phi::vectorize(value_grad->dims()); + std::vector compress_dims; + std::vector dims_without_1; + + funcs::CalCompressedDimsWith1AndWithout1( + &after_dims, &before_dims, &compress_dims, &dims_without_1); + + auto pre_dims = value_grad->dims(); + value_grad->Resize(phi::make_ddim(dims_without_1)); + IntArray v_axis(compress_dims); + SumKernel(dev_ctx, + tmp_value_grad, + v_axis, + value_grad->dtype(), + false, + value_grad); + value_grad->Resize(pre_dims); + } + } +} + +template +void IndexPutGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices, + const DenseTensor& value, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* x_grad, + DenseTensor* value_grad) { + PADDLE_ENFORCE_EQ( + x.dtype(), + value.dtype(), + phi::errors::InvalidArgument( + "The data type of tensor in indices must be same to the data type " + "of tensor x.")); + std::vector tmp_args; + std::vector int_indices_v = + funcs::DealWithBoolIndices(dev_ctx, indices, &tmp_args); + const size_t total_dims = x.dims().size(); + auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); + + std::vector res_dim_v(phi::vectorize(bd_dim)); + std::vector res_indices_v(x.dims().size(), nullptr); + std::vector tmp_res_indices_v; + std::vector range_tensor_v; + + for (int i = indices.size(); i < x.dims().size(); ++i) { + range_tensor_v.emplace_back(funcs::GetRangeCudaTensor( + dev_ctx, x.dims()[i], phi::DataType::INT64)); + } + + funcs::DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); + + switch (total_dims) { + case 1: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 2: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 3: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 4: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 5: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 6: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "dims of input tensor should be less than 7, But received" + "%d", + x.dims().size())); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(index_put_grad, + GPU, + ALL_LAYOUT, + phi::IndexPutGradKernel, + float, + double, + int, + int64_t, + bool, + 
phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu new file mode 100644 index 00000000000..ad27993c352 --- /dev/null +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -0,0 +1,198 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_put_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/funcs/index_put_utils.h" + +namespace phi { + +template +__global__ void index_put_cuda_kernel(const int64_t N, + const T* x, + const T* vals, + int64_t** indices, + phi::Array stride, + phi::Array shape, + int64_t is_single_val_tensor, + bool accumulate, + T* out) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t cur_ix = 0; + + if (idx >= N) { + return; + } + int64_t offset = 0; + for (int i = 0; i < Rank; ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + + if (accumulate) { + *(out + offset) += *(vals + (idx & is_single_val_tensor)); + } else { + *(out + offset) = *(vals + (idx & is_single_val_tensor)); + } +} + +template +void LaunchIndexPutCudaKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices, + const DenseTensor& value, + bool accumulate, + DenseTensor* out) { + auto* x_data = x.data(); + auto* val_data = value.data(); + bool is_initialized = out->initialized(); + T* out_data = dev_ctx.template Alloc(out); + + if (!is_initialized) { + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + } + + auto x_dims = x.dims(); + const int64_t numel = indices[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto x_stride = phi::stride(x_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = x_stride[idx]; + shape_a[idx] = x_dims[idx]; + } + + int64_t is_single_val_tensor = (value.numel() == 1) ? 
0 : INT64_MAX; + + auto pd_indices = + funcs::GetDevicePointerArray(dev_ctx, indices); + index_put_cuda_kernel + <<>>( + numel, + x_data, + val_data, + pd_indices, + stride_a, + shape_a, + is_single_val_tensor, + accumulate, + out_data); +} + +template +void IndexPutKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices, + const DenseTensor& value, + bool accumulate, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + x.dtype(), + value.dtype(), + phi::errors::InvalidArgument( + "The data type of tensor in indices must be same to the data type " + "of tensor x.")); + PADDLE_ENFORCE_EQ(indices.empty(), + false, + phi::errors::InvalidArgument("Indices cannot be empty.")); + std::vector tmp_args; + std::vector int_indices_v = + funcs::DealWithBoolIndices(dev_ctx, indices, &tmp_args); + const size_t total_dims = x.dims().size(); + auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); + + std::vector res_dim_v(phi::vectorize(bd_dim)); + std::vector res_indices_v(x.dims().size(), nullptr); + std::vector tmp_res_indices_v; + std::vector tmp_value_v; + std::vector range_tensor_v; + const DenseTensor* ptr_value = nullptr; + + for (int i = indices.size(); i < x.dims().size(); ++i) { + range_tensor_v.emplace_back(funcs::GetRangeCudaTensor( + dev_ctx, x.dims()[i], phi::DataType::INT64)); + } + + funcs::DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); + + if (value.numel() != 1) { + tmp_value_v.emplace_back( + DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v))); + ExpandKernel( + dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]); + ptr_value = &tmp_value_v[0]; + } else { + ptr_value = &value; + } + + switch (total_dims) { + case 1: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 2: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 3: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 4: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 5: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 6: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "dims of input tensor should be less than 7, But received" + "%d", + x.dims().size())); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(index_put, + GPU, + ALL_LAYOUT, + phi::IndexPutKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/index_put_grad_kernel.h b/paddle/phi/kernels/index_put_grad_kernel.h new file mode 100644 index 00000000000..575b7df5f27 --- /dev/null +++ b/paddle/phi/kernels/index_put_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void IndexPutGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* x_grad, + DenseTensor* value_grad); +} // namespace phi diff --git a/paddle/phi/kernels/index_put_kernel.h b/paddle/phi/kernels/index_put_kernel.h new file mode 100644 index 00000000000..4410a508244 --- /dev/null +++ b/paddle/phi/kernels/index_put_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void IndexPutKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + bool accumulate, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 673a855c795..8804cfff6de 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -198,6 +198,8 @@ from .tensor.manipulation import moveaxis # noqa: F401 from .tensor.manipulation import repeat_interleave # noqa: F401 from .tensor.manipulation import index_add # noqa: F401 from .tensor.manipulation import index_add_ # noqa: F401 +from .tensor.manipulation import index_put # noqa: F401 +from .tensor.manipulation import index_put_ # noqa: F401 from .tensor.manipulation import unflatten # noqa: F401 from .tensor.math import abs # noqa: F401 from .tensor.math import acos # noqa: F401 @@ -684,6 +686,8 @@ __all__ = [ # noqa 'tril_indices', 'index_add', "index_add_", + "index_put", + "index_put_", 'sgn', 'triu_indices', 'take', diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e7d11ed8b16..8ca0bc01ab3 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -916,6 +916,7 @@ set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 200) set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120) set_tests_properties(test_index_add_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_index_put_op PROPERTIES TIMEOUT 120) set_tests_properties(test_tensordot PROPERTIES TIMEOUT 200) set_tests_properties(test_partial_eager_deletion_transformer PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py new file mode 100644 index 00000000000..5f0257f2553 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -0,0 +1,826 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import unittest + +import numpy as np + +import paddle +from paddle.fluid import Program + + +def compute_index_put_ref(x_np, indices_np, value_np, accumulate=False): + if accumulate: + x_np[indices_np] += value_np + return x_np + else: + x_np[indices_np] = value_np + return x_np + + +def raw_index_put(x, indices, value, accummulate): + return paddle.index_put(x, indices, value, accummulate) + + +def has_duplicate_index(indices, shapes): + bd_shape = np.broadcast_shapes(*shapes) + bd_indices = [ + list(np.broadcast_to(indice, bd_shape).flatten()) for indice in indices + ] + + zip_res = list(zip(*bd_indices)) + if len(zip_res) == len(set(zip_res)): + return False + else: + return True + + +def gen_indices_np(x_shape, indices_shapes, index_type): + indices = [] + if index_type == np.bool_: + indice = np.zeros(indices_shapes[0], dtype=np.bool_) + indice.flatten() + for i in range(len(indice)): + indice[i] = (i & 1) == 0 + indice = indice.reshape(indices_shapes[0]) + indices.append(indice) + else: + while True: + indices = [] + for i in range(len(indices_shapes)): + np.random.seed() + index_np = np.random.randint( + low=0, + high=x_shape[i], + size=indices_shapes[i], + dtype=index_type, + ) + indices.append(index_np) + if not has_duplicate_index( + copy.deepcopy(indices), copy.deepcopy(indices_shapes) + ): + break + return tuple(indices) + + +class TestIndexPutAPIBase(unittest.TestCase): + def setUp(self): + self.init_dtype_type() + self.setPlace() + self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) + self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) + self.indices_np = gen_indices_np( + self.x_shape, self.indices_shapes, self.index_type_np + ) + + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + def setPlace(self): + self.place = ['cpu'] + if self.dtype_np is np.float16: + self.place = [] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def test_dygraph_forward(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) + self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) + self.indices_pd = [ + paddle.to_tensor(indice, dtype=self.index_type_pd) + for indice in self.indices_np + ] + self.indices_pd = tuple(self.indices_pd) + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + def test_static_forward(self): + paddle.enable_static() + for place in self.place: + with paddle.static.program_guard(Program()): + x = paddle.static.data( + 
name="x", shape=self.x_shape, dtype=self.dtype_pd + ) + indices = tuple( + [ + paddle.static.data( + name="indice" + str(i), + shape=self.indices_shapes[i], + dtype=self.index_type_pd, + ) + for i in range(len(self.indices_shapes)) + ] + ) + value = paddle.static.data( + name="value", shape=self.value_shape, dtype=self.dtype_pd + ) + + out = paddle.index_put(x, indices, value, self.accumulate) + exe = paddle.static.Executor(place=place) + feed_list = {} + feed_list.update({"x": self.x_np}) + for i in range(len(indices)): + feed_list.update({"indice" + str(i): self.indices_np[i]}) + feed_list.update({"value": self.value_np}) + pd_res = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + )[0] + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res, atol=1e-7) + + +class TestIndexPutAPI0(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = True + + +class TestIndexPutAPI1(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16), (1, 16)) + self.value_shape = (16, 16) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + +class TestIndexPutAPI2(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16), (1, 16)) + self.value_shape = (16, 16) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = True + + +class TestIndexPutAPI3(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110, 94)] + self.value_shape = (5170,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = False + + +class TestIndexPutAPI4(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110, 94)] + self.value_shape = (5170,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = True + + +class TestIndexPutAPI5(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + +class TestIndexPutAPI6(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = True + + +class TestIndexPutAPI7(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110,)] + self.value_shape = (55, 94) + 
self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = False + + +class TestIndexPutAPI8(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110,)] + self.value_shape = (55, 94) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = True + + +class TestIndexPutAPI9(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (56,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + +class TestIndexPutAPI10(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (56,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = True + + +class TestIndexPutAPI11(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (1,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + +class TestIndexPutAPI12(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (1,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = True + + +class TestIndexPutAPI13(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (44, 94) + self.indices_shapes = [(44,)] + self.value_shape = (94,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = False + + +class TestIndexPutAPI14(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (44, 94) + self.indices_shapes = [(44,)] + self.value_shape = (94,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = True + + +class TestIndexPutAPI15(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (44, 94) + self.indices_shapes = [(44,)] + self.value_shape = (1,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = False + + +class TestIndexPutAPI16(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (44, 94) + self.indices_shapes = [(44,)] + self.value_shape = (1,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = True + + +class TestIndexPutAPI17(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI18(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np 
= np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI19(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float32 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float32 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI20(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float32 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float32 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI21(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float16 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float16 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI22(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float16 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float16 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI23(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.int32 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int32 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI24(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.int32 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int32 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI25(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.int64 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int64 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI26(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.int64 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int64 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI27(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.bool_ + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.bool + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI28(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.bool_ + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.bool + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI29(TestIndexPutAPIBase): + def init_dtype_type(self): + 
self.dtype_np = np.float64 + self.index_type_np = np.int32 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI30(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int32 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutInplaceAPI(unittest.TestCase): + def setUp(self): + self.init_dtype_type() + self.setPlace() + self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) + self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) + self.indices_np = gen_indices_np( + self.x_shape, self.indices_shapes, self.index_type_np + ) + + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + def setPlace(self): + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def test_dygraph_forward(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) + self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) + self.indices_pd = [ + paddle.to_tensor(indice, dtype=self.index_type_pd) + for indice in self.indices_np + ] + self.indices_pd = tuple(self.indices_pd) + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + x_pd_bk = self.x_pd.clone() + pd_res = paddle.index_put_( + x_pd_bk, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + np.testing.assert_allclose(ref_res, x_pd_bk.numpy(), atol=1e-7) + + +class TestIndexPutInplaceAPI1(TestIndexPutInplaceAPI): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = True + + +class TestIndexPutAPIBackward(unittest.TestCase): + def setUp(self): + self.setPlace() + + def setPlace(self): + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def test_backward(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[4], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + out = paddle.index_put(x, (ix1, ix2), 
value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + def test_backwardScalarVal(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[1], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 + ) + + out = paddle.index_put(x, (ix1, ix2), value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 + ) + + def test_backwardBroadCastValue(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[2], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([2.0, 2.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + out = paddle.index_put(x, (ix1, ix2), value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([2.0, 2.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + def test_backwardBroadCastValue1(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[1, 2], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + 
np.array([[2.0, 2.0]], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + out = paddle.index_put(x, (ix1, ix2), value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([[2.0, 2.0]], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + def test_backwardBroadCastValue2(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[2, 1], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([[2.0], [2.0]], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + out = paddle.index_put(x, (ix1, ix2), value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([[2.0], [2.0]], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index f8d75c9651d..a399fad5a8f 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -135,6 +135,8 @@ from .manipulation import moveaxis # noqa: F401 from .manipulation import repeat_interleave # noqa: F401 from .manipulation import index_add # noqa: F401 from .manipulation import index_add_ # noqa: F401 +from .manipulation import index_put # noqa: F401 +from .manipulation import index_put_ # noqa: F401 from .manipulation import unflatten # noqa: F401 from .math import abs # noqa: F401 from .math import acos # noqa: F401 @@ -534,6 +536,8 @@ tensor_method_func = [ # noqa 'heaviside', 'index_add', "index_add_", + 'index_put', + 'index_put_', 'take', 'bucketize', 'sgn', diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index cc48d8a83dc..7fe717383cd 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4795,6 +4795,108 @@ def index_add_(x, index, axis, value, name=None): return _C_ops.index_add_(x, index, value, axis) +@inplace_apis_in_dygraph_only +def index_put_(x, indices, value, accumulate=False, name=None): + """ + Puts values from the tensor values into the tensor x using the indices specified in indices (which is a tuple of Tensors). + The expression paddle.index_put_(x, indices, values) is equivalent to tensor[indices] = values. Returns x. + If accumulate is True, the elements in values are added to x. If accumulate is False, the behavior is undefined if indices contain duplicate elements. + + Args: + x (Tensor) : The Source Tensor. Supported data types are int32, int64, float16, float32, float64, bool. 
+        indices (Tuple of Tensor): The tuple of Tensors containing the indices used to index ``x``.
+            The data type of the tensors in ``indices`` must be int32, int64 or bool.
+        value (Tensor): The tensor of values to be assigned to ``x``.
+        accumulate (bool, optional): Whether the elements in ``value`` are added to ``x``. Default: False.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Tensor, with the same dimension and dtype as ``x``.
+
+    Examples:
+        .. code-block:: python
+            import paddle
+
+            x = paddle.zeros([3, 3])
+            value = paddle.ones([3])
+            ix1 = paddle.to_tensor([0, 1, 2])
+            ix2 = paddle.to_tensor([1, 2, 1])
+            indices = (ix1, ix2)
+
+            out = paddle.index_put_(x, indices, value)
+            print(x)
+            # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [[0., 1., 0.],
+            #         [0., 0., 1.],
+            #         [0., 1., 0.]])
+            print(out)
+            # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [[0., 1., 0.],
+            #         [0., 0., 1.],
+            #         [0., 1., 0.]])
+    """
+    return _C_ops.index_put_(x, indices, value, accumulate)
+
+
+def index_put(x, indices, value, accumulate=False, name=None):
+    """
+    Outplace version of the ``index_put_`` API; the output Tensor is not inplaced with the input ``x``.
+    Please refer to :ref:`api_paddle_index_put`.
+
+    Examples:
+        .. code-block:: python
+            import paddle
+
+            x = paddle.zeros([3, 3])
+            value = paddle.ones([3])
+            ix1 = paddle.to_tensor([0, 1, 2])
+            ix2 = paddle.to_tensor([1, 2, 1])
+            indices = (ix1, ix2)
+
+            out = paddle.index_put(x, indices, value)
+            print(x)
+            # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [[0., 0., 0.],
+            #         [0., 0., 0.],
+            #         [0., 0., 0.]])
+            print(out)
+            # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [[0., 1., 0.],
+            #         [0., 0., 1.],
+            #         [0., 1., 0.]])
+    """
+    if in_dygraph_mode():
+        return _C_ops.index_put(x, indices, value, accumulate)
+
+    helper = LayerHelper("index_put", **locals())
+    check_variable_and_dtype(
+        x,
+        'x',
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
+        'paddle.tensor.manipulation.index_put',
+    )
+    check_variable_and_dtype(
+        value,
+        'value',
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
+        'paddle.tensor.manipulation.index_put',
+    )
+
+    out = helper.create_variable_for_type_inference(x.dtype)
+
+    helper.append_op(
+        type='index_put',
+        inputs={
+            'x': x,
+            'indices': indices,
+            'value': value,
+        },
+        outputs={'out': out},
+        attrs={'accumulate': accumulate},
+    )
+    return out
+
+
 def unflatten(x, axis, shape, name=None):
     """
     Expand a certain dimension of the input x Tensor into a desired shape.
@@ -4840,7 +4942,6 @@ def unflatten(x, axis, shape, name=None):
 
     # determine whether the input axis is valid.
     axis = non_negative_axis(x, axis)
-
     if isinstance(shape, (list, tuple)):
         new_shape = (
            list(x.shape[:axis]) + list(shape) + list(x.shape[axis + 1 :])
-- 
GitLab
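
The backward tests above lean on the forward semantics of index_put: with accumulate=False the selected positions are overwritten, so their incoming gradient is masked to zero in dx, while with accumulate=True the broadcast value is added on top and the gradient with respect to x passes through unchanged. The following is a minimal NumPy sketch of that forward rule; the helper name index_put_ref is hypothetical and only roughly mirrors what the test file's compute_index_put_ref is expected to compute.

import numpy as np

def index_put_ref(x, indices, value, accumulate=False):
    # Hypothetical reference helper (not the one defined in the test file):
    # write `value` (with NumPy broadcasting) at the positions selected by
    # the tuple of index arrays `indices`.
    out = np.array(x, copy=True)
    if accumulate:
        # np.add.at accumulates every contribution, including duplicate indices.
        np.add.at(out, indices, value)
    else:
        out[indices] = value
    return out

x = np.ones((16, 21), dtype=np.float64)
ix = (np.array([0, 1, 2, 3]), np.array([0, 1, 2, 3]))
print(index_put_ref(x, ix, np.array([2.0]))[0, 0])        # 2.0 (overwritten)
print(index_put_ref(x, ix, np.array([2.0]), True)[0, 0])  # 3.0 (1.0 + 2.0 accumulated)

np.add.at is used rather than out[indices] += value because the latter would drop repeated contributions when indices contain duplicates, which is exactly the case the accumulate=True path is meant to handle.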