Unverified commit c0e29233, authored by F From00, committed by GitHub

Move psroi_pool OP to phi (#40353)

* Move psroi_pool OP to phi

* Replace platform::TensorCopy with phi::Copy
Parent 89ed57e2
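For the second bullet, a minimal sketch of the substitution; `src`, `dst`, and `dev_ctx` are placeholder names, and the real call sites appear in the GPU kernels below:
// fluid: copy `src` to the target place through the device context's stream.
framework::TensorCopy(src, dev_ctx.GetPlace(), dev_ctx, &dst);
// phi: the equivalent call; `false` requests a non-blocking (async) copy.
phi::Copy(dev_ctx, src, dev_ctx.GetPlace(), false, &dst);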
@@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/psroi_pool_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class PSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -82,75 +82,6 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of PSROIPoolOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true,
platform::errors::InvalidArgument(
"Input(ROIs) of PSROIPoolOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of PSROIPoolOp should not be null."));
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"The format of input tensor is NCHW"));
PADDLE_ENFORCE_EQ(
rois_dims.size(), 2,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
PADDLE_ENFORCE_EQ(
rois_dims[1], 4,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
if (ctx->HasInput("RoisNum")) {
auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
platform::errors::InvalidArgument(
"The second dimension of RoisNum should "
"be 1, but received dimension is %d",
rois_num_dims.size()));
}
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
int output_channels = ctx->Attrs().Get<int>("output_channels");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE_EQ(
input_dims[1], output_channels * pooled_height * pooled_width,
platform::errors::InvalidArgument(
"the channel of X(%d) "
"should be equal to the product of "
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
input_dims[1], output_channels, pooled_height, pooled_width));
PADDLE_ENFORCE_GT(pooled_height, 0,
platform::errors::InvalidArgument(
"The pooled output height must be greater than 0"));
PADDLE_ENFORCE_GT(pooled_width, 0,
platform::errors::InvalidArgument(
"The pooled output width must be greater than 0"));
PADDLE_ENFORCE_GT(output_channels, 1,
platform::errors::InvalidArgument(
"The pooled output channels must greater than 1"));
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
platform::errors::InvalidArgument(
"The spatial scale must greater than 0."));
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] =
output_channels; // input_dims[1] / (pooled_height * pooled_width);
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -164,16 +95,6 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"The gradient of Out should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"The gradient of X should not be null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -204,15 +125,13 @@ class PSROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(psroi_pool, PsroiPoolInferShapeFunctor,
PD_INFER_META(phi::PsroiPoolInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(psroi_pool_grad, PsroiPoolGradInferShapeFunctor,
PD_INFER_META(phi::PsroiPoolGradInferMeta));
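// Note: the hand-written InferShape bodies removed from PSROIPoolOp and
// PSROIPoolGradOp above are now supplied by these two functors, which adapt
// the phi InferMeta functions to fluid's InferShape interface.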
REGISTER_OPERATOR(psroi_pool, ops::PSROIPoolOp, ops::PSROIPoolOpMaker,
ops::PSROIPoolGradMaker<paddle::framework::OpDesc>,
ops::PSROIPoolGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
psroi_pool,
ops::CPUPSROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPSROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
psroi_pool_grad,
ops::CPUPSROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPSROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>);
ops::PSROIPoolGradMaker<paddle::imperative::OpBase>,
PsroiPoolInferShapeFunctor);
REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp,
PsroiPoolGradInferShapeFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/psroi_pool_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
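// NumBlocks caps the grid at kNumMaximumNumBlocks; the kernels below use a
// grid-stride loop (i += blockDim.x * gridDim.x), so all nthreads elements
// are still covered when the element count exceeds blocks * threads.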
template <typename T>
__global__ void GPUPSROIPoolForward(
const int nthreads, const T* input_data, const T* input_rois,
const float spatial_scale, const int input_channels, const int height,
const int width, const int output_channels, const int pooled_height,
const int pooled_width, const int* rois_batch_id_data, T* output_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w = static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h = static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
const T* offset_input_data =
input_data +
(roi_batch_id * input_channels + input_channel) * height * width;
T outsum = 0;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
outsum += offset_input_data[input_index];
}
}
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
output_data[i] = is_empty ? 0. : outsum / bin_area;
}
}
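// Backward: each thread takes one output-gradient element, recomputes its
// ROI bin, and scatters diff_val = dout / bin_area into the input gradient
// with atomic adds, since bins from different ROIs can overlap in the input.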
template <typename T>
__global__ void GPUPSROIPoolBackward(
const int nthreads, const T* input_rois, const T* output_grad_data,
const float spatial_scale, const int input_channels, const int height,
const int width, const int output_channels, const int pooled_height,
const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_offset =
(roi_batch_id * input_channels + input_channel) * height * width;
T* offset_input_grad_data = input_grad_data + input_offset;
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w = static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h = static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Accumulate diff_val into input data
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
T diff_val = is_empty ? 0. : output_grad_data[i] / bin_area;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
platform::CudaAtomicAdd(offset_input_grad_data + input_index, diff_val);
}
}
}
}
template <typename Place, typename T>
class GPUPSROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
PADDLE_ENFORCE_EQ(
input_channels, output_channels * pooled_height * pooled_width,
platform::errors::InvalidArgument(
"The channels %d of input X should equal the product of "
"output_channels %d x pooled_height %d x pooled_width %d.",
input_channels, output_channels, pooled_height, pooled_width));
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
int rois_batch_size;
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be "
"the same but received batch size of input(ROIs) and "
"input(X) is %d and %d respectively.",
rois_batch_size, batch_size));
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(platform::CPUPlace(), rois_num_list.data(), ctx.GetPlace(),
rois_num_data, sizeof(int) * rois_batch_size, 0);
int rois_num_count = 0;
for (int i = 0; i < rois_batch_size; ++i) {
rois_num_count += rois_num_list[i];
}
PADDLE_ENFORCE_EQ(
rois_num_count, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and RoisNum must be the same"));
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be "
"the same but received batch size of input(ROIs) and "
"input(X) is %d and %d respectively.",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
platform::errors::InvalidArgument(
"The number of rois from input(ROIs) and its LOD "
"must be the same. Received rois %d of input(ROIs) "
"but the number of rois %d from its LOD is %d",
rois_num, rois_num_with_lod));
// set rois batch id
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
framework::Tensor rois_batch_id_list_gpu;
framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &rois_batch_id_list_gpu);
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
// call cuda kernel function
GPUPSROIPoolForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale,
input_channels, height, width, output_channels, pooled_height,
pooled_width, rois_batch_id_list_gpu.data<int>(),
out->mutable_data<T>(ctx.GetPlace()));
}
};
template <typename Place, typename T>
class GPUPSROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
int rois_num = rois->dims()[0];
int input_channels = in->dims()[1];
int height = in->dims()[2];
int width = in->dims()[3];
if (input_grad) {
// set roi batch id
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(platform::CPUPlace(), rois_num_list.data(), ctx.GetPlace(),
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
framework::Tensor rois_batch_id_list_gpu;
framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &rois_batch_id_list_gpu);
input_grad->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), input_grad, static_cast<T>(0));
int output_grad_size = output_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUPSROIPoolBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<T>(), output_grad->data<T>(),
spatial_scale, input_channels, height, width, output_channels,
pooled_height, pooled_width, rois_batch_id_list_gpu.data<int>(),
input_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
psroi_pool,
ops::GPUPSROIPoolOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUPSROIPoolOpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
psroi_pool_grad,
ops::GPUPSROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUPSROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CPUPSROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto output_channels = ctx.Attr<int>("output_channels");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
PADDLE_ENFORCE_EQ(input_channels,
output_channels * pooled_height * pooled_width,
platform::errors::InvalidArgument(
"the channels of input "
"X should equal the product of "
"output_channels x pooled_height x pooled_width"));
auto in_stride = phi::stride(in_dims);
auto out_stride = phi::stride(out->dims());
const T* input_data = in->data<T>();
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d",
rois_batch_size, batch_size));
int rois_num_count = 0;
for (int i = 0; i < rois_batch_size; ++i) {
rois_num_count += rois_num_data[i];
}
PADDLE_ENFORCE_EQ(
rois_num_count, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and RoisNum must be the same"));
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument("the rois_batch_size and input(X) "
"batch_size should be the same."));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(
rois_num_with_lod, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and lod must be the same"));
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
T* output_data = out->mutable_data<T>(ctx.GetPlace());
const T* input_rois = rois->data<T>();
// calculate psroipooling, parallel processing can be implemented per ROI
for (int n = 0; n < rois_num; ++n) {
// set roi batch id
int roi_batch_id = rois_batch_id_data[n];
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w =
static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h =
static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small rois to be 1 x 1
T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1);
// Compute bin size w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
// calculate each pixel of the output feature map.
int out_roi_offset = n * out_stride[0];
for (int c = 0; c < output_channels; ++c) {
// per category
int out_plane_offset = out_roi_offset + c * out_stride[1];
for (int ph = 0; ph < pooled_height; ++ph) {
int out_row_offset = out_plane_offset + ph * out_stride[2];
for (int pw = 0; pw < pooled_width; ++pw) {
// calculate w and h at input feature map
int hstart = floor(static_cast<T>(ph) * bin_size_h + roi_start_h);
int wstart = floor(static_cast<T>(pw) * bin_size_w + roi_start_w);
int hend = ceil(static_cast<T>(ph + 1) * bin_size_h + roi_start_h);
int wend = ceil(static_cast<T>(pw + 1) * bin_size_w + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart, 0), height);
wstart = std::min(std::max(wstart, 0), width);
hend = std::min(std::max(hend, 0), height);
wend = std::min(std::max(wend, 0), width);
int output_index = out_row_offset + pw;
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_plane_offset =
roi_batch_id * in_stride[0] + input_channel * in_stride[1];
const T* offset_input_data = input_data + input_plane_offset;
T out_sum = 0.;
bool is_empty = (hend <= hstart) || (wend <= wstart);
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * in_stride[2] + iw;
out_sum += offset_input_data[input_index];
}
}
T bin_area = (hend - hstart) * (wend - wstart);
output_data[output_index] = is_empty ? 0. : out_sum / bin_area;
}
}
}
}
return;
}
};
template <typename DeviceContext, typename T>
class CPUPSROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
if (input_grad) {
auto in_dims = in->dims();
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
// set roi batch id
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
const T* input_rois = rois->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// set gradient of X to be 0. before backpropagate.
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0));
// backpropagate gradient per output pixel
int output_grad_size = output_grad->numel();
for (int i = 0; i < output_grad_size; ++i) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_offset =
(roi_batch_id * input_channels + input_channel) * height * width;
T* offset_input_grad_data = input_grad_data + input_offset;
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w =
static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h =
static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart, 0), height);
hend = std::min(std::max(hend, 0), height);
wstart = std::min(std::max(wstart, 0), width);
wend = std::min(std::max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Accumulate diff_val into input data
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
T diff_val = is_empty ? 0. : output_grad_data[i] / bin_area;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
offset_input_grad_data[input_index] += diff_val;
}
}
}
}
return;
}
};
} // namespace operators
} // namespace paddle
@@ -115,6 +115,18 @@ void GatherNdGradInferMeta(const MetaTensor& x,
x_grad->set_dtype(dtype);
}
void PsroiPoolGradInferMeta(const MetaTensor& x,
const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num,
const MetaTensor& dout,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
MetaTensor* dx) {
dx->share_meta(x);
}
void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
...
@@ -47,6 +47,16 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
int axis,
MetaTensor* dx);
void PsroiPoolGradInferMeta(const MetaTensor& x,
const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num,
const MetaTensor& dout,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
MetaTensor* dx);
void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
...
@@ -28,6 +28,98 @@ std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors) {
return dims;
}
void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& avg_squared_grad,
const MetaTensor& avg_squared_update,
float rho,
float epsilon,
MetaTensor* param_out,
MetaTensor* avg_squared_grad_out,
MetaTensor* avg_squared_update_out) {
auto param_dims = param.dims();
PADDLE_ENFORCE_EQ(
param_dims,
grad.dims(),
errors::InvalidArgument(
"Param and grad input of AdadeltaOp should have same dimension."));
PADDLE_ENFORCE_EQ(
param_dims,
avg_squared_grad.dims(),
errors::InvalidArgument("Param and AvgSquaredGrad input of AdadeltaOp "
"should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
avg_squared_update.dims(),
errors::InvalidArgument("Param and AvgSquaredUpdate input of AdadeltaOp "
"should have same dimension"));
param_out->set_dims(param_dims);
param_out->set_dtype(param.dtype());
avg_squared_grad_out->set_dims(param_dims);
avg_squared_grad_out->set_dtype(avg_squared_grad.dtype());
avg_squared_update_out->set_dims(param_dims);
avg_squared_update_out->set_dtype(avg_squared_update.dtype());
}
void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& moment,
const MetaTensor& inf_norm,
const MetaTensor& beta1_pow,
float beta1,
float beta2,
float epsilon,
MetaTensor* param_out,
MetaTensor* moment_out,
MetaTensor* inf_norm_out) {
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_NE(
product(lr_dims),
0,
errors::InvalidArgument("Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(
product(lr_dims),
1,
errors::InvalidArgument("Learning rate should have 1 dimension"));
auto beta1_pow_dims = beta1_pow.dims();
PADDLE_ENFORCE_EQ(product(beta1_pow_dims),
1,
errors::InvalidArgument(
"Beta1 power accumulator should have 1 dimension"));
auto param_dims = param.dims();
PADDLE_ENFORCE_EQ(
param_dims,
grad.dims(),
errors::InvalidArgument(
"Param and Grad input of AdamaxOp should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
moment.dims(),
errors::InvalidArgument(
"Param and Moment input of AdamaxOp should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
inf_norm.dims(),
errors::InvalidArgument(
"Param and InfNorm input of AdamaxOp should have same dimension"));
param_out->set_dims(param_dims);
param_out->set_dtype(param.dtype());
moment_out->set_dims(param_dims);
moment_out->set_dtype(moment.dtype());
inf_norm_out->set_dims(param_dims);
inf_norm_out->set_dtype(inf_norm.dtype());
}
void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
@@ -108,98 +200,6 @@ void AucInferMeta(const MetaTensor& input,
}
}
void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& moment,
const MetaTensor& inf_norm,
const MetaTensor& beta1_pow,
float beta1,
float beta2,
float epsilon,
MetaTensor* param_out,
MetaTensor* moment_out,
MetaTensor* inf_norm_out) {
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_NE(
product(lr_dims),
0,
errors::InvalidArgument("Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(
product(lr_dims),
1,
errors::InvalidArgument("Learning rate should have 1 dimension"));
auto beta1_pow_dims = beta1_pow.dims();
PADDLE_ENFORCE_EQ(product(beta1_pow_dims),
1,
errors::InvalidArgument(
"Beta1 power accumulator should have 1 dimension"));
auto param_dims = param.dims();
PADDLE_ENFORCE_EQ(
param_dims,
grad.dims(),
errors::InvalidArgument(
"Param and Grad input of AdamaxOp should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
moment.dims(),
errors::InvalidArgument(
"Param and Moment input of AdamaxOp should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
inf_norm.dims(),
errors::InvalidArgument(
"Param and InfNorm input of AdamaxOp should have same dimension"));
param_out->set_dims(param_dims);
param_out->set_dtype(param.dtype());
moment_out->set_dims(param_dims);
moment_out->set_dtype(moment.dtype());
inf_norm_out->set_dims(param_dims);
inf_norm_out->set_dtype(inf_norm.dtype());
}
void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& avg_squared_grad,
const MetaTensor& avg_squared_update,
float rho,
float epsilon,
MetaTensor* param_out,
MetaTensor* avg_squared_grad_out,
MetaTensor* avg_squared_update_out) {
auto param_dims = param.dims();
PADDLE_ENFORCE_EQ(
param_dims,
grad.dims(),
errors::InvalidArgument(
"Param and grad input of AdadeltaOp should have same dimension."));
PADDLE_ENFORCE_EQ(
param_dims,
avg_squared_grad.dims(),
errors::InvalidArgument("Param and AvgSquaredGrad input of AdadeltaOp "
"should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
avg_squared_update.dims(),
errors::InvalidArgument("Param and AvgSquaredUpdate input of AdadeltaOp "
"should have same dimension"));
param_out->set_dims(param_dims);
param_out->set_dtype(param.dtype());
avg_squared_grad_out->set_dims(param_dims);
avg_squared_grad_out->set_dtype(avg_squared_grad.dtype());
avg_squared_update_out->set_dims(param_dims);
avg_squared_update_out->set_dtype(avg_squared_update.dtype());
}
void BilinearTensorProductInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
@@ -369,6 +369,81 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
out->share_lod(*x.at(0));
}
void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
MetaTensor* out) {
auto input_dims = x.dims();
auto rois_dims = rois.dims();
PADDLE_ENFORCE_EQ(
input_dims.size(),
4,
errors::InvalidArgument("The format of input tensor is NCHW"));
PADDLE_ENFORCE_EQ(rois_dims.size(),
2,
errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
PADDLE_ENFORCE_EQ(rois_dims[1],
4,
errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
if (rois_num.get_ptr()) {
auto rois_num_dims = rois_num->dims();
PADDLE_ENFORCE_EQ(
rois_num_dims.size(),
1,
errors::InvalidArgument("The second dimension of RoisNum should "
"be 1, but received dimension is %d",
rois_num_dims.size()));
}
PADDLE_ENFORCE_EQ(
input_dims[1],
output_channels * pooled_height * pooled_width,
errors::InvalidArgument(
"the channel of X(%d) "
"should be equal to the product of "
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
input_dims[1],
output_channels,
pooled_height,
pooled_width));
PADDLE_ENFORCE_GT(pooled_height,
0,
errors::InvalidArgument(
"The pooled output height must be greater than 0"));
PADDLE_ENFORCE_GT(pooled_width,
0,
errors::InvalidArgument(
"The pooled output width must be greater than 0"));
PADDLE_ENFORCE_GT(output_channels,
1,
errors::InvalidArgument(
"The pooled output channels must greater than 1"));
PADDLE_ENFORCE_GT(
spatial_scale,
0.0f,
errors::InvalidArgument("The spatial scale must greater than 0."));
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] =
output_channels; // input_dims[1] / (pooled_height * pooled_width);
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
...
@@ -20,6 +20,29 @@ namespace phi {
std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors);
void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& avg_squared_grad,
const MetaTensor& avg_squared_update,
float rho,
float epsilon,
MetaTensor* param_out,
MetaTensor* avg_squared_grad_out,
MetaTensor* avg_squared_update_out);
void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& moment,
const MetaTensor& inf_norm,
const MetaTensor& beta1_pow,
float beta1,
float beta2,
float epsilon,
MetaTensor* param_out,
MetaTensor* moment_out,
MetaTensor* inf_norm_out);
void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
@@ -47,32 +70,18 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
MetaTensor* out);
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& moment,
const MetaTensor& inf_norm,
const MetaTensor& beta1_pow,
float beta1,
float beta2,
float epsilon,
MetaTensor* param_out,
MetaTensor* moment_out,
MetaTensor* inf_norm_out);
void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& avg_squared_grad,
const MetaTensor& avg_squared_update,
float rho,
float epsilon,
MetaTensor* param_out,
MetaTensor* avg_squared_grad_out,
MetaTensor* avg_squared_update_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/psroi_pool_grad_kernel.h"
#include <algorithm>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void PsroiPoolGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& rois,
paddle::optional<const DenseTensor&> rois_num,
const DenseTensor& dout,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
DenseTensor* dx) {
if (dx) {
auto in_dims = x.dims();
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num_t = rois.dims()[0];
// set roi batch id
DenseTensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num_t});
int* rois_batch_id_data = ctx.template Alloc<int>(&rois_batch_id_list);
int rois_batch_size;
if (rois_num.get_ptr()) {
rois_batch_size = rois_num->numel();
auto* rois_num_t_data = rois_num->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_t_data[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_t_data[n];
}
} else {
auto rois_lod = rois.lod().back();
rois_batch_size = rois_lod.size() - 1;
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
const T* input_rois = rois.data<T>();
const T* dout_data = dout.data<T>();
T* dx_data = ctx.template Alloc<T>(dx);
// set gradient of X to be 0. before backpropagate.
funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, dx, static_cast<T>(0));
// backpropagate gradient per output pixel
int dout_size = dout.numel();
for (int i = 0; i < dout_size; ++i) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_offset =
(roi_batch_id * input_channels + input_channel) * height * width;
T* offset_dx_data = dx_data + input_offset;
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w =
static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h =
static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart, 0), height);
hend = std::min(std::max(hend, 0), height);
wstart = std::min(std::max(wstart, 0), width);
wend = std::min(std::max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Accumulate diff_val into input data
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
T diff_val = is_empty ? 0. : dout_data[i] / bin_area;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
offset_dx_data[input_index] += diff_val;
}
}
}
}
}
} // namespace phi
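// Registration note: RoisNum (input index 2) is pinned to int32 regardless of
// the value type T, matching the fluid kernels' use of an int RoisNum tensor.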
PD_REGISTER_KERNEL(
psroi_pool_grad, CPU, ALL_LAYOUT, phi::PsroiPoolGradKernel, float, double) {
kernel->InputAt(2).SetDataType(
paddle::experimental::CppTypeToDataType<int>::Type());
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/psroi_pool_kernel.h"
#include <algorithm>
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void PsroiPoolKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& rois,
paddle::optional<const DenseTensor&> rois_num,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
DenseTensor* out) {
auto in_dims = x.dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num_t = rois.dims()[0];
PADDLE_ENFORCE_EQ(input_channels,
output_channels * pooled_height * pooled_width,
errors::InvalidArgument(
"the channels of input "
"X should equal the product of "
"output_channels x pooled_height x pooled_width"));
auto in_stride = stride(in_dims);
auto out_stride = stride(out->dims());
const T* input_data = x.data<T>();
DenseTensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num_t});
int* rois_batch_id_data = ctx.template Alloc<int>(&rois_batch_id_list);
int rois_batch_size;
if (rois_num.get_ptr()) {
rois_batch_size = rois_num->numel();
auto* rois_num_data = rois_num->data<int>();
PADDLE_ENFORCE_EQ(
rois_batch_size,
batch_size,
errors::InvalidArgument(
"The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d",
rois_batch_size,
batch_size));
int rois_num_count = 0;
for (int i = 0; i < rois_batch_size; ++i) {
rois_num_count += rois_num_data[i];
}
PADDLE_ENFORCE_EQ(
rois_num_count,
rois_num_t,
errors::InvalidArgument(
"the rois_num from input and RoisNum must be the same"));
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois.lod().back();
rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size,
batch_size,
errors::InvalidArgument("the rois_batch_size and input(X) "
"batch_size should be the same."));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num_with_lod,
rois_num_t,
errors::InvalidArgument(
"the rois_num from input and lod must be the same"));
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
T* output_data = ctx.template Alloc<T>(out);
const T* input_rois = rois.data<T>();
// calculate psroipooling, parallel processing can be implemented per ROI
for (int n = 0; n < rois_num_t; ++n) {
// set roi batch id
int roi_batch_id = rois_batch_id_data[n];
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w = static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h = static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small rois to be 1 x 1
T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1);
// Compute bin size w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
// calculate each pixel of the output feature map.
int out_roi_offset = n * out_stride[0];
for (int c = 0; c < output_channels; ++c) {
// per category
int out_plane_offset = out_roi_offset + c * out_stride[1];
for (int ph = 0; ph < pooled_height; ++ph) {
int out_row_offset = out_plane_offset + ph * out_stride[2];
for (int pw = 0; pw < pooled_width; ++pw) {
// calculate w and h at input feature map
int hstart = floor(static_cast<T>(ph) * bin_size_h + roi_start_h);
int wstart = floor(static_cast<T>(pw) * bin_size_w + roi_start_w);
int hend = ceil(static_cast<T>(ph + 1) * bin_size_h + roi_start_h);
int wend = ceil(static_cast<T>(pw + 1) * bin_size_w + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart, 0), height);
wstart = std::min(std::max(wstart, 0), width);
hend = std::min(std::max(hend, 0), height);
wend = std::min(std::max(wend, 0), width);
int output_index = out_row_offset + pw;
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_plane_offset =
roi_batch_id * in_stride[0] + input_channel * in_stride[1];
const T* offset_input_data = input_data + input_plane_offset;
T out_sum = 0.;
bool is_empty = (hend <= hstart) || (wend <= wstart);
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * in_stride[2] + iw;
out_sum += offset_input_data[input_index];
}
}
T bin_area = (hend - hstart) * (wend - wstart);
output_data[output_index] = is_empty ? 0. : out_sum / bin_area;
}
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
psroi_pool, CPU, ALL_LAYOUT, phi::PsroiPoolKernel, float, double) {
kernel->InputAt(2).SetDataType(
paddle::experimental::CppTypeToDataType<int>::Type());
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/psroi_pool_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void GPUPSROIPoolBackward(const int nthreads,
const T* input_rois,
const T* dout_data,
const float spatial_scale,
const int input_channels,
const int height,
const int width,
const int output_channels,
const int pooled_height,
const int pooled_width,
const int* rois_batch_id_data,
T* dx_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_offset =
(roi_batch_id * input_channels + input_channel) * height * width;
T* offset_dx_data = dx_data + input_offset;
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w = static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h = static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Accumulate diff_val into input data
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
T diff_val = is_empty ? 0. : dout_data[i] / bin_area;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
paddle::platform::CudaAtomicAdd(offset_dx_data + input_index, diff_val);
}
}
}
}
template <typename T, typename Context>
void PsroiPoolGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& rois,
paddle::optional<const DenseTensor&> rois_num,
const DenseTensor& dout,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
DenseTensor* dx) {
int rois_num_t = rois.dims()[0];
int input_channels = x.dims()[1];
int height = x.dims()[2];
int width = x.dims()[3];
if (dx) {
// set roi batch id
DenseTensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num_t});
int* rois_batch_id_data = ctx.template HostAlloc<int>(&rois_batch_id_list);
int rois_batch_size;
if (rois_num.get_ptr()) {
rois_batch_size = rois_num->numel();
std::vector<int> rois_num_list(rois_batch_size);
paddle::memory::Copy(CPUPlace(),
rois_num_list.data(),
ctx.GetPlace(),
rois_num->data<int>(),
sizeof(int) * rois_batch_size,
0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois.lod().back();
rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
DenseTensor rois_batch_id_list_gpu;
Copy(ctx,
rois_batch_id_list,
ctx.GetPlace(),
false,
&rois_batch_id_list_gpu);
ctx.template Alloc<T>(dx);
funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, dx, static_cast<T>(0));
int dout_size = dout.numel();
int blocks = NumBlocks(dout_size);
int threads = kNumCUDAThreads;
if (dout_size > 0) {
GPUPSROIPoolBackward<T><<<blocks, threads, 0, ctx.stream()>>>(
dout_size,
rois.data<T>(),
dout.data<T>(),
spatial_scale,
input_channels,
height,
width,
output_channels,
pooled_height,
pooled_width,
rois_batch_id_list_gpu.data<int>(),
ctx.template Alloc<T>(dx));
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
psroi_pool_grad, GPU, ALL_LAYOUT, phi::PsroiPoolGradKernel, float, double) {
kernel->InputAt(2).SetDataType(
paddle::experimental::CppTypeToDataType<int>::Type());
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/psroi_pool_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void GPUPSROIPoolForward(const int nthreads,
const T* input_data,
const T* input_rois,
const float spatial_scale,
const int input_channels,
const int height,
const int width,
const int output_channels,
const int pooled_height,
const int pooled_width,
const int* rois_batch_id_data,
T* output_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
// The output is in order (n, c, ph, pw)
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % output_channels;
int n = i / pooled_width / pooled_height / output_channels;
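    // Illustrative decomposition (example values): with pooled_width =
    // pooled_height = 7 and output_channels = 10, i = 3456 yields pw = 5,
    // ph = 3, c = 0, n = 7, i.e. bin (ph, pw) = (3, 5) of channel 0 for
    // ROI 7.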
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
T roi_start_w = static_cast<T>(round(offset_input_rois[0])) * spatial_scale;
T roi_start_h = static_cast<T>(round(offset_input_rois[1])) * spatial_scale;
T roi_end_w =
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = max(roi_end_w - roi_start_w, (T)0.1);
// Compute w and h at input feature map
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(bin_size_h * static_cast<T>(ph) + roi_start_h);
int wstart = floor(bin_size_w * static_cast<T>(pw) + roi_start_w);
int hend = ceil(bin_size_h * static_cast<T>(ph + 1) + roi_start_h);
int wend = ceil(bin_size_w * static_cast<T>(pw + 1) + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
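    // Illustrative bin mapping (example values): with roi_start_h = 2 and
    // bin_size_h = 1.5, bin ph = 1 spans rows [floor(3.5), ceil(5.0)) =
    // [3, 5); a bin clipped to nothing (hend <= hstart) is flagged empty
    // and written as 0 below.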
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
const T* offset_input_data =
input_data +
(roi_batch_id * input_channels + input_channel) * height * width;
T outsum = 0;
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
outsum += offset_input_data[input_index];
}
}
T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
    output_data[i] = is_empty ? static_cast<T>(0) : outsum / bin_area;
}
}
template <typename T, typename Context>
void PsroiPoolKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& rois,
paddle::optional<const DenseTensor&> rois_num,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
DenseTensor* out) {
auto in_dims = x.dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
PADDLE_ENFORCE_EQ(
input_channels,
output_channels * pooled_height * pooled_width,
errors::InvalidArgument(
"The channels %d of input X should equal the product of "
"output_channels %d x pooled_height %d x pooled_width %d.",
input_channels,
output_channels,
pooled_height,
pooled_width));
int rois_num_t = rois.dims()[0];
if (rois_num_t == 0) return;
int rois_batch_size;
DenseTensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num_t});
int* rois_batch_id_data = ctx.template HostAlloc<int>(&rois_batch_id_list);
if (rois_num.get_ptr()) {
rois_batch_size = rois_num->numel();
auto* rois_num_data = rois_num->data<int>();
PADDLE_ENFORCE_EQ(rois_batch_size,
batch_size,
errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be "
"the same but received batch size of input(ROIs) and "
"input(X) is %d and %d respectively.",
rois_batch_size,
batch_size));
std::vector<int> rois_num_list(rois_batch_size);
paddle::memory::Copy(CPUPlace(),
rois_num_list.data(),
ctx.GetPlace(),
rois_num_data,
sizeof(int) * rois_batch_size,
0);
int rois_num_count = 0;
for (int i = 0; i < rois_batch_size; ++i) {
rois_num_count += rois_num_list[i];
}
    PADDLE_ENFORCE_EQ(
        rois_num_count,
        rois_num_t,
        errors::InvalidArgument(
            "The total number of rois from RoisNum (%d) must equal the "
            "number of rois from input(ROIs) (%d).",
            rois_num_count,
            rois_num_t));
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois.lod().back();
rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(rois_batch_size,
batch_size,
errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be "
"the same but received batch size of input(ROIs) and "
"input(X) is %d and %d respectively.",
rois_batch_size,
batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size];
    PADDLE_ENFORCE_EQ(rois_num_t,
                      rois_num_with_lod,
                      errors::InvalidArgument(
                          "The number of rois from input(ROIs) and its LOD "
                          "must be the same. Received %d rois from "
                          "input(ROIs) but %d from its LOD.",
                          rois_num_t,
                          rois_num_with_lod));
// set rois batch id
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
DenseTensor rois_batch_id_list_gpu;
Copy(ctx, rois_batch_id_list, ctx.GetPlace(), false, &rois_batch_id_list_gpu);
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
// call cuda kernel function
GPUPSROIPoolForward<T><<<blocks, threads, 0, ctx.stream()>>>(
output_size,
x.data<T>(),
rois.data<T>(),
spatial_scale,
input_channels,
height,
width,
output_channels,
pooled_height,
pooled_width,
rois_batch_id_list_gpu.data<int>(),
ctx.template Alloc<T>(out));
}
} // namespace phi
PD_REGISTER_KERNEL(
psroi_pool, GPU, ALL_LAYOUT, phi::PsroiPoolKernel, float, double) {
kernel->InputAt(2).SetDataType(
paddle::experimental::CppTypeToDataType<int>::Type());
}
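// A note on PD_REGISTER_KERNEL (sketch of the convention, not introduced by
// this patch): its arguments are the kernel name, backend, layout, kernel
// function, and the list of value dtypes; the trailing brace body customizes
// per-input metadata, here pinning input 2 (RoisNum) to int32 for both the
// float and double instantiations.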
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void PsroiPoolGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& rois,
paddle::optional<const DenseTensor&> rois_num,
const DenseTensor& dout,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void PsroiPoolKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& rois,
paddle::optional<const DenseTensor&> rois_num,
int pooled_height,
int pooled_width,
int output_channels,
float spatial_scale,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PsroiPoolOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"psroi_pool",
{"X", "ROIs", "RoisNum"},
{"pooled_height", "pooled_width", "output_channels", "spatial_scale"},
{"Out"});
}
KernelSignature PsroiPoolGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"psroi_pool_grad",
{"X", "ROIs", "RoisNum", GradVarName("Out")},
{"pooled_height", "pooled_width", "output_channels", "spatial_scale"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(psroi_pool, phi::PsroiPoolOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(psroi_pool_grad,
phi::PsroiPoolGradOpArgumentMapping);
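// A minimal sketch of how the mappings above are consumed (assumed call
// path, not shown in this patch): executing the legacy fluid op
// "psroi_pool" resolves through PsroiPoolOpArgumentMapping to a call
// shaped like
//   phi::PsroiPoolKernel<T, Context>(dev_ctx, X, ROIs, RoisNum,
//                                    pooled_height, pooled_width,
//                                    output_channels, spatial_scale, &Out);
// with inputs, attributes, and outputs forwarded in the order listed in
// the KernelSignature.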