From 3310f519abf83a16b382709e31b2b8c640479556 Mon Sep 17 00:00:00 2001 From: huangxu96 <46740794+huangxu96@users.noreply.github.com> Date: Tue, 28 Dec 2021 15:16:52 +0800 Subject: [PATCH] Add API and op for take_along_axis (#38396) * add API and op for take_along_axis * fix compile dependency problem and add example code and doc * add unitest * delete some code for CI coverage * fix code style problem * fix as review --- paddle/fluid/operators/CMakeLists.txt | 8 +- .../fluid/operators/gather_scatter_kernel.cc | 148 +++++++++++++++++ .../fluid/operators/gather_scatter_kernel.cu | 157 ++++++++++++++++++ .../fluid/operators/gather_scatter_kernel.h | 57 +++++++ paddle/fluid/operators/take_along_axis_op.cc | 154 +++++++++++++++++ paddle/fluid/operators/take_along_axis_op.cu | 98 +++++++++++ paddle/fluid/operators/take_along_axis_op.h | 92 ++++++++++ python/paddle/__init__.py | 1 + .../unittests/test_take_along_axis_op.py | 111 +++++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/manipulation.py | 54 ++++++ 11 files changed, 881 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/gather_scatter_kernel.cc create mode 100644 paddle/fluid/operators/gather_scatter_kernel.cu create mode 100644 paddle/fluid/operators/gather_scatter_kernel.h create mode 100644 paddle/fluid/operators/take_along_axis_op.cc create mode 100644 paddle/fluid/operators/take_along_axis_op.cu create mode 100644 paddle/fluid/operators/take_along_axis_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_take_along_axis_op.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 1f1bc01c40d..985f870ded4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -87,7 +87,13 @@ if(WITH_UNITY_BUILD) include(unity_build_rule.cmake) endif() -set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten pten_api_utils) +if (WITH_ROCM) + hip_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor) +else() + cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor) +endif() + +set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten pten_api_utils gather_scatter_kernel) register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/gather_scatter_kernel.cc b/paddle/fluid/operators/gather_scatter_kernel.cc new file mode 100644 index 00000000000..6c2ce23c368 --- /dev/null +++ b/paddle/fluid/operators/gather_scatter_kernel.cc @@ -0,0 +1,148 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/gather_scatter_kernel.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class TensorAssign { + public: + template + void operator()(tensor_t* self_data, tensor_t* src_data) const { + *self_data = *src_data; + } +}; +static TensorAssign tensor_assign; + +class ReduceAdd { + public: + template + void operator()(tensor_t* self_data, tensor_t* src_data) const { + *self_data += *src_data; + } +}; + +static ReduceAdd reduce_add; + +template +struct cpu_gather_scatter_functor { + template + void operator()(Tensor self, int dim, const Tensor& index, const Tensor& src, + const std::string& method_name, const func_t& reduce_op, + const platform::DeviceContext& ctx) { + if (index.numel() == 0) { + return; + } + auto* self_data = self.data(); + auto* index_data = index.data(); + auto* src_data = src.data(); + int64_t self_size = self.numel(); + int64_t index_size = index.numel(); + int64_t src_size = src.numel(); + auto self_dims = self.dims(); + auto index_dims = index.dims(); + auto src_dims = src.dims(); + if (self_size == 0 || src_size == 0 || index_size == 0) { + VLOG(3) << "zero size input found"; + platform::errors::InvalidArgument( + "self_size, src_size, index_size cannot be 0"); + return; + } + int select_dim_size = index_dims[dim]; + // index matrix has different shape with self matrix or src matrix. + int replaced_select_dim_size = + is_scatter_like ? self_dims[dim] : src_dims[dim]; + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + for (int64_t i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; + } + + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + } + + int64_t index_idx = 0; + int64_t self_idx, src_idx; + + // N layer loop squeezed into 3 layers loop + for (int64_t i = 0; i < inner_dim_size; i++) { + for (int64_t j = 0; j < select_dim_size; j++) { + for (int64_t k = 0; k < outer_dim_size; k++) { + int64_t index = index_data[index_idx]; + + /* + gather computation formula: + + self[i][j][k] = src[index[i][j][k]][j][k] # if dim == 0 + self[i][j][k] = src[i][index[i][j][k]][k] # if dim == 1 + self[i][j][k] = src[i][j][index[i][j][k]] # if dim == 2 + + scatter computation formula: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + */ + + // This index might out of bound of index matrix's index, so here + // multiply the replaced_select_dim_size. + int64_t replace_index = k + index * outer_dim_size + + i * outer_dim_size * replaced_select_dim_size; + + self_idx = is_scatter_like ? replace_index : index_idx; + src_idx = is_scatter_like ? 
index_idx : replace_index; + + reduce_op((tensor_t*)(self_data + self_idx), + (tensor_t*)(src_data + src_idx)); + index_idx++; + } + } + } + } +}; + +template +void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result, + const platform::DeviceContext& ctx) { + cpu_gather_scatter_functor()( + result, dim, index, self, "gather_out_cpu", tensor_assign, ctx); +} + +template +void cpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index, + Tensor src, const platform::DeviceContext& ctx) { + cpu_gather_scatter_functor()( + self, dim, index, src, "scatter_assign_cpu", tensor_assign, ctx); +} + +template +void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index, + Tensor src, const platform::DeviceContext& ctx) { + cpu_gather_scatter_functor()( + self, dim, index, src, "scatter_add_cpu", reduce_add, ctx); +} + +Instantiate_Template_Function(cpu_gather_kernel) + Instantiate_Template_Function(cpu_scatter_add_kernel) + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/gather_scatter_kernel.cu b/paddle/fluid/operators/gather_scatter_kernel.cu new file mode 100644 index 00000000000..b94001a56db --- /dev/null +++ b/paddle/fluid/operators/gather_scatter_kernel.cu @@ -0,0 +1,157 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/gather_scatter_kernel.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class TensorAssign { + public: + template + constexpr void operator()(tensor_t* self_data, tensor_t* src_data) const { + *self_data = *src_data; + } +}; +static TensorAssign tensor_assign; + +class ReduceAdd { + public: + template < + typename tensor_t, + std::enable_if_t::value>* = nullptr> + __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const { + platform::CudaAtomicAdd(self_data, *src_data); + } + template ::value>* = nullptr> + __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const { + *self_data += *src_data; + } +}; +static ReduceAdd reduce_add; + +template +__global__ void GatherScatterGPUKernel( + tensor_t* self_data, int dim, const index_t* index_data, tensor_t* src_data, + int64_t inner_dim_size, int select_dim_size, int replaced_select_dim_size, + int64_t outer_dim_size, int64_t numel, const func_t& reduce_op) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop + // squeezed from the N layers loop. 
+ /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */ + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + /* + gather computation formula: + + self[i][j][k] = src[index[i][j][k]][j][k] # if dim == 0 + self[i][j][k] = src[i][index[i][j][k]][k] # if dim == 1 + self[i][j][k] = src[i][j][index[i][j][k]] # if dim == 2 + + scatter computation formula: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + */ + // index matrix has different shape with self matrix or src matrix. + int64_t replace_index = k + index * outer_dim_size + + i * outer_dim_size * replaced_select_dim_size; + int64_t self_idx = is_scatter_like ? replace_index : tid; + int64_t src_idx = is_scatter_like ? tid : replace_index; + reduce_op((tensor_t*)(self_data + self_idx), (tensor_t*)(src_data + src_idx)); +} + +template +struct gpu_gather_scatter_functor { + template + void operator()(Tensor self, int dim, const Tensor& index, Tensor src, + const std::string& method_name, const func_t& reduce_op, + const platform::DeviceContext& ctx) { + if (index.numel() == 0) { + return; + } + auto* self_data = self.data(); + auto* index_data = index.data(); + auto* src_data = src.data(); + int64_t self_size = self.numel(); + int64_t index_size = index.numel(); + int64_t src_size = src.numel(); + auto self_dims = self.dims(); + auto index_dims = index.dims(); + auto src_dims = src.dims(); + if (self_size == 0 || src_size == 0 || index_size == 0) return; + int select_dim_size = index_dims[dim]; + // index matrix has different shape with self matrix or src matrix. + int replaced_select_dim_size = + is_scatter_like ? 
self_dims[dim] : src_dims[dim]; + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + for (int64_t i = 0; i < index_dims.size(); ++i) { + inner_dim_size *= index_dims[i]; + } + + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + } + + int64_t slice_size = 1; + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + int block = 512; + int64_t n = slice_size * index_size; + int64_t grid = (n + block - 1) / block; + auto stream = + reinterpret_cast(ctx).stream(); + GatherScatterGPUKernel<<>>( + self_data, dim, index_data, src_data, inner_dim_size, select_dim_size, + replaced_select_dim_size, outer_dim_size, index_size, reduce_op); + } +}; // struct gpu_gather_scatter_functor + +template +void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result, + const platform::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + result, dim, index, self, "gather_out_gpu", tensor_assign, ctx); + return; +} + +template +void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index, + Tensor src, const platform::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + self, dim, index, src, "scatter_add_gpu", reduce_add, ctx); +} + +namespace plat = paddle::platform; +Instantiate_Template_Function(gpu_gather_kernel) + Instantiate_Template_Function(gpu_scatter_add_kernel) + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/gather_scatter_kernel.h b/paddle/fluid/operators/gather_scatter_kernel.h new file mode 100644 index 00000000000..28eea5f2a03 --- /dev/null +++ b/paddle/fluid/operators/gather_scatter_kernel.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/tensor.h" + +#pragma once + +namespace paddle { +namespace operators { + +#define Instantiate_Template_Function(func) \ + Instantiate_Template_Function_index_t( \ + func, int) Instantiate_Template_Function_index_t(func, float) \ + Instantiate_Template_Function_index_t(func, double) \ + Instantiate_Template_Function_index_t(func, int64_t) \ + Instantiate_Template_Function_index_t(func, platform::float16) \ + Instantiate_Template_Function_index_t(func, unsigned char) + +#define Instantiate_Template_Function_index_t(func, tensor_t) \ + template void func(Tensor input, int dim, \ + const Tensor& index, Tensor result, \ + const platform::DeviceContext& ctx); \ + template void func(Tensor input, int dim, \ + const Tensor& index, Tensor result, \ + const platform::DeviceContext& ctx); + +using Tensor = framework::Tensor; + +template +void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result, + const platform::DeviceContext& ctx); + +template +void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index, + Tensor src, const platform::DeviceContext& ctx); + +template +void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result, + const platform::DeviceContext& ctx); + +template +void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index, + Tensor src, const platform::DeviceContext& ctx); + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/take_along_axis_op.cc b/paddle/fluid/operators/take_along_axis_op.cc new file mode 100644 index 00000000000..fef5d10f2da --- /dev/null +++ b/paddle/fluid/operators/take_along_axis_op.cc @@ -0,0 +1,154 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/take_along_axis_op.h" +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { + +class TakeAlongAxisOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("Input"), true, + platform::errors::InvalidArgument( + "Input(Input) of TakeAlongAxisOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Index"), true, + platform::errors::InvalidArgument( + "Input(Index) of TakeAlongAxisOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Result"), true, + platform::errors::InvalidArgument( + "Output(Result) of TakeAlongAxisOp should not be null.")); + + auto input_dim = ctx->GetInputDim("Input"); + auto index_dim = ctx->GetInputDim("Index"); + + PADDLE_ENFORCE_GT(input_dim.size(), 0, + platform::errors::InvalidArgument( + "Dimension of the input(Input) of TakeAlongAxisOp " + "should be greater than 0.", + input_dim)); + + PADDLE_ENFORCE_GT(index_dim.size(), 0, + platform::errors::InvalidArgument( + "Dimension of the input(Index) of TakeAlongAxisOp " + "should be greater than 0.", + index_dim)); + + ctx->SetOutputDim("Result", index_dim); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); + } + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + +class TakeAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", "The input tensor of TakeAlongAxisOp"); + AddInput("Index", "The index tensor of TakeAlongAxisOp"); + AddOutput("Result", "The result tensor of TakeAlongAxisOp"); + AddAttr("Axis", + "The Tensor which contains the axis that we do TakeAlongAxis " + "operation."); + AddComment(R"DOC( + Take_along_axis Operator.) 
+ )DOC"); + } +}; + +class TakeAlongAxisGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("Input"), + ctx->GetInputDim("Input")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Result")), + ctx.device_context()); + } + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + +template +class TakeAlongAxisGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("take_along_axis_grad"); + op->SetInput("Index", this->Input("Index")); + op->SetInput("Input", this->Input("Input")); + + op->SetInput(framework::GradVarName("Result"), this->OutputGrad("Result")); + op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(take_along_axis, ops::TakeAlongAxisOp, + ops::TakeAlongAxisOpMaker, + ops::TakeAlongAxisGradOpMaker, + ops::TakeAlongAxisGradOpMaker); + +REGISTER_OPERATOR(take_along_axis_grad, ops::TakeAlongAxisGradOp); + +REGISTER_OP_CPU_KERNEL(take_along_axis, ops::TakeAlongAxisOpKernel, + ops::TakeAlongAxisOpKernel, + ops::TakeAlongAxisOpKernel, + ops::TakeAlongAxisOpKernel, + ops::TakeAlongAxisOpKernel); + +REGISTER_OP_CPU_KERNEL(take_along_axis_grad, + ops::TakeAlongAxisGradOpKernel, + ops::TakeAlongAxisGradOpKernel, + ops::TakeAlongAxisGradOpKernel, + ops::TakeAlongAxisGradOpKernel, + ops::TakeAlongAxisGradOpKernel); diff --git a/paddle/fluid/operators/take_along_axis_op.cu b/paddle/fluid/operators/take_along_axis_op.cu new file mode 100644 index 00000000000..c705959390b --- /dev/null +++ b/paddle/fluid/operators/take_along_axis_op.cu @@ -0,0 +1,98 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/operators/take_along_axis_op.h" + +namespace paddle { +namespace operators { + +template +class TakeAlongAxisCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet( + "This kernel only runs on GPU device.")); + auto input = ctx.Input("Input"); + auto axis = ctx.Attr("Axis"); + auto index = ctx.Input("Index"); + auto result = ctx.Output("Result"); + result->Resize(index->dims()); + result->mutable_data(ctx.GetPlace()); + + const auto &index_type = index->type(); + if (index_type == framework::proto::VarType::INT32) { + gpu_gather_kernel(*input, axis, *index, *result, + ctx.device_context()); + } else if (index_type == framework::proto::VarType::INT64) { + gpu_gather_kernel(*input, axis, *index, *result, + ctx.device_context()); + } + } +}; + +template +class TakeAlongAxisGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("This kernel only runs on GPU.")); + + auto input_grad = ctx.Output(framework::GradVarName("Input")); + auto index = ctx.Input("Index"); + auto result_grad = ctx.Input(framework::GradVarName("Result")); + auto axis = ctx.Attr("Axis"); + // We need to know the shape of input matrix to determine the shape of grad + // matrix of input. + auto input = ctx.Input("Input"); + input_grad->Resize(input->dims()); + input_grad->mutable_data(ctx.GetPlace()); + + // Set to zero tensor. + auto &dev_ctx = ctx.template device_context(); + math::SetConstant functor; + functor(reinterpret_cast(dev_ctx), + input_grad, static_cast(0)); + const auto &index_type = index->type(); + + if (index_type == framework::proto::VarType::INT32) { + gpu_scatter_add_kernel( + *input_grad, axis, *index, *result_grad, + ctx.device_context()); // the gradient of gather is scatter + } else if (index_type == framework::proto::VarType::INT64) { + gpu_scatter_add_kernel(*input_grad, axis, *index, + *result_grad, ctx.device_context()); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(take_along_axis, ops::TakeAlongAxisCUDAKernel, + ops::TakeAlongAxisCUDAKernel, + ops::TakeAlongAxisCUDAKernel, + ops::TakeAlongAxisCUDAKernel, + ops::TakeAlongAxisCUDAKernel); +REGISTER_OP_CUDA_KERNEL(take_along_axis_grad, + ops::TakeAlongAxisGradOpCUDAKernel, + ops::TakeAlongAxisGradOpCUDAKernel, + ops::TakeAlongAxisGradOpCUDAKernel, + ops::TakeAlongAxisGradOpCUDAKernel, + ops::TakeAlongAxisGradOpCUDAKernel); diff --git a/paddle/fluid/operators/take_along_axis_op.h b/paddle/fluid/operators/take_along_axis_op.h new file mode 100644 index 00000000000..580ca528ceb --- /dev/null +++ b/paddle/fluid/operators/take_along_axis_op.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/gather_scatter_kernel.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class TakeAlongAxisOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_cpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("This kernel only runs on CPU.")); + + auto input = ctx.Input("Input"); + auto axis = ctx.Attr("Axis"); + auto index = ctx.Input("Index"); + auto result = ctx.Output("Result"); + result->Resize(index->dims()); + result->mutable_data(ctx.GetPlace()); + + const auto &index_type = index->type(); + if (index_type == framework::proto::VarType::INT32) { + cpu_gather_kernel(*input, axis, *index, *result, + ctx.device_context()); + } else if (index_type == framework::proto::VarType::INT64) { + cpu_gather_kernel(*input, axis, *index, *result, + ctx.device_context()); + } + } +}; + +template +class TakeAlongAxisGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_cpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("This kernel only runs on CPU.")); + + auto input_grad = ctx.Output(framework::GradVarName("Input")); + auto index = ctx.Input("Index"); + auto result_grad = ctx.Input(framework::GradVarName("Result")); + auto axis = ctx.Attr("Axis"); + // We need to know the shape of input matrix to determine the shape of grad + // matrix of input. + auto input = ctx.Input("Input"); + input_grad->Resize(input->dims()); + input_grad->mutable_data(ctx.GetPlace()); + + // Set to zero tensor. 
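+    // The gradient of take_along_axis is a scatter-add of Result@GRAD back along
+    // the axis, so input_grad must start from zero; positions never selected by
+    // Index keep a zero gradient.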
+ auto &dev_ctx = ctx.template device_context(); + math::SetConstant functor; + functor(reinterpret_cast(dev_ctx), + input_grad, static_cast(0)); + + const auto &index_type = index->type(); + if (index_type == framework::proto::VarType::INT32) { + cpu_scatter_add_kernel( + *input_grad, axis, *index, *result_grad, + ctx.device_context()); // the gradient of gather is scatter + } else if (index_type == framework::proto::VarType::INT64) { + cpu_scatter_add_kernel(*input_grad, axis, *index, + *result_grad, ctx.device_context()); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index e6311ea2e6a..b2effed3c9c 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -158,6 +158,7 @@ from .tensor.manipulation import unbind # noqa: F401 from .tensor.manipulation import roll # noqa: F401 from .tensor.manipulation import chunk # noqa: F401 from .tensor.manipulation import tolist # noqa: F401 +from .tensor.manipulation import take_along_axis # noqa: F401 from .tensor.manipulation import tensordot # noqa: F401 from .tensor.manipulation import as_complex # noqa: F401 from .tensor.manipulation import as_real # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_take_along_axis_op.py b/python/paddle/fluid/tests/unittests/test_take_along_axis_op.py new file mode 100644 index 00000000000..855a790cf05 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_take_along_axis_op.py @@ -0,0 +1,111 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid as fluid +from paddle.framework import core +from paddle.fluid.dygraph.base import switch_to_static_graph + +paddle.enable_static() + + +class TestTakeAlongAxisOp(OpTest): + def setUp(self): + self.init_data() + self.op_type = "take_along_axis" + self.xnp = np.random.random(self.x_shape).astype(self.x_type) + self.target = np.take_along_axis(self.xnp, self.index, self.axis) + broadcast_shape_list = list(self.x_shape) + broadcast_shape_list[self.axis] = 1 + self.braodcast_shape = tuple(broadcast_shape_list) + self.index_broadcast = np.broadcast_to(self.index, self.braodcast_shape) + self.inputs = { + 'Input': self.xnp, + 'Index': self.index_broadcast, + } + self.attrs = {'Axis': self.axis} + self.outputs = {'Result': self.target} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Input'], 'Result') + + def init_data(self): + self.x_type = "float64" + self.x_shape = (5, 5, 5) + self.index_type = "int32" + self.index = np.array( + [[[1]], [[1]], [[2]], [[4]], [[3]]]).astype(self.index_type) + self.axis = 2 + self.axis_type = "int64" + + +class TestCase1(TestTakeAlongAxisOp): + def init_data(self): + self.x_type = "float64" + self.x_shape = (5, 5, 5) + self.index_type = "int32" + self.index = np.array([[[0, 1, 2, 1, 4]]]).astype(self.index_type) + self.axis = 0 + self.axis_type = "int64" + + +class TestTakeAlongAxisAPI(unittest.TestCase): + def setUp(self): + np.random.seed(0) + self.shape = [3, 3] + self.index_shape = [1, 3] + self.index_np = np.array([[0, 1, 2]]).astype('int64') + self.x_np = np.random.random(self.shape).astype(np.float32) + self.place = [paddle.CPUPlace()] + self.axis = 0 + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.shape) + index = paddle.fluid.data('Index', self.index_shape, "int64") + out = paddle.take_along_axis(x, index, self.axis) + exe = paddle.static.Executor(self.place[0]) + res = exe.run(feed={'X': self.x_np, + 'Index': self.index_np}, + fetch_list=[out]) + out_ref = np.array( + np.take_along_axis(self.x_np, self.index_np, self.axis)) + for out in res: + self.assertEqual(np.allclose(out, out_ref, rtol=1e-03), True) + + def test_api_dygraph(self): + paddle.disable_static(self.place[0]) + x_tensor = paddle.to_tensor(self.x_np) + self.index = paddle.to_tensor(self.index_np) + out = paddle.take_along_axis(x_tensor, self.index, self.axis) + out_ref = np.array( + np.take_along_axis(self.x_np, self.index_np, self.axis)) + self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-03), True) + paddle.enable_static() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 957a42fc69a..4780d71a8d2 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -117,6 +117,7 @@ from .manipulation import roll # noqa: F401 from .manipulation import chunk # noqa: F401 from .manipulation import tensordot # noqa: F401 from .manipulation import as_complex # noqa: F401 +from .manipulation import take_along_axis # noqa: F401 from .manipulation import as_real # noqa: F401 from .manipulation import moveaxis # noqa: F401 from .manipulation import repeat_interleave 
# noqa: F401
@@ -464,6 +465,7 @@ tensor_method_func = [ #noqa
     'angle',
     'moveaxis',
     'repeat_interleave',
+    'take_along_axis',
     'exponential_',
 ]

diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 5b23ae9d37f..d50a0c34ee6 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -2749,3 +2749,57 @@ def moveaxis(x, source, destination, name=None):
                  'XShape': [x_shape]},
         attrs={'axis': perm})
     return out
+
+
+def take_along_axis(arr, indices, axis):
+    """
+    Take values from the input tensor along the designated axis, at the positions given by the indices tensor.
+
+    Args:
+        arr (Tensor) : The input Tensor. Supported data types are float32 and float64.
+        indices (Tensor) : Indices to take along each 1d slice of arr. This must have the same number of
+            dimensions as arr, and needs to be broadcastable against arr. Supported data types are int32 and int64.
+        axis (int) : The axis to take 1d slices along.
+
+    Returns:
+        Tensor: The indexed elements, with the same dtype as arr.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            x_np = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+            index_np = np.array([[0]])
+            x = paddle.to_tensor(x_np)
+            index = paddle.to_tensor(index_np)
+            axis = 0
+            result = paddle.take_along_axis(x, index, axis)
+            print(result)
+            # [[1, 2, 3]]
+    """
+    broadcast_shape_list = list(arr.shape)
+    broadcast_shape_list[axis] = 1
+    broadcast_shape = tuple(broadcast_shape_list)
+    if in_dygraph_mode():
+        indices = paddle.broadcast_to(indices, broadcast_shape)
+        return _C_ops.take_along_axis(arr, indices, 'Axis', axis)
+    check_variable_and_dtype(
+        arr, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
+        'take_along_axis')
+    check_variable_and_dtype(indices, 'index', ['int32', 'int64'],
+                             'take_along_axis')
+    indices = paddle.broadcast_to(
+        indices,
+        broadcast_shape)  # broadcast indices to arr's shape with size 1 along axis first.
+    helper = LayerHelper('take_along_axis', **locals())
+    dtype = helper.input_dtype()
+    result = helper.create_variable_for_type_inference(dtype)
+    helper.append_op(
+        type="take_along_axis",
+        inputs={"Input": arr,
+                "Index": indices},
+        attrs={"Axis": axis},
+        outputs={"Result": result})
+    return result
--
GitLab
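The gather formula spelled out in the gather_scatter kernel comments (result[i][j][k] = x[i][j][index[i][j][k]] when the axis is the last one) reduces to a plain triple loop over the index tensor. The following rough NumPy-only sketch, not part of the patch, checks that loop against numpy.take_along_axis, which is also the reference the unit test above compares the operator with:

    import numpy as np

    np.random.seed(0)
    x = np.arange(27, dtype=np.float64).reshape(3, 3, 3)
    index = np.random.randint(0, 3, size=(3, 3, 3))

    result = np.empty_like(x)
    for i in range(3):
        for j in range(3):
            for k in range(3):
                # gather formula from the kernel comments, axis == 2 case
                result[i][j][k] = x[i][j][index[i][j][k]]

    # numpy's reference implementation agrees with the explicit loops
    assert np.array_equal(result, np.take_along_axis(x, index, axis=2))

The CPU and GPU kernels collapse this same iteration into inner_dim_size, select_dim_size and outer_dim_size so that an N-dimensional tensor is handled with the three-index arithmetic shown in their comments.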