From 85914f7a88f748a3883bfd09df810e910cce3089 Mon Sep 17 00:00:00 2001 From: ShenLiang <2282912238@qq.com> Date: Fri, 30 Aug 2019 16:26:55 +0800 Subject: [PATCH] add gather_nd op and unit test (#19366) * fixed the code for coverage * fixed the document,test=document_preview test=develop --- paddle/fluid/API.spec | 1 + paddle/fluid/operators/gather.cu.h | 78 +++++++- paddle/fluid/operators/gather.h | 46 +++++ paddle/fluid/operators/gather_nd_op.cc | 182 ++++++++++++++++++ paddle/fluid/operators/gather_nd_op.cu | 105 ++++++++++ paddle/fluid/operators/gather_nd_op.h | 91 +++++++++ paddle/fluid/operators/scatter.cu.h | 75 ++++++++ paddle/fluid/operators/scatter.h | 46 ++++- python/paddle/fluid/layers/nn.py | 86 +++++++++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../tests/unittests/test_gather_nd_op.py | 169 ++++++++++++++++ 11 files changed, 878 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/gather_nd_op.cc create mode 100644 paddle/fluid/operators/gather_nd_op.cu create mode 100644 paddle/fluid/operators/gather_nd_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_gather_nd_op.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 9b2a948488..287b918255 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -194,6 +194,7 @@ paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale' paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '4836e98a634f6fbea26d0cdaa303f867')) paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '32ffc0e8818d7319ed1bf63a791e985d')) paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c')) +paddle.fluid.layers.gather_nd (ArgSpec(args=['input', 'index', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3cc24f9cf135770aa6263dba25b457f9')) paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab')) paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf')) paddle.fluid.layers.random_crop (ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)), ('document', '042af0b8abea96b40c22f6e70d99e042')) diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index 86b3a25235..d0ab24a39e 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include +#include "paddle/fluid/framework/dim.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -39,6 +43,27 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices, } } +template +__global__ void GatherNdCUDAKernel(const T* input, const int* input_dims, + const IndexT* indices, T* output, + size_t remain_size, size_t slice_size, + size_t end_size) { + CUDA_1D_KERNEL_LOOP(i, remain_size * slice_size) { + int indices_i = i / slice_size; + int slice_i = i - indices_i * slice_size; // offset inside the slice + IndexT gather_i = 0; + int64_t temp = slice_size; + for (int64_t j = end_size - 1; j >= 0; --j) { + auto index_value = indices[indices_i * end_size + j]; + assert(index_value >= 0 && index_value < input_dims[j]); + gather_i += (index_value * temp); + temp *= input_dims[j]; + } + IndexT input_i = gather_i + slice_i; + *(output + i) = *(input + input_i); + } +} + /** * A thin wrapper on gpu tensor * Return a new tensor from source tensor, gathered according to index @@ -84,5 +109,56 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, p_src, p_index, p_output, index_size, slice_size); } +template +void GPUGatherNd(const framework::ExecutionContext& context, + const Tensor& input, const Tensor& index, Tensor* output) { + const auto& ctx = context.template device_context(); + const auto gplace = boost::get(ctx.GetPlace()); + auto cplace = platform::CPUPlace(); + + auto index_dims = index.dims(); + auto index_dims_size = index_dims.size(); + auto input_dims = input.dims(); + auto input_dims_size = input_dims.size(); + + const T* p_input = input.data(); + const IndexT* p_index = index.data(); + T* p_output = output->data(); + + // final dim + int64_t end_size = index_dims[index_dims_size - 1]; + // remain dim + auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1); + int64_t remain_numel = framework::product(remain_ddim); + // slice size + int64_t slice_size = 1; + for (int64_t i = end_size; i < input_dims_size; ++i) { + slice_size *= input_dims[i]; + } + // source dim + std::vector v_input_dims(input_dims_size); + for (int i = 0; i < input_dims_size; ++i) { + v_input_dims[i] = static_cast(input_dims[i]); + } + + auto& dev_ctx = context.cuda_device_context(); + auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx); + int bytes = input_dims_size * sizeof(int); + auto p_input_dims = allocator.Allocate(bytes); + int* g_input_dims = reinterpret_cast(p_input_dims->ptr()); + memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes, + ctx.stream()); + + int block = 512; + int n = slice_size * remain_numel; + int grid = (n + block - 1) / block; + + GatherNdCUDAKernel<<< + grid, block, 0, + reinterpret_cast(ctx).stream()>>>( + p_input, g_input_dims, p_index, p_output, remain_numel, slice_size, + end_size); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/gather.h b/paddle/fluid/operators/gather.h index 1e02c036e3..d2f519c162 100644 --- a/paddle/fluid/operators/gather.h +++ b/paddle/fluid/operators/gather.h @@ -60,5 +60,51 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src, } } +template +void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input, + const Tensor& index, Tensor* output) { + PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true, + "It 
should be running on the CPU"); + + auto index_dims = index.dims(); + auto index_dims_size = index_dims.size(); + auto input_dims = input.dims(); + auto input_dims_size = input_dims.size(); + + const T* p_input = input.data(); + const IndexT* p_index = index.data(); + T* p_output = output->data(); + + // final dim + int64_t end_size = index_dims[index_dims_size - 1]; + // remain dim + auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1); + int64_t remain_numel = framework::product(remain_ddim); + // slice size + int64_t slice_size = 1; + for (int64_t i = end_size; i < input_dims_size; ++i) { + slice_size *= input_dims[i]; + } + const size_t slice_bytes = slice_size * sizeof(T); + + for (int64_t i = 0; i < remain_numel; ++i) { + int64_t index_ = 0; + int64_t temp = 1; + for (int64_t j = end_size - 1; j >= 0; --j) { + IndexT index_value = p_index[i * end_size + j]; + PADDLE_ENFORCE_LT(index_value, input_dims[j], + "Input(index[-1)] has wrong value, it is %d", + index_value); + PADDLE_ENFORCE_GE(index_value, 0UL, + "The value of Input(index) must be no less than 0"); + + index_ += (index_value * temp); + temp *= input_dims[j]; + } + memcpy(p_output + i * slice_size, p_input + index_ * slice_size, + slice_bytes); + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc new file mode 100644 index 0000000000..43699f57b6 --- /dev/null +++ b/paddle/fluid/operators/gather_nd_op.cc @@ -0,0 +1,182 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/gather_nd_op.h" +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" + +namespace paddle { +namespace operators { + +class GatherNdOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of GatherNdOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true, + "Input(Index) of GatherNdOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of GatherNdOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto x_dims_size = x_dims.size(); + auto index_dims = ctx->GetInputDim("Index"); + auto index_dims_size = index_dims.size(); + + PADDLE_ENFORCE_LE(index_dims[index_dims_size - 1], x_dims_size, + "Input(Index).shape[-1] <= Input(X).rank"); + PADDLE_ENFORCE_GE(index_dims_size, 2UL, + "The rank of Input(Index) should be greater than 1"); + + std::vector result_dims; + // The result dims is + // Index.shape[:-1] + X.shape[Index.shape[-1]:] + for (int i = 0; i < index_dims_size - 1; ++i) { + result_dims.emplace_back(index_dims[i]); + } + for (int i = index_dims[index_dims_size - 1]; i < x_dims_size; ++i) { + result_dims.emplace_back(x_dims[i]); + } + + ctx->SetOutputDim("Out", framework::make_ddim(result_dims)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class GatherNdGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.device_context()); + } +}; + +class GatherNdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The source input of gather_nd op"); + AddInput("Index", "The index input of gather_nd op"); + AddOutput("Out", "The output of gather_nd op"); + AddComment(R"DOC( + Gather_Nd Operator. + + This function is actually a high-dimensional extension of gather + and supports for simultaneous indexing by multiple axes. Out is + obtained by gathering slices from X into a tensor with shape + Index.shape[:-1] + X.shape[Index.shape[-1]:]. 
+ + Example: + + Given: + X = [[[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]], + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]] + + X.shape = (2, 3, 4) + + *Case 1: + + Index = [[1]] + + we get: + Out = + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]] + + *Case 2: + + Index = [[0,2]] + + we get: + + Out = [8, 9, 10, 11] + + *Case 3: + + Index = [[1, 2, 3]] + + we get: + + Out = [23] + +)DOC"); + } +}; + +class GatherNdGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("gather_nd_grad"); + op->SetInput("Index", Input("Index")); + op->SetInput("X", Input("X")); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GatherNdGradNoNeedBufferVarInference, + "X"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(gather_nd, ops::GatherNdOp, ops::GatherNdOpMaker, + ops::GatherNdGradOpDescMaker); + +REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp, + ops::GatherNdGradNoNeedBufferVarInference); + +REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel, + ops::GatherNdOpKernel, + ops::GatherNdOpKernel, + ops::GatherNdOpKernel, + ops::GatherNdOpKernel); + +REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel, + ops::GatherNdGradOpKernel, + ops::GatherNdGradOpKernel, + ops::GatherNdGradOpKernel, + ops::GatherNdGradOpKernel); diff --git a/paddle/fluid/operators/gather_nd_op.cu b/paddle/fluid/operators/gather_nd_op.cu new file mode 100644 index 0000000000..1ad335039a --- /dev/null +++ b/paddle/fluid/operators/gather_nd_op.cu @@ -0,0 +1,105 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/gather.cu.h" +#include "paddle/fluid/operators/gather_nd_op.h" +#include "paddle/fluid/operators/scatter.cu.h" + +namespace paddle { +namespace operators { + +template +class GatherNdOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *output = ctx.Output("Out"); + + output->mutable_data(ctx.GetPlace()); + if (x->numel() == 0) return; + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, true, + "Index holds the wrong type, it holds %s, but desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString(framework::proto::VarType::INT32), + paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); + if (index_type == framework::proto::VarType::INT32) { + GPUGatherNd(ctx, *x, *index, output); + } else if (index_type == framework::proto::VarType::INT64) { + GPUGatherNd(ctx, *x, *index, output); + } + } +}; + +template +class GatherNdGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + "This kernel only runs on GPU device."); + auto *index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); + + dX->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*dX); + auto &place = *ctx.template device_context() + .eigen_device(); + dxt.device(place) = dxt.constant(static_cast(0)); + if (dO->numel() == 0) return; + + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + + PADDLE_ENFORCE_EQ( + index_type_match, true, + "Index holds the wrong type, it holds %s, but desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString(framework::proto::VarType::INT32), + paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); + + if (index_type == framework::proto::VarType::INT32) { + GPUScatterNdAdd(ctx, *dO, *index, dX); + } else if (index_type == framework::proto::VarType::INT64) { + GPUScatterNdAdd(ctx, *dO, *index, dX); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel, + ops::GatherNdOpCUDAKernel, + ops::GatherNdOpCUDAKernel, + ops::GatherNdOpCUDAKernel, + ops::GatherNdOpCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(gather_nd_grad, + ops::GatherNdGradOpCUDAKernel, + ops::GatherNdGradOpCUDAKernel, + ops::GatherNdGradOpCUDAKernel, + ops::GatherNdGradOpCUDAKernel, + ops::GatherNdGradOpCUDAKernel); diff --git a/paddle/fluid/operators/gather_nd_op.h b/paddle/fluid/operators/gather_nd_op.h new file mode 100644 index 0000000000..059ca54c46 --- /dev/null +++ b/paddle/fluid/operators/gather_nd_op.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/gather.h" +#include "paddle/fluid/operators/scatter.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GatherNdOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true, + "This kernel only runs on CPU."); + + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *output = ctx.Output("Out"); + + output->mutable_data(ctx.GetPlace()); + if (x->numel() == 0) return; + + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, true, + "Index holds the wrong type, it holds %s, but desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString(framework::proto::VarType::INT32), + paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); + if (index_type == framework::proto::VarType::INT32) { + CPUGatherNd(ctx.device_context(), *x, *index, output); + } else if (index_type == framework::proto::VarType::INT64) { + CPUGatherNd(ctx.device_context(), *x, *index, output); + } + } +}; + +template +class GatherNdGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true, + "This kernel only runs on CPU."); + auto *index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); + dX->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*dX); + auto &place = *ctx.template device_context() + .eigen_device(); + dxt.device(place) = dxt.constant(static_cast(0)); + if (dO->numel() == 0) return; + + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, true, + "Index holds the wrong type, it holds %s, but desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString(framework::proto::VarType::INT32), + paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); + if (index_type == framework::proto::VarType::INT32) { + ScatterNdAdd(ctx, *dO, *index, dX); + } else if (index_type == framework::proto::VarType::INT64) { + ScatterNdAdd(ctx, *dO, *index, dX); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h index ce4af44266..f4aabd4618 100644 --- a/paddle/fluid/operators/scatter.cu.h +++ 
b/paddle/fluid/operators/scatter.cu.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "math/math_function.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -57,6 +58,26 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, } } +template +__global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices, + T* output, const int* output_dims, + size_t remain_size, size_t slice_size, + size_t end_size) { + CUDA_1D_KERNEL_LOOP(i, remain_size * slice_size) { + int indices_i = i / slice_size; + int slice_i = i - indices_i * slice_size; // offset inside the slice + IndexT gather_i = 0; + int64_t temp = slice_size; + for (int64_t j = end_size - 1; j >= 0; --j) { + IndexT index_value = indices[indices_i * end_size + j]; + gather_i += (index_value * temp); + temp *= output_dims[j]; + } + IndexT output_i = gather_i + slice_i; + paddle::platform::CudaAtomicAdd(output + output_i, *(update + i)); + } +} + /** * A thin wrapper on gpu tensor * Return a new updated tensor from source tensor, scatter-assigned according to @@ -109,5 +130,59 @@ void GPUScatterAssign(const framework::ExecutionContext& context, p_src, p_index, p_output, index_size, slice_size, overwrite); } +template +void GPUScatterNdAdd(const framework::ExecutionContext& context, + const Tensor& update, const Tensor& index, + Tensor* output) { + auto index_dims = index.dims(); + auto index_dims_size = index_dims.size(); + + auto output_dims = output->dims(); + auto output_dims_size = output_dims.size(); + + const T* p_update = update.data(); + const IndexT* p_index = index.data(); + T* p_output = output->data(); + + // final dim + int64_t end_size = index_dims[index_dims_size - 1]; + // remain dim + auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1); + int64_t remain_numel = framework::product(remain_ddim); + // slice size + int64_t slice_size = 1; + for (int64_t i = end_size; i < output_dims_size; ++i) { + slice_size *= output_dims[i]; + } + const size_t slice_bytes = slice_size * sizeof(T); + // put output_dims int CUDA + // gplace and cplace + const auto& ctx = context.template device_context(); + const auto gplace = boost::get(ctx.GetPlace()); + auto cplace = platform::CPUPlace(); + + std::vector v_output_dims(output_dims_size); + for (int i = 0; i < output_dims_size; ++i) { + v_output_dims[i] = static_cast(output_dims[i]); + } + auto& dev_ctx = context.cuda_device_context(); + auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx); + int bytes = output_dims_size * sizeof(int); + auto output_dims_ptr = allocator.Allocate(bytes); + int* g_output_dims = reinterpret_cast(output_dims_ptr->ptr()); + memory::Copy(gplace, g_output_dims, cplace, v_output_dims.data(), bytes, + ctx.stream()); + + int block = 512; + int n = slice_size * remain_numel; + int grid = (n + block - 1) / block; + + ScatterNdCUDAKernel<<< + grid, block, 0, + reinterpret_cast(ctx).stream()>>>( + p_update, p_index, p_output, g_output_dims, remain_numel, slice_size, + end_size); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h index 680dc282c1..3f6bfff5db 100644 --- a/paddle/fluid/operators/scatter.h +++ b/paddle/fluid/operators/scatter.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -144,5 +144,49 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, } } +template +void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update, + const Tensor& index, Tensor* output) { + PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()), + true, "It should be running on the CPU"); + + // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:] + auto index_dims = index.dims(); + auto index_dims_size = index_dims.size(); + + auto output_dims = output->dims(); + auto output_dims_size = output_dims.size(); + + const T* p_update = update.data(); + const IndexT* p_index = index.data(); + T* result_p_output = output->data(); + const T* p_output = output->data(); + + // final dim + int64_t end_size = index_dims[index_dims_size - 1]; + // remain dim + auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1); + int64_t remain_numel = framework::product(remain_ddim); + // slice size + int64_t slice_size = 1; + for (int64_t i = end_size; i < output_dims_size; ++i) { + slice_size *= output_dims[i]; + } + const size_t slice_bytes = slice_size * sizeof(T); + + for (int64_t i = 0; i < remain_numel; ++i) { + IndexT index_ = 0; + IndexT temp = 1; + for (int64_t j = end_size - 1; j >= 0; --j) { + IndexT index_value = p_index[i * end_size + j]; + index_ += (index_value * temp); + temp *= output_dims[j]; + } + elementwise_inner_add(ctx, p_update, p_output, result_p_output, + update, output, i, index_, slice_size, + slice_bytes); + } +} + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a86d558ffa..87f8454c62 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -122,6 +122,7 @@ __all__ = [ 'resize_trilinear', 'resize_nearest', 'gather', + 'gather_nd', 'scatter', 'sequence_scatter', 'random_crop', @@ -8449,6 +8450,91 @@ def gather(input, index, overwrite=True): return out +def gather_nd(input, index, name=None): + """ + **Gather Nd Layer** + + This function is actually a high-dimensional extension of :code:`gather` + and supports for simultaneous indexing by multiple axes. :attr:`index` is a + K-dimensional integer tensor, which is regarded as a (K-1)-dimensional + tensor of :attr:`index` into :attr:`input`, where each element defines + a slice of params: + + .. math:: + + output[(i_0, ..., i_{K-2})] = input[index[(i_0, ..., i_{K-2})]] + + Obviously, :code:`index.shape[-1] <= input.rank` . And, the output tensor has + shape :code:`index.shape[:-1] + input.shape[index.shape[-1]:]` . + + .. code-block:: text + + Given: + input = [[[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]], + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]] + input.shape = (2, 3, 4) + + * Case 1: + index = [[1]] + + gather_nd(input, index) + = [input[1, :, :]] + = [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]] + + * Case 2: + index = [[0,2]] + + gather_nd(input, index) + = [input[0, 2, :]] + = [8, 9, 10, 11] + + * Case 3: + index = [[1, 2, 3]] + + gather_nd(input, index) + = [input[1, 2, 3]] + = [23] + + Args: + input (Variable): The source input + index (Variable): The index input with rank > 1, index.shape[-1] <= input.rank + name (str|None): A name for this layer(optional). 
If set None, the + layer will be named automatically + + Returns: + output (Variable): A tensor with the shape index.shape[:-1] + input.shape[index.shape[-1]:] + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + x = fluid.layers.data(name='x', shape=[3, 4, 5], dtype='float32') + index = fluid.layers.data(name='index', shape=[2, 2], dtype='int32') + output = fluid.layers.gather_nd(x, index) + + """ + helper = LayerHelper('gather_nd', **locals()) + dtype = helper.input_dtype() + if name is None: + output = helper.create_variable_for_type_inference(dtype) + else: + output = helper.create_variable( + name=name, dtype=dtype, persistable=False) + helper.append_op( + type="gather_nd", + inputs={"X": input, + "Index": index}, + outputs={"Out": output}) + return output + + def scatter(input, index, updates, name=None, overwrite=True): """ **Scatter Layer** diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index a357b6b864..42f1ae3a0b 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -185,6 +185,7 @@ set(TEST_OPS_WITH_GC test_fill_constant_batch_size_like_op test_fill_zeros_like2_op test_gather_op + test_gather_nd_op test_gaussian_random_batch_size_like_op test_linear_chain_crf_op test_lod_reset_op diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py new file mode 100644 index 0000000000..3264b2aff4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py @@ -0,0 +1,169 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
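+
+# Note: the expected outputs in the tests below are computed with NumPy
+# advanced indexing. As an illustrative reference only (not used by the
+# tests), and assuming `x` and `index` are NumPy arrays, `index` has a
+# non-empty last dimension, and index.shape[-1] <= x.ndim, gather_nd can
+# be sketched as:
+#
+#     def gather_nd_ref(x, index):
+#         flat = index.reshape(-1, index.shape[-1])  # merge leading index dims
+#         out = x[tuple(flat.T)]                     # gather the slices
+#         # output shape: index.shape[:-1] + x.shape[index.shape[-1]:]
+#         return out.reshape(index.shape[:-1] + out.shape[1:])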
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid as fluid + + +class TestGatherNdOpWithEmptyIndex(OpTest): + """ + Index has empty element, which means copy entire tensor + """ + + def setUp(self): + self.op_type = "gather_nd" + xnp = np.array( + [[65, 17, 2], [-14, -25, -1], [76, 22, 3]]).astype("float32") + self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")} + self.outputs = { + 'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :])) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestGatherNdOpWithLowIndex(OpTest): + """ + Index has low rank, X has high rank + """ + + def setUp(self): + self.op_type = "gather_nd" + xnp = np.array( + [[65, 17, 2], [14, 25, 1], [76, 22, 3]]).astype("float32") + index = np.array([[1], [2]]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + + self.outputs = {'Out': xnp[tuple(index.T)]} #[[14, 25, 1], [76, 22, 3]] + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestGatherNdOpWithSameIndexAsX(OpTest): + """ + Index has same rank as X's rank + """ + + def setUp(self): + self.op_type = "gather_nd" + xnp = np.array( + [[65, 17, 2], [14, 25, 1], [76, 22, 3]]).astype("float64") + index = np.array([[1, 1], [2, 1]]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22] + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestGatherNdOpWithHighRankSame(OpTest): + """ + Both Index and X have high rank, and Rank(Index) = Rank(X) + """ + + def setUp(self): + self.op_type = "gather_nd" + shape = (20, 9, 8, 1, 31) + xnp = np.random.rand(*shape) + index = np.vstack([np.random.randint(0, s, size=150) for s in shape]).T + + self.inputs = {'X': xnp, 'Index': index.astype("int32")} + self.outputs = {'Out': xnp[tuple(index.T)]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestGatherNdOpWithHighRankDiff(OpTest): + """ + Both Index and X have high rank, and Rank(Index) < Rank(X) + """ + + def setUp(self): + self.op_type = "gather_nd" + shape = (20, 9, 8, 1, 31) + xnp = np.random.rand(*shape).astype("double") + index = np.vstack([np.random.randint(0, s, size=1000) for s in shape]).T + index_re = index.reshape([10, 5, 20, 5]) + + self.inputs = {'X': xnp, 'Index': index_re.astype("int32")} + self.outputs = {'Out': xnp[tuple(index.T)].reshape([10, 5, 20])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +#Test Python API +class TestGatherNdOpAPI(OpTest): + def test_case1(self): + x1 = fluid.layers.data( + name='x1', shape=[30, 40, 50, 60], dtype='float32') + index1 = fluid.layers.data(name='index1', shape=[2, 4], dtype='int32') + output1 = fluid.layers.gather_nd(x1, index1) + + def test_case2(self): + x2 = fluid.layers.data(name='x2', shape=[30, 40, 50], dtype='float32') + index2 = fluid.layers.data(name='index2', shape=[2, 2], dtype='int64') + output2 = fluid.layers.gather_nd(x2, index2) + + def test_case3(self): + x3 = fluid.layers.data(name='x3', shape=[3, 4, 5], dtype='float32') + index3 = fluid.layers.data(name='index3', shape=[2, 1], dtype='int32') + output3 = fluid.layers.gather_nd(x3, index3, 
name="gather_nd_layer") + + +#Test Raise Index Error +class TestGatherNdOpRaise(OpTest): + def test_check_raise(self): + def check_raise_is_test(): + try: + x = fluid.layers.data( + name='x', shape=[3, 4, 5], dtype='float32') + index = fluid.layers.data( + name='index', shape=[2, 10], dtype='int32') + output = fluid.layers.gather_nd(x, index) + except Exception as e: + t = \ + "Input(Index).shape[-1] <= Input(X).rank" + if t in str(e): + raise IndexError + + self.assertRaises(IndexError, check_raise_is_test) + + +if __name__ == "__main__": + unittest.main() -- GitLab