Unverified commit a7de0e66, authored by kuizhiqing, committed by GitHub

add op/api repeat/interleave (#37981)

Parent 885767e3
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/repeat_interleave_op.h"
#include <memory>
namespace paddle {
namespace operators {
using framework::Tensor;
class RepeatInterleaveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of RepeatInterleaveOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of RepeatInterleaveOp should not be null."));
auto input_dim = ctx->GetInputDim("X");
auto dim = ctx->Attrs().Get<int>("dim");
auto output_dim = framework::vectorize(input_dim);
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(), input_dim.size() - 1, dim));
auto repeats = ctx->Attrs().Get<int>("Repeats");
if (ctx->HasInput("RepeatsTensor")) {
auto repeats_dim = ctx->GetInputDim("RepeatsTensor");
PADDLE_ENFORCE_EQ(
repeats_dim.size() == 1 ||
(repeats_dim.size() == 2 && repeats_dim[1] == 1),
true, platform::errors::InvalidArgument(
"The 'shape' of Input(RepeatsTensor) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
repeats_dim, repeats_dim.size()));
PADDLE_ENFORCE_EQ(repeats_dim[0] != 0, true,
platform::errors::InvalidArgument(
"The length of Input(RepeatsTensor) can't be 0."));
if (dim < 0) {
dim += input_dim.size();
}
output_dim[dim] = -1;
} else if (repeats > 0) {
output_dim[dim] = input_dim[dim] * repeats;
}
VLOG(3) << "infershap out " << output_dim[dim];
ctx->SetOutputDim("Out", framework::make_ddim(output_dim));
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class RepeatInterleaveGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should be not null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class RepeatInterleaveOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) the input tensor.");
AddInput("RepeatsTensor",
"the 1-D tensor containing the repeats alongsize the axis.")
.AsDispensable();
AddOutput("Out", "the output tensor.");
AddAttr<int>("Repeats", "the number of repetitions for each element.")
.SetDefault(0);
AddAttr<int>("dim", "the dimension in which we repeat.").SetDefault(0);
AddComment(R"DOC(
Returns a new tensor which repeats the input tensor
along dimension dim, using the entries in repeats,
which may be a Tensor or an int.
The returned tensor has the same number of dimensions
as the original tensor (input), except along the given axis.
)DOC");
}
};
template <typename T>
class RepeatInterleaveGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("repeat_interleave_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("RepeatsTensor", this->Input("RepeatsTensor"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(RepeatInterleaveGradNoNeedBufferVarsInferer,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(repeat_interleave, ops::RepeatInterleaveOp,
ops::RepeatInterleaveOpMaker,
ops::RepeatInterleaveGradMaker<paddle::framework::OpDesc>,
ops::RepeatInterleaveGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(repeat_interleave_grad, ops::RepeatInterleaveGradOp,
ops::RepeatInterleaveGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
repeat_interleave,
ops::RepeatInterleaveKernel<paddle::platform::CPUDeviceContext, float>,
ops::RepeatInterleaveKernel<paddle::platform::CPUDeviceContext, double>,
ops::RepeatInterleaveKernel<paddle::platform::CPUDeviceContext, int>,
ops::RepeatInterleaveKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
repeat_interleave_grad,
ops::RepeatInterleaveGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::RepeatInterleaveGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::RepeatInterleaveGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::RepeatInterleaveGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/repeat_interleave_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
// function borrowed from repeat_interleave_op
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input, T* output,
const IndexT* index, int64_t N,
int64_t stride, int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
output[idx] = input[input_idx];
}
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index, int64_t nums,
int64_t N, int64_t stride,
int64_t size, int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
template <typename DeviceContext, typename T>
class RepeatInterleaveCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int dim = context.Attr<int>("dim");
auto input_dim = in->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = framework::stride(input_dim);
int64_t stride = stride_dim[dim];
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
int repeats = context.Attr<int>("Repeats");
framework::LoDTensor index;
auto* in_data = in->data<T>();
if (context.HasInput("RepeatsTensor")) {
auto repeats_tensor =
context.Input<framework::LoDTensor>("RepeatsTensor");
PADDLE_ENFORCE_EQ(repeats_tensor->dims()[0] == in->dims()[dim], true,
platform::errors::InvalidArgument(
"The length of Input(RepeatsTensor) must be the "
"same as length of Input(X) in axis. "
"But received: [%s], required: [%d].",
repeats_tensor->dims()[0], in->dims()[dim]));
const auto& index_type = repeats_tensor->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(
index_type_match, true,
platform::errors::InvalidArgument(
"Input(RepeatsTensor) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT64) {
RepeatsTensor2IndexTensor<DeviceContext, int64_t>(*repeats_tensor,
&index);
const int64_t* index_data = index.data<int64_t>();
auto output_dim = framework::vectorize(in->dims());
output_dim[dim] = index.dims()[0];
out->Resize(framework::make_ddim(output_dim));
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
index_select_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data,
numel, stride, size, delta);
} else {
RepeatsTensor2IndexTensor<DeviceContext, int>(*repeats_tensor, &index);
const int* index_data = index.data<int>();
auto output_dim = framework::vectorize(in->dims());
output_dim[dim] = index.dims()[0];
out->Resize(framework::make_ddim(output_dim));
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
index_select_cuda_kernel<T, int><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data,
numel, stride, size, delta);
}
} else if (repeats > 0) {
int64_t index_size = in->dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < in->dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(framework::make_ddim({index_size}));
auto ctx = paddle::platform::DeviceContextPool::Instance().Get(
context.GetPlace());
paddle::framework::TensorFromVector<int>(index_vec, *ctx, &index);
auto output_dim = framework::vectorize(in->dims());
output_dim[dim] = index_size;
out->Resize(framework::make_ddim(output_dim));
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const int* index_data = index.data<int>();
index_select_cuda_kernel<T, int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"repeats must given with RepeatsTensor (tensor) or repeats (int)"));
}
}
};
template <typename DeviceContext, typename T>
class RepeatInterleaveGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_grad = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* output_grad_data = output_grad->data<T>();
auto* in_grad_data = in_grad->mutable_data<T>(context.GetPlace());
int dim = context.Attr<int>("dim");
auto input_dim = in_grad->dims();
auto output_dim = output_grad->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = framework::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
int64_t numel = in_grad->numel();
int64_t out_nums = output_grad->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel);
int repeats = context.Attr<int>("Repeats");
framework::LoDTensor index;
if (context.HasInput("RepeatsTensor")) {
auto repeats_tensor =
context.Input<framework::LoDTensor>("RepeatsTensor");
const auto& index_type = repeats_tensor->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(
index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT64) {
RepeatsTensor2IndexTensor<DeviceContext, int64_t>(*repeats_tensor,
&index);
int64_t index_nums = index.numel();
const int64_t* index_data = index.data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
output_grad_data, in_grad_data, index_data, index_nums, out_nums,
stride, size, delta);
platform::GpuStreamSync(stream);
} else {
RepeatsTensor2IndexTensor<DeviceContext, int>(*repeats_tensor, &index);
int64_t index_nums = index.numel();
const int* index_data = index.data<int>();
index_select_grad_cuda_kernel<T, int><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
output_grad_data, in_grad_data, index_data, index_nums, out_nums,
stride, size, delta);
platform::GpuStreamSync(stream);
}
} else if (repeats > 0) {
int64_t index_size = in_grad->dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < in_grad->dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(framework::make_ddim({index_size}));
auto ctx = paddle::platform::DeviceContextPool::Instance().Get(
context.GetPlace());
paddle::framework::TensorFromVector<int>(index_vec, *ctx, &index);
const int* index_data = index.data<int>();
int64_t index_nums = index.numel();
index_select_grad_cuda_kernel<T, int><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums,
out_nums, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"repeats must given with RepeatsTensor (tensor) or repeats (int)"));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
repeat_interleave,
ops::RepeatInterleaveCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::RepeatInterleaveCUDAKernel<paddle::platform::CUDADeviceContext,
double>,
ops::RepeatInterleaveCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::RepeatInterleaveCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::RepeatInterleaveCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
REGISTER_OP_CUDA_KERNEL(
repeat_interleave_grad,
ops::RepeatInterleaveGradCUDAKernel<paddle::platform::CUDADeviceContext,
float>,
ops::RepeatInterleaveGradCUDAKernel<paddle::platform::CUDADeviceContext,
double>,
ops::RepeatInterleaveGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::RepeatInterleaveGradCUDAKernel<paddle::platform::CUDADeviceContext,
int>,
ops::RepeatInterleaveGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/index_select_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename RepeatsT = int>
void RepeatsTensor2IndexTensor(const LoDTensor& repeats, LoDTensor* index) {
LoDTensor repeats_cpu_copy;
if (!platform::is_cpu_place(repeats.place())) {
framework::TensorCopySync(repeats, platform::CPUPlace(), &repeats_cpu_copy);
}
const RepeatsT* repeats_data = platform::is_cpu_place(repeats.place())
? repeats.data<RepeatsT>()
: repeats_cpu_copy.data<RepeatsT>();
int64_t index_size = 0;
for (int i = 0; i < repeats.dims()[0]; i++) {
index_size += repeats_data[i];
}
std::vector<RepeatsT> index_vec(index_size);
int offset = 0;
for (int i = 0; i < repeats.dims()[0]; i++) {
std::fill_n(index_vec.begin() + offset, repeats_data[i], i);
offset += repeats_data[i];
}
index->Resize(framework::make_ddim({index_size}));
auto ctx =
paddle::platform::DeviceContextPool::Instance().Get(repeats.place());
paddle::framework::TensorFromVector<RepeatsT>(index_vec, *ctx, index);
}
template <typename DeviceContext, typename T>
class RepeatInterleaveKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto inputs = *context.Input<framework::LoDTensor>("X");
auto* output = context.Output<framework::LoDTensor>("Out");
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += inputs.dims().size();
}
int repeats = context.Attr<int>("Repeats");
framework::LoDTensor index;
if (context.HasInput("RepeatsTensor")) {
auto repeats_tensor =
context.Input<framework::LoDTensor>("RepeatsTensor");
PADDLE_ENFORCE_EQ(repeats_tensor->dims()[0] == inputs.dims()[dim], true,
platform::errors::InvalidArgument(
"The length of Input(RepeatsTensor) must be the "
"same as length of Input(X) in axis. "
"But received: [%s], required: [%d].",
repeats_tensor->dims()[0], inputs.dims()[dim]));
const auto& index_type = repeats_tensor->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
platform::errors::InvalidArgument(
"Input(RepeatsTensor) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
RepeatsTensor2IndexTensor<DeviceContext, int>(*repeats_tensor, &index);
auto output_dim = framework::vectorize(inputs.dims());
output_dim[dim] = index.dims()[0];
output->Resize(framework::make_ddim(output_dim));
IndexSelectInner<DeviceContext, T, int>(context, &inputs, index, output,
dim);
} else if (index_type == framework::proto::VarType::INT64) {
RepeatsTensor2IndexTensor<DeviceContext, int64_t>(*repeats_tensor,
&index);
auto output_dim = framework::vectorize(inputs.dims());
output_dim[dim] = index.dims()[0];
output->Resize(framework::make_ddim(output_dim));
IndexSelectInner<DeviceContext, T, int64_t>(context, &inputs, index,
output, dim);
}
} else if (repeats > 0) {
int64_t index_size = inputs.dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < inputs.dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(framework::make_ddim({index_size}));
paddle::framework::TensorFromVector<int>(index_vec, &index);
auto output_dim = framework::vectorize(inputs.dims());
output_dim[dim] = index_size;
output->Resize(framework::make_ddim(output_dim));
IndexSelectInner<DeviceContext, T, int>(context, &inputs, index, output,
dim);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"repeats must given with RepeatsTensor (tensor) or repeats (int)"));
}
}
};
template <typename DeviceContext, typename T>
class RepeatInterleaveGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_grad =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* out_grad =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += out_grad->dims().size();
}
int repeats = context.Attr<int>("Repeats");
framework::LoDTensor index;
if (context.HasInput("RepeatsTensor")) {
auto repeats_tensor =
context.Input<framework::LoDTensor>("RepeatsTensor");
const auto& index_type = repeats_tensor->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
platform::errors::InvalidArgument(
"Input(Repeats) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
RepeatsTensor2IndexTensor<DeviceContext, int>(*repeats_tensor, &index);
IndexSelectGradInner<DeviceContext, T, int>(context, *out_grad, index,
x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) {
RepeatsTensor2IndexTensor<DeviceContext, int64_t>(*repeats_tensor,
&index);
IndexSelectGradInner<DeviceContext, T, int64_t>(context, *out_grad,
index, x_grad, dim);
}
} else if (repeats > 0) {
int64_t index_size = x_grad->dims()[dim] * repeats;
std::vector<int> index_vec(index_size);
for (int i = 0; i < x_grad->dims()[dim]; i++) {
std::fill_n(index_vec.begin() + i * repeats, repeats, i);
}
index.Resize(framework::make_ddim({index_size}));
paddle::framework::TensorFromVector<int>(index_vec, &index);
IndexSelectGradInner<DeviceContext, T, int>(context, *out_grad, index,
x_grad, dim);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"repeats must given with RepeatsTensor (tensor) or repeats (int)"));
}
}
};
} // namespace operators
} // namespace paddle
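The helper at the top of this header does the real work: RepeatsTensor2IndexTensor expands a per-element repeat count into an explicit gather index, after which both the CPU and CUDA kernels reduce repeat_interleave to index_select. A short Python sketch of that expansion:

repeats = [3, 2, 1]
index = [i for i, r in enumerate(repeats) for _ in range(r)]
assert index == [0, 0, 0, 1, 1, 2]  # element i appears repeats[i] times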
......@@ -44,6 +44,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"nll_loss", {"X", "Label", "Weight"}},
{"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}},
{"gather", {"X", "Index", "Axis"}},
{"repeat_interleave", {"X", "RepeatsTensor"}},
{"roi_pool", {"X", "ROIs", "RoisNum"}},
{"roi_align", {"X", "ROIs", "RoisNum"}},
{"psroi_pool", {"X", "ROIs", "RoisNum"}},
......
......@@ -159,6 +159,7 @@ from .tensor.manipulation import tensordot # noqa: F401
from .tensor.manipulation import as_complex # noqa: F401
from .tensor.manipulation import as_real # noqa: F401
from .tensor.manipulation import moveaxis # noqa: F401
from .tensor.manipulation import repeat_interleave # noqa: F401
from .tensor.math import abs # noqa: F401
from .tensor.math import acos # noqa: F401
from .tensor.math import asin # noqa: F401
......@@ -579,4 +580,5 @@ __all__ = [ # noqa
'fmax',
'fmin',
'moveaxis',
'repeat_interleave',
]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestRepeatInterleaveOp(OpTest):
def setUp(self):
self.op_type = "repeat_interleave"
self.init_dtype_type()
index_np = np.random.randint(
low=0, high=3, size=self.index_size).astype(self.index_type)
x_np = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = {'X': x_np, 'RepeatsTensor': index_np}
self.attrs = {'dim': self.dim}
outer_loop = np.prod(self.x_shape[:self.dim])
x_reshape = [outer_loop] + list(self.x_shape[self.dim:])
x_np_reshape = np.reshape(x_np, tuple(x_reshape))
out_list = []
for i in range(outer_loop):
for j in range(self.index_size):
for k in range(index_np[j]):
out_list.append(x_np_reshape[i, j])
self.out_shape = list(self.x_shape)
self.out_shape[self.dim] = np.sum(index_np)
self.out_shape = tuple(self.out_shape)
out = np.reshape(out_list, self.out_shape)
self.outputs = {'Out': out}
def init_dtype_type(self):
self.dim = 1
self.x_type = np.float64
self.index_type = np.int64
self.x_shape = (8, 4, 5)
self.index_size = self.x_shape[self.dim]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
class TestRepeatInterleaveOp2(OpTest):
def setUp(self):
self.op_type = "repeat_interleave"
self.init_dtype_type()
index_np = 2
x_np = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = {'X': x_np}
self.attrs = {'dim': self.dim, 'Repeats': index_np}
outer_loop = np.prod(self.x_shape[:self.dim])
x_reshape = [outer_loop] + list(self.x_shape[self.dim:])
x_np_reshape = np.reshape(x_np, tuple(x_reshape))
out_list = []
for i in range(outer_loop):
for j in range(self.index_size):
for k in range(index_np):
out_list.append(x_np_reshape[i, j])
self.out_shape = list(self.x_shape)
self.out_shape[self.dim] = index_np * self.index_size
self.out_shape = tuple(self.out_shape)
out = np.reshape(out_list, self.out_shape)
self.outputs = {'Out': out}
def init_dtype_type(self):
self.dim = 1
self.x_type = np.float64
self.x_shape = (8, 4, 5)
self.index_size = self.x_shape[self.dim]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
class TestIndexSelectAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]])
self.data_index = np.array([0, 1, 2, 1]).astype('int32')
def test_repeat_interleave_api(self):
paddle.enable_static()
self.input_data()
# case 1:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data(
name='repeats',
shape=[4],
dtype='int32',
append_batch_size=False)
z = paddle.repeat_interleave(x, index, axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x,
'repeats': self.data_index},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.repeat(self.data_x, self.data_index, axis=1)
self.assertTrue(np.allclose(expect_out, np.array(res)))
# case 2:
repeats = np.array([1, 2, 1]).astype('int32')
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data(
name='repeats',
shape=[3],
dtype='int32',
append_batch_size=False)
z = paddle.repeat_interleave(x, index, axis=0)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={
'x': self.data_x,
'repeats': repeats,
},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.repeat(self.data_x, repeats, axis=0)
self.assertTrue(np.allclose(expect_out, np.array(res)))
repeats = 2
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
z = paddle.repeat_interleave(x, repeats, axis=0)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.repeat(self.data_x, repeats, axis=0)
self.assertTrue(np.allclose(expect_out, np.array(res)))
def test_dygraph_api(self):
self.input_data()
# case axis none
input_x = np.array([[1, 2, 1], [1, 2, 3]]).astype('int32')
index_x = np.array([1, 1, 2, 1, 2, 2]).astype('int32')
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(input_x)
index = fluid.dygraph.to_variable(index_x)
z = paddle.repeat_interleave(x, index, None)
np_z = z.numpy()
expect_out = np.repeat(input_x, index_x, axis=None)
self.assertTrue(np.allclose(expect_out, np_z))
# case repeats int
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(input_x)
index = 2
z = paddle.repeat_interleave(x, index, None)
np_z = z.numpy()
expect_out = np.repeat(input_x, index, axis=None)
self.assertTrue(np.allclose(expect_out, np_z))
# case 1:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
index = fluid.dygraph.to_variable(self.data_index)
z = paddle.repeat_interleave(x, index, -1)
np_z = z.numpy()
expect_out = np.repeat(self.data_x, self.data_index, axis=-1)
self.assertTrue(np.allclose(expect_out, np_z))
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
index = fluid.dygraph.to_variable(self.data_index)
z = paddle.repeat_interleave(x, index, 1)
np_z = z.numpy()
expect_out = np.repeat(self.data_x, self.data_index, axis=1)
self.assertTrue(np.allclose(expect_out, np_z))
# case 2:
index_x = np.array([1, 2, 1]).astype('int32')
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
index = fluid.dygraph.to_variable(index_x)
z = paddle.repeat_interleave(x, index, axis=0)
np_z = z.numpy()
expect_out = np.repeat(self.data_x, index, axis=0)
self.assertTrue(np.allclose(expect_out, np_z))
if __name__ == '__main__':
unittest.main()
......@@ -114,6 +114,7 @@ from .manipulation import tensordot # noqa: F401
from .manipulation import as_complex # noqa: F401
from .manipulation import as_real # noqa: F401
from .manipulation import moveaxis # noqa: F401
from .manipulation import repeat_interleave # noqa: F401
from .math import abs # noqa: F401
from .math import acos # noqa: F401
from .math import asin # noqa: F401
......@@ -436,7 +437,8 @@ tensor_method_func = [ #noqa
'lerp',
'lerp_',
'angle',
'moveaxis'
'moveaxis',
'repeat_interleave',
]
#this list used in math_op_patch.py for magic_method bind
......
......@@ -2584,6 +2584,68 @@ def as_real(x, name=None):
return out
def repeat_interleave(x, repeats, axis=None, name=None):
"""
Returns a new tensor which repeats the ``x`` tensor along dimension ``axis``, using
the entries in ``repeats``, which is an int or a Tensor.
Args:
    x (Tensor): The input Tensor to be operated on. The data type of ``x`` can be float32, float64, int32 or int64.
    repeats (Tensor or int): The number of repetitions for each element. ``repeats`` is broadcast to fit the shape of the given axis.
    axis (int, optional): The dimension along which to repeat. Default: if None, the input tensor is flattened and the output is likewise flat.
    name (str, optional): The default value is None. Normally there is no
        need for the user to set this property. For more information, please
        refer to :ref:`api_guide_Name`.
Returns:
    Tensor: A Tensor with the same data type as ``x``.
Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
        repeats = paddle.to_tensor([3, 2, 1], dtype='int32')
        paddle.repeat_interleave(x, repeats, 1)
        # [[1, 1, 1, 2, 2, 3],
        #  [4, 4, 4, 5, 5, 6]]
        paddle.repeat_interleave(x, 2, 0)
        # [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]
        paddle.repeat_interleave(x, 2, None)
        # [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
"""
if axis is None:
x = paddle.flatten(x)
axis = 0
if in_dygraph_mode():
if isinstance(repeats, int):
return _C_ops.repeat_interleave(x, None, 'Repeats', repeats, 'dim',
axis)
elif isinstance(repeats, Variable):
return _C_ops.repeat_interleave(x, repeats, 'dim', axis)
helper = LayerHelper("repeat_interleave", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'paddle.tensor.manipulation.repeat_interleave')
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='repeat_interleave',
inputs={
'X': x,
'RepeatsTensor': repeats if isinstance(repeats, Variable) else None
},
outputs={'Out': out},
attrs={
'dim': axis,
'Repeats': repeats if isinstance(repeats, int) else 0
})
return out
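One detail of the implementation above: when axis is None the input is flattened and axis is reset to 0, so the call is equivalent to repeating the flattened tensor along axis 0. A quick dygraph check of that equivalence (illustrative, not part of the patch):

import paddle

x = paddle.to_tensor([[1, 2], [3, 4]])
a = paddle.repeat_interleave(x, 2, None)
b = paddle.repeat_interleave(paddle.flatten(x), 2, 0)
assert (a.numpy() == b.numpy()).all()  # [1, 1, 2, 2, 3, 3, 4, 4]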
def moveaxis(x, source, destination, name=None):
"""
Move the axis of tensor from ``source`` position to ``destination`` position.
......