diff --git a/paddle/fluid/operators/index_add_op.cc b/paddle/fluid/operators/index_add_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b856e479fba5238befb744a9bb4a7a20af204a76
--- /dev/null
+++ b/paddle/fluid/operators/index_add_op.cc
@@ -0,0 +1,118 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"
+
+namespace paddle {
+namespace operators {
+
+class IndexAddOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class IndexAddOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor), "
+             "the input feature data of IndexAddOp, dtype should be "
+             "bool, int32, int64, float16, float32, float64.");
+    AddInput("Index",
+             "(Tensor), the 1-D tensor containing the indices to index.");
+    AddInput("AddValue", "(Tensor), the tensor containing values to add.");
+    AddOutput(
+        "Out",
+        "(Tensor), "
+        "the output of IndexAddOp, whose dtype and shape are the same as X.");
+    AddAttr<int>("axis", "the dimension in which we index.").SetDefault(0);
+    AddComment(R"DOC(
+    IndexAdd operator.
+    Add the elements of AddValue to the elements of X that are selected,
+    along the given axis, by the indices in Index.
+    This operator also supports inplace modification.
+ )DOC"); + } +}; + +template +class IndexAddGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType("index_add_grad"); + op->SetInput("Index", this->Input("Index")); + op->SetInput("AddValue", this->Input("AddValue")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("AddValue"), + this->InputGrad("AddValue")); + } +}; + +class IndexAddGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); + } +}; + +DECLARE_INPLACE_OP_INFERER(IndexAddInplaceInferer, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(IndexAddGradInplaceInferer, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(index_add, + IndexAddInferShapeFunctor, + PD_INFER_META(phi::IndexAddInferMeta)); + +REGISTER_OPERATOR(index_add, + ops::IndexAddOp, + ops::IndexAddOpMaker, + ops::IndexAddGradMaker, + ops::IndexAddGradMaker, + ops::IndexAddInplaceInferer, + IndexAddInferShapeFunctor); + +DECLARE_INFER_SHAPE_FUNCTOR(index_add_grad, + IndexAddGradInferShapeFunctor, + PD_INFER_META(phi::IndexAddGradInferMeta)); + +REGISTER_OPERATOR(index_add_grad, + ops::IndexAddGradOp, + ops::IndexAddGradInplaceInferer, + IndexAddGradInferShapeFunctor); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 6fc34172d0434ea5357da51b69d5b8619ec3dc8c..0012f8e426533c584bfe77fc36ac856b5309adbc 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1257,6 +1257,17 @@ func : increment inplace : (x -> out) +- api : index_add + args : (Tensor x, Tensor index, Tensor add_value, int axis) + output : Tensor(out) + infer_meta : + func : IndexAddInferMeta + kernel : + func : index_add + data_type : x + inplace : (x -> out) + backward : index_add_grad + - api : index_sample args : (Tensor x, Tensor index) output : Tensor diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index fe8c37940665e85c370273f6292825e7bfdfa276..465b43786c1f0d13fe2e3bff97b806f434b0ac97 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1092,6 +1092,17 @@ output : Tensor(x_grad) invoke : imag_grad_impl(out_grad, x_grad) +- backward_api : index_add_grad + forward : index_add(Tensor x, Tensor index, Tensor add_value, int axis) -> Tensor(out) + args : (Tensor index, Tensor add_value, Tensor out_grad, int axis) + output : Tensor(x_grad), Tensor(add_value_grad) + infer_meta : + func : IndexAddGradInferMeta + kernel : + func : index_add_grad + data_type : out_grad + inplace : (out_grad -> x_grad) + - backward_api : index_sample_grad forward : index_sample (Tensor x, Tensor index) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 83cf1a713dc551a38bb993658b9199a954ea6659..a1c1a07861bf3aff432ff0a76c7a4dad839e90a1 
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 83cf1a713dc551a38bb993658b9199a954ea6659..a1c1a07861bf3aff432ff0a76c7a4dad839e90a1 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -981,4 +981,26 @@ void Yolov3LossGradInferMeta(const MetaTensor& x,
   }
 }
 
+void IndexAddGradInferMeta(const MetaTensor& index,
+                           const MetaTensor& add_value,
+                           const MetaTensor& out_grad,
+                           int axis,
+                           MetaTensor* x_grad,
+                           MetaTensor* add_value_grad) {
+  auto do_dims = out_grad.dims();
+  auto add_value_dims = add_value.dims();
+  if (x_grad) {
+    x_grad->set_dims(do_dims);
+    x_grad->set_dtype(out_grad.dtype());
+    x_grad->set_layout(out_grad.layout());
+    x_grad->share_lod(out_grad);
+  }
+  if (add_value_grad) {
+    add_value_grad->set_dims(add_value_dims);
+    add_value_grad->set_dtype(add_value.dtype());
+    add_value_grad->set_layout(add_value.layout());
+    add_value_grad->share_lod(add_value);
+  }
+}
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 36edb0e56bafdf42dc3a8c9815bead5430fd6bba..2a11986a3940d140e9580404e53a0be36a0e59ce 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -398,4 +398,11 @@ void Yolov3LossGradInferMeta(const MetaTensor& x,
                              MetaTensor* gt_label_grad,
                              MetaTensor* gt_score_grad);
 
+void IndexAddGradInferMeta(const MetaTensor& index,
+                           const MetaTensor& add_value,
+                           const MetaTensor& out_grad,
+                           int axis,
+                           MetaTensor* x_grad,
+                           MetaTensor* add_value_grad);
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index ad8897bb4c07848cf285199cb20e9b02395b7b09..25e35edb35c3e13700814d2ce332ce271d233452 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -1507,6 +1507,63 @@ void IndexSelectInferMeta(const MetaTensor& x,
   output->share_lod(x);
 }
 
+void IndexAddInferMeta(const MetaTensor& x,
+                       const MetaTensor& index,
+                       const MetaTensor& add_value,
+                       int axis,
+                       MetaTensor* output) {
+  auto input_dim = x.dims();
+  auto index_dim = index.dims();
+  auto add_value_dim = add_value.dims();
+
+  PADDLE_ENFORCE_EQ(
+      axis < input_dim.size() && axis >= (0 - input_dim.size()),
+      true,
+      phi::errors::OutOfRange(
+          "Attr(axis) is out of range. It's expected "
+          "to be in range of [-%d, %d]. But received Attr(axis) = %d.",
+          input_dim.size(),
+          input_dim.size() - 1,
+          axis));
+
+  int real_axis = axis >= 0 ? axis : axis + input_dim.size();
+
+  PADDLE_ENFORCE_EQ(index_dim.size() == 1,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "The 'shape' of Input(Index) must be a 1-D tensor. "
+                        "But received: the 'shape' of Input(Index) is [%s], "
+                        "the dimension of Input(Index) is [%d].",
+                        index_dim,
+                        index_dim.size()));
+
+  PADDLE_ENFORCE_EQ(
+      index_dim[0] != 0,
+      true,
+      phi::errors::InvalidArgument("The length of Input(Index) can't be 0."));
+
+  // Note: add_value does not support broadcast for now.
+  PADDLE_ENFORCE_EQ(input_dim.size() == add_value_dim.size(),
+                    true,
+                    phi::errors::InvalidArgument(
+                        "The add_value must have the same dimension as x."));
+  for (int i = 0; i < input_dim.size(); i++) {
+    if (i != real_axis) {
+      PADDLE_ENFORCE_EQ(input_dim[i] == add_value_dim[i],
+                        true,
+                        phi::errors::InvalidArgument(
+                            "The add_value parameter does not support "
+                            "broadcast, so input_dim[i] must be equal to "
+                            "add_value_dim[i] when i != axis."));
+    }
+  }
+
+  output->set_dims(x.dims());
+  output->set_dtype(x.dtype());
+  output->set_layout(x.layout());
+  output->share_lod(x);
+}
+
 void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
   auto dim_x = x.dims();
   auto dim_y = y.dims();
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 10430c289e41d2a7d6d2b1d50e5f24fcf38a830a..591f0e41e0f8ca86674c4e047a3669d71103310a 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -237,6 +237,12 @@ void IndexSelectInferMeta(const MetaTensor& x,
                           int dim,
                           MetaTensor* output);
 
+void IndexAddInferMeta(const MetaTensor& x,
+                       const MetaTensor& index,
+                       const MetaTensor& add_value,
+                       int axis,
+                       MetaTensor* output);
+
 void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
 
 void LogLossInferMeta(const MetaTensor& input,
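IndexAddInferMeta enforces three things: `axis` is in `[-rank(x), rank(x))`, `index` is a non-empty 1-D tensor, and `add_value` matches `x` in every dimension except along the (normalized) axis, where its extent must equal the number of indices. A small illustrative checker restating those constraints (the helper name is hypothetical, not part of the patch):

```python
def check_index_add_shapes(x_shape, index_shape, add_value_shape, axis):
    """Illustrative mirror of IndexAddInferMeta's shape constraints."""
    rank = len(x_shape)
    if not (-rank <= axis < rank):
        raise ValueError(f"axis must be in [{-rank}, {rank}), got {axis}")
    real_axis = axis if axis >= 0 else axis + rank
    if len(index_shape) != 1 or index_shape[0] == 0:
        raise ValueError("index must be a non-empty 1-D tensor")
    if len(add_value_shape) != rank:
        raise ValueError("add_value must have the same rank as x")
    # no broadcasting: add_value matches x except along real_axis,
    # where its extent equals the number of indices
    for i in range(rank):
        expected = index_shape[0] if i == real_axis else x_shape[i]
        if add_value_shape[i] != expected:
            raise ValueError(f"bad add_value extent at dim {i}")

check_index_add_shapes((100, 5), (20,), (20, 5), axis=0)   # ok
# check_index_add_shapes((100, 5), (20,), (20,), axis=0)   # raises: rank mismatch
```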
diff --git a/paddle/phi/kernels/cpu/index_add_grad_kernel.cc b/paddle/phi/kernels/cpu/index_add_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..64be0927210c9aa60d3c3ab27b93b3e3482b0492
--- /dev/null
+++ b/paddle/phi/kernels/cpu/index_add_grad_kernel.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/index_add_grad_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/cpu/index_select_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void IndexAddGradKernel(const Context& ctx,
+                        const DenseTensor& index,
+                        const DenseTensor& add_value,
+                        const DenseTensor& out_grad,
+                        int axis,
+                        DenseTensor* x_grad,
+                        DenseTensor* add_value_grad) {
+  if (axis < 0) {
+    axis += out_grad.dims().size();
+  }
+  const auto& index_type = index.dtype();
+
+  bool index_type_match =
+      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "Input(Index) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        index_type,
+                        phi::DataType::INT32,
+                        phi::DataType::INT64));
+
+  // get x_grad: copy out_grad to x_grad.
+  ctx.template Alloc<T>(x_grad);
+  phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+
+  auto inputs = out_grad;
+  // get add_value_grad by using index_select(out_grad, index, axis)
+  if (index_type == phi::DataType::INT32) {
+    IndexSelectInner<Context, T, int>(
+        ctx, &inputs, index, add_value_grad, axis);
+  } else if (index_type == phi::DataType::INT64) {
+    IndexSelectInner<Context, T, int64_t>(
+        ctx, &inputs, index, add_value_grad, axis);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(index_add_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::IndexAddGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
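The grad kernel exploits that `out = x + scatter(add_value, index, axis)` is linear in both inputs: the Jacobian with respect to `x` is the identity, so `x_grad` is just a copy of `out_grad`, and the Jacobian with respect to `add_value` is a gather, so `add_value_grad` is `index_select(out_grad, index, axis)`. A NumPy restatement of these identities (illustrative shapes):

```python
import numpy as np

out_grad = np.arange(12.0).reshape(4, 3)
index = np.array([1, 3])

x_grad = out_grad.copy()                          # identity Jacobian w.r.t. x
add_value_grad = np.take(out_grad, index, axis=0)  # gather == index_select
# each add_value slice j contributed only to row index[j] of out,
# so its gradient is exactly out_grad's row index[j]
```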
diff --git a/paddle/phi/kernels/cpu/index_add_impl.h b/paddle/phi/kernels/cpu/index_add_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..d9a1b93d7217de27f30337049c9b56b41be68d49
--- /dev/null
+++ b/paddle/phi/kernels/cpu/index_add_impl.h
@@ -0,0 +1,118 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename Context, typename T, typename IndexT>
+void IndexAddInner(const Context& ctx,
+                   DenseTensor* input,
+                   const DenseTensor& index,
+                   int axis,
+                   DenseTensor* add_value,
+                   DenseTensor* output) {
+  auto input_dim = input->dims();
+  auto input_dim_size = input_dim.size();
+  auto output_dim = output->dims();
+  auto index_size = index.dims()[0];
+  auto add_value_dim = add_value->dims();
+
+  const IndexT* index_data = index.data<IndexT>();
+
+  ctx.template Alloc<T>(output);
+
+  // copy x to output.
+  // TODO(@limin29): inplace does not need this copy.
+  phi::Copy(ctx, *input, ctx.GetPlace(), false, output);
+
+  auto slice_size = 1;
+  for (auto i = axis + 1; i < input_dim_size; i++) {
+    slice_size *= input_dim[i];
+  }
+  auto outer_nums = 1;
+  for (auto i = 0; i < axis; i++) {
+    outer_nums *= input_dim[i];
+  }
+
+  for (int i = 0; i < index_size; i++) {
+    PADDLE_ENFORCE_GE(
+        index_data[i],
+        0,
+        phi::errors::InvalidArgument(
+            "Variable value (index) of OP(index_add) "
+            "expected >= 0 and < %ld, but got %ld. Please check input "
+            "value.",
+            input_dim[axis],
+            index_data[i]));
+    PADDLE_ENFORCE_LT(
+        index_data[i],
+        input_dim[axis],
+        phi::errors::InvalidArgument(
+            "Variable value (index) of OP(index_add) "
+            "expected >= 0 and < %ld, but got %ld. Please check input "
+            "value.",
+            input_dim[axis],
+            index_data[i]));
+  }
+
+  VLOG(3) << "Index_Add_Debug; outer_nums: " << outer_nums
+          << "; slice_size: " << slice_size << "; index_size: " << index_size;
+
+  output->Resize(phi::make_ddim({outer_nums, input_dim[axis], slice_size}));
+  add_value->Resize(phi::make_ddim({outer_nums, index_size, slice_size}));
+  VLOG(3) << "output.dims: " << output->dims()
+          << ", add_value.dims: " << add_value->dims();
+
+  auto add_value_tensor = EigenTensor<T, 3>::From(*add_value);
+  auto output_tensor = EigenTensor<T, 3>::From(*output);
+
+  auto& place = *ctx.eigen_device();
+  for (auto j = 0; j < index_size; j++) {
+    IndexT index_value = index_data[j];
+    auto output_t = output_tensor.chip(index_value, 1);
+    output_t.device(place) = output_t + add_value_tensor.chip(j, 1);
+  }
+  output->Resize(output_dim);
+  add_value->Resize(add_value_dim);
+}
+
+template <typename Context, typename T>
+void IndexAddBaseKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& index,
+                        int axis,
+                        const DenseTensor& add_value,
+                        DenseTensor* output) {
+  const auto& index_type = index.dtype();
+  if (axis < 0) {
+    axis += x.dims().size();
+  }
+  auto inputs = x;
+  auto add_values = add_value;
+  if (index_type == phi::DataType::INT32) {
+    IndexAddInner<Context, T, int>(
+        dev_ctx, &inputs, index, axis, &add_values, output);
+  } else if (index_type == phi::DataType::INT64) {
+    IndexAddInner<Context, T, int64_t>(
+        dev_ctx, &inputs, index, axis, &add_values, output);
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/index_add_kernel.cc b/paddle/phi/kernels/cpu/index_add_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5e75e6a8d0a0a255b6aaf0080cbdef3ee93e0e5f
--- /dev/null
+++ b/paddle/phi/kernels/cpu/index_add_kernel.cc
@@ -0,0 +1,45 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/index_add_kernel.h"
+
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/cpu/index_add_impl.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void IndexAddKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& index,
+                    const DenseTensor& add_value,
+                    int axis,
+                    DenseTensor* output) {
+  IndexAddBaseKernel<Context, T>(dev_ctx, x, index, axis, add_value, output);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(index_add,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::IndexAddKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
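The CPU implementation above collapses `x` to a 3-D view `(outer_nums, x.shape[axis], slice_size)` and `add_value` to `(outer_nums, index_size, slice_size)`, then accumulates whole middle-axis slices one index at a time (Eigen's `chip`). A NumPy sketch of the same strategy, assuming a non-negative `axis`:

```python
import numpy as np

def index_add_3d(x, index, add_value, axis):
    """Mirror of IndexAddInner: collapse to (outer, dim, slice) and
    accumulate whole middle-axis slices, one index at a time."""
    out = x.copy()
    outer = int(np.prod(x.shape[:axis], dtype=np.int64))
    slice_size = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    out3 = out.reshape(outer, x.shape[axis], slice_size)   # view into out
    val3 = add_value.reshape(outer, len(index), slice_size)
    for j, idx in enumerate(index):
        out3[:, idx, :] += val3[:, j, :]   # Eigen chip(idx, 1) += chip(j, 1)
    return out

x = np.zeros((2, 5, 3), np.float32)
out = index_add_3d(x, np.array([4, 0]), np.ones((2, 2, 3), np.float32), axis=1)
```

Because the indices are visited sequentially, repeated indices accumulate correctly without any atomics, unlike the GPU kernel below.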
diff --git a/paddle/phi/kernels/gpu/index_add_grad_kernel.cu b/paddle/phi/kernels/gpu/index_add_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1afcb59f8f1c4f82c1ed9b3ace2262413059f3ef
--- /dev/null
+++ b/paddle/phi/kernels/gpu/index_add_grad_kernel.cu
@@ -0,0 +1,108 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/index_add_grad_kernel.h"
+
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/gpu/index_select_impl.h"
+
+namespace phi {
+
+using paddle::platform::PADDLE_CUDA_NUM_THREADS;
+
+template <typename T, typename Context>
+void IndexAddGradKernel(const Context& ctx,
+                        const DenseTensor& index,
+                        const DenseTensor& add_value,
+                        const DenseTensor& out_grad,
+                        int dim,
+                        DenseTensor* x_grad,
+                        DenseTensor* add_value_grad) {
+  auto* output_grad_data = out_grad.data<T>();
+  auto* in_grad_data = ctx.template Alloc<T>(x_grad);
+  auto* add_value_grad_data = ctx.template Alloc<T>(add_value_grad);
+
+  auto input_dim = x_grad->dims();
+  auto output_dim = out_grad.dims();
+  auto add_value_dim = add_value_grad->dims();
+  dim = dim >= 0 ? dim : dim + input_dim.size();
+  auto stride_dim = phi::stride(input_dim);
+  int64_t stride = stride_dim[dim];
+  int64_t size = add_value_dim[dim];
+  int64_t delta = input_dim[dim] - size;
+  const auto& index_type = index.dtype();
+
+  bool index_type_match =
+      index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "Input(Index) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        index_type,
+                        phi::DataType::INT32,
+                        phi::DataType::INT64));
+
+  int64_t numel = add_value_grad->numel();
+  if (numel == 0) {
+    return;
+  }
+  auto stream = ctx.stream();
+
+  // get x_grad: copy out_grad to x_grad.
+  phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+
+  // get add_value_grad: index_select(out_grad, index, axis)
+  unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
+  dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
+  paddle::platform::LimitGridDim(ctx, &grid_dim);
+
+  if (index_type == phi::DataType::INT64) {
+    const int64_t* index_data = index.data<int64_t>();
+    index_select_cuda_kernel<T, int64_t>
+        <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
+                                             add_value_grad_data,
+                                             index_data,
+                                             numel,
+                                             stride,
+                                             size,
+                                             delta);
+  } else {
+    const int* index_data = index.data<int>();
+    index_select_cuda_kernel<T, int>
+        <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
+                                             add_value_grad_data,
+                                             index_data,
+                                             numel,
+                                             stride,
+                                             size,
+                                             delta);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(index_add_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::IndexAddGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/index_add_kernel.cu b/paddle/phi/kernels/gpu/index_add_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..109027d6f4e16dddb9c11f5fe2922544808917d6
--- /dev/null
+++ b/paddle/phi/kernels/gpu/index_add_kernel.cu
@@ -0,0 +1,128 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/index_add_kernel.h"
+
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+
+namespace phi {
+
+using paddle::platform::PADDLE_CUDA_NUM_THREADS;
+
+template <typename T, typename IndexT>
+__global__ void index_add_cuda_kernel(const T* input,
+                                      const IndexT* index,
+                                      const T* add_value,
+                                      int64_t N,
+                                      int64_t stride,
+                                      int64_t size,
+                                      int64_t delta,
+                                      T* output) {
+  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
+    int64_t pre_idx = idx / (stride * size);
+    int64_t dim_idx = idx % (stride * size) / stride;
+    IndexT src_dim_idx = index[dim_idx];
+    int64_t input_idx =
+        idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
+    paddle::platform::CudaAtomicAdd(&output[input_idx], add_value[idx]);
+  }
+}
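The kernel assigns one thread per flat element of `add_value` and computes the matching flat offset in `output` directly: with `add_value` viewed as `(outer, size, stride)` and `x` as `(outer, size + delta, stride)`, it derives the outer position `pre_idx`, the position `dim_idx` along the indexed axis, and shifts the flat index by `(delta * pre_idx + index[dim_idx] - dim_idx) * stride`. The atomic add is what makes repeated indices safe. A Python check of that index arithmetic on a small, illustrative case:

```python
import numpy as np

x_shape, axis = (2, 5, 3), 1
index = np.array([4, 0])
size, stride = len(index), 3          # stride = product of dims after axis
delta = x_shape[axis] - size          # 5 - 2 = 3
for idx in range(2 * size * stride):  # iterate add_value's flat positions
    pre_idx = idx // (stride * size)
    dim_idx = idx % (stride * size) // stride
    input_idx = idx + (delta * pre_idx + index[dim_idx] - dim_idx) * stride
    # must agree with raveling (pre_idx, index[dim_idx], k) into x's shape
    k = idx % stride
    assert input_idx == np.ravel_multi_index(
        (pre_idx, index[dim_idx], k), x_shape)
```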
+
+template <typename T, typename Context>
+void IndexAddKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& index,
+                    const DenseTensor& add_value,
+                    int axis,
+                    DenseTensor* output) {
+  int dim = axis;
+  auto input_dim = x.dims();
+  auto output_dim = output->dims();
+  auto add_value_dim = add_value.dims();
+  dim = dim >= 0 ? dim : dim + input_dim.size();
+  auto stride_dim = phi::stride(input_dim);
+  int64_t stride = stride_dim[dim];
+  int64_t size = add_value_dim[dim];
+  int64_t delta = input_dim[dim] - size;
+  const auto& index_type = index.dtype();
+
+  bool index_type_match =
+      index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "Input(Index) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        index_type,
+                        phi::DataType::INT32,
+                        phi::DataType::INT64));
+
+  auto* in_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(output);
+  auto* add_value_data = add_value.data<T>();
+
+  int64_t numel = add_value.numel();
+  if (numel == 0) {
+    return;
+  }
+  auto stream = ctx.stream();
+
+  unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
+  dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
+  paddle::platform::LimitGridDim(ctx, &grid_dim);
+
+  // copy input to output.
+  // TODO(@limin29): inplace does not need this copy.
+  phi::Copy(ctx, x, ctx.GetPlace(), false, output);
+
+  if (index_type == phi::DataType::INT64) {
+    const int64_t* index_data = index.data<int64_t>();
+    index_add_cuda_kernel<T, int64_t>
+        <<<grid_dim, block_dim, 0, stream>>>(in_data,
+                                             index_data,
+                                             add_value_data,
+                                             numel,
+                                             stride,
+                                             size,
+                                             delta,
+                                             out_data);
+  } else {
+    const int* index_data = index.data<int>();
+    index_add_cuda_kernel<T, int>
+        <<<grid_dim, block_dim, 0, stream>>>(in_data,
+                                             index_data,
+                                             add_value_data,
+                                             numel,
+                                             stride,
+                                             size,
+                                             delta,
+                                             out_data);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(index_add,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::IndexAddKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/index_select_impl.h b/paddle/phi/kernels/gpu/index_select_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..fc631b651540fb1f002c533da8e5b142cdc30ddc
--- /dev/null
+++ b/paddle/phi/kernels/gpu/index_select_impl.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+
+namespace phi {
+
+using paddle::platform::PADDLE_CUDA_NUM_THREADS;
+
+template <typename T, typename IndexT>
+__global__ void index_select_cuda_kernel(const T* input,
+                                         T* output,
+                                         const IndexT* index,
+                                         int64_t N,
+                                         int64_t stride,
+                                         int64_t size,
+                                         int64_t delta) {
+  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
+    int64_t pre_idx = idx / (stride * size);
+    int64_t dim_idx = idx % (stride * size) / stride;
+    IndexT src_dim_idx = index[dim_idx];
+    int64_t input_idx =
+        idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
+    output[idx] = input[input_idx];
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/index_select_kernel.cu b/paddle/phi/kernels/gpu/index_select_kernel.cu
index 0a6ac69cef0981edebb2d273c04e103e9679e3ad..e9228b54edf7c171f115a13fe4954d639db862ac 100644
--- a/paddle/phi/kernels/gpu/index_select_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_select_kernel.cu
@@ -19,29 +19,12 @@
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/gpu/index_select_impl.h"
 
 namespace phi {
 
 using paddle::platform::PADDLE_CUDA_NUM_THREADS;
 
-template <typename T, typename IndexT>
-__global__ void index_select_cuda_kernel(const T* input,
-                                         T* output,
-                                         const IndexT* index,
-                                         int64_t N,
-                                         int64_t stride,
-                                         int64_t size,
-                                         int64_t delta) {
-  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
-    int64_t pre_idx = idx / (stride * size);
-    int64_t dim_idx = idx % (stride * size) / stride;
-    IndexT src_dim_idx = index[dim_idx];
-    int64_t input_idx =
-        idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
-    output[idx] = input[input_idx];
-  }
-}
-
 template <typename T, typename Context>
 void IndexSelectKernel(const Context& ctx,
                        const DenseTensor& x,
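Moving `index_select_cuda_kernel` into a shared `index_select_impl.h` lets the index_add grad kernel reuse it directly. This works because scatter-add (the forward) and gather (the backward) are adjoint linear maps: `<scatter_add(v), g> == <v, gather(g)>` for any `v` and `g`. A NumPy check of this adjointness (illustrative shapes, repeated index on purpose):

```python
import numpy as np

rng = np.random.default_rng(0)
index = np.array([1, 3, 1])            # repeated index on purpose
v = rng.standard_normal((3, 4))        # plays the role of add_value
g = rng.standard_normal((5, 4))        # plays the role of out_grad

scatter = np.zeros((5, 4))
np.add.at(scatter, index, v)           # unbuffered scatter-add along axis 0

lhs = np.vdot(scatter, g)
rhs = np.vdot(v, np.take(g, index, axis=0))
assert np.isclose(lhs, rhs)            # gather is the adjoint of scatter-add
```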
diff --git a/paddle/phi/kernels/index_add_grad_kernel.h b/paddle/phi/kernels/index_add_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..3ba130c60cebae40eff92e84d2f2deea224d6f92
--- /dev/null
+++ b/paddle/phi/kernels/index_add_grad_kernel.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void IndexAddGradKernel(const Context& ctx,
+                        const DenseTensor& index,
+                        const DenseTensor& add_value,
+                        const DenseTensor& out_grad,
+                        int axis,
+                        DenseTensor* x_grad,
+                        DenseTensor* add_value_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/index_add_kernel.h b/paddle/phi/kernels/index_add_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..62693af8229426a3f54d30a21cc4472c695596f4
--- /dev/null
+++ b/paddle/phi/kernels/index_add_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void IndexAddKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& index,
+                    const DenseTensor& add_value,
+                    int axis,
+                    DenseTensor* output);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/index_add_sig.cc b/paddle/phi/ops/compat/index_add_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..39b231406fa66793289bbe291d019fdd21184bc3
--- /dev/null
+++ b/paddle/phi/ops/compat/index_add_sig.cc
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature IndexAddOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "index_add", {"X", "Index", "AddValue"}, {"axis"}, {"Out"});
+}
+
+KernelSignature IndexAddGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("index_add_grad",
+                         {"Index", "AddValue", "Out@GRAD"},
+                         {"axis"},
+                         {"X@GRAD", "AddValue@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(index_add, phi::IndexAddOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(index_add_grad, phi::IndexAddGradOpArgumentMapping);
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index e419f09479a9c9aa616d8d72b5ed51836cfbf22f..61b9674ca730d3063f8faff3a061bb98d8a06573 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -186,6 +186,8 @@
 from .tensor.manipulation import as_complex  # noqa: F401
 from .tensor.manipulation import as_real  # noqa: F401
 from .tensor.manipulation import moveaxis  # noqa: F401
 from .tensor.manipulation import repeat_interleave  # noqa: F401
+from .tensor.manipulation import index_add  # noqa: F401
+from .tensor.manipulation import index_add_  # noqa: F401
 from .tensor.math import abs  # noqa: F401
 from .tensor.math import acos  # noqa: F401
 from .tensor.math import asin  # noqa: F401
@@ -655,6 +657,8 @@ __all__ = [  # noqa
     'put_along_axis',
     'heaviside',
     'tril_indices',
+    'index_add',
+    'index_add_',
     'sgn',
     'triu_indices',
     'take',
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index d1b82387f9dfc236c129d91071f60a9cde3711fc..1c413d72cd7a226036f7860bd3489815618689ee 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1149,6 +1149,7 @@ endif()
 set_tests_properties(test_imperative_selected_rows_to_lod_tensor
                      PROPERTIES TIMEOUT 200)
 set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_index_add_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_parallel_ssa_graph_inference_feed_partial_data
                      PROPERTIES TIMEOUT 120)
 set_tests_properties(test_parallel_executor_crf PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/test_index_add_op.py b/python/paddle/fluid/tests/unittests/test_index_add_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c6aca4a45b0cb7a5c97851955bbc2e07a64e7fb
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_index_add_op.py
@@ -0,0 +1,362 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle
+import numpy as np
+import paddle.fluid.core as core
+from op_test import OpTest
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+
+
+def compute_index_add_ref(axis, x_shape, x_np, add_value_shape, add_value_np,
+                          index_size, index_np):
+    if axis < 0:
+        axis = axis + len(x_shape)
+    if axis != 0:
+        outer_loop = np.prod(x_shape[:axis]).astype(int)
+        x_reshape = [outer_loop] + list(x_shape[axis:])
+        x_np_reshape = np.reshape(x_np, tuple(x_reshape))
+
+        add_value_reshape = [np.prod(add_value_shape[:axis]).astype(int)
+                             ] + list(add_value_shape[axis:])
+
+        add_value_np_reshape = np.reshape(add_value_np,
+                                          tuple(add_value_reshape))
+    else:
+        x_np_reshape = x_np
+        add_value_np_reshape = add_value_np
+    out_np = x_np_reshape.copy()
+
+    if axis != 0:
+        for i in range(outer_loop):
+            for j in range(index_size):
+                out_np[i, index_np[j]] += add_value_np_reshape[i, j]
+    else:
+        for j in range(index_size):
+            out_np[index_np[j]] += add_value_np_reshape[j]
+    ref_out = np.reshape(out_np, x_shape)
+    return ref_out
+
+
+def raw_index_add(x, index, value, axis):
+    return paddle.index_add(x, index, axis, value)
+
+
+class TestIndexAddOp(OpTest):
+
+    def setUp(self):
+        self.python_api = raw_index_add
+        self.op_type = "index_add"
+        self.init_dtype_type()
+        index_np = np.random.randint(low=0,
+                                     high=self.x_shape[self.axis],
+                                     size=self.index_size)
+        x_np = np.random.random(self.x_shape).astype(self.x_type)
+        add_value_np = np.random.random(self.add_value_shape).astype(
+            self.x_type)
+
+        self.inputs = {'X': x_np, 'Index': index_np, 'AddValue': add_value_np}
+        self.attrs = {'axis': self.axis}
+        out = compute_index_add_ref(self.axis, self.x_shape, x_np,
+                                    self.add_value_shape, add_value_np,
+                                    self.index_size, index_np)
+        self.outputs = {'Out': out}
+
+    def init_dtype_type(self):
+        self.axis = 0
+        self.x_type = np.float64
+        self.index_type = np.int64
+        self.x_shape = (101, 3)
+        self.index_size = 3
+        self.add_value_shape = (3, 3)
+
+    def test_check_output(self):
+        self.check_output(check_eager=True, atol=1e-2)
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'AddValue'], 'Out', check_eager=True)
+
+
+class TestIndexAddAPI(unittest.TestCase):
+
+    def setUp(self):
+        self.setType()
+        self.setPlace()
+        self.config()
+        self.check_backward = True
+        self.generate_input_data()
+
+        self.index_shape = tuple([self.index_size])
+
+        self.rtol = 1e-5
+        self.atol = 1e-2
+        if self.x_type is np.float16:
+            self.atol = 1e-1
+
+    def setType(self):
+        self.x_type = np.float32
+        self.index_type = np.int32
+
+    def setPlace(self):
+        self.place = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            self.place.append('gpu')
+
+    def config(self):
+        self.axis = 0
+        self.x_shape = (100, 5)
+        self.index_size = 20
+        self.add_value_shape = (20, 5)
+
+    def generate_input_data(self):
+        axis = self.axis
+        if self.axis < 0:
+            axis = self.axis + len(self.x_shape)
+
+        self.x_np = np.random.random(self.x_shape).astype(self.x_type)
+        self.add_value_np = np.random.random(self.add_value_shape).astype(
+            self.x_type)
+        self.index_np = np.random.randint(low=0,
+                                          high=self.x_shape[axis],
+                                          size=self.index_size).astype(
+                                              self.index_type)
+        if self.check_backward:
+            self.dout_np = np.random.random(self.x_shape).astype(self.x_type)
+
+    def compute_index_add_backward_ref(self):
+        axis = self.axis
+        if self.axis < 0:
+            axis = self.axis + len(self.x_shape)
+
+        x_grad = self.dout_np
+
+        dout_tensor = paddle.to_tensor(self.dout_np)
+        index = paddle.to_tensor(self.index_np)
+        add_value_grad = paddle.index_select(dout_tensor, index, axis)
+
+        return x_grad, add_value_grad.numpy()
+
+    def run_imperative(self, device):
+        paddle.device.set_device(device)
+        input_tensor = paddle.to_tensor(self.x_np, stop_gradient=False)
+        index = paddle.to_tensor(self.index_np)
+        add_value = paddle.to_tensor(self.add_value_np, stop_gradient=False)
+
+        out = paddle.index_add(input_tensor, index, self.axis, add_value)
+        ref_out = compute_index_add_ref(self.axis, self.x_shape, self.x_np,
+                                        self.add_value_shape,
+                                        self.add_value_np, self.index_size,
+                                        self.index_np)
+        np.testing.assert_allclose(ref_out,
+                                   out.numpy(),
+                                   rtol=self.rtol,
+                                   atol=self.atol)
+
+        if self.check_backward:
+            dout_tensor = paddle.to_tensor(self.dout_np)
+            paddle.autograd.backward([out], [dout_tensor], retain_graph=True)
+            ref_x_grad, ref_add_value_grad = self.compute_index_add_backward_ref(
+            )
+            np.testing.assert_allclose(ref_x_grad,
+                                       input_tensor.grad.numpy(),
+                                       rtol=self.rtol,
+                                       atol=self.atol)
+            np.testing.assert_allclose(ref_add_value_grad,
+                                       add_value.grad.numpy(),
+                                       rtol=self.rtol,
+                                       atol=self.atol)
+
+    def run_static(self, device):
+        x = paddle.static.data(name='X', shape=self.x_shape, dtype=self.x_type)
+        index = paddle.static.data(name='Index',
+                                   shape=self.index_shape,
+                                   dtype=self.index_type)
+        add_value = paddle.static.data(name='AddValue',
+                                       shape=self.add_value_shape,
+                                       dtype=self.x_type)
+
+        out = paddle.index_add(x, index, self.axis, add_value)
+
+        if device == "cpu":
+            place = paddle.CPUPlace()
+        elif device == "gpu":
+            place = paddle.CUDAPlace(0)
+        else:
+            raise TypeError(
+                "paddle.index_add api only support cpu and gpu device now.")
+
+        exe = paddle.static.Executor(place)
+        exe.run(paddle.static.default_startup_program())
+
+        res = exe.run(paddle.static.default_main_program(),
+                      feed={
+                          "X": self.x_np,
+                          "Index": self.index_np,
+                          "AddValue": self.add_value_np,
+                      },
+                      fetch_list=[out.name],
+                      return_numpy=False)
+        return res
+
+    def test_static(self):
+        paddle.enable_static()
+        for device in self.place:
+            with paddle.static.program_guard(Program()):
+                out = self.run_static(device)
+            ref_out = compute_index_add_ref(self.axis, self.x_shape, self.x_np,
+                                            self.add_value_shape,
+                                            self.add_value_np,
+                                            self.index_size, self.index_np)
+            np.testing.assert_allclose(ref_out,
+                                       np.array(out[0]),
+                                       rtol=self.rtol,
+                                       atol=self.atol)
+
+    def test_dynamic(self):
+        paddle.disable_static()
+        for device in self.place:
+            self.run_imperative(device)
+
+
+class TestIndexAddAPIMoreType(TestIndexAddAPI):
+
+    def setType(self):
+        self.x_type = np.float64
+        self.index_type = np.int64
+
+
+class TestIndexAddAPICase2(TestIndexAddAPI):
+
+    def config(self):
+        self.axis = 1
+        self.x_shape = (100, 100, 5)
+        self.index_size = 20
+        self.add_value_shape = (100, 20, 5)
+
+
+class TestIndexAddAPICase3(TestIndexAddAPI):
+
+    def config(self):
+        self.axis = 2
+        self.x_shape = (100, 100, 25)
+        self.index_size = 20
+        self.add_value_shape = (100, 100, 20)
+
+
+class TestIndexAddAPICase4(TestIndexAddAPI):
+
+    def config(self):
+        self.axis = 0
+        self.x_shape = (10, )
+        self.index_size = 4
+        self.add_value_shape = (4, )
+
+
+class TestIndexAddAPICase5(TestIndexAddAPI):
+
+    def config(self):
+        self.axis = -1
+        self.x_shape = (10, 10)
+        self.index_size = 4
+        self.add_value_shape = (10, 4)
+
+
+class TestIndexAddAPIError(unittest.TestCase):
+
+    def test_errors(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+
+            def test_add_value_shape():
+                axis = 0
+                x = paddle.static.data(name='X',
+                                       shape=[10, 10],
+                                       dtype="float64")
+                index = paddle.static.data(name='Index',
+                                           shape=[4],
+                                           dtype="int32")
+                add_value = paddle.static.data(name='AddValue',
+                                               shape=[4, 3],
+                                               dtype="float64")
+                out = paddle.index_add(x, index, axis, add_value)
+
+            self.assertRaises(ValueError, test_add_value_shape)
+
+            def test_index_dtype():
+                axis = 0
+                x = paddle.static.data(name='X1',
+                                       shape=[10, 10],
+                                       dtype="float64")
+                index = paddle.static.data(name='Index1',
+                                           shape=[4],
+                                           dtype="float32")
+                add_value = paddle.static.data(name='AddValue1',
+                                               shape=[4, 10],
+                                               dtype="float64")
+                out = paddle.index_add(x, index, axis, add_value)
+
+            self.assertRaises(TypeError, test_index_dtype)
+
+            def test_index_shape():
+                axis = 0
+                x = paddle.static.data(name='X2',
+                                       shape=[10, 10],
+                                       dtype="float64")
+                index = paddle.static.data(name='Index2',
+                                           shape=[4, 3],
+                                           dtype="int32")
+                add_value = paddle.static.data(name='AddValue2',
+                                               shape=[4, 10],
+                                               dtype="float64")
+                out = paddle.index_add(x, index, axis, add_value)
+
+            self.assertRaises(ValueError, test_index_shape)
+
+            def test_axis_value():
+                axis = 3
+                x = paddle.static.data(name='X3',
+                                       shape=[10, 10],
+                                       dtype="float64")
+                index = paddle.static.data(name='Index3',
+                                           shape=[4],
+                                           dtype="int32")
+                add_value = paddle.static.data(name='AddValue3',
+                                               shape=[4, 10],
+                                               dtype="float64")
+                out = paddle.index_add(x, index, axis, add_value)
+
+            self.assertRaises(ValueError, test_axis_value)
+
+            def test_add_value_broadcast():
+                axis = 0
+                x = paddle.static.data(name='X4',
+                                       shape=[10, 10],
+                                       dtype="float64")
+                index = paddle.static.data(name='Index4',
+                                           shape=[4],
+                                           dtype="int32")
+                add_value = paddle.static.data(name='AddValue4',
+                                               shape=[4],
+                                               dtype="float64")
+                out = paddle.index_add(x, index, axis, add_value)
+
+            self.assertRaises(ValueError, test_add_value_broadcast)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py
index 626ea6c2ae0a30404b3e25e55ef853a678adce6e..a3a3eca128cb0a3882e7a0d175456e88ac9efca8 100644
--- a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py
@@ -13,20 +13,8 @@
 # limitations under the License.
 NEED_TO_FIX_OP_LIST = [
-    'fused_elemwise_activation',
-    'bilinear_tensor_product',
-    'conv2d_transpose',
-    'depthwise_conv2d_transpose',
-    'grid_sampler',
-    'lstmp',
-    'margin_rank_loss',
-    'matmul',
-    'scatter',
-    'soft_relu',
-    'squared_l2_distance',
-    'tree_conv',
-    'cvm',
-    'cudnn_lstm',
-    'rnn',
-    'multi_dot',
+    'fused_elemwise_activation', 'bilinear_tensor_product', 'conv2d_transpose',
+    'depthwise_conv2d_transpose', 'grid_sampler', 'lstmp', 'margin_rank_loss',
+    'matmul', 'scatter', 'soft_relu', 'squared_l2_distance', 'tree_conv', 'cvm',
+    'cudnn_lstm', 'rnn', 'multi_dot', 'index_add'
 ]
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index a5c06cee8509378782dda9e0bdf5a310864f3233..93decb4278efaf9d0e727f769ee46812cb17ea4d 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -128,6 +128,8 @@
 from .manipulation import put_along_axis_  # noqa: F401
 from .manipulation import as_real  # noqa: F401
 from .manipulation import moveaxis  # noqa: F401
 from .manipulation import repeat_interleave  # noqa: F401
+from .manipulation import index_add  # noqa: F401
+from .manipulation import index_add_  # noqa: F401
 from .math import abs  # noqa: F401
 from .math import acos  # noqa: F401
 from .math import asin  # noqa: F401
@@ -506,6 +508,8 @@
     'put_along_axis_',
     'exponential_',
     'heaviside',
+    'index_add',
+    'index_add_',
     'take',
     'bucketize',
     'sgn',
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 555ce0f4270375c702123e4f05c58310f1d9ec6a..ef7a20911da538fb764a019d5b46279589791813 100755
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -4386,6 +4386,120 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'):
                                        "Reduce", reduce)
 
 
+def _index_add_params_check(x, index, input_axis, add_value):
+    dims = len(x.shape)
+    add_value_dims = len(add_value.shape)
+
+    if input_axis >= 0:
+        axis = input_axis
+    else:
+        axis = input_axis + dims
+
+    check_axis = axis
+    if check_axis >= dims or check_axis < -dims:
+        raise ValueError("Axis should be in range [-rank(x), rank(x)).")
+
+    if isinstance(index, Variable):
+        if index.dtype not in [paddle.int64, paddle.int32]:
+            raise TypeError("The index dtype should be int32 or int64.")
+        if len(index.shape) != 1:
+            raise ValueError("The index should be a 1-D Tensor.")
+
+    if dims != add_value_dims:
+        raise ValueError(
+            "The add_value does not support broadcast now. It must have the same dimension as x."
+        )
+    for i in range(dims):
+        if i != axis and x.shape[i] != add_value.shape[i]:
+            raise ValueError(
+                "The add_value.shape[i] should be equal to x.shape[i] when i != axis."
+            )
+
+
+def index_add(x, index, axis, value, name=None):
+    """
+    Adds the elements of ``value`` to the elements of ``x`` selected by the indices given in ``index``, along the specified ``axis``.
+
+    Args:
+        x (Tensor): The destination Tensor. Supported data types are int32, int64, float16, float32, float64.
+        index (Tensor): The 1-D Tensor containing the indices to index.
+            The data type of ``index`` must be int32 or int64.
+        axis (int): The dimension in which we index.
+        value (Tensor): The tensor used to add the elements along the target axis.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Tensor: A Tensor with the same dimension and dtype as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            # required: gpu
+            import paddle
+
+            input_tensor = paddle.to_tensor(paddle.ones((3, 3)), dtype="float32")
+            index = paddle.to_tensor([0, 2], dtype="int32")
+            value = paddle.to_tensor([[1, 1, 1], [1, 1, 1]], dtype="float32")
+            outplace_res = paddle.index_add(input_tensor, index, 0, value)
+            print(outplace_res.numpy())
+            # [[2. 2. 2.]
+            #  [1. 1. 1.]
+            #  [2. 2. 2.]]
+    """
+    _index_add_params_check(x, index, axis, value)
+
+    if in_dygraph_mode():
+        return _C_ops.index_add(x, index, value, axis)
+
+    helper = LayerHelper("index_add", **locals())
+    check_variable_and_dtype(
+        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        'paddle.tensor.manipulation.index_add')
+    check_variable_and_dtype(index, 'index', ['int32', 'int64'],
+                             'paddle.tensor.manipulation.index_add')
+    check_variable_and_dtype(
+        value, 'add_value',
+        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        'paddle.tensor.manipulation.index_add')
+
+    out = helper.create_variable_for_type_inference(x.dtype)
+
+    helper.append_op(type='index_add',
+                     inputs={
+                         'X': x,
+                         'Index': index,
+                         'AddValue': value,
+                     },
+                     outputs={'Out': out},
+                     attrs={'axis': axis})
+    return out
+
+
+@inplace_apis_in_dygraph_only
+def index_add_(x, index, axis, value, name=None):
+    """
+    Inplace version of ``index_add`` API, the output Tensor will be inplaced with input ``x``.
+    Please refer to :ref:`api_paddle_tensor_index_add`.
+
+    Examples:
+        .. code-block:: python
+
+            # required: gpu
+            import paddle
+
+            input_tensor = paddle.to_tensor(paddle.ones((3, 3)), dtype="float32")
+            index = paddle.to_tensor([0, 2], dtype="int32")
+            value = paddle.to_tensor([[1, 1], [1, 1], [1, 1]], dtype="float32")
+            inplace_res = paddle.index_add_(input_tensor, index, 1, value)
+            print(inplace_res.numpy())
+            # [[2. 1. 2.]
+            #  [2. 1. 2.]
+            #  [2. 1. 2.]]
+    """
+    _index_add_params_check(x, index, axis, value)
+    return _C_ops.index_add_(x, index, value, axis)
+
+
 # TODO(dev): We need avoid implementing it by this way.
 __METHODS = {
     'fill_': fill_,
diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py
index 7b793c4cf83bd9666a5a1c03d20a455a66856c8c..5e5003cf53e166791a0cbdb894a962c23ab27621 100755
--- a/tools/parallel_UT_rule.py
+++ b/tools/parallel_UT_rule.py
@@ -845,7 +845,7 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
     'test_normalization_wrapper', 'test_flip', 'test_cosine_similarity_api',
     'test_cumsum_op', 'test_range', 'test_log_loss_op', 'test_where_index',
     'test_tril_triu_op', 'test_lod_reset_op', 'test_lod_tensor',
-    'test_addmm_op', 'test_index_select_op', 'test_nvprof',
+    'test_addmm_op', 'test_index_select_op', 'test_index_add_op', 'test_nvprof',
     'test_index_sample_op', 'test_unstack_op', 'test_increment',
     'strided_memcpy_test', 'test_target_assign_op',
     'test_trt_dynamic_shape_transformer_prune',
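For completeness, a dynamic-graph sketch of the new API together with its gradients (shapes and values are illustrative; assumes a build that includes this patch):

```python
import numpy as np
import paddle

x = paddle.ones([3, 3], dtype="float32")
x.stop_gradient = False
index = paddle.to_tensor([0, 2], dtype="int32")
value = paddle.full([2, 3], 2.0, dtype="float32")
value.stop_gradient = False

out = paddle.index_add(x, index, 0, value)  # rows 0 and 2 become 3.0
out.backward(paddle.ones_like(out))

# x receives the full upstream gradient; value receives the gathered rows,
# matching the x_grad / add_value_grad identities implemented above.
np.testing.assert_allclose(x.grad.numpy(), np.ones((3, 3), np.float32))
np.testing.assert_allclose(value.grad.numpy(), np.ones((2, 3), np.float32))
```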