Unverified commit 3b32835f, authored by seemingwang, committed by GitHub

move repeat interleave (#44753)

* move repeat interleave

* fix api name

* recover op registration

* fix arguments order

* fix

* fix infermeta

* fix infermeta

* fix header

* fix infermeta

* fix

* fix

* fix dtype

* log&test

* test

* remove logs

* fix

* remove logs

* combine files

* combine

* combine files

* fix cuda place
Parent 63df05d3
......@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/repeat_interleave_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_select_op.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
......@@ -164,22 +166,13 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(RepeatInterleaveGradNoNeedBufferVarsInferer,
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(repeat_interleave,
ops::RepeatInterleaveOp,
ops::RepeatInterleaveOpMaker,
ops::RepeatInterleaveGradMaker<paddle::framework::OpDesc>,
ops::RepeatInterleaveGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(repeat_interleave_grad,
ops::RepeatInterleaveGradOp,
ops::RepeatInterleaveGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(repeat_interleave,
ops::RepeatInterleaveKernel<phi::CPUContext, float>,
ops::RepeatInterleaveKernel<phi::CPUContext, double>,
ops::RepeatInterleaveKernel<phi::CPUContext, int>,
ops::RepeatInterleaveKernel<phi::CPUContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
repeat_interleave_grad,
ops::RepeatInterleaveGradKernel<phi::CPUContext, float>,
ops::RepeatInterleaveGradKernel<phi::CPUContext, double>,
ops::RepeatInterleaveGradKernel<phi::CPUContext, int>,
ops::RepeatInterleaveGradKernel<phi::CPUContext, int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/repeat_interleave_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
// function borrowed from repeat_interleave_op
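// Index arithmetic used by the kernel below: an output element at flattened
// index idx is decomposed as pre_idx = idx / (stride * size) for the leading
// dimensions and dim_idx = (idx % (stride * size)) / stride for the position
// along the repeated dimension, with stride covering the trailing dimensions.
// The matching input element sits at
// idx + (delta * pre_idx + index[dim_idx] - dim_idx) * stride,
// where delta = input_dim[dim] - output_dim[dim]. For example, with input
// shape [2, 3], dim = 1 and index = [0, 0, 1, 1, 2, 2] (repeats = 2), output
// element (1, 3), i.e. idx = 9, reads input element (1, index[3]) = (1, 1).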
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input,
T* output,
const IndexT* index,
int64_t N,
int64_t stride,
int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
output[idx] = input[input_idx];
}
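// Gradient counterpart of the kernel above: several output positions along
// the repeated dimension map back to the same input position, so gradients
// are scattered with CudaAtomicAdd after input_grad has been zeroed by
// index_select_grad_init.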
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index,
int64_t nums,
int64_t N,
int64_t stride,
int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
template <typename DeviceContext, typename T>
class RepeatInterleaveCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
// auto* index = context.Input<LoDTensor>("RepeatsTensor");
auto* out = context.Output<LoDTensor>("Out");
int dim = context.Attr<int>("dim");
auto input_dim = in->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
auto stream = context.template device_context<phi::GPUContext>().stream();
int repeats = context.Attr<int>("Repeats");
framework::LoDTensor index;
auto* in_data = in->data<T>();
if (context.HasInput("RepeatsTensor")) {
auto repeats_tensor =
context.Input<framework::LoDTensor>("RepeatsTensor");
PADDLE_ENFORCE_EQ(repeats_tensor->dims()[0] == in->dims()[dim],
true,
platform::errors::InvalidArgument(
"The length of Input(RepeatsTensor) must be the "
"same as length of Input(X) in axis. "
"But received: [%s], required: [%d].",
repeats_tensor->dims()[0],
in->dims()[dim]));
const auto& index_type =
framework::TransToProtoVarType(repeats_tensor->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
platform::errors::InvalidArgument(
"Input(RepeatsTensor) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT64) {
RepeatsTensor2IndexTensor<DeviceContext, int64_t>(*repeats_tensor,
&index);
const int64_t* index_data = index.data<int64_t>();
auto output_dim = phi::vectorize(in->dims());
output_dim[dim] = index.dims()[0];
out->Resize(phi::make_ddim(output_dim));
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
index_select_cuda_kernel<T, int64_t>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
} else {
RepeatsTensor2IndexTensor<DeviceContext, int>(*repeats_tensor, &index);
const int* index_data = index.data<int>();
auto output_dim = phi::vectorize(in->dims());
output_dim[dim] = index.dims()[0];
out->Resize(phi::make_ddim(output_dim));
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
index_select_cuda_kernel<T, int>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
}
} else if (repeats > 0) {
int64_t index_size = in->dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < in->dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(phi::make_ddim({index_size}));
auto ctx = paddle::platform::DeviceContextPool::Instance().Get(
context.GetPlace());
paddle::framework::TensorFromVector<int>(index_vec, *ctx, &index);
auto output_dim = phi::vectorize(in->dims());
output_dim[dim] = index_size;
out->Resize(phi::make_ddim(output_dim));
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const int* index_data = index.data<int>();
index_select_cuda_kernel<T, int>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"repeats must given with RepeatsTensor (tensor) or repeats (int)"));
}
}
};
template <typename DeviceContext, typename T>
class RepeatInterleaveGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_grad = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* output_grad_data = output_grad->data<T>();
auto* in_grad_data = in_grad->mutable_data<T>(context.GetPlace());
int dim = context.Attr<int>("dim");
auto input_dim = in_grad->dims();
auto output_dim = output_grad->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
int64_t numel = in_grad->numel();
int64_t out_nums = output_grad->numel();
auto stream = context.template device_context<phi::GPUContext>().stream();
index_select_grad_init<T>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(in_grad_data, numel);
int repeats = context.Attr<int>("Repeats");
framework::LoDTensor index;
if (context.HasInput("RepeatsTensor")) {
auto repeats_tensor =
context.Input<framework::LoDTensor>("RepeatsTensor");
const auto& index_type =
framework::TransToProtoVarType(repeats_tensor->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT64) {
RepeatsTensor2IndexTensor<DeviceContext, int64_t>(*repeats_tensor,
&index);
int64_t index_nums = index.numel();
const int64_t* index_data = index.data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t>
<<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
platform::GpuStreamSync(stream);
} else {
RepeatsTensor2IndexTensor<DeviceContext, int>(*repeats_tensor, &index);
int64_t index_nums = index.numel();
const int* index_data = index.data<int>();
index_select_grad_cuda_kernel<T, int>
<<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
platform::GpuStreamSync(stream);
}
} else if (repeats > 0) {
int64_t index_size = in_grad->dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < in_grad->dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(phi::make_ddim({index_size}));
auto ctx = paddle::platform::DeviceContextPool::Instance().Get(
context.GetPlace());
paddle::framework::TensorFromVector<int>(index_vec, *ctx, &index);
const int* index_data = index.data<int>();
int64_t index_nums = index.numel();
index_select_grad_cuda_kernel<T, int>
<<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
platform::GpuStreamSync(stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"repeats must given with RepeatsTensor (tensor) or repeats (int)"));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
repeat_interleave,
ops::RepeatInterleaveCUDAKernel<phi::GPUContext, float>,
ops::RepeatInterleaveCUDAKernel<phi::GPUContext, double>,
ops::RepeatInterleaveCUDAKernel<phi::GPUContext, paddle::platform::float16>,
ops::RepeatInterleaveCUDAKernel<phi::GPUContext, int>,
ops::RepeatInterleaveCUDAKernel<phi::GPUContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
repeat_interleave_grad,
ops::RepeatInterleaveGradCUDAKernel<phi::GPUContext, float>,
ops::RepeatInterleaveGradCUDAKernel<phi::GPUContext, double>,
ops::RepeatInterleaveGradCUDAKernel<phi::GPUContext,
paddle::platform::float16>,
ops::RepeatInterleaveGradCUDAKernel<phi::GPUContext, int>,
ops::RepeatInterleaveGradCUDAKernel<phi::GPUContext, int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_select_op.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
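// Expands a 1-D repeats tensor into the index tensor consumed by the
// index_select path: entry i is written repeats[i] times, so e.g.
// repeats = [2, 1, 3] yields index = [0, 0, 1, 2, 2, 2] and the output holds
// sum(repeats) entries along dim. A non-CPU repeats tensor is copied to the
// host first so its values can be read.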
template <typename DeviceContext, typename RepeatsT = int>
void RepeatsTensor2IndexTensor(const LoDTensor& repeats, LoDTensor* index) {
LoDTensor repeats_cpu_copy;
if (!platform::is_cpu_place(repeats.place())) {
framework::TensorCopySync(repeats, platform::CPUPlace(), &repeats_cpu_copy);
}
const RepeatsT* repeats_data = platform::is_cpu_place(repeats.place())
? repeats.data<RepeatsT>()
: repeats_cpu_copy.data<RepeatsT>();
int64_t index_size = 0;
for (int i = 0; i < repeats.dims()[0]; i++) {
index_size += repeats_data[i];
}
std::vector<RepeatsT> index_vec(index_size);
int offset = 0;
for (int i = 0; i < repeats.dims()[0]; i++) {
std::fill_n(index_vec.begin() + offset, repeats_data[i], i);
offset += repeats_data[i];
}
index->Resize(phi::make_ddim({index_size}));
auto ctx =
paddle::platform::DeviceContextPool::Instance().Get(repeats.place());
paddle::framework::TensorFromVector<RepeatsT>(index_vec, *ctx, index);
}
template <typename DeviceContext, typename T>
class RepeatInterleaveKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto inputs = *context.Input<framework::LoDTensor>("X");
auto* output = context.Output<framework::LoDTensor>("Out");
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += inputs.dims().size();
}
int repeats = context.Attr<int>("Repeats");
framework::LoDTensor index;
if (context.HasInput("RepeatsTensor")) {
auto repeats_tensor =
context.Input<framework::LoDTensor>("RepeatsTensor");
PADDLE_ENFORCE_EQ(repeats_tensor->dims()[0] == inputs.dims()[dim],
true,
platform::errors::InvalidArgument(
"The length of Input(RepeatsTensor) must be the "
"same as length of Input(X) in axis. "
"But received: [%s], required: [%d].",
repeats_tensor->dims()[0],
inputs.dims()[dim]));
const auto& index_type =
framework::TransToProtoVarType(repeats_tensor->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
platform::errors::InvalidArgument(
"Input(RepeatsTensor) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
RepeatsTensor2IndexTensor<DeviceContext, int>(*repeats_tensor, &index);
auto output_dim = phi::vectorize(inputs.dims());
output_dim[dim] = index.dims()[0];
output->Resize(phi::make_ddim(output_dim));
IndexSelectInner<DeviceContext, T, int>(
context, &inputs, index, output, dim);
} else if (index_type == framework::proto::VarType::INT64) {
RepeatsTensor2IndexTensor<DeviceContext, int64_t>(*repeats_tensor,
&index);
auto output_dim = phi::vectorize(inputs.dims());
output_dim[dim] = index.dims()[0];
output->Resize(phi::make_ddim(output_dim));
IndexSelectInner<DeviceContext, T, int64_t>(
context, &inputs, index, output, dim);
}
} else if (repeats > 0) {
int64_t index_size = inputs.dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < inputs.dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(phi::make_ddim({index_size}));
paddle::framework::TensorFromVector<int>(index_vec, &index);
auto output_dim = phi::vectorize(inputs.dims());
output_dim[dim] = index_size;
output->Resize(phi::make_ddim(output_dim));
IndexSelectInner<DeviceContext, T, int>(
context, &inputs, index, output, dim);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"repeats must given with RepeatsTensor (tensor) or repeats (int)"));
}
}
};
template <typename DeviceContext, typename T>
class RepeatInterleaveGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_grad =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* out_grad =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += out_grad->dims().size();
}
int repeats = context.Attr<int>("Repeats");
framework::LoDTensor index;
if (context.HasInput("RepeatsTensor")) {
auto repeats_tensor =
context.Input<framework::LoDTensor>("RepeatsTensor");
const auto& index_type =
framework::TransToProtoVarType(repeats_tensor->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
platform::errors::InvalidArgument(
"Input(Repeats) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
RepeatsTensor2IndexTensor<DeviceContext, int>(*repeats_tensor, &index);
IndexSelectGradInner<DeviceContext, T, int>(
context, *out_grad, index, x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) {
RepeatsTensor2IndexTensor<DeviceContext, int64_t>(*repeats_tensor,
&index);
IndexSelectGradInner<DeviceContext, T, int64_t>(
context, *out_grad, index, x_grad, dim);
}
} else if (repeats > 0) {
int64_t index_size = x_grad->dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < x_grad->dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(phi::make_ddim({index_size}));
paddle::framework::TensorFromVector<int>(index_vec, &index);
IndexSelectGradInner<DeviceContext, T, int>(
context, *out_grad, index, x_grad, dim);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"repeats must given with RepeatsTensor (tensor) or repeats (int)"));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -1987,6 +1987,27 @@
func : renorm
backward : renorm_grad
- api : repeat_interleave
args : (Tensor x, int repeats, int dim)
output : Tensor(out)
infer_meta :
func : RepeatInterleaveInferMeta
param : [x, repeats, dim]
kernel :
func : repeat_interleave
backward : repeat_interleave_grad
- api : repeat_interleave_with_tensor_index
args : (Tensor x, Tensor repeats, int dim)
output : Tensor(out)
infer_meta :
func : RepeatInterleaveWithTensorIndexInferMeta
param : [x, repeats, dim]
kernel :
func : repeat_interleave_with_tensor_index
data_type : x
backward : repeat_interleave_with_tensor_index_grad
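# A scalar repeats argument maps to the repeat_interleave kernel, whose output
# length along dim is statically repeats * dim size; a 1-D repeats tensor maps
# to repeat_interleave_with_tensor_index, whose output length is only known at
# run time (see the corresponding InferMeta functions below).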
- api : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out), Tensor(xshape)
......
......@@ -1802,6 +1802,27 @@
kernel :
func : renorm_grad
- backward_api : repeat_interleave_grad
forward : repeat_interleave(Tensor x, int repeats, int dim) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int repeats, int dim)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : repeat_interleave_grad
- backward_api : repeat_interleave_with_tensor_index_grad
forward : repeat_interleave_with_tensor_index(Tensor x, Tensor repeats, int dim) -> Tensor(out)
args : (Tensor x, Tensor repeats, Tensor out_grad, int dim)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : repeat_interleave_with_tensor_index_grad
data_type : x
- backward_api : reshape_double_grad
forward : reshape_grad (Tensor xshape, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor grad_out, Tensor grad_x_grad)
......
......@@ -2017,6 +2017,52 @@ void PriorBoxInferMeta(const MetaTensor& input,
var->set_dims(phi::make_ddim(dim_vec));
}
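// With a tensor of repeats, the output length along dim depends on runtime
// values, so this infer-meta routine only validates the inputs and marks that
// dimension as -1 (unknown); the kernel resizes the output once the repeats
// values have been read.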
void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x,
const MetaTensor& repeats,
int dim,
MetaTensor* out) {
const auto& input_dim = x.dims();
auto output_dim = phi::vectorize(input_dim);
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()),
true,
phi::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(),
input_dim.size() - 1,
dim));
auto repeats_dim = repeats.dims();
PADDLE_ENFORCE_EQ(
repeats_dim.size() == 1 ||
(repeats_dim.size() == 2 && repeats_dim[1] == 1),
true,
phi::errors::InvalidArgument(
"The 'shape' of Input(RepeatsTensor) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
repeats_dim,
repeats_dim.size()));
PADDLE_ENFORCE_EQ(repeats_dim[0] != 0,
true,
phi::errors::InvalidArgument(
"The length of Input(RepeatsTensor) can't be 0."));
PADDLE_ENFORCE_NE(out,
nullptr,
phi::errors::InvalidArgument(
"repeat_interleave's output tensor can't be nullptr"));
if (dim < 0) {
dim += input_dim.size();
}
output_dim[dim] = -1;
out->set_dims(phi::make_ddim(output_dim));
out->share_lod(x);
out->set_dtype(x.dtype());
}
void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
const MetaTensor& value,
bool out_int32,
......
......@@ -279,6 +279,10 @@ void PReluInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x,
const MetaTensor& repeats,
int dim,
MetaTensor* out);
void PriorBoxInferMeta(const MetaTensor& input,
const MetaTensor& image,
const std::vector<float>& min_sizes,
......
......@@ -2438,6 +2438,37 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout());
}
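// With a scalar repeats, the output shape is fully static:
// output_dim[dim] = input_dim[dim] * repeats.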
void RepeatInterleaveInferMeta(const MetaTensor& x,
int repeats,
int dim,
MetaTensor* out) {
const auto& input_dim = x.dims();
auto output_dim = phi::vectorize(input_dim);
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()),
true,
phi::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(),
input_dim.size() - 1,
dim));
PADDLE_ENFORCE_EQ(
repeats > 0,
true,
phi::errors::InvalidArgument("repeats should be larger than zero"));
PADDLE_ENFORCE_NE(out,
nullptr,
phi::errors::InvalidArgument(
"repeat_interleave's output tensor can't be nullptr"));
output_dim[dim] = input_dim[dim] * repeats;
out->set_dims(phi::make_ddim(output_dim));
out->share_lod(x);
out->set_dtype(x.dtype());
}
void ReshapeInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out,
......
......@@ -322,6 +322,11 @@ void ReduceInferMetaBase(const MetaTensor& x,
bool reduce_all,
MetaTensor* out);
void RepeatInterleaveInferMeta(const MetaTensor& x,
int repeats,
int dim,
MetaTensor* out);
void ReshapeInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h"
#include "paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h"
namespace phi {
template <typename T, typename Context>
void RepeatInterleaveWithTensorIndexGradKernel(
const Context& ctx,
const DenseTensor& x,
const DenseTensor& repeats_tensor,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad) {
auto input_dim = x_grad->dims();
if (dim < 0) {
dim += input_dim.size();
}
DenseTensor index;
PADDLE_ENFORCE_EQ(repeats_tensor.dims()[0] == x_grad->dims()[dim],
true,
phi::errors::InvalidArgument(
"The length of Input(RepeatsTensor) must be the "
"same as length of Input(X) in axis. "
"But received: [%s], required: [%d].",
repeats_tensor.dims()[0],
x_grad->dims()[dim]));
const auto& index_type =
paddle::framework::TransToProtoVarType(repeats_tensor.dtype());
bool index_type_match =
index_type == paddle::framework::proto::VarType::INT32 ||
index_type == paddle::framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Repeats) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64)));
paddle::platform::DeviceContextPool::Instance().Get(repeats_tensor.place());
if (index_type == paddle::framework::proto::VarType::INT32) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int>(
ctx, repeats_tensor, &index);
IndexSelectGradInner<Context, T, int>(ctx, out_grad, index, x_grad, dim);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>(
ctx, repeats_tensor, &index);
IndexSelectGradInner<Context, T, int64_t>(
ctx, out_grad, index, x_grad, dim);
}
}
template <typename T, typename Context>
void RepeatInterleaveGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int repeats,
int dim,
DenseTensor* x_grad) {
auto input_dim = x_grad->dims();
if (dim < 0) {
dim += input_dim.size();
}
DenseTensor index;
int64_t index_size = x_grad->dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < x_grad->dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(phi::make_ddim({index_size}));
paddle::framework::TensorFromVector<int>(index_vec, &index);
const DenseTensor index_copy = index;
IndexSelectGradInner<Context, T, int>(ctx, out_grad, index_copy, x_grad, dim);
}
} // namespace phi
PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad,
CPU,
ALL_LAYOUT,
phi::RepeatInterleaveWithTensorIndexGradKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(repeat_interleave_grad,
CPU,
ALL_LAYOUT,
phi::RepeatInterleaveGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/repeat_interleave_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h"
PD_REGISTER_KERNEL(repeat_interleave,
CPU,
ALL_LAYOUT,
phi::RepeatInterleaveKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index,
CPU,
ALL_LAYOUT,
phi::RepeatInterleaveWithTensorIndexKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace funcs {
template <typename Context, typename RepeatsT = int>
void RepeatsTensor2IndexTensor(const Context& ctx,
const DenseTensor& repeats,
DenseTensor* index) {
DenseTensor repeats_cpu_copy;
if (!paddle::platform::is_cpu_place(repeats.place())) {
phi::Copy(
ctx, repeats, paddle::platform::CPUPlace(), true, &repeats_cpu_copy);
}
const RepeatsT* repeats_data = paddle::platform::is_cpu_place(repeats.place())
? repeats.data<RepeatsT>()
: repeats_cpu_copy.data<RepeatsT>();
int64_t index_size = 0;
for (int i = 0; i < repeats.dims()[0]; i++) {
index_size += repeats_data[i];
}
std::vector<RepeatsT> index_vec(index_size);
int offset = 0;
for (int i = 0; i < repeats.dims()[0]; i++) {
std::fill_n(index_vec.begin() + offset, repeats_data[i], i);
offset += repeats_data[i];
}
index->Resize(phi::make_ddim({index_size}));
paddle::framework::TensorFromVector<RepeatsT>(index_vec, ctx, index);
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h"
PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad,
GPU,
ALL_LAYOUT,
phi::RepeatInterleaveWithTensorIndexGradKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(repeat_interleave_grad,
GPU,
ALL_LAYOUT,
phi::RepeatInterleaveGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/repeat_interleave_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h"
PD_REGISTER_KERNEL(repeat_interleave,
GPU,
ALL_LAYOUT,
phi::RepeatInterleaveKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index,
GPU,
ALL_LAYOUT,
phi::RepeatInterleaveWithTensorIndexKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h"
#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#endif
#include "paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h"
namespace phi {
#if defined(__NVCC__) || defined(__HIPCC__)
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index,
int64_t nums,
int64_t N,
int64_t stride,
int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
#endif
template <typename T, typename Context>
void RepeatInterleaveWithTensorIndexGradKernel(
const Context& ctx,
const DenseTensor& x,
const DenseTensor& repeats_tensor,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad) {
auto place = ctx.GetPlace();
auto cpu_place = phi::CPUPlace();
auto input_dim = x_grad->dims();
if (dim < 0) {
dim += input_dim.size();
}
DenseTensor index;
PADDLE_ENFORCE_EQ(repeats_tensor.dims()[0] == x_grad->dims()[dim],
true,
phi::errors::InvalidArgument(
"The length of Input(RepeatsTensor) must be the "
"same as length of Input(X) in axis. "
"But received: [%s], required: [%d].",
repeats_tensor.dims()[0],
x_grad->dims()[dim]));
const auto& index_type =
paddle::framework::TransToProtoVarType(repeats_tensor.dtype());
bool index_type_match =
index_type == paddle::framework::proto::VarType::INT32 ||
index_type == paddle::framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Repeats) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64)));
#if defined(__NVCC__) || defined(__HIPCC__)
auto output_dim = out_grad.dims();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
int64_t numel = x_grad->numel();
int64_t out_nums = out_grad.numel();
auto* out_grad_data = out_grad.data<T>();
ctx.template Alloc<T>(x_grad);
auto* in_grad_data = x_grad->data<T>();
auto stream = ctx.stream();
index_select_grad_init<T>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(in_grad_data, numel);
if (index_type == paddle::framework::proto::VarType::INT64) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>(
ctx, repeats_tensor, &index);
int64_t index_nums = index.numel();
const int64_t* index_data = index.data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t>
<<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(out_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
} else {
phi::funcs::RepeatsTensor2IndexTensor<Context, int>(
ctx, repeats_tensor, &index);
int64_t index_nums = index.numel();
const int* index_data = index.data<int>();
index_select_grad_cuda_kernel<T, int>
<<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(out_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
}
#endif
}
template <typename T, typename Context>
void RepeatInterleaveGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int repeats,
int dim,
DenseTensor* x_grad) {
auto place = ctx.GetPlace();
auto cpu_place = phi::CPUPlace();
auto input_dim = x_grad->dims();
if (dim < 0) {
dim += input_dim.size();
}
DenseTensor index;
#if defined(__NVCC__) || defined(__HIPCC__)
auto output_dim = out_grad.dims();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
int64_t numel = x_grad->numel();
int64_t out_nums = out_grad.numel();
auto* out_grad_data = out_grad.data<T>();
ctx.template Alloc<T>(x_grad);
auto* in_grad_data = x_grad->data<T>();
auto stream = ctx.stream();
index_select_grad_init<T>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(in_grad_data, numel);
int64_t index_size = x_grad->dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < x_grad->dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(phi::make_ddim({index_size}));
paddle::framework::TensorFromVector<int>(index_vec, ctx, &index);
const int* index_data = index.data<int>();
int64_t index_nums = index.numel();
index_select_grad_cuda_kernel<T, int>
<<<(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(out_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
#endif
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h"
#include "paddle/phi/kernels/repeat_interleave_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#endif
#include "paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h"
namespace phi {
#if defined(__NVCC__) || defined(__HIPCC__)
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input,
T* output,
const IndexT* index,
int64_t N,
int64_t stride,
int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
output[idx] = input[input_idx];
}
#endif
template <typename T, typename Context>
void RepeatInterleaveKernel(const Context& ctx,
const DenseTensor& x,
int repeats,
int dim,
DenseTensor* out) {
auto place = ctx.GetPlace();
auto cpu_place = phi::CPUPlace();
auto input_dim = x.dims();
if (dim < 0) {
dim += input_dim.size();
}
DenseTensor index;
int64_t index_size = input_dim[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < input_dim[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(phi::make_ddim({index_size}));
if (place == cpu_place) {
DenseTensor x_copy = x;
paddle::framework::TensorFromVector<int>(index_vec, &index);
auto output_dim = phi::vectorize(x.dims());
output_dim[dim] = index_size;
out->Resize(phi::make_ddim(output_dim));
phi::IndexSelectInner<Context, T, int>(ctx, &x_copy, index, out, dim);
}
#if defined(__NVCC__) || defined(__HIPCC__)
else {
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
paddle::framework::TensorFromVector<int>(index_vec, ctx, &index);
auto stream = ctx.stream();
auto output_dim = phi::vectorize(x.dims());
output_dim[dim] = index_size;
out->Resize(phi::make_ddim(output_dim));
ctx.template Alloc<T>(out);
auto* out_data = out->data<T>();
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const int* index_data = index.data<int>();
index_select_cuda_kernel<T, int>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
x.data<T>(), out_data, index_data, numel, stride, size, delta);
}
#endif
}
template <typename T, typename Context>
void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& repeats_tensor,
int dim,
DenseTensor* out) {
auto place = ctx.GetPlace();
auto cpu_place = phi::CPUPlace();
auto input_dim = x.dims();
if (dim < 0) {
dim += input_dim.size();
}
DenseTensor index;
PADDLE_ENFORCE_EQ(repeats_tensor.dims()[0] == x.dims()[dim],
true,
phi::errors::InvalidArgument(
"The length of Input(RepeatsTensor) must be the "
"same as length of Input(X) in axis. "
"But received: [%s], required: [%d].",
repeats_tensor.dims()[0],
x.dims()[dim]));
const auto& index_type =
paddle::framework::TransToProtoVarType(repeats_tensor.dtype());
bool index_type_match =
index_type == paddle::framework::proto::VarType::INT32 ||
index_type == paddle::framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
phi::errors::InvalidArgument(
"Input(RepeatsTensor) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64)));
if (place == cpu_place) {
auto x_copy = x;
if (index_type == paddle::framework::proto::VarType::INT32) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int>(
ctx, repeats_tensor, &index);
auto output_dim = phi::vectorize(x.dims());
output_dim[dim] = index.dims()[0];
out->Resize(phi::make_ddim(output_dim));
IndexSelectInner<Context, T, int>(ctx, &x_copy, index, out, dim);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>(
ctx, repeats_tensor, &index);
auto output_dim = phi::vectorize(x.dims());
output_dim[dim] = index.dims()[0];
out->Resize(phi::make_ddim(output_dim));
IndexSelectInner<Context, T, int64_t>(ctx, &x_copy, index, out, dim);
}
}
#if defined(__NVCC__) || defined(__HIPCC__)
else {
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
auto stream = ctx.stream();
auto* in_data = x.data<T>();
if (index_type == paddle::framework::proto::VarType::INT64) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>(
ctx, repeats_tensor, &index);
const int64_t* index_data = index.data<int64_t>();
auto output_dim = phi::vectorize(x.dims());
output_dim[dim] = index.dims()[0];
out->Resize(phi::make_ddim(output_dim));
T* out_data = ctx.template Alloc<T>(out);
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
index_select_cuda_kernel<T, int64_t>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
} else {
phi::funcs::RepeatsTensor2IndexTensor<Context, int>(
ctx, repeats_tensor, &index);
const int* index_data = index.data<int>();
auto output_dim = phi::vectorize(x.dims());
output_dim[dim] = index.dims()[0];
out->Resize(phi::make_ddim(output_dim));
T* out_data = ctx.template Alloc<T>(out);
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
index_select_cuda_kernel<T, int>
<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
}
}
#endif
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void RepeatInterleaveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int repeats,
int dim,
DenseTensor* x_grad);
template <typename T, typename Context>
void RepeatInterleaveWithTensorIndexGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& repeats_tensor,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void RepeatInterleaveKernel(const Context& dev_ctx,
const DenseTensor& x,
int repeats,
int dim,
DenseTensor* out);
template <typename T, typename Context>
void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& repeat_tensor,
int dim,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RepeatInterleaveOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("RepeatsTensor")) {
VLOG(3) << "sig------ repeat_interleave_with_tensor_index";
return KernelSignature("repeat_interleave_with_tensor_index",
{"X", "RepeatsTensor"},
{"dim"},
{"Out"});
} else {
VLOG(3) << "sig ------repeat_interleave";
return KernelSignature(
"repeat_interleave", {"X"}, {"Repeats", "dim"}, {"Out"});
}
}
KernelSignature RepeatInterleaveGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("RepeatsTensor")) {
VLOG(3) << "sig ------repeat_interleave with tensor grad";
return KernelSignature("repeat_interleave_with_tensor_index_grad",
{"X", "RepeatsTensor", "Out@GRAD"},
{"dim"},
{"X@GRAD"});
} else {
VLOG(3) << "sig repeat_interleave grad";
return KernelSignature("repeat_interleave_grad",
{"X", "Out@GRAD"},
{"Repeats", "dim"},
{"X@GRAD"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(repeat_interleave,
phi::RepeatInterleaveOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(repeat_interleave_grad,
phi::RepeatInterleaveGradOpArgumentMapping);
......@@ -27,10 +27,12 @@ class TestRepeatInterleaveOp(OpTest):
def setUp(self):
self.op_type = "repeat_interleave"
self.python_api = paddle.repeat_interleave
self.init_dtype_type()
index_np = np.random.randint(
low=0, high=3, size=self.index_size).astype(self.index_type)
x_np = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = {'X': x_np, 'RepeatsTensor': index_np}
self.attrs = {'dim': self.dim}
......@@ -57,16 +59,17 @@ class TestRepeatInterleaveOp(OpTest):
self.index_size = self.x_shape[self.dim]
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
class TestRepeatInterleaveOp2(OpTest):
def setUp(self):
self.op_type = "repeat_interleave"
self.python_api = paddle.repeat_interleave
self.init_dtype_type()
index_np = 2
x_np = np.random.random(self.x_shape).astype(self.x_type)
......@@ -95,10 +98,10 @@ class TestRepeatInterleaveOp2(OpTest):
self.index_size = self.x_shape[self.dim]
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
class TestIndexSelectAPI(unittest.TestCase):
......@@ -115,7 +118,7 @@ class TestIndexSelectAPI(unittest.TestCase):
# case 1:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data(name='repeats',
index = fluid.layers.data(name='repeats_',
shape=[4],
dtype='int32',
append_batch_size=False)
......@@ -123,7 +126,7 @@ class TestIndexSelectAPI(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={
'x': self.data_x,
'repeats': self.data_index
'repeats_': self.data_index
},
fetch_list=[z.name],
return_numpy=False)
......@@ -134,7 +137,7 @@ class TestIndexSelectAPI(unittest.TestCase):
repeats = np.array([1, 2, 1]).astype('int32')
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data(name='repeats',
index = fluid.layers.data(name='repeats_',
shape=[3],
dtype='int32',
append_batch_size=False)
......@@ -142,7 +145,7 @@ class TestIndexSelectAPI(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={
'x': self.data_x,
'repeats': repeats,
'repeats_': repeats,
},
fetch_list=[z.name],
return_numpy=False)
......
......@@ -4003,12 +4003,11 @@ def repeat_interleave(x, repeats, axis=None, name=None):
x = paddle.flatten(x)
axis = 0
if paddle.in_dynamic_mode():
if isinstance(repeats, int):
return _C_ops.repeat_interleave(x, None, 'Repeats', repeats, 'dim',
axis)
elif isinstance(repeats, Variable):
return _C_ops.repeat_interleave(x, repeats, 'dim', axis)
if in_dygraph_mode():
if isinstance(repeats, Variable):
return _C_ops.final_state_repeat_interleave_with_tensor_index(
x, repeats, axis)
return _C_ops.final_state_repeat_interleave(x, repeats, axis)
helper = LayerHelper("repeat_interleave", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
......
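For reference, a minimal usage sketch of the Python API shown above, assuming a Paddle build that includes this change. In dynamic graph mode a scalar repeats dispatches to the final-state repeat_interleave op, while a Tensor repeats dispatches to repeat_interleave_with_tensor_index:

import paddle

x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
# scalar repeats -> final_state_repeat_interleave
paddle.repeat_interleave(x, 2, axis=0)  # shape [4, 3]
# Tensor repeats -> final_state_repeat_interleave_with_tensor_index
paddle.repeat_interleave(x, paddle.to_tensor([2, 1]), axis=0)  # shape [3, 3]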