From 3b32835fd7cb3e07bfdf22f9cbfb1686a7d7262b Mon Sep 17 00:00:00 2001 From: seemingwang Date: Wed, 3 Aug 2022 14:14:12 +0800 Subject: [PATCH] move repeat interleave (#44753) * move repeat interleave * fix api name * recover op registration * fix arguments order * fix * fix infermeta * fix infermeta * fix header * fix infermeta * fix * fix * fix dtype * log&test * test * remove logs * fix * remove logs * combine files * combine * combine files * fix cuda place --- .../fluid/operators/repeat_interleave_op.cc | 19 +- .../fluid/operators/repeat_interleave_op.cu | 341 ------------------ paddle/fluid/operators/repeat_interleave_op.h | 202 ----------- paddle/phi/api/yaml/legacy_api.yaml | 21 ++ paddle/phi/api/yaml/legacy_backward.yaml | 21 ++ paddle/phi/infermeta/binary.cc | 46 +++ paddle/phi/infermeta/binary.h | 4 + paddle/phi/infermeta/unary.cc | 31 ++ paddle/phi/infermeta/unary.h | 5 + .../cpu/repeat_interleave_grad_kernel.cc | 119 ++++++ .../kernels/cpu/repeat_interleave_kernel.cc | 37 ++ .../funcs/repeat_tensor2index_tensor.h | 48 +++ .../gpu/repeat_interleave_grad_kernel.cu | 36 ++ .../kernels/gpu/repeat_interleave_kernel.cu | 37 ++ .../impl/repeat_interleave_grad_kernel_impl.h | 226 ++++++++++++ .../impl/repeat_interleave_kernel_impl.h | 215 +++++++++++ .../kernels/repeat_interleave_grad_kernel.h | 38 ++ paddle/phi/kernels/repeat_interleave_kernel.h | 35 ++ .../phi/ops/compat/repeat_interleave_sig.cc | 56 +++ .../unittests/test_repeat_interleave_op.py | 19 +- python/paddle/tensor/manipulation.py | 11 +- 21 files changed, 997 insertions(+), 570 deletions(-) delete mode 100644 paddle/fluid/operators/repeat_interleave_op.cu delete mode 100644 paddle/fluid/operators/repeat_interleave_op.h create mode 100644 paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/repeat_interleave_kernel.cc create mode 100644 paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h create mode 100644 paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/repeat_interleave_kernel.cu create mode 100644 paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h create mode 100644 paddle/phi/kernels/repeat_interleave_grad_kernel.h create mode 100644 paddle/phi/kernels/repeat_interleave_kernel.h create mode 100644 paddle/phi/ops/compat/repeat_interleave_sig.cc diff --git a/paddle/fluid/operators/repeat_interleave_op.cc b/paddle/fluid/operators/repeat_interleave_op.cc index f3bec9489f..a3f04dd202 100644 --- a/paddle/fluid/operators/repeat_interleave_op.cc +++ b/paddle/fluid/operators/repeat_interleave_op.cc @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/repeat_interleave_op.h" - #include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/index_select_op.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { namespace operators { @@ -164,22 +166,13 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(RepeatInterleaveGradNoNeedBufferVarsInferer, } // namespace paddle namespace ops = paddle::operators; + REGISTER_OPERATOR(repeat_interleave, ops::RepeatInterleaveOp, ops::RepeatInterleaveOpMaker, ops::RepeatInterleaveGradMaker, ops::RepeatInterleaveGradMaker); + REGISTER_OPERATOR(repeat_interleave_grad, ops::RepeatInterleaveGradOp, ops::RepeatInterleaveGradNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL(repeat_interleave, - ops::RepeatInterleaveKernel, - ops::RepeatInterleaveKernel, - ops::RepeatInterleaveKernel, - ops::RepeatInterleaveKernel); -REGISTER_OP_CPU_KERNEL( - repeat_interleave_grad, - ops::RepeatInterleaveGradKernel, - ops::RepeatInterleaveGradKernel, - ops::RepeatInterleaveGradKernel, - ops::RepeatInterleaveGradKernel); diff --git a/paddle/fluid/operators/repeat_interleave_op.cu b/paddle/fluid/operators/repeat_interleave_op.cu deleted file mode 100644 index 07099c3027..0000000000 --- a/paddle/fluid/operators/repeat_interleave_op.cu +++ /dev/null @@ -1,341 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/fluid/operators/repeat_interleave_op.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -using platform::PADDLE_CUDA_NUM_THREADS; -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -// function borrowed from repeat_interleave_op -template -__global__ void index_select_cuda_kernel(const T* input, - T* output, - const IndexT* index, - int64_t N, - int64_t stride, - int64_t size, - int64_t delta) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { - return; - } - - int64_t pre_idx = idx / (stride * size); - int64_t dim_idx = idx % (stride * size) / stride; - IndexT src_dim_idx = index[dim_idx]; - int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; - output[idx] = input[input_idx]; -} - -template -__global__ void index_select_grad_cuda_kernel(const T* output_grad, - T* input_grad, - const IndexT* index, - int64_t nums, - int64_t N, - int64_t stride, - int64_t size, - int64_t delta) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { - return; - } - - int64_t pre_idx = idx / (stride * size); - int64_t dim_idx = idx % (stride * size) / stride; - IndexT src_dim_idx = index[dim_idx]; - int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; - paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]); -} - -template -__global__ void index_select_grad_init(T* input_grad, int64_t N) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { - return; - } - input_grad[idx] = 0.0; -} -template -class RepeatInterleaveCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - // auto* index = context.Input("RepeatsTensor"); - auto* out = context.Output("Out"); - int dim = context.Attr("dim"); - auto input_dim = in->dims(); - dim = dim >= 0 ? dim : dim + input_dim.size(); - auto stride_dim = phi::stride(input_dim); - int64_t stride = stride_dim[dim]; - - auto stream = context.template device_context().stream(); - - int repeats = context.Attr("Repeats"); - framework::LoDTensor index; - auto* in_data = in->data(); - if (context.HasInput("RepeatsTensor")) { - auto repeats_tensor = - context.Input("RepeatsTensor"); - - PADDLE_ENFORCE_EQ(repeats_tensor->dims()[0] == in->dims()[dim], - true, - platform::errors::InvalidArgument( - "The length of Input(RepeatsTensor) must be the " - "same as length of Input(X) in axis. " - "But received: [%s], required: [%d].", - repeats_tensor->dims()[0], - in->dims()[dim])); - - const auto& index_type = - framework::TransToProtoVarType(repeats_tensor->dtype()); - bool index_type_match = index_type == framework::proto::VarType::INT64 || - index_type == framework::proto::VarType::INT32; - PADDLE_ENFORCE_EQ( - index_type_match, - true, - platform::errors::InvalidArgument( - "Input(RepeatsTensor) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - - if (index_type == framework::proto::VarType::INT64) { - RepeatsTensor2IndexTensor(*repeats_tensor, - &index); - - const int64_t* index_data = index.data(); - auto output_dim = phi::vectorize(in->dims()); - output_dim[dim] = index.dims()[0]; - out->Resize(phi::make_ddim(output_dim)); - auto* out_data = out->mutable_data(context.GetPlace()); - int64_t numel = out->numel(); - int64_t size = output_dim[dim]; - int64_t delta = input_dim[dim] - size; - - index_select_cuda_kernel - <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>( - in_data, out_data, index_data, numel, stride, size, delta); - } else { - RepeatsTensor2IndexTensor(*repeats_tensor, &index); - - const int* index_data = index.data(); - auto output_dim = phi::vectorize(in->dims()); - output_dim[dim] = index.dims()[0]; - out->Resize(phi::make_ddim(output_dim)); - auto* out_data = out->mutable_data(context.GetPlace()); - int64_t numel = out->numel(); - int64_t size = output_dim[dim]; - int64_t delta = input_dim[dim] - size; - - index_select_cuda_kernel - <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>( - in_data, out_data, index_data, numel, stride, size, delta); - } - } else if (repeats > 0) { - int64_t index_size = in->dims()[dim] * repeats; - std::vector index_vec(index_size); - for (int i = 0; i < in->dims()[dim]; i++) { - std::fill_n(index_vec.begin() + i * repeats, repeats, i); - } - index.Resize(phi::make_ddim({index_size})); - auto ctx = paddle::platform::DeviceContextPool::Instance().Get( - context.GetPlace()); - paddle::framework::TensorFromVector(index_vec, *ctx, &index); - - auto output_dim = phi::vectorize(in->dims()); - output_dim[dim] = index_size; - out->Resize(phi::make_ddim(output_dim)); - auto* out_data = out->mutable_data(context.GetPlace()); - - int64_t numel = out->numel(); - int64_t size = output_dim[dim]; - int64_t delta = input_dim[dim] - size; - - const int* index_data = index.data(); - index_select_cuda_kernel - <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>( - in_data, out_data, index_data, numel, stride, size, delta); - platform::GpuStreamSync(stream); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "repeats must given with RepeatsTensor (tensor) or repeats (int)")); - } - } -}; - -template -class RepeatInterleaveGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* output_grad = context.Input(framework::GradVarName("Out")); - auto* in_grad = context.Output(framework::GradVarName("X")); - - auto* output_grad_data = output_grad->data(); - auto* in_grad_data = in_grad->mutable_data(context.GetPlace()); - - int dim = context.Attr("dim"); - auto input_dim = in_grad->dims(); - auto output_dim = output_grad->dims(); - dim = dim >= 0 ? dim : dim + input_dim.size(); - auto stride_dim = phi::stride(input_dim); - int64_t stride = stride_dim[dim]; - int64_t size = output_dim[dim]; - int64_t delta = input_dim[dim] - size; - - int64_t numel = in_grad->numel(); - int64_t out_nums = output_grad->numel(); - - auto stream = context.template device_context().stream(); - - index_select_grad_init - <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(in_grad_data, numel); - - int repeats = context.Attr("Repeats"); - framework::LoDTensor index; - if (context.HasInput("RepeatsTensor")) { - auto repeats_tensor = - context.Input("RepeatsTensor"); - - const auto& index_type = - framework::TransToProtoVarType(repeats_tensor->dtype()); - bool index_type_match = index_type == framework::proto::VarType::INT64 || - index_type == framework::proto::VarType::INT32; - PADDLE_ENFORCE_EQ( - index_type_match, - true, - platform::errors::InvalidArgument( - "Input(Index) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - - if (index_type == framework::proto::VarType::INT64) { - RepeatsTensor2IndexTensor(*repeats_tensor, - &index); - int64_t index_nums = index.numel(); - - const int64_t* index_data = index.data(); - index_select_grad_cuda_kernel - <<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / - PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(output_grad_data, - in_grad_data, - index_data, - index_nums, - out_nums, - stride, - size, - delta); - platform::GpuStreamSync(stream); - } else { - RepeatsTensor2IndexTensor(*repeats_tensor, &index); - int64_t index_nums = index.numel(); - - const int* index_data = index.data(); - index_select_grad_cuda_kernel - <<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / - PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(output_grad_data, - in_grad_data, - index_data, - index_nums, - out_nums, - stride, - size, - delta); - platform::GpuStreamSync(stream); - } - } else if (repeats > 0) { - int64_t index_size = in_grad->dims()[dim] * repeats; - std::vector index_vec(index_size); - for (int i = 0; i < in_grad->dims()[dim]; i++) { - std::fill_n(index_vec.begin() + i * repeats, repeats, i); - } - index.Resize(phi::make_ddim({index_size})); - auto ctx = paddle::platform::DeviceContextPool::Instance().Get( - context.GetPlace()); - paddle::framework::TensorFromVector(index_vec, *ctx, &index); - - const int* index_data = index.data(); - int64_t index_nums = index.numel(); - index_select_grad_cuda_kernel - <<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(output_grad_data, - in_grad_data, - index_data, - index_nums, - out_nums, - stride, - size, - delta); - platform::GpuStreamSync(stream); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "repeats must given with RepeatsTensor (tensor) or repeats (int)")); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - repeat_interleave, - ops::RepeatInterleaveCUDAKernel, - ops::RepeatInterleaveCUDAKernel, - ops::RepeatInterleaveCUDAKernel, - ops::RepeatInterleaveCUDAKernel, - ops::RepeatInterleaveCUDAKernel); -REGISTER_OP_CUDA_KERNEL( - repeat_interleave_grad, - ops::RepeatInterleaveGradCUDAKernel, - ops::RepeatInterleaveGradCUDAKernel, - ops::RepeatInterleaveGradCUDAKernel, - ops::RepeatInterleaveGradCUDAKernel, - ops::RepeatInterleaveGradCUDAKernel); diff --git a/paddle/fluid/operators/repeat_interleave_op.h b/paddle/fluid/operators/repeat_interleave_op.h deleted file mode 100644 index 2dd62a90a8..0000000000 --- a/paddle/fluid/operators/repeat_interleave_op.h +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/index_select_op.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using DDim = framework::DDim; - -template -void RepeatsTensor2IndexTensor(const LoDTensor& repeats, LoDTensor* index) { - LoDTensor repeats_cpu_copy; - if (!platform::is_cpu_place(repeats.place())) { - framework::TensorCopySync(repeats, platform::CPUPlace(), &repeats_cpu_copy); - } - const RepeatsT* repeats_data = platform::is_cpu_place(repeats.place()) - ? repeats.data() - : repeats_cpu_copy.data(); - - int64_t index_size = 0; - for (int i = 0; i < repeats.dims()[0]; i++) { - index_size += repeats_data[i]; - } - std::vector index_vec(index_size); - int offset = 0; - for (int i = 0; i < repeats.dims()[0]; i++) { - std::fill_n(index_vec.begin() + offset, repeats_data[i], i); - offset += repeats_data[i]; - } - index->Resize(phi::make_ddim({index_size})); - - auto ctx = - paddle::platform::DeviceContextPool::Instance().Get(repeats.place()); - paddle::framework::TensorFromVector(index_vec, *ctx, index); -} - -template -class RepeatInterleaveKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto inputs = *context.Input("X"); - auto* output = context.Output("Out"); - - int dim = context.Attr("dim"); - if (dim < 0) { - dim += inputs.dims().size(); - } - - int repeats = context.Attr("Repeats"); - framework::LoDTensor index; - if (context.HasInput("RepeatsTensor")) { - auto repeats_tensor = - context.Input("RepeatsTensor"); - - PADDLE_ENFORCE_EQ(repeats_tensor->dims()[0] == inputs.dims()[dim], - true, - platform::errors::InvalidArgument( - "The length of Input(RepeatsTensor) must be the " - "same as length of Input(X) in axis. " - "But received: [%s], required: [%d].", - repeats_tensor->dims()[0], - inputs.dims()[dim])); - - const auto& index_type = - framework::TransToProtoVarType(repeats_tensor->dtype()); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ( - index_type_match, - true, - platform::errors::InvalidArgument( - "Input(RepeatsTensor) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - - if (index_type == framework::proto::VarType::INT32) { - RepeatsTensor2IndexTensor(*repeats_tensor, &index); - auto output_dim = phi::vectorize(inputs.dims()); - output_dim[dim] = index.dims()[0]; - output->Resize(phi::make_ddim(output_dim)); - IndexSelectInner( - context, &inputs, index, output, dim); - } else if (index_type == framework::proto::VarType::INT64) { - RepeatsTensor2IndexTensor(*repeats_tensor, - &index); - auto output_dim = phi::vectorize(inputs.dims()); - output_dim[dim] = index.dims()[0]; - output->Resize(phi::make_ddim(output_dim)); - IndexSelectInner( - context, &inputs, index, output, dim); - } - } else if (repeats > 0) { - int64_t index_size = inputs.dims()[dim] * repeats; - std::vector index_vec(index_size); - for (int i = 0; i < inputs.dims()[dim]; i++) { - std::fill_n(index_vec.begin() + i * repeats, repeats, i); - } - index.Resize(phi::make_ddim({index_size})); - paddle::framework::TensorFromVector(index_vec, &index); - - auto output_dim = phi::vectorize(inputs.dims()); - output_dim[dim] = index_size; - output->Resize(phi::make_ddim(output_dim)); - - IndexSelectInner( - context, &inputs, index, output, dim); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "repeats must given with RepeatsTensor (tensor) or repeats (int)")); - } - } -}; - -template -class RepeatInterleaveGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x_grad = - context.Output(framework::GradVarName("X")); - auto* out_grad = - context.Input(framework::GradVarName("Out")); - - int dim = context.Attr("dim"); - if (dim < 0) { - dim += out_grad->dims().size(); - } - - int repeats = context.Attr("Repeats"); - framework::LoDTensor index; - if (context.HasInput("RepeatsTensor")) { - auto repeats_tensor = - context.Input("RepeatsTensor"); - const auto& index_type = - framework::TransToProtoVarType(repeats_tensor->dtype()); - - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ( - index_type_match, - true, - platform::errors::InvalidArgument( - "Input(Repeats) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - - if (index_type == framework::proto::VarType::INT32) { - RepeatsTensor2IndexTensor(*repeats_tensor, &index); - IndexSelectGradInner( - context, *out_grad, index, x_grad, dim); - } else if (index_type == framework::proto::VarType::INT64) { - RepeatsTensor2IndexTensor(*repeats_tensor, - &index); - IndexSelectGradInner( - context, *out_grad, index, x_grad, dim); - } - } else if (repeats > 0) { - int64_t index_size = x_grad->dims()[dim] * repeats; - std::vector index_vec(index_size); - for (int i = 0; i < x_grad->dims()[dim]; i++) { - std::fill_n(index_vec.begin() + i * repeats, repeats, i); - } - index.Resize(phi::make_ddim({index_size})); - paddle::framework::TensorFromVector(index_vec, &index); - - IndexSelectGradInner( - context, *out_grad, index, x_grad, dim); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "repeats must given with RepeatsTensor (tensor) or repeats (int)")); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 78184cf1da..60e3012f57 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1987,6 +1987,27 @@ func : renorm backward : renorm_grad +- api : repeat_interleave + args : (Tensor x, int repeats, int dim) + output : Tensor(out) + infer_meta : + func : RepeatInterleaveInferMeta + param : [x,repeats, dim] + kernel : + func : repeat_interleave + backward: repeat_interleave_grad + +- api : repeat_interleave_with_tensor_index + args : (Tensor x, Tensor repeats, int dim) + output : Tensor(out) + infer_meta : + func : RepeatInterleaveWithTensorIndexInferMeta + param : [x,repeats, dim] + kernel : + func : repeat_interleave_with_tensor_index + data_type : x + backward: repeat_interleave_with_tensor_index_grad + - api : reshape args : (Tensor x, IntArray shape) output : Tensor(out), Tensor(xshape) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 64b68ba6b3..5182709967 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1802,6 +1802,27 @@ kernel : func : renorm_grad +- backward_api : repeat_interleave_grad + forward : repeat_interleave(Tensor x, int repeats, int dim) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int repeats, int dim) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : repeat_interleave_grad + +- backward_api : repeat_interleave_with_tensor_index_grad + forward : repeat_interleave_with_tensor_index(Tensor x, Tensor repeats, int dim) -> Tensor(out) + args : (Tensor x, Tensor repeats, Tensor out_grad, int dim) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : repeat_interleave_with_tensor_index_grad + data_type : x + - backward_api : reshape_double_grad forward : reshape_grad (Tensor xshape, Tensor grad_out) -> Tensor(grad_x) args : (Tensor grad_out, Tensor grad_x_grad) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 566a2a953d..8ba4290e69 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2017,6 +2017,52 @@ void PriorBoxInferMeta(const MetaTensor& input, var->set_dims(phi::make_ddim(dim_vec)); } +void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x, + const MetaTensor& repeats, + int dim, + MetaTensor* out) { + const auto& input_dim = x.dims(); + auto output_dim = phi::vectorize(input_dim); + PADDLE_ENFORCE_EQ( + dim < input_dim.size() && dim >= (0 - input_dim.size()), + true, + phi::errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_dim.size(), + input_dim.size() - 1, + dim)); + + auto repeats_dim = repeats.dims(); + + PADDLE_ENFORCE_EQ( + repeats_dim.size() == 1 || + (repeats_dim.size() == 2 && repeats_dim[1] == 1), + true, + phi::errors::InvalidArgument( + "The 'shape' of Input(RepeatsTensor) must be 1-D tensor. " + "But received: the 'shape' of Input(Index) is [%s], " + "the dimension of Input(Index) is [%d].", + repeats_dim, + repeats_dim.size())); + + PADDLE_ENFORCE_EQ(repeats_dim[0] != 0, + true, + phi::errors::InvalidArgument( + "The length of Input(RepeatsTensor) can't be 0.")); + PADDLE_ENFORCE_NE(out, + nullptr, + phi::errors::InvalidArgument( + "repeat_interleave's output tensor can't be nullptr")); + if (dim < 0) { + dim += input_dim.size(); + } + output_dim[dim] = -1; + + out->set_dims(phi::make_ddim(output_dim)); + out->share_lod(x); + out->set_dtype(x.dtype()); +} void SearchsortedInferMeta(const MetaTensor& sorted_sequence, const MetaTensor& value, bool out_int32, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 3b03ce01a7..9f548256f4 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -279,6 +279,10 @@ void PReluInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x, + const MetaTensor& repeats, + int dim, + MetaTensor* out); void PriorBoxInferMeta(const MetaTensor& input, const MetaTensor& image, const std::vector& min_sizes, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a659909df9..8add8e970c 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2438,6 +2438,37 @@ void ReduceInferMetaBase(const MetaTensor& x, out->set_layout(x.layout()); } +void RepeatInterleaveInferMeta(const MetaTensor& x, + int repeats, + int dim, + MetaTensor* out) { + const auto& input_dim = x.dims(); + auto output_dim = phi::vectorize(input_dim); + + PADDLE_ENFORCE_EQ( + dim < input_dim.size() && dim >= (0 - input_dim.size()), + true, + phi::errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(dim) = %d.", + input_dim.size(), + input_dim.size() - 1, + dim)); + PADDLE_ENFORCE_EQ( + repeats > 0, + true, + phi::errors::InvalidArgument("repeats should be larger than zero")); + + PADDLE_ENFORCE_NE(out, + nullptr, + phi::errors::InvalidArgument( + "repeat_interleave's output tensor can't be nullptr")); + + output_dim[dim] = input_dim[dim] * repeats; + out->set_dims(phi::make_ddim(output_dim)); + out->share_lod(x); + out->set_dtype(x.dtype()); +} void ReshapeInferMeta(const MetaTensor& x, const IntArray& shape, MetaTensor* out, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index b8fe4a2205..a2753e46c8 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -322,6 +322,11 @@ void ReduceInferMetaBase(const MetaTensor& x, bool reduce_all, MetaTensor* out); +void RepeatInterleaveInferMeta(const MetaTensor& x, + int repeats, + int dim, + MetaTensor* out); + void ReshapeInferMeta(const MetaTensor& x, const IntArray& shape, MetaTensor* out, diff --git a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc new file mode 100644 index 0000000000..8f4af6a82c --- /dev/null +++ b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/cpu/index_select_impl.h" +#include "paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h" + +namespace phi { + +template +void RepeatInterleaveWithTensorIndexGradKernel( + const Context& ctx, + const DenseTensor& x, + const DenseTensor& repeats_tensor, + const DenseTensor& out_grad, + int dim, + DenseTensor* x_grad) { + auto input_dim = x_grad->dims(); + if (dim < 0) { + dim += input_dim.size(); + } + + DenseTensor index; + PADDLE_ENFORCE_EQ(repeats_tensor.dims()[0] == x_grad->dims()[dim], + true, + phi::errors::InvalidArgument( + "The length of Input(RepeatsTensor) must be the " + "same as length of Input(X) in axis. " + "But received: [%s], required: [%d].", + repeats_tensor.dims()[0], + x_grad->dims()[dim])); + + const auto& index_type = + paddle::framework::TransToProtoVarType(repeats_tensor.dtype()); + + bool index_type_match = + index_type == paddle::framework::proto::VarType::INT32 || + index_type == paddle::framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "Input(Repeats) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + paddle::framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + paddle::framework::proto::VarType::INT64))); + + paddle::platform::DeviceContextPool::Instance().Get(repeats_tensor.place()); + if (index_type == paddle::framework::proto::VarType::INT32) { + phi::funcs::RepeatsTensor2IndexTensor( + ctx, repeats_tensor, &index); + IndexSelectGradInner(ctx, out_grad, index, x_grad, dim); + } else if (index_type == paddle::framework::proto::VarType::INT64) { + phi::funcs::RepeatsTensor2IndexTensor( + ctx, repeats_tensor, &index); + IndexSelectGradInner( + ctx, out_grad, index, x_grad, dim); + } +} + +template +void RepeatInterleaveGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + int repeats, + int dim, + DenseTensor* x_grad) { + auto input_dim = x_grad->dims(); + if (dim < 0) { + dim += input_dim.size(); + } + + DenseTensor index; + int64_t index_size = x_grad->dims()[dim] * repeats; + std::vector index_vec(index_size); + for (int i = 0; i < x_grad->dims()[dim]; i++) { + std::fill_n(index_vec.begin() + i * repeats, repeats, i); + } + index.Resize(phi::make_ddim({index_size})); + paddle::framework::TensorFromVector(index_vec, &index); + const DenseTensor index_copy = index; + IndexSelectGradInner(ctx, out_grad, index_copy, x_grad, dim); +} +} // namespace phi + +PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, + CPU, + ALL_LAYOUT, + phi::RepeatInterleaveWithTensorIndexGradKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(repeat_interleave_grad, + CPU, + ALL_LAYOUT, + phi::RepeatInterleaveGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc new file mode 100644 index 0000000000..388e243eff --- /dev/null +++ b/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/repeat_interleave_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h" + +PD_REGISTER_KERNEL(repeat_interleave, + CPU, + ALL_LAYOUT, + phi::RepeatInterleaveKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, + CPU, + ALL_LAYOUT, + phi::RepeatInterleaveWithTensorIndexKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h b/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h new file mode 100644 index 0000000000..545ecb660f --- /dev/null +++ b/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h @@ -0,0 +1,48 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/phi/core/dense_tensor.h" +namespace phi { +namespace funcs { +template +void RepeatsTensor2IndexTensor(const Context& ctx, + const DenseTensor& repeats, + DenseTensor* index) { + DenseTensor repeats_cpu_copy; + if (!paddle::platform::is_cpu_place(repeats.place())) { + phi::Copy( + ctx, repeats, paddle::platform::CPUPlace(), true, &repeats_cpu_copy); + } + const RepeatsT* repeats_data = paddle::platform::is_cpu_place(repeats.place()) + ? repeats.data() + : repeats_cpu_copy.data(); + + int64_t index_size = 0; + for (int i = 0; i < repeats.dims()[0]; i++) { + index_size += repeats_data[i]; + } + std::vector index_vec(index_size); + int offset = 0; + for (int i = 0; i < repeats.dims()[0]; i++) { + std::fill_n(index_vec.begin() + offset, repeats_data[i], i); + offset += repeats_data[i]; + } + index->Resize(phi::make_ddim({index_size})); + + paddle::framework::TensorFromVector(index_vec, ctx, index); +} +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu new file mode 100644 index 0000000000..52a0e31339 --- /dev/null +++ b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, + GPU, + ALL_LAYOUT, + phi::RepeatInterleaveWithTensorIndexGradKernel, + float, + double, + int, + int64_t) {} +PD_REGISTER_KERNEL(repeat_interleave_grad, + GPU, + ALL_LAYOUT, + phi::RepeatInterleaveGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu new file mode 100644 index 0000000000..ed62278f06 --- /dev/null +++ b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/repeat_interleave_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h" + +PD_REGISTER_KERNEL(repeat_interleave, + GPU, + ALL_LAYOUT, + phi::RepeatInterleaveKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, + GPU, + ALL_LAYOUT, + phi::RepeatInterleaveWithTensorIndexKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h b/paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h new file mode 100644 index 0000000000..ddaaebafbc --- /dev/null +++ b/paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h @@ -0,0 +1,226 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/cpu/index_select_impl.h" +#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" +#ifdef __NVCC__ +#include "cub/cub.cuh" +#else +#include +namespace cub = hipcub; +#endif +#endif + +#include "paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h" + +namespace phi { + +#if defined(__NVCC__) || defined(__HIPCC__) +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void index_select_grad_cuda_kernel(const T* output_grad, + T* input_grad, + const IndexT* index, + int64_t nums, + int64_t N, + int64_t stride, + int64_t size, + int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]); +} + +template +__global__ void index_select_grad_init(T* input_grad, int64_t N) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + input_grad[idx] = 0.0; +} +#endif +template +void RepeatInterleaveWithTensorIndexGradKernel( + const Context& ctx, + const DenseTensor& x, + const DenseTensor& repeats_tensor, + const DenseTensor& out_grad, + int dim, + DenseTensor* x_grad) { + auto place = ctx.GetPlace(); + auto cpu_place = phi::CPUPlace(); + + auto input_dim = x_grad->dims(); + if (dim < 0) { + dim += input_dim.size(); + } + + DenseTensor index; + PADDLE_ENFORCE_EQ(repeats_tensor.dims()[0] == x_grad->dims()[dim], + true, + phi::errors::InvalidArgument( + "The length of Input(RepeatsTensor) must be the " + "same as length of Input(X) in axis. " + "But received: [%s], required: [%d].", + repeats_tensor.dims()[0], + x_grad->dims()[dim])); + + const auto& index_type = + paddle::framework::TransToProtoVarType(repeats_tensor.dtype()); + + bool index_type_match = + index_type == paddle::framework::proto::VarType::INT32 || + index_type == paddle::framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "Input(Repeats) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + paddle::framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + paddle::framework::proto::VarType::INT64))); +#if defined(__NVCC__) || defined(__HIPCC__) + + auto output_dim = out_grad.dims(); + auto stride_dim = phi::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + int64_t numel = x_grad->numel(); + int64_t out_nums = out_grad.numel(); + auto* out_grad_data = out_grad.data(); + ctx.template Alloc(x_grad); + auto* in_grad_data = x_grad->data(); + auto stream = ctx.stream(); + index_select_grad_init + <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(in_grad_data, numel); + + if (index_type == paddle::framework::proto::VarType::INT64) { + phi::funcs::RepeatsTensor2IndexTensor( + ctx, repeats_tensor, &index); + int64_t index_nums = index.numel(); + + const int64_t* index_data = index.data(); + index_select_grad_cuda_kernel + <<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(out_grad_data, + in_grad_data, + index_data, + index_nums, + out_nums, + stride, + size, + delta); + } else { + phi::funcs::RepeatsTensor2IndexTensor( + ctx, repeats_tensor, &index); + int64_t index_nums = index.numel(); + + const int* index_data = index.data(); + index_select_grad_cuda_kernel + <<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(out_grad_data, + in_grad_data, + index_data, + index_nums, + out_nums, + stride, + size, + delta); + } +#endif +} + +template +void RepeatInterleaveGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + int repeats, + int dim, + DenseTensor* x_grad) { + auto place = ctx.GetPlace(); + auto cpu_place = phi::CPUPlace(); + + auto input_dim = x_grad->dims(); + if (dim < 0) { + dim += input_dim.size(); + } + + DenseTensor index; +#if defined(__NVCC__) || defined(__HIPCC__) + auto output_dim = out_grad.dims(); + auto stride_dim = phi::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + int64_t numel = x_grad->numel(); + int64_t out_nums = out_grad.numel(); + auto* out_grad_data = out_grad.data(); + ctx.template Alloc(x_grad); + auto* in_grad_data = x_grad->data(); + auto stream = ctx.stream(); + index_select_grad_init + <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(in_grad_data, numel); + int64_t index_size = x_grad->dims()[dim] * repeats; + std::vector index_vec(index_size); + for (int i = 0; i < x_grad->dims()[dim]; i++) { + std::fill_n(index_vec.begin() + i * repeats, repeats, i); + } + index.Resize(phi::make_ddim({index_size})); + paddle::framework::TensorFromVector(index_vec, ctx, &index); + + const int* index_data = index.data(); + int64_t index_nums = index.numel(); + index_select_grad_cuda_kernel + <<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(out_grad_data, + in_grad_data, + index_data, + index_nums, + out_nums, + stride, + size, + delta); +#endif +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h b/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h new file mode 100644 index 0000000000..dd950a14f6 --- /dev/null +++ b/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h @@ -0,0 +1,215 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/cpu/index_select_impl.h" +#include "paddle/phi/kernels/repeat_interleave_kernel.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/gpu/gpu_resources.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" +#endif + +#include "paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h" + +namespace phi { + +#if defined(__NVCC__) || defined(__HIPCC__) +using paddle::platform::PADDLE_CUDA_NUM_THREADS; +template +__global__ void index_select_cuda_kernel(const T* input, + T* output, + const IndexT* index, + int64_t N, + int64_t stride, + int64_t size, + int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + output[idx] = input[input_idx]; +} +#endif + +template +void RepeatInterleaveKernel(const Context& ctx, + const DenseTensor& x, + int repeats, + int dim, + DenseTensor* out) { + auto place = ctx.GetPlace(); + auto cpu_place = phi::CPUPlace(); + + auto input_dim = x.dims(); + if (dim < 0) { + dim += input_dim.size(); + } + + DenseTensor index; + int64_t index_size = input_dim[dim] * repeats; + std::vector index_vec(index_size); + for (int i = 0; i < input_dim[dim]; i++) { + std::fill_n(index_vec.begin() + i * repeats, repeats, i); + } + index.Resize(phi::make_ddim({index_size})); + if (place == cpu_place) { + DenseTensor x_copy = x; + paddle::framework::TensorFromVector(index_vec, &index); + + auto output_dim = phi::vectorize(x.dims()); + output_dim[dim] = index_size; + out->Resize(phi::make_ddim(output_dim)); + phi::IndexSelectInner(ctx, &x_copy, index, out, dim); + } +#if defined(__NVCC__) || defined(__HIPCC__) + else { + auto stride_dim = phi::stride(input_dim); + int64_t stride = stride_dim[dim]; + paddle::framework::TensorFromVector(index_vec, ctx, &index); + auto stream = ctx.stream(); + auto output_dim = phi::vectorize(x.dims()); + output_dim[dim] = index_size; + out->Resize(phi::make_ddim(output_dim)); + ctx.template Alloc(out); + auto* out_data = out->data(); + int64_t numel = out->numel(); + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + const int* index_data = index.data(); + index_select_cuda_kernel + <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>( + x.data(), out_data, index_data, numel, stride, size, delta); + } +#endif +} + +template +void RepeatInterleaveWithTensorIndexKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& repeats_tensor, + int dim, + DenseTensor* out) { + auto place = ctx.GetPlace(); + auto cpu_place = phi::CPUPlace(); + + auto input_dim = x.dims(); + if (dim < 0) { + dim += input_dim.size(); + } + DenseTensor index; + PADDLE_ENFORCE_EQ(repeats_tensor.dims()[0] == x.dims()[dim], + true, + phi::errors::InvalidArgument( + "The length of Input(RepeatsTensor) must be the " + "same as length of Input(X) in axis. " + "But received: [%s], required: [%d].", + repeats_tensor.dims()[0], + x.dims()[dim])); + const auto& index_type = + paddle::framework::TransToProtoVarType(repeats_tensor.dtype()); + bool index_type_match = + index_type == paddle::framework::proto::VarType::INT32 || + index_type == paddle::framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, + true, + phi::errors::InvalidArgument( + "Input(RepeatsTensor) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + paddle::framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + paddle::framework::proto::VarType::INT64))); + if (place == cpu_place) { + auto x_copy = x; + if (index_type == paddle::framework::proto::VarType::INT32) { + phi::funcs::RepeatsTensor2IndexTensor( + ctx, repeats_tensor, &index); + auto output_dim = phi::vectorize(x.dims()); + output_dim[dim] = index.dims()[0]; + out->Resize(phi::make_ddim(output_dim)); + IndexSelectInner(ctx, &x_copy, index, out, dim); + } else if (index_type == paddle::framework::proto::VarType::INT64) { + phi::funcs::RepeatsTensor2IndexTensor( + ctx, repeats_tensor, &index); + auto output_dim = phi::vectorize(x.dims()); + output_dim[dim] = index.dims()[0]; + out->Resize(phi::make_ddim(output_dim)); + IndexSelectInner(ctx, &x_copy, index, out, dim); + } + } +#if defined(__NVCC__) || defined(__HIPCC__) + else { + auto stride_dim = phi::stride(input_dim); + int64_t stride = stride_dim[dim]; + auto stream = ctx.stream(); + auto* in_data = x.data(); + if (index_type == paddle::framework::proto::VarType::INT64) { + phi::funcs::RepeatsTensor2IndexTensor( + ctx, repeats_tensor, &index); + + const int64_t* index_data = index.data(); + auto output_dim = phi::vectorize(x.dims()); + output_dim[dim] = index.dims()[0]; + out->Resize(phi::make_ddim(output_dim)); + T* out_data = ctx.template Alloc(out); + int64_t numel = out->numel(); + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + index_select_cuda_kernel + <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>( + in_data, out_data, index_data, numel, stride, size, delta); + } else { + phi::funcs::RepeatsTensor2IndexTensor( + ctx, repeats_tensor, &index); + + const int* index_data = index.data(); + auto output_dim = phi::vectorize(x.dims()); + output_dim[dim] = index.dims()[0]; + out->Resize(phi::make_ddim(output_dim)); + T* out_data = ctx.template Alloc(out); + int64_t numel = out->numel(); + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + index_select_cuda_kernel + <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>( + in_data, out_data, index_data, numel, stride, size, delta); + } + } +#endif +} + +} // namespace phi diff --git a/paddle/phi/kernels/repeat_interleave_grad_kernel.h b/paddle/phi/kernels/repeat_interleave_grad_kernel.h new file mode 100644 index 0000000000..75f493bd99 --- /dev/null +++ b/paddle/phi/kernels/repeat_interleave_grad_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RepeatInterleaveGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + int repeats, + int dim, + DenseTensor* x_grad); + +template +void RepeatInterleaveWithTensorIndexGradKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& repeats_tensor, + const DenseTensor& out_grad, + int dim, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/repeat_interleave_kernel.h b/paddle/phi/kernels/repeat_interleave_kernel.h new file mode 100644 index 0000000000..871b720800 --- /dev/null +++ b/paddle/phi/kernels/repeat_interleave_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RepeatInterleaveKernel(const Context& dev_ctx, + const DenseTensor& x, + int repeats, + int dim, + DenseTensor* out); + +template +void RepeatInterleaveWithTensorIndexKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& repeat_tensor, + int dim, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/repeat_interleave_sig.cc b/paddle/phi/ops/compat/repeat_interleave_sig.cc new file mode 100644 index 0000000000..ad087ed467 --- /dev/null +++ b/paddle/phi/ops/compat/repeat_interleave_sig.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature RepeatInterleaveOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.HasInput("RepeatsTensor")) { + VLOG(3) << "sig------ repeat_interleave_with_tensor_index"; + return KernelSignature("repeat_interleave_with_tensor_index", + {"X", "RepeatsTensor"}, + {"dim"}, + {"Out"}); + } else { + VLOG(3) << "sig ------repeat_interleave"; + return KernelSignature( + "repeat_interleave", {"X"}, {"Repeats", "dim"}, {"Out"}); + } +} + +KernelSignature RepeatInterleaveGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.HasInput("RepeatsTensor")) { + VLOG(3) << "sig ------repeat_interleave with tensor grad"; + return KernelSignature("repeat_interleave_with_tensor_index_grad", + {"X", "RepeatsTensor", "Out@GRAD"}, + {"dim"}, + {"X@GRAD"}); + } else { + VLOG(3) << "sig repeat_interleave grad"; + return KernelSignature("repeat_interleave_grad", + {"X", "Out@GRAD"}, + {"Repeats", "dim"}, + {"X@GRAD"}); + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(repeat_interleave, + phi::RepeatInterleaveOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(repeat_interleave_grad, + phi::RepeatInterleaveGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py b/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py index 7abc758617..28e49c7d16 100644 --- a/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py +++ b/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py @@ -27,10 +27,12 @@ class TestRepeatInterleaveOp(OpTest): def setUp(self): self.op_type = "repeat_interleave" + self.python_api = paddle.repeat_interleave self.init_dtype_type() index_np = np.random.randint( low=0, high=3, size=self.index_size).astype(self.index_type) x_np = np.random.random(self.x_shape).astype(self.x_type) + self.inputs = {'X': x_np, 'RepeatsTensor': index_np} self.attrs = {'dim': self.dim} @@ -57,16 +59,17 @@ class TestRepeatInterleaveOp(OpTest): self.index_size = self.x_shape[self.dim] def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestRepeatInterleaveOp2(OpTest): def setUp(self): self.op_type = "repeat_interleave" + self.python_api = paddle.repeat_interleave self.init_dtype_type() index_np = 2 x_np = np.random.random(self.x_shape).astype(self.x_type) @@ -95,10 +98,10 @@ class TestRepeatInterleaveOp2(OpTest): self.index_size = self.x_shape[self.dim] def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestIndexSelectAPI(unittest.TestCase): @@ -115,7 +118,7 @@ class TestIndexSelectAPI(unittest.TestCase): # case 1: with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[-1, 4]) - index = fluid.layers.data(name='repeats', + index = fluid.layers.data(name='repeats_', shape=[4], dtype='int32', append_batch_size=False) @@ -123,7 +126,7 @@ class TestIndexSelectAPI(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace()) res, = exe.run(feed={ 'x': self.data_x, - 'repeats': self.data_index + 'repeats_': self.data_index }, fetch_list=[z.name], return_numpy=False) @@ -134,7 +137,7 @@ class TestIndexSelectAPI(unittest.TestCase): repeats = np.array([1, 2, 1]).astype('int32') with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[-1, 4]) - index = fluid.layers.data(name='repeats', + index = fluid.layers.data(name='repeats_', shape=[3], dtype='int32', append_batch_size=False) @@ -142,7 +145,7 @@ class TestIndexSelectAPI(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace()) res, = exe.run(feed={ 'x': self.data_x, - 'repeats': repeats, + 'repeats_': repeats, }, fetch_list=[z.name], return_numpy=False) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index b58ba75270..9da7f76e70 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4003,12 +4003,11 @@ def repeat_interleave(x, repeats, axis=None, name=None): x = paddle.flatten(x) axis = 0 - if paddle.in_dynamic_mode(): - if isinstance(repeats, int): - return _C_ops.repeat_interleave(x, None, 'Repeats', repeats, 'dim', - axis) - elif isinstance(repeats, Variable): - return _C_ops.repeat_interleave(x, repeats, 'dim', axis) + if in_dygraph_mode(): + if isinstance(repeats, Variable): + return _C_ops.final_state_repeat_interleave_with_tensor_index( + x, repeats, axis) + return _C_ops.final_state_repeat_interleave(x, repeats, axis) helper = LayerHelper("repeat_interleave", **locals()) check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], -- GitLab