diff --git a/paddle/fluid/operators/psroi_pool_op.cc b/paddle/fluid/operators/psroi_pool_op.cc index da637dfeb237dd4f17816e784882720dc2f2ff64..cfacffff234105ac9c6dc41b86f06594d319dcbb 100644 --- a/paddle/fluid/operators/psroi_pool_op.cc +++ b/paddle/fluid/operators/psroi_pool_op.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/psroi_pool_op.h" -#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - class PSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -82,75 +82,6 @@ class PSROIPoolOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::InvalidArgument( - "Input(X) of PSROIPoolOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true, - platform::errors::InvalidArgument( - "Input(ROIs) of PSROIPoolOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Output(Out) of PSROIPoolOp should not be null.")); - auto input_dims = ctx->GetInputDim("X"); - auto rois_dims = ctx->GetInputDim("ROIs"); - - PADDLE_ENFORCE_EQ(input_dims.size(), 4, - platform::errors::InvalidArgument( - "The format of input tensor is NCHW")); - PADDLE_ENFORCE_EQ( - rois_dims.size(), 2, - platform::errors::InvalidArgument( - "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " - "given as [(x1, y1, x2, y2), ...]")); - PADDLE_ENFORCE_EQ( - rois_dims[1], 4, - platform::errors::InvalidArgument( - "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " - "given as [(x1, y1, x2, y2), ...]")); - if (ctx->HasInput("RoisNum")) { - auto rois_num_dims = ctx->GetInputDim("RoisNum"); - PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1, - platform::errors::InvalidArgument( - "The second dimension of RoisNum should " - "be 1, but received dimension is %d", - rois_num_dims.size())); - } - int pooled_height = ctx->Attrs().Get("pooled_height"); - int pooled_width = ctx->Attrs().Get("pooled_width"); - int output_channels = ctx->Attrs().Get("output_channels"); - float spatial_scale = ctx->Attrs().Get("spatial_scale"); - - PADDLE_ENFORCE_EQ( - input_dims[1], output_channels * pooled_height * pooled_width, - platform::errors::InvalidArgument( - "the channel of X(%d) " - "should be equal to the product of " - "output_channels(%d), pooled_height(%d) and pooled_width(%d)", - input_dims[1], output_channels, pooled_height, pooled_width)); - - PADDLE_ENFORCE_GT(pooled_height, 0, - platform::errors::InvalidArgument( - "The pooled output height must be greater than 0")); - PADDLE_ENFORCE_GT(pooled_width, 0, - platform::errors::InvalidArgument( - "The pooled output width must be greater than 0")); - PADDLE_ENFORCE_GT(output_channels, 1, - platform::errors::InvalidArgument( - "The pooled output channels must greater than 1")); - PADDLE_ENFORCE_GT(spatial_scale, 0.0f, - platform::errors::InvalidArgument( - "The spatial scale must greater than 0.")); - - auto out_dims = input_dims; - out_dims[0] = rois_dims[0]; - out_dims[1] = - output_channels; // input_dims[1] / (pooled_height * pooled_width); - out_dims[2] = pooled_height; - out_dims[3] = pooled_width; - ctx->SetOutputDim("Out", out_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -164,16 +95,6 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, - platform::errors::InvalidArgument( - "The gradient of Out should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, - platform::errors::InvalidArgument( - "The gradient of X should not be null.")); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -204,15 +125,13 @@ class PSROIPoolGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(psroi_pool, PsroiPoolInferShapeFunctor, + PD_INFER_META(phi::PsroiPoolInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(psroi_pool_grad, PsroiPoolGradInferShapeFunctor, + PD_INFER_META(phi::PsroiPoolGradInferMeta)); REGISTER_OPERATOR(psroi_pool, ops::PSROIPoolOp, ops::PSROIPoolOpMaker, ops::PSROIPoolGradMaker, - ops::PSROIPoolGradMaker); -REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp); -REGISTER_OP_CPU_KERNEL( - psroi_pool, - ops::CPUPSROIPoolOpKernel, - ops::CPUPSROIPoolOpKernel); -REGISTER_OP_CPU_KERNEL( - psroi_pool_grad, - ops::CPUPSROIPoolGradOpKernel, - ops::CPUPSROIPoolGradOpKernel); + ops::PSROIPoolGradMaker, + PsroiPoolInferShapeFunctor); +REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp, + PsroiPoolGradInferShapeFunctor); diff --git a/paddle/fluid/operators/psroi_pool_op.cu b/paddle/fluid/operators/psroi_pool_op.cu deleted file mode 100644 index c1917501db8b5afebf4b7951b0f04de69758b49d..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/psroi_pool_op.cu +++ /dev/null @@ -1,350 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/psroi_pool_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -static constexpr int kNumCUDAThreads = 512; -static constexpr int kNumMaximumNumBlocks = 4096; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaximumNumBlocks); -} - -template -__global__ void GPUPSROIPoolForward( - const int nthreads, const T* input_data, const T* input_rois, - const float spatial_scale, const int input_channels, const int height, - const int width, const int output_channels, const int pooled_height, - const int pooled_width, const int* rois_batch_id_data, T* output_data) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - // The output is in order (n, c, ph, pw) - int pw = i % pooled_width; - int ph = (i / pooled_width) % pooled_height; - int c = (i / pooled_width / pooled_height) % output_channels; - int n = i / pooled_width / pooled_height / output_channels; - - // set roi_batch_id - int roi_batch_id = rois_batch_id_data[n]; - - // [start, end) interval for spatial sampling - const T* offset_input_rois = input_rois + n * 4; - T roi_start_w = static_cast(round(offset_input_rois[0])) * spatial_scale; - T roi_start_h = static_cast(round(offset_input_rois[1])) * spatial_scale; - T roi_end_w = - static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; - T roi_end_h = - static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; - - // Force too small ROIs to be 1x1 - T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 - T roi_width = max(roi_end_w - roi_start_w, (T)0.1); - - // Compute w and h at input feature map - T bin_size_h = roi_height / static_cast(pooled_height); - T bin_size_w = roi_width / static_cast(pooled_width); - - int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); - int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); - int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); - int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); - - // Add roi offsets and clip to input boundaries - hstart = min(max(hstart, 0), height); - hend = min(max(hend, 0), height); - wstart = min(max(wstart, 0), width); - wend = min(max(wend, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - int input_channel = (c * pooled_height + ph) * pooled_width + pw; - const T* offset_input_data = - input_data + - (roi_batch_id * input_channels + input_channel) * height * width; - T outsum = 0; - - for (int ih = hstart; ih < hend; ++ih) { - for (int iw = wstart; iw < wend; ++iw) { - int input_index = ih * width + iw; - outsum += offset_input_data[input_index]; - } - } - - T bin_area = static_cast((hend - hstart) * (wend - wstart)); - output_data[i] = is_empty ? 0. : outsum / bin_area; - } -} - -template -__global__ void GPUPSROIPoolBackward( - const int nthreads, const T* input_rois, const T* output_grad_data, - const float spatial_scale, const int input_channels, const int height, - const int width, const int output_channels, const int pooled_height, - const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (int i = index; i < nthreads; i += offset) { - // The output is in order (n, c, ph, pw) - int pw = i % pooled_width; - int ph = (i / pooled_width) % pooled_height; - int c = (i / pooled_width / pooled_height) % output_channels; - int n = i / pooled_width / pooled_height / output_channels; - - // set roi_batch_id - int roi_batch_id = rois_batch_id_data[n]; - int input_channel = (c * pooled_height + ph) * pooled_width + pw; - int input_offset = - (roi_batch_id * input_channels + input_channel) * height * width; - T* offset_input_grad_data = input_grad_data + input_offset; - - // [start, end) interval for spatial sampling - const T* offset_input_rois = input_rois + n * 4; - T roi_start_w = static_cast(round(offset_input_rois[0])) * spatial_scale; - T roi_start_h = static_cast(round(offset_input_rois[1])) * spatial_scale; - T roi_end_w = - static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; - T roi_end_h = - static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; - - // Force too small ROIs to be 1x1 - T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 - T roi_width = max(roi_end_w - roi_start_w, (T)0.1); - - // Compute w and h at input feature map - T bin_size_h = roi_height / static_cast(pooled_height); - T bin_size_w = roi_width / static_cast(pooled_width); - - int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); - int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); - int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); - int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); - - // Add roi offsets and clip to input boundaries - hstart = min(max(hstart, 0), height); - hend = min(max(hend, 0), height); - wstart = min(max(wstart, 0), width); - wend = min(max(wend, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - // Accumulate diff_val into input data - T bin_area = static_cast((hend - hstart) * (wend - wstart)); - T diff_val = is_empty ? 0. : output_grad_data[i] / bin_area; - for (int ih = hstart; ih < hend; ++ih) { - for (int iw = wstart; iw < wend; ++iw) { - int input_index = ih * width + iw; - platform::CudaAtomicAdd(offset_input_grad_data + input_index, diff_val); - } - } - } -} - -template -class GPUPSROIPoolOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out = ctx.Output("Out"); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto output_channels = ctx.Attr("output_channels"); - auto spatial_scale = ctx.Attr("spatial_scale"); - - auto in_dims = in->dims(); - int batch_size = in_dims[0]; - int input_channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - - PADDLE_ENFORCE_EQ( - input_channels, output_channels * pooled_height * pooled_width, - platform::errors::InvalidArgument( - "The channels %d of input X should equal the product of " - "output_channels %d x pooled_height %d x pooled_width %d.", - input_channels, output_channels, pooled_height, pooled_width)); - - int rois_num = rois->dims()[0]; - if (rois_num == 0) return; - int rois_batch_size; - framework::Tensor rois_batch_id_list; - rois_batch_id_list.Resize({rois_num}); - int* rois_batch_id_data = - rois_batch_id_list.mutable_data(platform::CPUPlace()); - - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - rois_batch_size = rois_num_t->numel(); - auto* rois_num_data = rois_num_t->data(); - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The batch size of input(ROIs) and input(X) must be " - "the same but received batch size of input(ROIs) and " - "input(X) is %d and %d respectively.", - rois_batch_size, batch_size)); - std::vector rois_num_list(rois_batch_size); - memory::Copy(platform::CPUPlace(), rois_num_list.data(), ctx.GetPlace(), - rois_num_data, sizeof(int) * rois_batch_size, 0); - int rois_num_count = 0; - for (int i = 0; i < rois_batch_size; ++i) { - rois_num_count += rois_num_list[i]; - } - PADDLE_ENFORCE_EQ( - rois_num_count, rois_num, - platform::errors::InvalidArgument( - "the rois_num from input and RoisNum must be the same")); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_list[n]; ++i) { - rois_batch_id_data[i] = n; - } - start += rois_num_list[n]; - } - } else { - auto rois_lod = rois->lod().back(); - rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The batch size of input(ROIs) and input(X) must be " - "the same but received batch size of input(ROIs) and " - "input(X) is %d and %d respectively.", - rois_batch_size, batch_size)); - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, - platform::errors::InvalidArgument( - "The number of rois from input(ROIs) and its LOD " - "must be the same. Received rois %d of input(ROIs) " - "but the number of rois %d from its LOD is %d", - rois_num, rois_num_with_lod)); - - // set rois batch id - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; - } - } - } - framework::Tensor rois_batch_id_list_gpu; - framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), - ctx.device_context(), &rois_batch_id_list_gpu); - - int output_size = out->numel(); - int blocks = NumBlocks(output_size); - int threads = kNumCUDAThreads; - - // call cuda kernel function - GPUPSROIPoolForward< - T><<>>( - output_size, in->data(), rois->data(), spatial_scale, - input_channels, height, width, output_channels, pooled_height, - pooled_width, rois_batch_id_list_gpu.data(), - out->mutable_data(ctx.GetPlace())); - } -}; - -template -class GPUPSROIPoolGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - - auto* output_grad = ctx.Input(framework::GradVarName("Out")); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto output_channels = ctx.Attr("output_channels"); - auto spatial_scale = ctx.Attr("spatial_scale"); - - int rois_num = rois->dims()[0]; - int input_channels = in->dims()[1]; - int height = in->dims()[2]; - int width = in->dims()[3]; - - if (input_grad) { - // set roi batch id - framework::Tensor rois_batch_id_list; - rois_batch_id_list.Resize({rois_num}); - int* rois_batch_id_data = - rois_batch_id_list.mutable_data(platform::CPUPlace()); - int rois_batch_size; - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - rois_batch_size = rois_num_t->numel(); - std::vector rois_num_list(rois_batch_size); - memory::Copy(platform::CPUPlace(), rois_num_list.data(), ctx.GetPlace(), - rois_num_t->data(), sizeof(int) * rois_batch_size, 0); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_list[n]; ++i) { - rois_batch_id_data[i] = n; - } - start += rois_num_list[n]; - } - } else { - auto rois_lod = rois->lod().back(); - rois_batch_size = rois_lod.size() - 1; - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; - } - } - } - framework::Tensor rois_batch_id_list_gpu; - framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), - ctx.device_context(), &rois_batch_id_list_gpu); - - input_grad->mutable_data(ctx.GetPlace()); - phi::funcs::SetConstant set_zero; - set_zero(ctx.cuda_device_context(), input_grad, static_cast(0)); - - int output_grad_size = output_grad->numel(); - int blocks = NumBlocks(output_grad_size); - int threads = kNumCUDAThreads; - - if (output_grad_size > 0) { - GPUPSROIPoolBackward< - T><<>>( - output_grad_size, rois->data(), output_grad->data(), - spatial_scale, input_channels, height, width, output_channels, - pooled_height, pooled_width, rois_batch_id_list_gpu.data(), - input_grad->mutable_data(ctx.GetPlace())); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - psroi_pool, - ops::GPUPSROIPoolOpKernel, - ops::GPUPSROIPoolOpKernel); -REGISTER_OP_CUDA_KERNEL( - psroi_pool_grad, - ops::GPUPSROIPoolGradOpKernel, - ops::GPUPSROIPoolGradOpKernel); diff --git a/paddle/fluid/operators/psroi_pool_op.h b/paddle/fluid/operators/psroi_pool_op.h deleted file mode 100644 index 3f020d93391b0e648898c1b83858a7bd9809aa03..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/psroi_pool_op.h +++ /dev/null @@ -1,295 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -class CPUPSROIPoolOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out = ctx.Output("Out"); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - auto output_channels = ctx.Attr("output_channels"); - - auto in_dims = in->dims(); - int batch_size = in_dims[0]; - int input_channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - int rois_num = rois->dims()[0]; - - PADDLE_ENFORCE_EQ(input_channels, - output_channels * pooled_height * pooled_width, - platform::errors::InvalidArgument( - "the channels of input " - "X should equal the product of " - "output_channels x pooled_height x pooled_width")); - - auto in_stride = phi::stride(in_dims); - auto out_stride = phi::stride(out->dims()); - - const T* input_data = in->data(); - - framework::Tensor rois_batch_id_list; - rois_batch_id_list.Resize({rois_num}); - int* rois_batch_id_data = - rois_batch_id_list.mutable_data(ctx.GetPlace()); - int rois_batch_size; - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - rois_batch_size = rois_num_t->numel(); - auto* rois_num_data = rois_num_t->data(); - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The batch size of rois and the batch size of images " - " must be the same. But received the batch size of rois is %d, " - "and the batch size of images is %d", - rois_batch_size, batch_size)); - int rois_num_count = 0; - for (int i = 0; i < rois_batch_size; ++i) { - rois_num_count += rois_num_data[i]; - } - PADDLE_ENFORCE_EQ( - rois_num_count, rois_num, - platform::errors::InvalidArgument( - "the rois_num from input and RoisNum must be the same")); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_data[n]; ++i) { - rois_batch_id_data[i] = n; - } - start += rois_num_data[n]; - } - } else { - auto rois_lod = rois->lod().back(); - rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument("the rois_batch_size and input(X) " - "batch_size should be the same.")); - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ( - rois_num_with_lod, rois_num, - platform::errors::InvalidArgument( - "the rois_num from input and lod must be the same")); - // calculate batch id index for each roi according to LoD - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; - } - } - } - T* output_data = out->mutable_data(ctx.GetPlace()); - const T* input_rois = rois->data(); - - // calculate psroipooling, parallel processing can be implemented per ROI - for (int n = 0; n < rois_num; ++n) { - // set roi batch id - int roi_batch_id = rois_batch_id_data[n]; - - // [start, end) interval for spatial sampling - const T* offset_input_rois = input_rois + n * 4; - T roi_start_w = - static_cast(round(offset_input_rois[0])) * spatial_scale; - T roi_start_h = - static_cast(round(offset_input_rois[1])) * spatial_scale; - T roi_end_w = - static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; - T roi_end_h = - static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; - // Force too small rois to be 1 x 1 - T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 - T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1); - - // Compute bin size w and h at input feature map - T bin_size_h = roi_height / static_cast(pooled_height); - T bin_size_w = roi_width / static_cast(pooled_width); - - // calculate each pixel of the output feature map. - int out_roi_offset = n * out_stride[0]; - for (int c = 0; c < output_channels; ++c) { - // per category - int out_plane_offset = out_roi_offset + c * out_stride[1]; - for (int ph = 0; ph < pooled_height; ++ph) { - int out_row_offset = out_plane_offset + ph * out_stride[2]; - for (int pw = 0; pw < pooled_width; ++pw) { - // calculate w and h at input feature map - int hstart = floor(static_cast(ph) * bin_size_h + roi_start_h); - int wstart = floor(static_cast(pw) * bin_size_w + roi_start_w); - int hend = ceil(static_cast(ph + 1) * bin_size_h + roi_start_h); - int wend = ceil(static_cast(pw + 1) * bin_size_w + roi_start_w); - // Add roi offsets and clip to input boundaries - hstart = std::min(std::max(hstart, 0), height); - wstart = std::min(std::max(wstart, 0), width); - hend = std::min(std::max(hend, 0), height); - wend = std::min(std::max(wend, 0), width); - - int output_index = out_row_offset + pw; - int input_channel = (c * pooled_height + ph) * pooled_width + pw; - int input_plane_offset = - roi_batch_id * in_stride[0] + input_channel * in_stride[1]; - const T* offset_input_data = input_data + input_plane_offset; - T out_sum = 0.; - bool is_empty = (hend <= hstart) || (wend <= wstart); - for (int ih = hstart; ih < hend; ++ih) { - for (int iw = wstart; iw < wend; ++iw) { - int input_index = ih * in_stride[2] + iw; - out_sum += offset_input_data[input_index]; - } - } - T bin_area = (hend - hstart) * (wend - wstart); - output_data[output_index] = is_empty ? 0. : out_sum / bin_area; - } - } - } - } - return; - } -}; - -template -class CPUPSROIPoolGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* output_grad = - ctx.Input(framework::GradVarName("Out")); - auto* input_grad = - ctx.Output(framework::GradVarName("X")); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto output_channels = ctx.Attr("output_channels"); - auto spatial_scale = ctx.Attr("spatial_scale"); - - if (input_grad) { - auto in_dims = in->dims(); - int input_channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - int rois_num = rois->dims()[0]; - - // set roi batch id - framework::Tensor rois_batch_id_list; - rois_batch_id_list.Resize({rois_num}); - int* rois_batch_id_data = - rois_batch_id_list.mutable_data(ctx.GetPlace()); - int rois_batch_size; - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - rois_batch_size = rois_num_t->numel(); - auto* rois_num_data = rois_num_t->data(); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_data[n]; ++i) { - rois_batch_id_data[i] = n; - } - start += rois_num_data[n]; - } - } else { - auto rois_lod = rois->lod().back(); - rois_batch_size = rois_lod.size() - 1; - // calculate batch id index for each roi according to LoD - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; - } - } - } - const T* input_rois = rois->data(); - const T* output_grad_data = output_grad->data(); - T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - - // set gradient of X to be 0. before backpropagate. - phi::funcs::SetConstant set_zero; - set_zero(ctx.template device_context(), input_grad, - static_cast(0)); - - // backpropagate gradient per output pixel - int output_grad_size = output_grad->numel(); - for (int i = 0; i < output_grad_size; ++i) { - // The output is in order (n, c, ph, pw) - int pw = i % pooled_width; - int ph = (i / pooled_width) % pooled_height; - int c = (i / pooled_width / pooled_height) % output_channels; - int n = i / pooled_width / pooled_height / output_channels; - - // set roi_batch_id - int roi_batch_id = rois_batch_id_data[n]; - int input_channel = (c * pooled_height + ph) * pooled_width + pw; - int input_offset = - (roi_batch_id * input_channels + input_channel) * height * width; - T* offset_input_grad_data = input_grad_data + input_offset; - - // [start, end) interval for spatial sampling - const T* offset_input_rois = input_rois + n * 4; - T roi_start_w = - static_cast(round(offset_input_rois[0])) * spatial_scale; - T roi_start_h = - static_cast(round(offset_input_rois[1])) * spatial_scale; - T roi_end_w = - static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; - T roi_end_h = - static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; - - // Force too small ROIs to be 1x1 - T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 - T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1); - - // Compute w and h at input feature map - T bin_size_h = roi_height / static_cast(pooled_height); - T bin_size_w = roi_width / static_cast(pooled_width); - - int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); - int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); - int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); - int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); - - // Add roi offsets and clip to input boundaries - hstart = std::min(std::max(hstart, 0), height); - hend = std::min(std::max(hend, 0), height); - wstart = std::min(std::max(wstart, 0), width); - wend = std::min(std::max(wend, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - // Accumulate diff_val into input data - T bin_area = static_cast((hend - hstart) * (wend - wstart)); - T diff_val = is_empty ? 0. : output_grad_data[i] / bin_area; - for (int ih = hstart; ih < hend; ++ih) { - for (int iw = wstart; iw < wend; ++iw) { - int input_index = ih * width + iw; - offset_input_grad_data[input_index] += diff_val; - } - } - } - } - return; - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 4ddef5b0002e286181ce5ac1ad198136424861a9..0a2b4dcae58ca07b054e04a5a5f8e7a720591034 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -115,6 +115,18 @@ void GatherNdGradInferMeta(const MetaTensor& x, x_grad->set_dtype(dtype); } +void PsroiPoolGradInferMeta(const MetaTensor& x, + const MetaTensor& rois, + paddle::optional rois_num, + const MetaTensor& dout, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + MetaTensor* dx) { + dx->share_meta(x); +} + void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index f7b0eed5dd929e180810af52914e9a3139676e8a..c4003ca1fe76b865079e8f577fdee9db3be895ab 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -47,6 +47,16 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out, int axis, MetaTensor* dx); +void PsroiPoolGradInferMeta(const MetaTensor& x, + const MetaTensor& rois, + paddle::optional rois_num, + const MetaTensor& dout, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + MetaTensor* dx); + void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index acce40713b82159e9e6fd902a30c8b269c6c4e52..84441ed8b740be172ddaa7de3fc23ad420ebf077 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -28,6 +28,98 @@ std::vector GetMetaTensorsDim(const std::vector& tensors) { return dims; } +void AdadeltaInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& avg_squared_grad, + const MetaTensor& avg_squared_update, + float rho, + float epsilon, + MetaTensor* param_out, + MetaTensor* avg_squared_grad_out, + MetaTensor* avg_squared_update_out) { + auto param_dims = param.dims(); + PADDLE_ENFORCE_EQ( + param_dims, + grad.dims(), + errors::InvalidArgument( + "Param and grad input of AdadeltaOp should have same dimension.")); + PADDLE_ENFORCE_EQ( + param_dims, + avg_squared_grad.dims(), + errors::InvalidArgument("Param and AvgSquaredGrad input of AdadeltaOp " + "should have same dimension")); + PADDLE_ENFORCE_EQ( + param_dims, + avg_squared_update.dims(), + errors::InvalidArgument("Param and AvgSquaredUpdate input of AdadeltaOp " + "should have same dimension")); + + param_out->set_dims(param_dims); + param_out->set_dtype(param.dtype()); + + avg_squared_grad_out->set_dims(param_dims); + avg_squared_grad_out->set_dtype(avg_squared_grad.dtype()); + + avg_squared_update_out->set_dims(param_dims); + avg_squared_update_out->set_dtype(avg_squared_update.dtype()); +} + +void AdamaxInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& moment, + const MetaTensor& inf_norm, + const MetaTensor& beta1_pow, + float beta1, + float beta2, + float epsilon, + MetaTensor* param_out, + MetaTensor* moment_out, + MetaTensor* inf_norm_out) { + auto lr_dims = learning_rate.dims(); + PADDLE_ENFORCE_NE( + product(lr_dims), + 0, + errors::InvalidArgument("Maybe the Input variable LearningRate has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.")); + PADDLE_ENFORCE_EQ( + product(lr_dims), + 1, + errors::InvalidArgument("Learning rate should have 1 dimension")); + auto beta1_pow_dims = beta1_pow.dims(); + PADDLE_ENFORCE_EQ(product(beta1_pow_dims), + 1, + errors::InvalidArgument( + "Beta1 power accumulator should have 1 dimension")); + auto param_dims = param.dims(); + PADDLE_ENFORCE_EQ( + param_dims, + grad.dims(), + errors::InvalidArgument( + "Param and Grad input of AdamaxOp should have same dimension")); + PADDLE_ENFORCE_EQ( + param_dims, + moment.dims(), + errors::InvalidArgument( + "Param and Moment input of AdamaxOp should have same dimension")); + PADDLE_ENFORCE_EQ( + param_dims, + inf_norm.dims(), + errors::InvalidArgument( + "Param and InfNorm input of AdamaxOp should have same dimension")); + + param_out->set_dims(param_dims); + param_out->set_dtype(param.dtype()); + + moment_out->set_dims(param_dims); + moment_out->set_dtype(moment.dtype()); + + inf_norm_out->set_dims(param_dims); + inf_norm_out->set_dtype(inf_norm.dtype()); +} + void AucInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& stat_pos, @@ -108,98 +200,6 @@ void AucInferMeta(const MetaTensor& input, } } -void AdamaxInferMeta(const MetaTensor& param, - const MetaTensor& grad, - const MetaTensor& learning_rate, - const MetaTensor& moment, - const MetaTensor& inf_norm, - const MetaTensor& beta1_pow, - float beta1, - float beta2, - float epsilon, - MetaTensor* param_out, - MetaTensor* moment_out, - MetaTensor* inf_norm_out) { - auto lr_dims = learning_rate.dims(); - PADDLE_ENFORCE_NE( - product(lr_dims), - 0, - errors::InvalidArgument("Maybe the Input variable LearningRate has not " - "been initialized. You may need to confirm " - "if you put exe.run(startup_program) " - "after optimizer.minimize function.")); - PADDLE_ENFORCE_EQ( - product(lr_dims), - 1, - errors::InvalidArgument("Learning rate should have 1 dimension")); - auto beta1_pow_dims = beta1_pow.dims(); - PADDLE_ENFORCE_EQ(product(beta1_pow_dims), - 1, - errors::InvalidArgument( - "Beta1 power accumulator should have 1 dimension")); - auto param_dims = param.dims(); - PADDLE_ENFORCE_EQ( - param_dims, - grad.dims(), - errors::InvalidArgument( - "Param and Grad input of AdamaxOp should have same dimension")); - PADDLE_ENFORCE_EQ( - param_dims, - moment.dims(), - errors::InvalidArgument( - "Param and Moment input of AdamaxOp should have same dimension")); - PADDLE_ENFORCE_EQ( - param_dims, - inf_norm.dims(), - errors::InvalidArgument( - "Param and InfNorm input of AdamaxOp should have same dimension")); - - param_out->set_dims(param_dims); - param_out->set_dtype(param.dtype()); - - moment_out->set_dims(param_dims); - moment_out->set_dtype(moment.dtype()); - - inf_norm_out->set_dims(param_dims); - inf_norm_out->set_dtype(inf_norm.dtype()); -} - -void AdadeltaInferMeta(const MetaTensor& param, - const MetaTensor& grad, - const MetaTensor& avg_squared_grad, - const MetaTensor& avg_squared_update, - float rho, - float epsilon, - MetaTensor* param_out, - MetaTensor* avg_squared_grad_out, - MetaTensor* avg_squared_update_out) { - auto param_dims = param.dims(); - PADDLE_ENFORCE_EQ( - param_dims, - grad.dims(), - errors::InvalidArgument( - "Param and grad input of AdadeltaOp should have same dimension.")); - PADDLE_ENFORCE_EQ( - param_dims, - avg_squared_grad.dims(), - errors::InvalidArgument("Param and AvgSquaredGrad input of AdadeltaOp " - "should have same dimension")); - PADDLE_ENFORCE_EQ( - param_dims, - avg_squared_update.dims(), - errors::InvalidArgument("Param and AvgSquaredUpdate input of AdadeltaOp " - "should have same dimension")); - - param_out->set_dims(param_dims); - param_out->set_dtype(param.dtype()); - - avg_squared_grad_out->set_dims(param_dims); - avg_squared_grad_out->set_dtype(avg_squared_grad.dtype()); - - avg_squared_update_out->set_dims(param_dims); - avg_squared_update_out->set_dtype(avg_squared_update.dtype()); -} - void BilinearTensorProductInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, @@ -369,6 +369,81 @@ void ConcatInferMeta(const std::vector& x, out->share_lod(*x.at(0)); } +void PsroiPoolInferMeta(const MetaTensor& x, + const MetaTensor& rois, + paddle::optional rois_num, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + MetaTensor* out) { + auto input_dims = x.dims(); + auto rois_dims = rois.dims(); + + PADDLE_ENFORCE_EQ( + input_dims.size(), + 4, + errors::InvalidArgument("The format of input tensor is NCHW")); + PADDLE_ENFORCE_EQ(rois_dims.size(), + 2, + errors::InvalidArgument( + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]")); + PADDLE_ENFORCE_EQ(rois_dims[1], + 4, + errors::InvalidArgument( + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]")); + if (rois_num.get_ptr()) { + auto rois_num_dims = rois_num->dims(); + PADDLE_ENFORCE_EQ( + rois_num_dims.size(), + 1, + errors::InvalidArgument("The second dimension of RoisNum should " + "be 1, but received dimension is %d", + rois_num_dims.size())); + } + + PADDLE_ENFORCE_EQ( + input_dims[1], + output_channels * pooled_height * pooled_width, + errors::InvalidArgument( + "the channel of X(%d) " + "should be equal to the product of " + "output_channels(%d), pooled_height(%d) and pooled_width(%d)", + input_dims[1], + output_channels, + pooled_height, + pooled_width)); + + PADDLE_ENFORCE_GT(pooled_height, + 0, + errors::InvalidArgument( + "The pooled output height must be greater than 0")); + PADDLE_ENFORCE_GT(pooled_width, + 0, + errors::InvalidArgument( + "The pooled output width must be greater than 0")); + PADDLE_ENFORCE_GT(output_channels, + 1, + errors::InvalidArgument( + "The pooled output channels must greater than 1")); + PADDLE_ENFORCE_GT( + spatial_scale, + 0.0f, + errors::InvalidArgument("The spatial scale must greater than 0.")); + + auto out_dims = input_dims; + out_dims[0] = rois_dims[0]; + out_dims[1] = + output_channels; // input_dims[1] / (pooled_height * pooled_width); + out_dims[2] = pooled_height; + out_dims[3] = pooled_width; + + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + void WhereInferMeta(const MetaTensor& condition, const MetaTensor& x, const MetaTensor& y, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 26bdc62302f18ad011fd3ab74f4b2dd708d4c1ef..c11843212ed33fd8170e6677a4d6e0ad95b730dc 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -20,6 +20,29 @@ namespace phi { std::vector GetMetaTensorsDim(const std::vector& tensors); +void AdadeltaInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& avg_squared_grad, + const MetaTensor& avg_squared_update, + float rho, + float epsilon, + MetaTensor* param_out, + MetaTensor* avg_squared_grad_out, + MetaTensor* avg_squared_update_out); + +void AdamaxInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& moment, + const MetaTensor& inf_norm, + const MetaTensor& beta1_pow, + float beta1, + float beta2, + float epsilon, + MetaTensor* param_out, + MetaTensor* moment_out, + MetaTensor* inf_norm_out); + void AucInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& stat_pos, @@ -47,32 +70,18 @@ void ConcatInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void PsroiPoolInferMeta(const MetaTensor& x, + const MetaTensor& rois, + paddle::optional rois_num, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + MetaTensor* out); + void WhereInferMeta(const MetaTensor& condition, const MetaTensor& x, const MetaTensor& y, MetaTensor* out); -void AdamaxInferMeta(const MetaTensor& param, - const MetaTensor& grad, - const MetaTensor& learning_rate, - const MetaTensor& moment, - const MetaTensor& inf_norm, - const MetaTensor& beta1_pow, - float beta1, - float beta2, - float epsilon, - MetaTensor* param_out, - MetaTensor* moment_out, - MetaTensor* inf_norm_out); - -void AdadeltaInferMeta(const MetaTensor& param, - const MetaTensor& grad, - const MetaTensor& avg_squared_grad, - const MetaTensor& avg_squared_update, - float rho, - float epsilon, - MetaTensor* param_out, - MetaTensor* avg_squared_grad_out, - MetaTensor* avg_squared_update_out); - } // namespace phi diff --git a/paddle/phi/kernels/cpu/psroi_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/psroi_pool_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..fbed3f1cb133ada68b90a5283fc182373488c565 --- /dev/null +++ b/paddle/phi/kernels/cpu/psroi_pool_grad_kernel.cc @@ -0,0 +1,140 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/psroi_pool_grad_kernel.h" + +#include +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PsroiPoolGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& rois, + paddle::optional rois_num, + const DenseTensor& dout, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + DenseTensor* dx) { + if (dx) { + auto in_dims = x.dims(); + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num_t = rois.dims()[0]; + + // set roi batch id + DenseTensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num_t}); + int* rois_batch_id_data = ctx.template Alloc(&rois_batch_id_list); + int rois_batch_size; + if (rois_num.get_ptr()) { + rois_batch_size = rois_num->numel(); + auto* rois_num_t_data = rois_num->data(); + int start = 0; + for (int n = 0; n < rois_batch_size; ++n) { + for (int i = start; i < start + rois_num_t_data[n]; ++i) { + rois_batch_id_data[i] = n; + } + start += rois_num_t_data[n]; + } + } else { + auto rois_lod = rois.lod().back(); + rois_batch_size = rois_lod.size() - 1; + // calculate batch id index for each roi according to LoD + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + } + const T* input_rois = rois.data(); + const T* dout_data = dout.data(); + T* dx_data = ctx.template Alloc(dx); + + // set gradient of X to be 0. before backpropagate. + funcs::SetConstant set_zero; + set_zero(ctx, dx, static_cast(0)); + + // backpropagate gradient per output pixel + int dout_size = dout.numel(); + for (int i = 0; i < dout_size; ++i) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_offset = + (roi_batch_id * input_channels + input_channel) * height * width; + T* offset_dx_data = dx_data + input_offset; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = + static_cast(round(offset_input_rois[0])) * spatial_scale; + T roi_start_h = + static_cast(round(offset_input_rois[1])) * spatial_scale; + T roi_end_w = + static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; + T roi_end_h = + static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; + + // Force too small ROIs to be 1x1 + T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 + T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); + int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); + int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); + int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart, 0), height); + hend = std::min(std::max(hend, 0), height); + wstart = std::min(std::max(wstart, 0), width); + wend = std::min(std::max(wend, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Accumulate diff_val into input data + T bin_area = static_cast((hend - hstart) * (wend - wstart)); + T diff_val = is_empty ? 0. : dout_data[i] / bin_area; + for (int ih = hstart; ih < hend; ++ih) { + for (int iw = wstart; iw < wend; ++iw) { + int input_index = ih * width + iw; + offset_dx_data[input_index] += diff_val; + } + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + psroi_pool_grad, CPU, ALL_LAYOUT, phi::PsroiPoolGradKernel, float, double) { + kernel->InputAt(2).SetDataType( + paddle::experimental::CppTypeToDataType::Type()); +} diff --git a/paddle/phi/kernels/cpu/psroi_pool_kernel.cc b/paddle/phi/kernels/cpu/psroi_pool_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..06cd03395d9656614995ef0ad91dad04b27717bf --- /dev/null +++ b/paddle/phi/kernels/cpu/psroi_pool_kernel.cc @@ -0,0 +1,174 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/psroi_pool_kernel.h" + +#include +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void PsroiPoolKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& rois, + paddle::optional rois_num, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + DenseTensor* out) { + auto in_dims = x.dims(); + int batch_size = in_dims[0]; + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num_t = rois.dims()[0]; + + PADDLE_ENFORCE_EQ(input_channels, + output_channels * pooled_height * pooled_width, + errors::InvalidArgument( + "the channels of input " + "X should equal the product of " + "output_channels x pooled_height x pooled_width")); + + auto in_stride = stride(in_dims); + auto out_stride = stride(out->dims()); + + const T* input_data = x.data(); + + DenseTensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num_t}); + int* rois_batch_id_data = ctx.template Alloc(&rois_batch_id_list); + + int rois_batch_size; + if (rois_num.get_ptr()) { + rois_batch_size = rois_num->numel(); + auto* rois_num_data = rois_num->data(); + PADDLE_ENFORCE_EQ( + rois_batch_size, + batch_size, + errors::InvalidArgument( + "The batch size of rois and the batch size of images " + " must be the same. But received the batch size of rois is %d, " + "and the batch size of images is %d", + rois_batch_size, + batch_size)); + int rois_num_count = 0; + for (int i = 0; i < rois_batch_size; ++i) { + rois_num_count += rois_num_data[i]; + } + PADDLE_ENFORCE_EQ( + rois_num_count, + rois_num_t, + errors::InvalidArgument( + "the rois_num from input and RoisNum must be the same")); + int start = 0; + for (int n = 0; n < rois_batch_size; ++n) { + for (int i = start; i < start + rois_num_data[n]; ++i) { + rois_batch_id_data[i] = n; + } + start += rois_num_data[n]; + } + } else { + auto rois_lod = rois.lod().back(); + rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, + batch_size, + errors::InvalidArgument("the rois_batch_size and input(X) " + "batch_size should be the same.")); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num_with_lod, + rois_num_t, + errors::InvalidArgument( + "the rois_num from input and lod must be the same")); + // calculate batch id index for each roi according to LoD + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + } + T* output_data = ctx.template Alloc(out); + const T* input_rois = rois.data(); + + // calculate psroipooling, parallel processing can be implemented per ROI + for (int n = 0; n < rois_num_t; ++n) { + // set roi batch id + int roi_batch_id = rois_batch_id_data[n]; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(round(offset_input_rois[0])) * spatial_scale; + T roi_start_h = static_cast(round(offset_input_rois[1])) * spatial_scale; + T roi_end_w = + static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; + T roi_end_h = + static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; + // Force too small rois to be 1 x 1 + T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 + T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1); + + // Compute bin size w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + // calculate each pixel of the output feature map. + int out_roi_offset = n * out_stride[0]; + for (int c = 0; c < output_channels; ++c) { + // per category + int out_plane_offset = out_roi_offset + c * out_stride[1]; + for (int ph = 0; ph < pooled_height; ++ph) { + int out_row_offset = out_plane_offset + ph * out_stride[2]; + for (int pw = 0; pw < pooled_width; ++pw) { + // calculate w and h at input feature map + int hstart = floor(static_cast(ph) * bin_size_h + roi_start_h); + int wstart = floor(static_cast(pw) * bin_size_w + roi_start_w); + int hend = ceil(static_cast(ph + 1) * bin_size_h + roi_start_h); + int wend = ceil(static_cast(pw + 1) * bin_size_w + roi_start_w); + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart, 0), height); + wstart = std::min(std::max(wstart, 0), width); + hend = std::min(std::max(hend, 0), height); + wend = std::min(std::max(wend, 0), width); + + int output_index = out_row_offset + pw; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_plane_offset = + roi_batch_id * in_stride[0] + input_channel * in_stride[1]; + const T* offset_input_data = input_data + input_plane_offset; + T out_sum = 0.; + bool is_empty = (hend <= hstart) || (wend <= wstart); + for (int ih = hstart; ih < hend; ++ih) { + for (int iw = wstart; iw < wend; ++iw) { + int input_index = ih * in_stride[2] + iw; + out_sum += offset_input_data[input_index]; + } + } + T bin_area = (hend - hstart) * (wend - wstart); + output_data[output_index] = is_empty ? 0. : out_sum / bin_area; + } + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + psroi_pool, CPU, ALL_LAYOUT, phi::PsroiPoolKernel, float, double) { + kernel->InputAt(2).SetDataType( + paddle::experimental::CppTypeToDataType::Type()); +} diff --git a/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..6745653eba7d175447eb54c319919fd6d87fb5dd --- /dev/null +++ b/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu @@ -0,0 +1,193 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/psroi_pool_kernel.h" + +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +__global__ void GPUPSROIPoolBackward(const int nthreads, + const T* input_rois, + const T* dout_data, + const float spatial_scale, + const int input_channels, + const int height, + const int width, + const int output_channels, + const int pooled_height, + const int pooled_width, + const int* rois_batch_id_data, + T* dx_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_offset = + (roi_batch_id * input_channels + input_channel) * height * width; + T* offset_dx_data = dx_data + input_offset; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(round(offset_input_rois[0])) * spatial_scale; + T roi_start_h = static_cast(round(offset_input_rois[1])) * spatial_scale; + T roi_end_w = + static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; + T roi_end_h = + static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; + + // Force too small ROIs to be 1x1 + T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 + T roi_width = max(roi_end_w - roi_start_w, (T)0.1); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); + int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); + int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); + int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart, 0), height); + hend = min(max(hend, 0), height); + wstart = min(max(wstart, 0), width); + wend = min(max(wend, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Accumulate diff_val into input data + T bin_area = static_cast((hend - hstart) * (wend - wstart)); + T diff_val = is_empty ? 0. : dout_data[i] / bin_area; + for (int ih = hstart; ih < hend; ++ih) { + for (int iw = wstart; iw < wend; ++iw) { + int input_index = ih * width + iw; + paddle::platform::CudaAtomicAdd(offset_dx_data + input_index, diff_val); + } + } + } +} + +template +void PsroiPoolGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& rois, + paddle::optional rois_num, + const DenseTensor& dout, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + DenseTensor* dx) { + int rois_num_t = rois.dims()[0]; + int input_channels = x.dims()[1]; + int height = x.dims()[2]; + int width = x.dims()[3]; + + if (dx) { + // set roi batch id + DenseTensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num_t}); + int* rois_batch_id_data = ctx.template HostAlloc(&rois_batch_id_list); + int rois_batch_size; + if (rois_num.get_ptr()) { + rois_batch_size = rois_num->numel(); + std::vector rois_num_list(rois_batch_size); + paddle::memory::Copy(CPUPlace(), + rois_num_list.data(), + ctx.GetPlace(), + rois_num->data(), + sizeof(int) * rois_batch_size, + 0); + int start = 0; + for (int n = 0; n < rois_batch_size; ++n) { + for (int i = start; i < start + rois_num_list[n]; ++i) { + rois_batch_id_data[i] = n; + } + start += rois_num_list[n]; + } + } else { + auto rois_lod = rois.lod().back(); + rois_batch_size = rois_lod.size() - 1; + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + } + + DenseTensor rois_batch_id_list_gpu; + Copy(ctx, + rois_batch_id_list, + ctx.GetPlace(), + false, + &rois_batch_id_list_gpu); + + ctx.template Alloc(dx); + funcs::SetConstant set_zero; + set_zero(ctx, dx, static_cast(0)); + + int dout_size = dout.numel(); + int blocks = NumBlocks(dout_size); + int threads = kNumCUDAThreads; + + if (dout_size > 0) { + GPUPSROIPoolBackward<<>>( + dout_size, + rois.data(), + dout.data(), + spatial_scale, + input_channels, + height, + width, + output_channels, + pooled_height, + pooled_width, + rois_batch_id_list_gpu.data(), + ctx.template Alloc(dx)); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + psroi_pool_grad, GPU, ALL_LAYOUT, phi::PsroiPoolGradKernel, float, double) { + kernel->InputAt(2).SetDataType( + paddle::experimental::CppTypeToDataType::Type()); +} diff --git a/paddle/phi/kernels/gpu/psroi_pool_kernel.cu b/paddle/phi/kernels/gpu/psroi_pool_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..8f9be001ba763d323ad93fdfd4cc06e97e266188 --- /dev/null +++ b/paddle/phi/kernels/gpu/psroi_pool_kernel.cu @@ -0,0 +1,231 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/psroi_pool_kernel.h" + +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +__global__ void GPUPSROIPoolForward(const int nthreads, + const T* input_data, + const T* input_rois, + const float spatial_scale, + const int input_channels, + const int height, + const int width, + const int output_channels, + const int pooled_height, + const int pooled_width, + const int* rois_batch_id_data, + T* output_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(round(offset_input_rois[0])) * spatial_scale; + T roi_start_h = static_cast(round(offset_input_rois[1])) * spatial_scale; + T roi_end_w = + static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; + T roi_end_h = + static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; + + // Force too small ROIs to be 1x1 + T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 + T roi_width = max(roi_end_w - roi_start_w, (T)0.1); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); + int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); + int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); + int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart, 0), height); + hend = min(max(hend, 0), height); + wstart = min(max(wstart, 0), width); + wend = min(max(wend, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + const T* offset_input_data = + input_data + + (roi_batch_id * input_channels + input_channel) * height * width; + T outsum = 0; + + for (int ih = hstart; ih < hend; ++ih) { + for (int iw = wstart; iw < wend; ++iw) { + int input_index = ih * width + iw; + outsum += offset_input_data[input_index]; + } + } + + T bin_area = static_cast((hend - hstart) * (wend - wstart)); + output_data[i] = is_empty ? 0. : outsum / bin_area; + } +} + +template +void PsroiPoolKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& rois, + paddle::optional rois_num, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + DenseTensor* out) { + auto in_dims = x.dims(); + int batch_size = in_dims[0]; + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + + PADDLE_ENFORCE_EQ( + input_channels, + output_channels * pooled_height * pooled_width, + errors::InvalidArgument( + "The channels %d of input X should equal the product of " + "output_channels %d x pooled_height %d x pooled_width %d.", + input_channels, + output_channels, + pooled_height, + pooled_width)); + + int rois_num_t = rois.dims()[0]; + if (rois_num_t == 0) return; + int rois_batch_size; + DenseTensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num_t}); + int* rois_batch_id_data = ctx.template HostAlloc(&rois_batch_id_list); + + if (rois_num.get_ptr()) { + rois_batch_size = rois_num->numel(); + auto* rois_num_data = rois_num->data(); + PADDLE_ENFORCE_EQ(rois_batch_size, + batch_size, + errors::InvalidArgument( + "The batch size of input(ROIs) and input(X) must be " + "the same but received batch size of input(ROIs) and " + "input(X) is %d and %d respectively.", + rois_batch_size, + batch_size)); + std::vector rois_num_list(rois_batch_size); + paddle::memory::Copy(CPUPlace(), + rois_num_list.data(), + ctx.GetPlace(), + rois_num_data, + sizeof(int) * rois_batch_size, + 0); + int rois_num_count = 0; + for (int i = 0; i < rois_batch_size; ++i) { + rois_num_count += rois_num_list[i]; + } + PADDLE_ENFORCE_EQ( + rois_num_count, + rois_num_t, + errors::InvalidArgument( + "the rois_num from input and RoisNum must be the same")); + int start = 0; + for (int n = 0; n < rois_batch_size; ++n) { + for (int i = start; i < start + rois_num_list[n]; ++i) { + rois_batch_id_data[i] = n; + } + start += rois_num_list[n]; + } + } else { + auto rois_lod = rois.lod().back(); + rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ(rois_batch_size, + batch_size, + errors::InvalidArgument( + "The batch size of input(ROIs) and input(X) must be " + "the same but received batch size of input(ROIs) and " + "input(X) is %d and %d respectively.", + rois_batch_size, + batch_size)); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num_t, + rois_num_with_lod, + errors::InvalidArgument( + "The number of rois from input(ROIs) and its LOD " + "must be the same. Received rois %d of input(ROIs) " + "but the number of rois %d from its LOD is %d", + rois_num, + rois_num_with_lod)); + + // set rois batch id + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + } + DenseTensor rois_batch_id_list_gpu; + Copy(ctx, rois_batch_id_list, ctx.GetPlace(), false, &rois_batch_id_list_gpu); + + int output_size = out->numel(); + int blocks = NumBlocks(output_size); + int threads = kNumCUDAThreads; + + // call cuda kernel function + GPUPSROIPoolForward<<>>( + output_size, + x.data(), + rois.data(), + spatial_scale, + input_channels, + height, + width, + output_channels, + pooled_height, + pooled_width, + rois_batch_id_list_gpu.data(), + ctx.template Alloc(out)); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + psroi_pool, GPU, ALL_LAYOUT, phi::PsroiPoolKernel, float, double) { + kernel->InputAt(2).SetDataType( + paddle::experimental::CppTypeToDataType::Type()); +} diff --git a/paddle/phi/kernels/psroi_pool_grad_kernel.h b/paddle/phi/kernels/psroi_pool_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..87163eb8e079ffd580d6f937179e24a8506376e9 --- /dev/null +++ b/paddle/phi/kernels/psroi_pool_grad_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void PsroiPoolGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& rois, + paddle::optional rois_num, + const DenseTensor& dout, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + DenseTensor* dx); + +} // namespace phi diff --git a/paddle/phi/kernels/psroi_pool_kernel.h b/paddle/phi/kernels/psroi_pool_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..341037af2caeca28e211da8862e3c8d6089b9bac --- /dev/null +++ b/paddle/phi/kernels/psroi_pool_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void PsroiPoolKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& rois, + paddle::optional rois_num, + int pooled_height, + int pooled_width, + int output_channels, + float spatial_scale, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/psroi_pool_sig.cc b/paddle/phi/ops/compat/psroi_pool_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d694d9a7759d9e3cdf0c385164a367260f2a020 --- /dev/null +++ b/paddle/phi/ops/compat/psroi_pool_sig.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature PsroiPoolOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "psroi_pool", + {"X", "ROIs", "RoisNum"}, + {"pooled_height", "pooled_width", "output_channels", "spatial_scale"}, + {"Out"}); +} + +KernelSignature PsroiPoolGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "psroi_pool_grad", + {"X", "ROIs", "RoisNum", GradVarName("Out")}, + {"pooled_height", "pooled_width", "output_channels", "spatial_scale"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(psroi_pool, phi::PsroiPoolOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(psroi_pool_grad, + phi::PsroiPoolGradOpArgumentMapping);