From a7de0e6654c9569c2059371822f843ea1ba29307 Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Fri, 17 Dec 2021 14:48:42 +0800 Subject: [PATCH] add op/api repeat/interleave (#37981) --- .../fluid/operators/repeat_interleave_op.cc | 174 ++++++++++ .../fluid/operators/repeat_interleave_op.cu | 307 ++++++++++++++++++ paddle/fluid/operators/repeat_interleave_op.h | 196 +++++++++++ paddle/fluid/pybind/op_function_generator.h | 1 + python/paddle/__init__.py | 2 + .../unittests/test_repeat_interleave_op.py | 212 ++++++++++++ python/paddle/tensor/__init__.py | 4 +- python/paddle/tensor/manipulation.py | 62 ++++ 8 files changed, 957 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/repeat_interleave_op.cc create mode 100644 paddle/fluid/operators/repeat_interleave_op.cu create mode 100644 paddle/fluid/operators/repeat_interleave_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py diff --git a/paddle/fluid/operators/repeat_interleave_op.cc b/paddle/fluid/operators/repeat_interleave_op.cc new file mode 100644 index 00000000000..7957dd1c1a2 --- /dev/null +++ b/paddle/fluid/operators/repeat_interleave_op.cc @@ -0,0 +1,174 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/repeat_interleave_op.h" +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class RepeatInterleaveOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::InvalidArgument( + "Input(X) of RepeatInterleaveOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::InvalidArgument( + "Output(Out) of RepeatInterleaveOp should not be null.")); + + auto input_dim = ctx->GetInputDim("X"); + auto dim = ctx->Attrs().Get("dim"); + auto output_dim = framework::vectorize(input_dim); + PADDLE_ENFORCE_EQ( + dim < input_dim.size() && dim >= (0 - input_dim.size()), true, + platform::errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_dim.size(), input_dim.size() - 1, dim)); + + auto repeats = ctx->Attrs().Get("Repeats"); + if (ctx->HasInput("RepeatsTensor")) { + auto repeats_dim = ctx->GetInputDim("RepeatsTensor"); + + PADDLE_ENFORCE_EQ( + repeats_dim.size() == 1 || + (repeats_dim.size() == 2 && repeats_dim[1] == 1), + true, platform::errors::InvalidArgument( + "The 'shape' of Input(RepeatsTensor) must be 1-D tensor. " + "But received: the 'shape' of Input(Index) is [%s], " + "the dimension of Input(Index) is [%d].", + repeats_dim, repeats_dim.size())); + + PADDLE_ENFORCE_EQ(repeats_dim[0] != 0, true, + platform::errors::InvalidArgument( + "The length of Input(RepeatsTensor) can't be 0.")); + + if (dim < 0) { + dim += input_dim.size(); + } + output_dim[dim] = -1; + } else if (repeats > 0) { + output_dim[dim] = input_dim[dim] * repeats; + } + VLOG(3) << "infershap out " << output_dim[dim]; + ctx->SetOutputDim("Out", framework::make_ddim(output_dim)); + auto type = ctx->GetInputsVarType("X")[0]; + if (type == framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class RepeatInterleaveGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::InvalidArgument( + "Input(Out@GRAD) should be not null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, + platform::errors::InvalidArgument( + "Output(X@GRAD) should be not null.")); + + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +class RepeatInterleaveOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) the input tensor."); + AddInput("RepeatsTensor", + "the 1-D tensor containing the repeats alongsize the axis.") + .AsDispensable(); + AddOutput("Out", "the output tensor."); + AddAttr("Repeats", "the number of repetitions for each element.") + .SetDefault(0); + AddAttr("dim", "the dimension in which we repeat.").SetDefault(0); + AddComment(R"DOC( +Returns a new tensor which repeats the input tensor +along dimension dim using the entries in repeats which +is a Tensor or int. + +The returned tensor has the same number of dimensions +as the original tensor (input), except along the given axis. + )DOC"); + } +}; + +template +class RepeatInterleaveGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("repeat_interleave_grad"); + + op->SetInput("X", this->Input("X")); + op->SetInput("RepeatsTensor", this->Input("RepeatsTensor")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER(RepeatInterleaveGradNoNeedBufferVarsInferer, + "X"); +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(repeat_interleave, ops::RepeatInterleaveOp, + ops::RepeatInterleaveOpMaker, + ops::RepeatInterleaveGradMaker, + ops::RepeatInterleaveGradMaker); +REGISTER_OPERATOR(repeat_interleave_grad, ops::RepeatInterleaveGradOp, + ops::RepeatInterleaveGradNoNeedBufferVarsInferer); +REGISTER_OP_CPU_KERNEL( + repeat_interleave, + ops::RepeatInterleaveKernel, + ops::RepeatInterleaveKernel, + ops::RepeatInterleaveKernel, + ops::RepeatInterleaveKernel); +REGISTER_OP_CPU_KERNEL( + repeat_interleave_grad, + ops::RepeatInterleaveGradKernel, + ops::RepeatInterleaveGradKernel, + ops::RepeatInterleaveGradKernel, + ops::RepeatInterleaveGradKernel); diff --git a/paddle/fluid/operators/repeat_interleave_op.cu b/paddle/fluid/operators/repeat_interleave_op.cu new file mode 100644 index 00000000000..8acb4f216ea --- /dev/null +++ b/paddle/fluid/operators/repeat_interleave_op.cu @@ -0,0 +1,307 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/repeat_interleave_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +// function borrowed from repeat_interleave_op +template +__global__ void index_select_cuda_kernel(const T* input, T* output, + const IndexT* index, int64_t N, + int64_t stride, int64_t size, + int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + output[idx] = input[input_idx]; +} + +template +__global__ void index_select_grad_cuda_kernel(const T* output_grad, + T* input_grad, + const IndexT* index, int64_t nums, + int64_t N, int64_t stride, + int64_t size, int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]); +} + +template +__global__ void index_select_grad_init(T* input_grad, int64_t N) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + input_grad[idx] = 0.0; +} +template +class RepeatInterleaveCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + // auto* index = context.Input("RepeatsTensor"); + auto* out = context.Output("Out"); + int dim = context.Attr("dim"); + auto input_dim = in->dims(); + dim = dim >= 0 ? dim : dim + input_dim.size(); + auto stride_dim = framework::stride(input_dim); + int64_t stride = stride_dim[dim]; + + auto stream = + context.template device_context().stream(); + + int repeats = context.Attr("Repeats"); + framework::LoDTensor index; + auto* in_data = in->data(); + if (context.HasInput("RepeatsTensor")) { + auto repeats_tensor = + context.Input("RepeatsTensor"); + + PADDLE_ENFORCE_EQ(repeats_tensor->dims()[0] == in->dims()[dim], true, + platform::errors::InvalidArgument( + "The length of Input(RepeatsTensor) must be the " + "same as length of Input(X) in axis. " + "But received: [%s], required: [%d].", + repeats_tensor->dims()[0], in->dims()[dim])); + + const auto& index_type = repeats_tensor->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ( + index_type_match, true, + platform::errors::InvalidArgument( + "Input(RepeatsTensor) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + if (index_type == framework::proto::VarType::INT64) { + RepeatsTensor2IndexTensor(*repeats_tensor, + &index); + + const int64_t* index_data = index.data(); + auto output_dim = framework::vectorize(in->dims()); + output_dim[dim] = index.dims()[0]; + out->Resize(framework::make_ddim(output_dim)); + auto* out_data = out->mutable_data(context.GetPlace()); + int64_t numel = out->numel(); + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + index_select_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data, + numel, stride, size, delta); + } else { + RepeatsTensor2IndexTensor(*repeats_tensor, &index); + + const int* index_data = index.data(); + auto output_dim = framework::vectorize(in->dims()); + output_dim[dim] = index.dims()[0]; + out->Resize(framework::make_ddim(output_dim)); + auto* out_data = out->mutable_data(context.GetPlace()); + int64_t numel = out->numel(); + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + index_select_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data, + numel, stride, size, delta); + } + } else if (repeats > 0) { + int64_t index_size = in->dims()[dim] * repeats; + std::vector index_vec(index_size); + for (int i = 0; i < in->dims()[dim]; i++) { + std::fill_n(index_vec.begin() + i * repeats, repeats, i); + } + index.Resize(framework::make_ddim({index_size})); + auto ctx = paddle::platform::DeviceContextPool::Instance().Get( + context.GetPlace()); + paddle::framework::TensorFromVector(index_vec, *ctx, &index); + + auto output_dim = framework::vectorize(in->dims()); + output_dim[dim] = index_size; + out->Resize(framework::make_ddim(output_dim)); + auto* out_data = out->mutable_data(context.GetPlace()); + + int64_t numel = out->numel(); + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + const int* index_data = index.data(); + index_select_cuda_kernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_data, out_data, index_data, numel, stride, size, delta); + platform::GpuStreamSync(stream); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "repeats must given with RepeatsTensor (tensor) or repeats (int)")); + } + } +}; + +template +class RepeatInterleaveGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto* in_grad = context.Output(framework::GradVarName("X")); + + auto* output_grad_data = output_grad->data(); + auto* in_grad_data = in_grad->mutable_data(context.GetPlace()); + + int dim = context.Attr("dim"); + auto input_dim = in_grad->dims(); + auto output_dim = output_grad->dims(); + dim = dim >= 0 ? dim : dim + input_dim.size(); + auto stride_dim = framework::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + int64_t numel = in_grad->numel(); + int64_t out_nums = output_grad->numel(); + + auto stream = + context.template device_context().stream(); + + index_select_grad_init< + T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel); + + int repeats = context.Attr("Repeats"); + framework::LoDTensor index; + if (context.HasInput("RepeatsTensor")) { + auto repeats_tensor = + context.Input("RepeatsTensor"); + + const auto& index_type = repeats_tensor->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ( + index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + if (index_type == framework::proto::VarType::INT64) { + RepeatsTensor2IndexTensor(*repeats_tensor, + &index); + int64_t index_nums = index.numel(); + + const int64_t* index_data = index.data(); + index_select_grad_cuda_kernel<<< + (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + output_grad_data, in_grad_data, index_data, index_nums, out_nums, + stride, size, delta); + platform::GpuStreamSync(stream); + } else { + RepeatsTensor2IndexTensor(*repeats_tensor, &index); + int64_t index_nums = index.numel(); + + const int* index_data = index.data(); + index_select_grad_cuda_kernel<<< + (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + output_grad_data, in_grad_data, index_data, index_nums, out_nums, + stride, size, delta); + platform::GpuStreamSync(stream); + } + } else if (repeats > 0) { + int64_t index_size = in_grad->dims()[dim] * repeats; + std::vector index_vec(index_size); + for (int i = 0; i < in_grad->dims()[dim]; i++) { + std::fill_n(index_vec.begin() + i * repeats, repeats, i); + } + index.Resize(framework::make_ddim({index_size})); + auto ctx = paddle::platform::DeviceContextPool::Instance().Get( + context.GetPlace()); + paddle::framework::TensorFromVector(index_vec, *ctx, &index); + + const int* index_data = index.data(); + int64_t index_nums = index.numel(); + index_select_grad_cuda_kernel<<< + (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, + index_data, index_nums, + out_nums, stride, size, delta); + platform::GpuStreamSync(stream); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "repeats must given with RepeatsTensor (tensor) or repeats (int)")); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + repeat_interleave, + ops::RepeatInterleaveCUDAKernel, + ops::RepeatInterleaveCUDAKernel, + ops::RepeatInterleaveCUDAKernel, + ops::RepeatInterleaveCUDAKernel, + ops::RepeatInterleaveCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + repeat_interleave_grad, + ops::RepeatInterleaveGradCUDAKernel, + ops::RepeatInterleaveGradCUDAKernel, + ops::RepeatInterleaveGradCUDAKernel, + ops::RepeatInterleaveGradCUDAKernel, + ops::RepeatInterleaveGradCUDAKernel); diff --git a/paddle/fluid/operators/repeat_interleave_op.h b/paddle/fluid/operators/repeat_interleave_op.h new file mode 100644 index 00000000000..1a38b0271dd --- /dev/null +++ b/paddle/fluid/operators/repeat_interleave_op.h @@ -0,0 +1,196 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" + +#include "paddle/fluid/operators/index_select_op.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DDim = framework::DDim; + +template +void RepeatsTensor2IndexTensor(const LoDTensor& repeats, LoDTensor* index) { + LoDTensor repeats_cpu_copy; + if (!platform::is_cpu_place(repeats.place())) { + framework::TensorCopySync(repeats, platform::CPUPlace(), &repeats_cpu_copy); + } + const RepeatsT* repeats_data = platform::is_cpu_place(repeats.place()) + ? repeats.data() + : repeats_cpu_copy.data(); + + int64_t index_size = 0; + for (int i = 0; i < repeats.dims()[0]; i++) { + index_size += repeats_data[i]; + } + std::vector index_vec(index_size); + int offset = 0; + for (int i = 0; i < repeats.dims()[0]; i++) { + std::fill_n(index_vec.begin() + offset, repeats_data[i], i); + offset += repeats_data[i]; + } + index->Resize(framework::make_ddim({index_size})); + + auto ctx = + paddle::platform::DeviceContextPool::Instance().Get(repeats.place()); + paddle::framework::TensorFromVector(index_vec, *ctx, index); +} + +template +class RepeatInterleaveKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto inputs = *context.Input("X"); + auto* output = context.Output("Out"); + + int dim = context.Attr("dim"); + if (dim < 0) { + dim += inputs.dims().size(); + } + + int repeats = context.Attr("Repeats"); + framework::LoDTensor index; + if (context.HasInput("RepeatsTensor")) { + auto repeats_tensor = + context.Input("RepeatsTensor"); + + PADDLE_ENFORCE_EQ(repeats_tensor->dims()[0] == inputs.dims()[dim], true, + platform::errors::InvalidArgument( + "The length of Input(RepeatsTensor) must be the " + "same as length of Input(X) in axis. " + "But received: [%s], required: [%d].", + repeats_tensor->dims()[0], inputs.dims()[dim])); + + const auto& index_type = repeats_tensor->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, true, + platform::errors::InvalidArgument( + "Input(RepeatsTensor) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + if (index_type == framework::proto::VarType::INT32) { + RepeatsTensor2IndexTensor(*repeats_tensor, &index); + auto output_dim = framework::vectorize(inputs.dims()); + output_dim[dim] = index.dims()[0]; + output->Resize(framework::make_ddim(output_dim)); + IndexSelectInner(context, &inputs, index, output, + dim); + } else if (index_type == framework::proto::VarType::INT64) { + RepeatsTensor2IndexTensor(*repeats_tensor, + &index); + auto output_dim = framework::vectorize(inputs.dims()); + output_dim[dim] = index.dims()[0]; + output->Resize(framework::make_ddim(output_dim)); + IndexSelectInner(context, &inputs, index, + output, dim); + } + } else if (repeats > 0) { + int64_t index_size = inputs.dims()[dim] * repeats; + std::vector index_vec(index_size); + for (int i = 0; i < inputs.dims()[dim]; i++) { + std::fill_n(index_vec.begin() + i * repeats, repeats, i); + } + index.Resize(framework::make_ddim({index_size})); + paddle::framework::TensorFromVector(index_vec, &index); + + auto output_dim = framework::vectorize(inputs.dims()); + output_dim[dim] = index_size; + output->Resize(framework::make_ddim(output_dim)); + + IndexSelectInner(context, &inputs, index, output, + dim); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "repeats must given with RepeatsTensor (tensor) or repeats (int)")); + } + } +}; + +template +class RepeatInterleaveGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x_grad = + context.Output(framework::GradVarName("X")); + auto* out_grad = + context.Input(framework::GradVarName("Out")); + + int dim = context.Attr("dim"); + if (dim < 0) { + dim += out_grad->dims().size(); + } + + int repeats = context.Attr("Repeats"); + framework::LoDTensor index; + if (context.HasInput("RepeatsTensor")) { + auto repeats_tensor = + context.Input("RepeatsTensor"); + const auto& index_type = repeats_tensor->type(); + + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, true, + platform::errors::InvalidArgument( + "Input(Repeats) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + if (index_type == framework::proto::VarType::INT32) { + RepeatsTensor2IndexTensor(*repeats_tensor, &index); + IndexSelectGradInner(context, *out_grad, index, + x_grad, dim); + } else if (index_type == framework::proto::VarType::INT64) { + RepeatsTensor2IndexTensor(*repeats_tensor, + &index); + IndexSelectGradInner(context, *out_grad, + index, x_grad, dim); + } + } else if (repeats > 0) { + int64_t index_size = x_grad->dims()[dim] * repeats; + std::vector index_vec(index_size); + for (int i = 0; i < x_grad->dims()[dim]; i++) { + std::fill_n(index_vec.begin() + i * repeats, repeats, i); + } + index.Resize(framework::make_ddim({index_size})); + paddle::framework::TensorFromVector(index_vec, &index); + + IndexSelectGradInner(context, *out_grad, index, + x_grad, dim); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "repeats must given with RepeatsTensor (tensor) or repeats (int)")); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 3e1c5b736f2..f148c971432 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -44,6 +44,7 @@ std::map> op_ins_map = { {"nll_loss", {"X", "Label", "Weight"}}, {"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}}, {"gather", {"X", "Index", "Axis"}}, + {"repeat_interleave", {"X", "RepeatsTensor"}}, {"roi_pool", {"X", "ROIs", "RoisNum"}}, {"roi_align", {"X", "ROIs", "RoisNum"}}, {"psroi_pool", {"X", "ROIs", "RoisNum"}}, diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ae287f63aef..7e1198bac51 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -159,6 +159,7 @@ from .tensor.manipulation import tensordot # noqa: F401 from .tensor.manipulation import as_complex # noqa: F401 from .tensor.manipulation import as_real # noqa: F401 from .tensor.manipulation import moveaxis # noqa: F401 +from .tensor.manipulation import repeat_interleave # noqa: F401 from .tensor.math import abs # noqa: F401 from .tensor.math import acos # noqa: F401 from .tensor.math import asin # noqa: F401 @@ -579,4 +580,5 @@ __all__ = [ # noqa 'fmax', 'fmin', 'moveaxis', + 'repeat_interleave', ] diff --git a/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py b/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py new file mode 100644 index 00000000000..b047b0c53d8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py @@ -0,0 +1,212 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + + +class TestRepeatInterleaveOp(OpTest): + def setUp(self): + self.op_type = "repeat_interleave" + self.init_dtype_type() + index_np = np.random.randint( + low=0, high=3, size=self.index_size).astype(self.index_type) + x_np = np.random.random(self.x_shape).astype(self.x_type) + self.inputs = {'X': x_np, 'RepeatsTensor': index_np} + self.attrs = {'dim': self.dim} + + outer_loop = np.prod(self.x_shape[:self.dim]) + x_reshape = [outer_loop] + list(self.x_shape[self.dim:]) + x_np_reshape = np.reshape(x_np, tuple(x_reshape)) + out_list = [] + for i in range(outer_loop): + for j in range(self.index_size): + for k in range(index_np[j]): + out_list.append(x_np_reshape[i, j]) + self.out_shape = list(self.x_shape) + self.out_shape[self.dim] = np.sum(index_np) + self.out_shape = tuple(self.out_shape) + + out = np.reshape(out_list, self.out_shape) + self.outputs = {'Out': out} + + def init_dtype_type(self): + self.dim = 1 + self.x_type = np.float64 + self.index_type = np.int64 + self.x_shape = (8, 4, 5) + self.index_size = self.x_shape[self.dim] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out') + + +class TestRepeatInterleaveOp2(OpTest): + def setUp(self): + self.op_type = "repeat_interleave" + self.init_dtype_type() + index_np = 2 + x_np = np.random.random(self.x_shape).astype(self.x_type) + self.inputs = {'X': x_np} #, 'RepeatsTensor': None} + self.attrs = {'dim': self.dim, 'Repeats': index_np} + + outer_loop = np.prod(self.x_shape[:self.dim]) + x_reshape = [outer_loop] + list(self.x_shape[self.dim:]) + x_np_reshape = np.reshape(x_np, tuple(x_reshape)) + out_list = [] + for i in range(outer_loop): + for j in range(self.index_size): + for k in range(index_np): + out_list.append(x_np_reshape[i, j]) + self.out_shape = list(self.x_shape) + self.out_shape[self.dim] = index_np * self.index_size + self.out_shape = tuple(self.out_shape) + + out = np.reshape(out_list, self.out_shape) + self.outputs = {'Out': out} + + def init_dtype_type(self): + self.dim = 1 + self.x_type = np.float64 + self.x_shape = (8, 4, 5) + self.index_size = self.x_shape[self.dim] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out') + + +class TestIndexSelectAPI(unittest.TestCase): + def input_data(self): + self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0]]) + self.data_index = np.array([0, 1, 2, 1]).astype('int32') + + def test_repeat_interleave_api(self): + paddle.enable_static() + self.input_data() + + # case 1: + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[-1, 4]) + index = fluid.layers.data( + name='repeats', + shape=[4], + dtype='int32', + append_batch_size=False) + z = paddle.repeat_interleave(x, index, axis=1) + exe = fluid.Executor(fluid.CPUPlace()) + res, = exe.run(feed={'x': self.data_x, + 'repeats': self.data_index}, + fetch_list=[z.name], + return_numpy=False) + expect_out = np.repeat(self.data_x, self.data_index, axis=1) + self.assertTrue(np.allclose(expect_out, np.array(res))) + + # case 2: + repeats = np.array([1, 2, 1]).astype('int32') + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[-1, 4]) + index = fluid.layers.data( + name='repeats', + shape=[3], + dtype='int32', + append_batch_size=False) + z = paddle.repeat_interleave(x, index, axis=0) + exe = fluid.Executor(fluid.CPUPlace()) + res, = exe.run(feed={ + 'x': self.data_x, + 'repeats': repeats, + }, + fetch_list=[z.name], + return_numpy=False) + expect_out = np.repeat(self.data_x, repeats, axis=0) + self.assertTrue(np.allclose(expect_out, np.array(res))) + + repeats = 2 + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[-1, 4]) + z = paddle.repeat_interleave(x, repeats, axis=0) + exe = fluid.Executor(fluid.CPUPlace()) + res, = exe.run(feed={'x': self.data_x}, + fetch_list=[z.name], + return_numpy=False) + expect_out = np.repeat(self.data_x, repeats, axis=0) + self.assertTrue(np.allclose(expect_out, np.array(res))) + + def test_dygraph_api(self): + self.input_data() + # case axis none + input_x = np.array([[1, 2, 1], [1, 2, 3]]).astype('int32') + index_x = np.array([1, 1, 2, 1, 2, 2]).astype('int32') + + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(input_x) + index = fluid.dygraph.to_variable(index_x) + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat(input_x, index_x, axis=None) + self.assertTrue(np.allclose(expect_out, np_z)) + + # case repeats int + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(input_x) + index = 2 + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat(input_x, index, axis=None) + self.assertTrue(np.allclose(expect_out, np_z)) + + # case 1: + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(self.data_x) + index = fluid.dygraph.to_variable(self.data_index) + z = paddle.repeat_interleave(x, index, -1) + np_z = z.numpy() + expect_out = np.repeat(self.data_x, self.data_index, axis=-1) + self.assertTrue(np.allclose(expect_out, np_z)) + + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(self.data_x) + index = fluid.dygraph.to_variable(self.data_index) + z = paddle.repeat_interleave(x, index, 1) + np_z = z.numpy() + expect_out = np.repeat(self.data_x, self.data_index, axis=1) + self.assertTrue(np.allclose(expect_out, np_z)) + + # case 2: + index_x = np.array([1, 2, 1]).astype('int32') + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(self.data_x) + index = fluid.dygraph.to_variable(index_x) + z = paddle.repeat_interleave(x, index, axis=0) + np_z = z.numpy() + expect_out = np.repeat(self.data_x, index, axis=0) + self.assertTrue(np.allclose(expect_out, np_z)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 7794903c1bb..424cbbe4f2d 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -114,6 +114,7 @@ from .manipulation import tensordot # noqa: F401 from .manipulation import as_complex # noqa: F401 from .manipulation import as_real # noqa: F401 from .manipulation import moveaxis # noqa: F401 +from .manipulation import repeat_interleave # noqa: F401 from .math import abs # noqa: F401 from .math import acos # noqa: F401 from .math import asin # noqa: F401 @@ -436,7 +437,8 @@ tensor_method_func = [ #noqa 'lerp', 'lerp_', 'angle', - 'moveaxis' + 'moveaxis', + 'repeat_interleave', ] #this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 5d263bde8b3..a03b179e883 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2584,6 +2584,68 @@ def as_real(x, name=None): return out +def repeat_interleave(x, repeats, axis=None, name=None): + """ + + Returns a new tensor which repeats the ``x`` tensor along dimension ``axis`` using + the entries in ``repeats`` which is a int or a Tensor. + + Args: + x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float32, float64, int32, int64. + repeats (Tensor or int): The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis. + axis (int, optional): The dimension in which we manipulate. Default: if None, the output tensor is flatten. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + Tensor: A Tensor with same data type as ``x``. + + x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + repeats = paddle.to_tensor([3, 2, 1], dtype='int32') + + paddle.repeat_interleave(x, repeats, 1) + # [[1, 1, 1, 2, 2, 3], + # [4, 4, 4, 5, 5, 6]] + + paddle.repeat_interleave(x, 2, 0) + # [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]] + + paddle.repeat_interleave(x, 2, None) + # [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6] + """ + + if axis is None: + x = paddle.flatten(x) + axis = 0 + + if in_dygraph_mode(): + if isinstance(repeats, int): + return _C_ops.repeat_interleave(x, None, 'Repeats', repeats, 'dim', + axis) + elif isinstance(repeats, Variable): + return _C_ops.repeat_interleave(x, repeats, 'dim', axis) + + helper = LayerHelper("repeat_interleave", **locals()) + check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], + 'paddle.tensor.manipulation.repeat_interleave') + + out = helper.create_variable_for_type_inference(x.dtype) + + helper.append_op( + type='repeat_interleave', + inputs={ + 'X': x, + 'RepeatsTensor': repeats if isinstance(repeats, Variable) else None + }, + outputs={'Out': out}, + attrs={ + 'dim': axis, + 'Repeats': repeats if isinstance(repeats, int) else 0 + }) + return out + + def moveaxis(x, source, destination, name=None): """ Move the axis of tensor from ``source`` position to ``destination`` position. -- GitLab