Unverified commit 7d0db629, authored by zyfncg, committed by GitHub

[PHI] move roi_pool kernel to phi (#40574)

* move roi_pool forward kernel to phi

* move roi_pool_grad to phi

* fix compile bug

* fix compile bug

* fix register data_type
Parent 681a6865
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/roi_pool_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/roi_pool_kernel.h"
namespace paddle {
namespace operators {
@@ -57,7 +58,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
"%d-dimensional LoDTensor",
rois_dims.size()));
PADDLE_ENFORCE_EQ(
rois_dims[1], kROISize,
rois_dims[1], phi::kROISize,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ...]. But the second dimension of "
@@ -216,16 +217,7 @@ REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
ops::ROIPoolGradMaker<paddle::framework::OpDesc>,
ops::ROIPoolGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
roi_pool,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(
roi_pool_grad,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_VERSION(roi_pool)
.AddCheckpoint(
R"ROC(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/roi_pool_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
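// Illustrative sketch (not part of this op): NumBlocks caps the grid at
// kNumMaxinumNumBlocks, so the kernels below walk their index space with a
// grid-stride loop instead of assuming one thread per element. A minimal
// kernel using the same launch pattern, e.g.
//   ExampleScale<float><<<NumBlocks(n), kNumCUDAThreads, 0, stream>>>(n, p);
template <typename T>
__global__ void ExampleScale(const int n, T* data) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (int i = index; i < n; i += stride) {
    data[i] *= static_cast<T>(2);  // each thread strides across the grid
  }
}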
template <typename T>
__global__ void GPUROIPoolForward(
const int nthreads, const T* input_data, const T* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
const T* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n];
int roi_start_w = round(offset_input_rois[0] * spatial_scale);
int roi_start_h = round(offset_input_rois[1] * spatial_scale);
int roi_end_w = round(offset_input_rois[2] * spatial_scale);
int roi_end_h = round(offset_input_rois[3] * spatial_scale);
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
int hstart = static_cast<int>(floor(static_cast<double>(ph) *
static_cast<double>(roi_height) /
static_cast<double>(pooled_height)));
int wstart = static_cast<int>(floor(static_cast<double>(pw) *
static_cast<double>(roi_width) /
static_cast<double>(pooled_width)));
int hend = static_cast<int>(ceil(static_cast<double>(ph + 1) *
static_cast<double>(roi_height) /
static_cast<double>(pooled_height)));
int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
static_cast<double>(roi_width) /
static_cast<double>(pooled_width)));
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
int maxidx = -1;
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_data_index = h * width + w;
if (offset_input_data[input_data_index] > maxval) {
maxval = offset_input_data[input_data_index];
maxidx = input_data_index;
}
}
}
output_data[i] = maxval;
if (argmax_data) {
argmax_data[i] = maxidx;
}
}
}
template <typename T>
__global__ void GPUROIPoolBackward(
const int nthreads, const T* input_rois, const T* output_grad,
const int64_t* argmax_data, const int num_rois, const float spatial_scale,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, int* roi_batch_id_data,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
int roi_batch_ind = roi_batch_id_data[n];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
T* offset_input_grad = input_grad + input_offset;
const int64_t* offset_argmax_data = argmax_data + output_offset;
int argmax = offset_argmax_data[ph * pooled_width + pw];
if (argmax != -1) {
platform::CudaAtomicAdd(
offset_input_grad + argmax,
static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
}
}
}
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
auto* argmax = ctx.Output<Tensor>("Argmax");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
int batch_size = in_dims[0];
auto in_stride = phi::stride(in_dims);
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto& dev_ctx = ctx.cuda_device_context();
auto gplace = ctx.GetPlace();
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be the same but "
"received batch size of input(ROIs) and input(X) is %d and %d "
"respectively.",
rois_batch_size, batch_size));
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_num_list.data(), gplace,
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be the same but "
"received batch size of input(ROIs) and input(X) is %d and %d "
"respectively.",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
platform::errors::InvalidArgument(
"The number of rois from input(ROIs) and its LOD "
"must be the same. Received rois %d of input(ROIs) "
"but the number of rois %d from its LOD is %d",
rois_num, rois_num_with_lod));
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
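// Worked example (illustrative): with rois_lod = {0, 2, 5}, rois 0 and 1
// belong to image 0 and rois 2, 3, 4 to image 1, so roi_batch_id_data
// becomes {0, 0, 1, 1, 1}.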
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
dev_ctx.stream());
GPUROIPoolForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
height, width, pooled_height, pooled_width, roi_id_data,
out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
}
};
template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* rois_lod = ctx.Input<Tensor>("RoisNum");
auto* argmax = ctx.Input<Tensor>("Argmax");
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
int rois_num = rois->dims()[0];
int channels = in->dims()[1];
int height = in->dims()[2];
int width = in->dims()[3];
if (x_grad) {
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto& dev_ctx = ctx.cuda_device_context();
auto gplace = ctx.GetPlace();
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_num_t->numel();
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_num_list.data(), gplace,
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
dev_ctx.stream());
x_grad->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<Place, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIPoolBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_grad_size, rois->data<T>(), out_grad->data<T>(),
argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
width, pooled_height, pooled_width, roi_id_data,
x_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
roi_pool,
ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
roi_pool_grad,
ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <limits>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
static constexpr int kROISize = 4;
template <typename DeviceContext, typename T>
class CPUROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* argmax = ctx.Output<framework::Tensor>("Argmax");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
auto in_stride = phi::stride(in_dims);
auto argmax_stride = phi::stride(argmax->dims());
auto roi_stride = phi::stride(rois->dims());
auto out_stride = phi::stride(out->dims());
const T* input_data = in->data<T>();
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument("The rois_batch_size and imgs "
"batch_size must be the same."));
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument("The rois_batch_size and imgs "
"batch_size must be the same."));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(
rois_num, rois_num_with_lod,
platform::errors::InvalidArgument("The rois_num from input "
"and lod must be the same."));
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
T* output_data = out->mutable_data<T>(ctx.GetPlace());
int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());
const T* rois_data = rois->data<T>();
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = roi_batch_id_data[n];
int roi_start_w = round(rois_data[0] * spatial_scale);
int roi_start_h = round(rois_data[1] * spatial_scale);
int roi_end_w = round(rois_data[2] * spatial_scale);
int roi_end_h = round(rois_data[3] * spatial_scale);
// Force malformed ROIs to be 1x1
int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
const float bin_size_h =
static_cast<float>(roi_height) / static_cast<float>(pooled_height);
const float bin_size_w =
static_cast<float>(roi_width) / static_cast<float>(pooled_width);
const T* batch_data = input_data + roi_batch_id * in_stride[0];
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
// Compute pooling region for this output unit:
// start (included) = floor(ph * roi_height / pooled_height_)
// end (excluded) = ceil((ph + 1) * roi_height / pooled_height_)
int hstart =
static_cast<int>(floor(static_cast<float>(ph) * bin_size_h));
int wstart =
static_cast<int>(floor(static_cast<float>(pw) * bin_size_w));
int hend =
static_cast<int>(ceil(static_cast<float>(ph + 1) * bin_size_h));
int wend =
static_cast<int>(ceil(static_cast<float>(pw + 1) * bin_size_w));
hstart = std::min(std::max(hstart + roi_start_h, 0), height);
hend = std::min(std::max(hend + roi_start_h, 0), height);
wstart = std::min(std::max(wstart + roi_start_w, 0), width);
wend = std::min(std::max(wend + roi_start_w, 0), width);
const int pool_index = ph * pooled_width + pw;
// Define an empty pooling region to be zero
bool is_empty = (hend <= hstart) || (wend <= wstart);
output_data[pool_index] =
is_empty ? 0 : -std::numeric_limits<T>::max();
argmax_data[pool_index] = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
const int index = h * width + w;
if (batch_data[index] > output_data[pool_index]) {
output_data[pool_index] = batch_data[index];
argmax_data[pool_index] = index;
}
}
}
}
}
batch_data += in_stride[1];
output_data += out_stride[1];
argmax_data += argmax_stride[1];
}
// Increment ROI data pointer
rois_data += roi_stride[0];
}
return;
}
};
template <typename DeviceContext, typename T>
class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* argmax = ctx.Input<framework::Tensor>("Argmax");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
if (in_grad) {
int rois_num = rois->dims()[0];
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
const T* rois_data = rois->data<T>();
const T* out_grad_data = out_grad->data<T>();
const int64_t* argmax_data = argmax->data<int64_t>();
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(ctx.template device_context<DeviceContext>(), in_grad,
static_cast<T>(0));
auto in_stride = phi::stride(in->dims());
auto argmax_stride = phi::stride(argmax->dims());
auto roi_stride = phi::stride(rois->dims());
auto out_stride = phi::stride(out_grad->dims());
int channels = in->dims()[1];
for (int n = 0; n < rois_num; ++n) {
int roi_batch_idx = roi_batch_id_data[n];
T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0];
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int pool_index = ph * pooled_width + pw;
if (argmax_data[pool_index] >= 0) {
auto index = argmax_data[pool_index];
batch_grad_data[index] += out_grad_data[pool_index];
}
}
}
batch_grad_data += in_stride[1];
out_grad_data += out_stride[1];
argmax_data += argmax_stride[1];
}
rois_data += roi_stride[0];
}
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roi_pool_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void RoiPoolGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
const DenseTensor& arg_max,
const DenseTensor& out_grad,
int pooled_height,
int pooled_width,
float spatial_scale,
DenseTensor* dx) {
if (dx) {
int rois_num = boxes.dims()[0];
DenseTensor box_batch_id_list = Empty<int>(dev_ctx, {rois_num});
int* box_batch_id_data = box_batch_id_list.data<int>();
int boxes_batch_size;
if (boxes_num) {
boxes_batch_size = boxes_num->numel();
auto* boxes_num_data = boxes_num->data<int>();
int start = 0;
for (int n = 0; n < boxes_batch_size; ++n) {
for (int i = start; i < start + boxes_num_data[n]; ++i) {
box_batch_id_data[i] = n;
}
start += boxes_num_data[n];
}
} else {
auto boxes_lod = boxes.lod().back();
boxes_batch_size = boxes_lod.size() - 1;
for (int n = 0; n < boxes_batch_size; ++n) {
for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) {
box_batch_id_data[i] = n;
}
}
}
const T* boxes_data = boxes.data<T>();
const T* out_grad_data = out_grad.data<T>();
const int64_t* arg_max_data = arg_max.data<int64_t>();
T* dx_data = dev_ctx.template Alloc<T>(dx);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, dx, static_cast<T>(0));
auto in_stride = phi::stride(x.dims());
auto arg_max_stride = phi::stride(arg_max.dims());
auto roi_stride = phi::stride(boxes.dims());
auto out_stride = phi::stride(out_grad.dims());
int channels = x.dims()[1];
for (int n = 0; n < rois_num; ++n) {
int roi_batch_idx = box_batch_id_data[n];
T* batch_grad_data = dx_data + roi_batch_idx * in_stride[0];
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int pool_index = ph * pooled_width + pw;
if (arg_max_data[pool_index] >= 0) {
auto index = arg_max_data[pool_index];
batch_grad_data[index] += out_grad_data[pool_index];
}
}
}
batch_grad_data += in_stride[1];
out_grad_data += out_stride[1];
arg_max_data += arg_max_stride[1];
}
boxes_data += roi_stride[0];
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(roi_pool_grad,
CPU,
ALL_LAYOUT,
phi::RoiPoolGradKernel,
float,
double,
int) {
kernel->InputAt(3).SetDataType(phi::DataType::INT64);
}
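// Note: input index 3 of the signature (x, boxes, boxes_num, arg_max,
// out_grad) is arg_max, which holds int64 indices regardless of the pooled
// value type T, hence the explicit SetDataType above.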
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roi_pool_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
template <typename T, typename Context>
void RoiPoolKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
int pooled_height,
int pooled_width,
float spatial_scale,
DenseTensor* out,
DenseTensor* arg_max) {
auto x_dims = x.dims();
int batch_size = x_dims[0];
int channels = x_dims[1];
int height = x_dims[2];
int width = x_dims[3];
int rois_num = boxes.dims()[0];
auto in_stride = phi::stride(x_dims);
auto arg_max_stride = phi::stride(arg_max->dims());
auto box_stride = phi::stride(boxes.dims());
auto out_stride = phi::stride(out->dims());
const T* input_data = x.data<T>();
DenseTensor box_batch_id_list = Empty<int>(dev_ctx, {rois_num});
int* box_batch_id_data = box_batch_id_list.data<int>();
int boxes_batch_size;
if (boxes_num) {
boxes_batch_size = boxes_num->numel();
PADDLE_ENFORCE_EQ(
boxes_batch_size,
batch_size,
phi::errors::InvalidArgument("The boxes_batch_size and imgs "
"batch_size must be the same."));
auto* boxes_num_data = boxes_num->data<int>();
int start = 0;
for (int n = 0; n < boxes_batch_size; ++n) {
for (int i = start; i < start + boxes_num_data[n]; ++i) {
box_batch_id_data[i] = n;
}
start += boxes_num_data[n];
}
} else {
auto boxes_lod = boxes.lod().back();
boxes_batch_size = boxes_lod.size() - 1;
PADDLE_ENFORCE_EQ(
boxes_batch_size,
batch_size,
phi::errors::InvalidArgument("The boxes_batch_size and imgs "
"batch_size must be the same."));
int rois_num_with_lod = boxes_lod[boxes_batch_size];
PADDLE_ENFORCE_EQ(
rois_num,
rois_num_with_lod,
phi::errors::InvalidArgument("The rois_num from input "
"and lod must be the same."));
for (int n = 0; n < boxes_batch_size; ++n) {
for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) {
box_batch_id_data[i] = n;
}
}
}
T* output_data = dev_ctx.template Alloc<T>(out);
int64_t* arg_max_data = dev_ctx.template Alloc<int64_t>(arg_max);
const T* boxes_data = boxes.data<T>();
for (int n = 0; n < rois_num; ++n) {
int box_batch_id = box_batch_id_data[n];
int box_start_w = round(boxes_data[0] * spatial_scale);
int box_start_h = round(boxes_data[1] * spatial_scale);
int box_end_w = round(boxes_data[2] * spatial_scale);
int box_end_h = round(boxes_data[3] * spatial_scale);
// Force malformed ROIs to be 1x1
int box_height = std::max(box_end_h - box_start_h + 1, 1);
int box_width = std::max(box_end_w - box_start_w + 1, 1);
const float bin_size_h =
static_cast<float>(box_height) / static_cast<float>(pooled_height);
const float bin_size_w =
static_cast<float>(box_width) / static_cast<float>(pooled_width);
const T* batch_data = input_data + box_batch_id * in_stride[0];
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
// Compute pooling region for this output unit:
// start (included) = floor(ph * box_height / pooled_height_)
// end (excluded) = ceil((ph + 1) * box_height / pooled_height_)
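// Worked example (illustrative): box_height = 7 and pooled_height = 2 give
// bin_size_h = 3.5, so ph = 0 covers rows [0, 4) and ph = 1 covers rows
// [3, 7); adjacent bins may overlap by one row, which matches the original
// RoI pooling definition.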
int hstart =
static_cast<int>(floor(static_cast<float>(ph) * bin_size_h));
int wstart =
static_cast<int>(floor(static_cast<float>(pw) * bin_size_w));
int hend =
static_cast<int>(ceil(static_cast<float>(ph + 1) * bin_size_h));
int wend =
static_cast<int>(ceil(static_cast<float>(pw + 1) * bin_size_w));
hstart = std::min(std::max(hstart + box_start_h, 0), height);
hend = std::min(std::max(hend + box_start_h, 0), height);
wstart = std::min(std::max(wstart + box_start_w, 0), width);
wend = std::min(std::max(wend + box_start_w, 0), width);
const int pool_index = ph * pooled_width + pw;
// Define an empty pooling region to be zero
bool is_empty = (hend <= hstart) || (wend <= wstart);
output_data[pool_index] =
is_empty ? 0 : -std::numeric_limits<T>::max();
arg_max_data[pool_index] = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
const int index = h * width + w;
if (batch_data[index] > output_data[pool_index]) {
output_data[pool_index] = batch_data[index];
arg_max_data[pool_index] = index;
}
}
}
}
}
batch_data += in_stride[1];
output_data += out_stride[1];
arg_max_data += arg_max_stride[1];
}
// Increment ROI data pointer
boxes_data += box_stride[0];
}
}
} // namespace phi
PD_REGISTER_KERNEL(
roi_pool, CPU, ALL_LAYOUT, phi::RoiPoolKernel, float, double, int) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
@@ -18,7 +18,6 @@
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/fluid/memory/memory.h"
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roi_pool_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void GPURoiPoolBackward(const int nthreads,
const T* input_rois,
const T* output_grad,
const int64_t* arg_max_data,
const int num_rois,
const float spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
int* box_batch_id_data,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
int roi_batch_ind = box_batch_id_data[n];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
T* offset_input_grad = input_grad + input_offset;
const int64_t* offset_arg_max_data = arg_max_data + output_offset;
int arg_max = offset_arg_max_data[ph * pooled_width + pw];
if (arg_max != -1) {
paddle::platform::CudaAtomicAdd(
offset_input_grad + arg_max,
static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
}
}
}
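// Several output bins can record the same arg_max index (bins overlap, and
// multiple boxes may cover the same image), so the scatter above must use
// CudaAtomicAdd; a plain store would drop colliding gradient contributions.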
template <typename T, typename Context>
void RoiPoolGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
const DenseTensor& arg_max,
const DenseTensor& out_grad,
int pooled_height,
int pooled_width,
float spatial_scale,
DenseTensor* dx) {
auto x_dims = x.dims();
int channels = x_dims[1];
int height = x_dims[2];
int width = x_dims[3];
int rois_num = boxes.dims()[0];
if (dx) {
DenseTensor box_batch_id_list;
box_batch_id_list.Resize({rois_num});
int* box_batch_id_data =
dev_ctx.template HostAlloc<int>(&box_batch_id_list);
auto gplace = dev_ctx.GetPlace();
if (boxes_num) {
int boxes_batch_size = boxes_num->numel();
std::vector<int> boxes_num_list(boxes_batch_size);
paddle::memory::Copy(phi::CPUPlace(),
boxes_num_list.data(),
gplace,
boxes_num->data<int>(),
sizeof(int) * boxes_batch_size,
0);
int start = 0;
for (int n = 0; n < boxes_batch_size; ++n) {
for (int i = start; i < start + boxes_num_list[n]; ++i) {
box_batch_id_data[i] = n;
}
start += boxes_num_list[n];
}
} else {
auto boxes_lod = boxes.lod().back();
int boxes_batch_size = boxes_lod.size() - 1;
for (int n = 0; n < boxes_batch_size; ++n) {
for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) {
box_batch_id_data[i] = n;
}
}
}
int bytes = box_batch_id_list.numel() * sizeof(int);
auto roi_ptr = paddle::memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
paddle::memory::Copy(gplace,
roi_id_data,
phi::CPUPlace(),
box_batch_id_data,
bytes,
dev_ctx.stream());
dev_ctx.template Alloc<T>(dx);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, dx, static_cast<T>(0));
int output_grad_size = out_grad.numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPURoiPoolBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_grad_size,
boxes.data<T>(),
out_grad.data<T>(),
arg_max.data<int64_t>(),
rois_num,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
roi_id_data,
dx->data<T>());
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
roi_pool_grad, GPU, ALL_LAYOUT, phi::RoiPoolGradKernel, float, double) {
kernel->InputAt(3).SetDataType(phi::DataType::INT64);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roi_pool_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/memory/memory.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void GPURoiPoolForward(const int nthreads,
const T* input_data,
const T* input_rois,
const float spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
int* box_batch_id_data,
T* output_data,
int64_t* arg_max_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
const T* offset_input_rois = input_rois + n * kROISize;
int box_batch_ind = box_batch_id_data[n];
int box_start_w = round(offset_input_rois[0] * spatial_scale);
int box_start_h = round(offset_input_rois[1] * spatial_scale);
int box_end_w = round(offset_input_rois[2] * spatial_scale);
int box_end_h = round(offset_input_rois[3] * spatial_scale);
int box_width = max(box_end_w - box_start_w + 1, 1);
int box_height = max(box_end_h - box_start_h + 1, 1);
int hstart = static_cast<int>(floor(static_cast<double>(ph) *
static_cast<double>(box_height) /
static_cast<double>(pooled_height)));
int wstart = static_cast<int>(floor(static_cast<double>(pw) *
static_cast<double>(box_width) /
static_cast<double>(pooled_width)));
int hend = static_cast<int>(ceil(static_cast<double>(ph + 1) *
static_cast<double>(box_height) /
static_cast<double>(pooled_height)));
int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
static_cast<double>(box_width) /
static_cast<double>(pooled_width)));
hstart = min(max(hstart + box_start_h, 0), height);
hend = min(max(hend + box_start_h, 0), height);
wstart = min(max(wstart + box_start_w, 0), width);
wend = min(max(wend + box_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
int maxidx = -1;
const T* offset_input_data =
input_data + (box_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_data_index = h * width + w;
if (offset_input_data[input_data_index] > maxval) {
maxval = offset_input_data[input_data_index];
maxidx = input_data_index;
}
}
}
output_data[i] = maxval;
if (arg_max_data) {
arg_max_data[i] = maxidx;
}
}
}
template <typename T, typename Context>
void RoiPoolKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
int pooled_height,
int pooled_width,
float spatial_scale,
DenseTensor* out,
DenseTensor* arg_max) {
auto x_dims = x.dims();
int batch_size = x_dims[0];
auto in_stride = phi::stride(x_dims);
int channels = x_dims[1];
int height = x_dims[2];
int width = x_dims[3];
int rois_num = boxes.dims()[0];
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
DenseTensor box_batch_id_list;
box_batch_id_list.Resize({rois_num});
int* box_batch_id_data = dev_ctx.template HostAlloc<int>(&box_batch_id_list);
auto gplace = dev_ctx.GetPlace();
if (boxes_num) {
int boxes_batch_size = boxes_num->numel();
PADDLE_ENFORCE_EQ(
boxes_batch_size,
batch_size,
phi::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be the same but "
"received batch size of input(ROIs) and input(X) is %d and %d "
"respectively.",
boxes_batch_size,
batch_size));
std::vector<int> boxes_num_list(boxes_batch_size);
paddle::memory::Copy(phi::CPUPlace(),
boxes_num_list.data(),
gplace,
boxes_num->data<int>(),
sizeof(int) * boxes_batch_size,
0);
int start = 0;
for (int n = 0; n < boxes_batch_size; ++n) {
for (int i = start; i < start + boxes_num_list[n]; ++i) {
box_batch_id_data[i] = n;
}
start += boxes_num_list[n];
}
} else {
auto boxes_lod = boxes.lod().back();
int boxes_batch_size = boxes_lod.size() - 1;
PADDLE_ENFORCE_EQ(
boxes_batch_size,
batch_size,
phi::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be the same but "
"received batch size of input(ROIs) and input(X) is %d and %d "
"respectively.",
boxes_batch_size,
batch_size));
int boxes_num_with_lod = boxes_lod[boxes_batch_size];
PADDLE_ENFORCE_EQ(rois_num,
boxes_num_with_lod,
phi::errors::InvalidArgument(
"The number of rois from input(ROIs) and its LOD "
"must be the same. Received rois %d of input(ROIs) "
"but the number of rois %d from its LOD is %d",
rois_num,
boxes_num_with_lod));
for (int n = 0; n < boxes_batch_size; ++n) {
for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) {
box_batch_id_data[i] = n;
}
}
}
int bytes = box_batch_id_list.numel() * sizeof(int);
auto box_ptr = paddle::memory::Alloc(dev_ctx, bytes);
int* box_id_data = reinterpret_cast<int*>(box_ptr->ptr());
paddle::memory::Copy(gplace,
box_id_data,
phi::CPUPlace(),
box_batch_id_data,
bytes,
dev_ctx.stream());
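// The batch-id table is assembled on the host (box_batch_id_data) and copied
// to the device once per launch; the forward kernel then resolves each RoI's
// source image with a single indexed read.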
T* output_data = dev_ctx.template Alloc<T>(out);
int64_t* arg_max_data = dev_ctx.template Alloc<int64_t>(arg_max);
GPURoiPoolForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_size,
x.data<T>(),
boxes.data<T>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
box_id_data,
output_data,
arg_max_data);
}
} // namespace phi
PD_REGISTER_KERNEL(
roi_pool, GPU, ALL_LAYOUT, phi::RoiPoolKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void RoiPoolGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
const DenseTensor& arg_max,
const DenseTensor& out_grad,
int pooled_height,
int pooled_width,
float spatial_scale,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
static constexpr int kROISize = 4;
template <typename T, typename Context>
void RoiPoolKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
int pooled_height,
int pooled_width,
float spatial_scale,
DenseTensor* out,
DenseTensor* arg_max);
} // namespace phi
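A minimal usage sketch of the new phi API (not part of the commit): the
wrapper name RoiPoolExample and the fixed 7x7 output are assumptions, output
shapes are normally set by InferMeta rather than by hand, and paddle::none is
assumed to be available from paddle/utils/optional.h.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/roi_pool_kernel.h"
// Hypothetical helper: max-pools each box in `boxes` over `x` to a 7x7 grid.
void RoiPoolExample(const phi::CPUContext& dev_ctx,
                    const phi::DenseTensor& x,       // [N, C, H, W]
                    const phi::DenseTensor& boxes) { // [num_rois, 4]
  phi::DenseTensor out;
  phi::DenseTensor arg_max;
  // InferMeta would normally size these; done by hand for the sketch.
  out.Resize({boxes.dims()[0], x.dims()[1], 7, 7});
  arg_max.Resize({boxes.dims()[0], x.dims()[1], 7, 7});
  phi::RoiPoolKernel<float, phi::CPUContext>(
      dev_ctx, x, boxes,
      paddle::none,  // no boxes_num: batching falls back to the LoD of boxes
      /*pooled_height=*/7, /*pooled_width=*/7, /*spatial_scale=*/1.0f,
      &out, &arg_max);
}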
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RoiPoolOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("roi_pool",
{"X", "ROIs", "RoisNum"},
{"pooled_height", "pooled_width", "spatial_scale"},
{"Out", "Argmax"});
}
KernelSignature RoiPoolOpGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("roi_pool_grad",
{"X", "ROIs", "RoisNum", "Argmax", GradVarName("Out")},
{"pooled_height", "pooled_width", "spatial_scale"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(roi_pool, phi::RoiPoolOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(roi_pool_grad, phi::RoiPoolOpGradArgumentMapping);
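// Effect (illustrative): a fluid roi_pool op with inputs {X, ROIs, RoisNum}
// and attributes {pooled_height, pooled_width, spatial_scale} now dispatches
// to phi::RoiPoolKernel producing {Out, Argmax}; GradVarName("Out") resolves
// to the framework's "Out@GRAD" variable, wiring the gradient op to
// phi::RoiPoolGradKernel.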