diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 8deb3b93e9c50489dcfc6805063f23e3705cb634..16f2df79246f782ead9cc3177679674d98c3d1a9 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -341,7 +341,6 @@ void BuildDygraphPhiKernelContext( } for (size_t i = 0; i < attr_names.size(); ++i) { - VLOG(1) << "############## attr_name: " << i << " : " << attr_names[i]; if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) { if (attrs.find(attr_names[i]) != attrs.end()) { // shape is in the attribute diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc index ac0cd75237baf5e8b860f197d42cd27bae65270e..bf78b6a696559cab152a6de2c4730a32dfdbb780 100644 --- a/paddle/fluid/operators/roi_align_op.cc +++ b/paddle/fluid/operators/roi_align_op.cc @@ -9,9 +9,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/roi_align_op.h" #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/ternary.h" namespace paddle { namespace operators { @@ -23,79 +26,6 @@ class ROIAlignOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound("Input(X) of ROIAlignOp " - "is not found.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true, - platform::errors::NotFound("Input(ROIs) of ROIAlignOp " - "is not found.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::NotFound("Output(Out) of ROIAlignOp " - "is not found.")); - auto input_dims = ctx->GetInputDim("X"); - auto rois_dims = ctx->GetInputDim("ROIs"); - - if (ctx->HasInput("RoisNum")) { - auto rois_num_dims = ctx->GetInputDim("RoisNum"); - PADDLE_ENFORCE_EQ( - rois_num_dims.size(), 1, - platform::errors::InvalidArgument("The size of RoisNum should be 1" - ", but received size = %d", - rois_num_dims.size())); - } - PADDLE_ENFORCE_EQ( - input_dims.size(), 4, - platform::errors::InvalidArgument( - "The format of Input(X) in" - "RoIAlignOp is NCHW. And the rank of input must be 4. " - "But received rank = %d", - input_dims.size())); - PADDLE_ENFORCE_EQ(rois_dims.size(), 2, platform::errors::InvalidArgument( - "The rank of Input(ROIs) " - "in RoIAlignOp should be 2. " - "But the rank of RoIs is %d", - rois_dims.size())); - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(rois_dims[1], 4, - platform::errors::InvalidArgument( - "The second dimension " - "of Input(ROIs) should be 4. But received the " - "dimension = %d", - rois_dims[1])); - } - int pooled_height = ctx->Attrs().Get("pooled_height"); - int pooled_width = ctx->Attrs().Get("pooled_width"); - float spatial_scale = ctx->Attrs().Get("spatial_scale"); - - PADDLE_ENFORCE_GT(pooled_height, 0, - platform::errors::InvalidArgument( - "The 'pooled_height' attribute in RoIAlignOp is " - "invalid. The height must be greater than 0. But " - "received 'pooled_height' = %d", - pooled_height)); - PADDLE_ENFORCE_GT(pooled_width, 0, - platform::errors::InvalidArgument( - "The 'pooled_width' attribute in RoIAlignOp is " - "invalid. The width must be greater than 0. But " - "received 'pooled_width' = %d", - pooled_width)); - PADDLE_ENFORCE_GT(spatial_scale, 0.0f, - platform::errors::InvalidArgument( - "The 'spatial_scale' attribute in RoIAlignOp is " - "invalid. The scale must be greater than 0. But " - "received 'spatial_scale' = %f", - spatial_scale)); - - auto out_dims = input_dims; - out_dims[0] = rois_dims[0]; - out_dims[1] = input_dims[1]; - out_dims[2] = pooled_height; - out_dims[3] = pooled_width; - - ctx->SetOutputDim("Out", out_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -221,17 +151,16 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(RoiAlignGradNoNeedBufVarsInferer, "X"); } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(roi_align, RoiAlignInferShapeFunctor, + PD_INFER_META(phi::RoiAlignInferMeta)); + REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker, ops::ROIAlignGradMaker, - ops::ROIAlignGradMaker); + ops::ROIAlignGradMaker, + RoiAlignInferShapeFunctor); REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp, ops::RoiAlignGradNoNeedBufVarsInferer); -REGISTER_OP_CPU_KERNEL( - roi_align_grad, - ops::CPUROIAlignGradOpKernel, - ops::CPUROIAlignGradOpKernel, - ops::CPUROIAlignGradOpKernel); REGISTER_OP_VERSION(roi_align) .AddCheckpoint( R"ROC( diff --git a/paddle/fluid/operators/roi_align_op.cu b/paddle/fluid/operators/roi_align_op.cu deleted file mode 100644 index 1a2e64cd45ca401f5fb8ca6b6975a029ba735280..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/roi_align_op.cu +++ /dev/null @@ -1,227 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/roi_align_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -static constexpr int kNumCUDAThreads = 512; -static constexpr int kNumMaxinumNumBlocks = 4096; -static constexpr int kROISize = 4; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaxinumNumBlocks); -} - -template -__device__ void BilinearInterpolateGradient(const int height, const int width, - T y, T x, T* w1, T* w2, T* w3, - T* w4, int* x_low, int* x_high, - int* y_low, int* y_high) { - if (y < -1.0 || y > height || x < -1.0 || x > width) { - return; - } - - y = y <= 0 ? 0 : y; - x = x <= 0 ? 0 : x; - *y_low = static_cast(y); - *x_low = static_cast(x); - if (*y_low >= height - 1) { - *y_high = *y_low = height - 1; - y = static_cast(*y_low); - } else { - *y_high = *y_low + 1; - } - if (*x_low >= width - 1) { - *x_high = *x_low = width - 1; - x = static_cast(*x_low); - } else { - *x_high = *x_low + 1; - } - T ly = y - *y_low, lx = x - *x_low; - T hy = 1. - ly, hx = 1. - lx; - *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; - - return; -} - -template -__global__ void GPUROIAlignBackward( - const int nthreads, const T* input_rois, const T* out_grad, - const int num_rois, const float spatial_scale, const int channels, - const int height, const int width, const int pooled_height, - const int pooled_width, const int sampling_ratio, int* roi_batch_id_data, - T* input_grad, const bool continuous_coordinate) { - CUDA_KERNEL_LOOP(i, nthreads) { - int pw = i % pooled_width; - int ph = (i / pooled_width) % pooled_height; - int c = (i / pooled_width / pooled_height) % channels; - int n = i / pooled_width / pooled_height / channels; - const T* offset_input_rois = input_rois + n * kROISize; - int roi_batch_ind = roi_batch_id_data[n]; - - T roi_offset = continuous_coordinate ? T(0.5) : 0; - T roi_xmin = offset_input_rois[0] * spatial_scale - roi_offset; - T roi_ymin = offset_input_rois[1] * spatial_scale - roi_offset; - T roi_xmax = offset_input_rois[2] * spatial_scale - roi_offset; - T roi_ymax = offset_input_rois[3] * spatial_scale - roi_offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - if (!continuous_coordinate) { - roi_width = max(roi_width, static_cast(1.)); - roi_height = max(roi_height, static_cast(1.)); - } - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - T* offset_input_grad = - input_grad + (roi_batch_ind * channels + c) * height * width; - - const T* offset_out_grad = - out_grad + (n * channels + c) * pooled_height * pooled_width; - const T out_grad_this_bin = offset_out_grad[ph * pooled_width + pw]; - - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - const T count = roi_bin_grid_h * roi_bin_grid_w; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = roi_ymin + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_xmin + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - T w1 = 0, w2 = 0, w3 = 0, w4 = 0; - int x_low = -1, x_high = -1, y_low = -1, y_high = -1; - BilinearInterpolateGradient(height, width, y, x, &w1, &w2, &w3, &w4, - &x_low, &x_high, &y_low, &y_high); - T diff1 = out_grad_this_bin * w1 / count; - T diff2 = out_grad_this_bin * w2 / count; - T diff3 = out_grad_this_bin * w3 / count; - T diff4 = out_grad_this_bin * w4 / count; - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - platform::CudaAtomicAdd(offset_input_grad + y_low * width + x_low, - diff1); - platform::CudaAtomicAdd(offset_input_grad + y_low * width + x_high, - diff2); - platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_low, - diff3); - platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_high, - diff4); - } - } - } - } -} - -template -class GPUROIAlignGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - - auto* out_grad = ctx.Input(framework::GradVarName("Out")); - auto* in_grad = ctx.Output(framework::GradVarName("X")); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - auto sampling_ratio = ctx.Attr("sampling_ratio"); - auto aligned = ctx.Attr("aligned"); - - int rois_num = rois->dims()[0]; - int channels = in->dims()[1]; - int height = in->dims()[2]; - int width = in->dims()[3]; - - if (!in_grad) { - return; - } - Tensor roi_batch_id_list; - roi_batch_id_list.Resize({rois_num}); - auto cplace = platform::CPUPlace(); - int* roi_batch_id_data = roi_batch_id_list.mutable_data(cplace); - - auto& dev_ctx = ctx.cuda_device_context(); - auto gplace = ctx.GetPlace(); - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - int rois_batch_size = rois_num_t->numel(); - std::vector rois_num_list(rois_batch_size); - memory::Copy(cplace, rois_num_list.data(), gplace, - rois_num_t->data(), sizeof(int) * rois_batch_size, 0); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = start; i < start + rois_num_list[n]; ++i) { - roi_batch_id_data[i] = n; - } - start += rois_num_list[n]; - } - } else { - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - } - auto roi_ptr = - memory::Alloc(dev_ctx, roi_batch_id_list.numel() * sizeof(int)); - int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); - int bytes = roi_batch_id_list.numel() * sizeof(int); - memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes, - dev_ctx.stream()); - in_grad->mutable_data(ctx.GetPlace()); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, in_grad, static_cast(0)); - - int output_grad_size = out_grad->numel(); - int blocks = NumBlocks(output_grad_size); - int threads = kNumCUDAThreads; - - if (output_grad_size > 0) { - GPUROIAlignBackward<<>>( - output_grad_size, rois->data(), out_grad->data(), rois_num, - spatial_scale, channels, height, width, pooled_height, pooled_width, - sampling_ratio, roi_id_data, in_grad->mutable_data(ctx.GetPlace()), - aligned); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - roi_align_grad, - ops::GPUROIAlignGradOpKernel, - ops::GPUROIAlignGradOpKernel); diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h deleted file mode 100644 index 589e35e4ab7ae4caf5efd3fb4d93a26b2ca86b26..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/roi_align_op.h +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -template -void bilinear_interpolate_gradient(const int height, const int width, T y, T x, - const T out_grad_this_bin, const T count, - T* batch_grad_data) { - int x_low, y_low, x_high, y_high; - T w1, w2, w3, w4; - if (y < -1.0 || y > height || x < -1.0 || x > width) { - w1 = w2 = w3 = w4 = 0; - x_low = x_high = y_low = y_high = -1; - return; - } - y = y <= 0 ? 0 : y; - x = x <= 0 ? 0 : x; - y_low = static_cast(y); - x_low = static_cast(x); - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = static_cast(y_low); - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = static_cast(x_low); - } else { - x_high = x_low + 1; - } - - T ly = y - y_low, lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - T diff1 = out_grad_this_bin * w1 / count; - T diff2 = out_grad_this_bin * w2 / count; - T diff3 = out_grad_this_bin * w3 / count; - T diff4 = out_grad_this_bin * w4 / count; - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - *(batch_grad_data + y_low * width + x_low) += diff1; - *(batch_grad_data + y_low * width + x_high) += diff2; - *(batch_grad_data + y_high * width + x_low) += diff3; - *(batch_grad_data + y_high * width + x_high) += diff4; - } -} - -template -class CPUROIAlignGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - auto* in_grad = ctx.Output(framework::GradVarName("X")); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - auto sampling_ratio = ctx.Attr("sampling_ratio"); - auto in_dims = in->dims(); - auto aligned = ctx.Attr("aligned"); - - int channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - int rois_num = rois->dims()[0]; - - if (!in_grad) { - return; - } - Tensor roi_batch_id_list; - roi_batch_id_list.Resize({rois_num}); - int* roi_batch_id_data = - roi_batch_id_list.mutable_data(ctx.GetPlace()); - - int rois_batch_size; - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - rois_batch_size = rois_num_t->numel(); - auto* rois_num_data = rois_num_t->data(); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_data[n]; ++i) { - roi_batch_id_data[i] = n; - } - start += rois_num_data[n]; - } - } else { - auto rois_lod = rois->lod().back(); - rois_batch_size = rois_lod.size() - 1; - for (int n = 0; n < rois_batch_size; ++n) { - for (std::size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - } - in_grad->mutable_data(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, in_grad, static_cast(0)); - - int output_grad_size = out_grad->numel(); - - if ((!out_grad->IsInitialized()) || (output_grad_size <= 0)) { - return; - } - - const T* rois_data = rois->data(); - const T* out_grad_data = out_grad->data(); - T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); - - auto in_stride = phi::stride(in->dims()); - auto roi_stride = phi::stride(rois->dims()); - auto out_stride = phi::stride(out_grad->dims()); - - T roi_offset = aligned ? T(0.5) : 0; - for (int n = 0; n < rois_num; ++n) { - int roi_batch_idx = roi_batch_id_data[n]; - T roi_xmin = rois_data[0] * spatial_scale - roi_offset; - T roi_ymin = rois_data[1] * spatial_scale - roi_offset; - T roi_xmax = rois_data[2] * spatial_scale - roi_offset; - T roi_ymax = rois_data[3] * spatial_scale - roi_offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - roi_width = std::max(roi_width, static_cast(1.)); - roi_height = std::max(roi_height, static_cast(1.)); - if (!aligned) { - roi_width = std::max(roi_width, static_cast(1.)); - roi_height = std::max(roi_height, static_cast(1.)); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - for (int c = 0; c < channels; ++c) { - T* batch_grad_data = - in_grad_data + roi_batch_idx * in_stride[0] + c * in_stride[1]; - const T* batch_out_grad_data = - out_grad_data + n * out_stride[0] + c * out_stride[1]; - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int pool_index = ph * pooled_width + pw; - T out_grad_this_bin = batch_out_grad_data[pool_index]; - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_width / pooled_width); - T count = roi_bin_grid_h * roi_bin_grid_w; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = roi_ymin + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_xmin + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - bilinear_interpolate_gradient(height, width, y, x, - out_grad_this_bin, count, - batch_grad_data); - } - } - } - } - } - rois_data += roi_stride[0]; - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 06ee5a205d7b0f2f842e1b9b4b8fad8948168b64..260fbfe7197912fd3dd5b9103a0a991a45d55816 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -21,6 +21,10 @@ limitations under the License. */ namespace phi { +// Common InferMeta Functions for backward operators. +// +// NOTE: The InferMeta Functions in this file are arranged in alphabetic order. + void BilinearTensorProductGradInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 1727e85b1d533a8aaf4e044beca6b2308e441908..8cf7ce3930e941a3c5243306fa38e4466059509a 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -29,6 +29,8 @@ namespace phi { // NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good. // Because functions in this file not only can infer shape, but also need // infer lod or other useful data. +// +// The InferMeta Functions in this file are arranged in alphabetic order. void AllValueCompareInferMeta(const MetaTensor& x, const MetaTensor& y, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 0bdd35d5f58e8e9d5c3dd7956897bac0adbdf550..6de95386dd998810b508db6d0469691a37cd53dd 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -18,6 +18,23 @@ limitations under the License. */ #include "paddle/phi/core/meta_tensor.h" namespace phi { +// Common InferMeta Functions for multiary operators, The format like: +// +// 1. The number of input MetaTensor is more than 3: +// void [FunctionDesc|OpName]InferMeta(const MetaTensor& x, +// const MetaTensor& y, +// const MetaTensor& z, +// const MetaTensor& w, +// ..., +// MetaTensor* out) {} +// +// 2. There are `const vector&` in params: +// void [FunctionDesc|OpName]InferMeta(const vector& x, +// ..., +// MetaTensor* out) {} +// +// NOTE: The InferMeta Functions in this file are arranged in alphabetic order. + std::vector GetMetaTensorsDim(const std::vector& tensors); void AdadeltaInferMeta(const MetaTensor& param, diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 38eaa636f8c8779c5a1f597b8cfb23ce6efc5edc..55e59b27e71cfb1d9b16a659e40d299ed3f2fc54 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -27,6 +27,8 @@ namespace phi { // NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good. // Because functions in this file not only can infer shape, but also need // infer lod or other useful data. +// +// The InferMeta Functions in this file are arranged in alphabetic order. void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out); diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 235cfe368c1921eac546b670470963fb49100290..837750710c9a3dcf3c8b414c5c52a7272a0b3f58 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -322,6 +322,83 @@ void NllLossRawInferMeta(const MetaTensor& input, total_weight->set_dtype(input.dtype()); } +void RoiAlignInferMeta(const MetaTensor& x, + const MetaTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + int sampling_ratio, + bool aligned, + MetaTensor* out, + MetaConfig config) { + auto input_dims = x.dims(); + auto boxes_dims = boxes.dims(); + + if (boxes_num) { + auto boxes_num_dims = boxes_num->dims(); + PADDLE_ENFORCE_EQ( + boxes_num_dims.size(), + 1, + phi::errors::InvalidArgument("The size of RoisNum should be 1" + ", but received size = %d", + boxes_num_dims.size())); + } + PADDLE_ENFORCE_EQ(input_dims.size(), + 4, + phi::errors::InvalidArgument( + "The format of Input(X) in" + "RoIAlignOp is NCHW. And the rank of input must be 4. " + "But received rank = %d", + input_dims.size())); + PADDLE_ENFORCE_EQ(boxes_dims.size(), + 2, + phi::errors::InvalidArgument("The rank of Input(ROIs) " + "in RoIAlignOp should be 2. " + "But the rank of RoIs is %d", + boxes_dims.size())); + if (config.is_runtime) { + PADDLE_ENFORCE_EQ(boxes_dims[1], + 4, + phi::errors::InvalidArgument( + "The second dimension " + "of Input(ROIs) should be 4. But received the " + "dimension = %d", + boxes_dims[1])); + } + + PADDLE_ENFORCE_GT(pooled_height, + 0, + phi::errors::InvalidArgument( + "The 'pooled_height' attribute in RoIAlignOp is " + "invalid. The height must be greater than 0. But " + "received 'pooled_height' = %d", + pooled_height)); + PADDLE_ENFORCE_GT(pooled_width, + 0, + phi::errors::InvalidArgument( + "The 'pooled_width' attribute in RoIAlignOp is " + "invalid. The width must be greater than 0. But " + "received 'pooled_width' = %d", + pooled_width)); + PADDLE_ENFORCE_GT(spatial_scale, + 0.0f, + phi::errors::InvalidArgument( + "The 'spatial_scale' attribute in RoIAlignOp is " + "invalid. The scale must be greater than 0. But " + "received 'spatial_scale' = %f", + spatial_scale)); + + auto out_dims = input_dims; + out_dims[0] = boxes_dims[0]; + out_dims[1] = input_dims[1]; + out_dims[2] = pooled_height; + out_dims[3] = pooled_width; + + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + void ScatterInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& updates, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 209a07db18b5c7a87ba094c5839149533757220d..0e7b9cb12a4d0b44727f488412af754e2ba8ad94 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -30,6 +30,8 @@ namespace phi { // Because functions in this file not only can infer shape, but also need // infer lod or other useful data. // +// The InferMeta Functions in this file are arranged in alphabetic order. + void AccuracyInferMeta(const MetaTensor& out, const MetaTensor& indice, const MetaTensor& label, @@ -71,6 +73,17 @@ void NllLossRawInferMeta(const MetaTensor& input, MetaTensor* total_weight, MetaConfig config = MetaConfig()); +void RoiAlignInferMeta(const MetaTensor& x, + const MetaTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + int sampling_ratio, + bool aligned, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void ScatterInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& updates, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 5447c9a573fbf3702dbb540f5052f2598899150e..3dfc9b797c089281cd9631642640a54be05ce679 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -31,6 +31,8 @@ class MetaConfig; // NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good. // Because functions in this file not only can infer shape, but also need // infer lod or other useful data. +// +// The InferMeta Functions in this file are arranged in alphabetic order. void ArgMinMaxInferMeta(const MetaTensor& x, int64_t axis, diff --git a/paddle/phi/kernels/cpu/roi_align_grad_kernel.cc b/paddle/phi/kernels/cpu/roi_align_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..a91b8b6c1fcd3306521fb7cbc26d8c7adaf2d4f8 --- /dev/null +++ b/paddle/phi/kernels/cpu/roi_align_grad_kernel.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roi_align_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void bilinear_interpolate_gradient(const int height, + const int width, + T y, + T x, + const T out_grad_this_bin, + const T count, + T* batch_grad_data) { + int x_low, y_low, x_high, y_high; + T w1, w2, w3, w4; + if (y < -1.0 || y > height || x < -1.0 || x > width) { + w1 = w2 = w3 = w4 = 0; + x_low = x_high = y_low = y_high = -1; + return; + } + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + y_low = static_cast(y); + x_low = static_cast(x); + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = static_cast(y_low); + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = static_cast(x_low); + } else { + x_high = x_low + 1; + } + + T ly = y - y_low, lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + T diff1 = out_grad_this_bin * w1 / count; + T diff2 = out_grad_this_bin * w2 / count; + T diff3 = out_grad_this_bin * w3 / count; + T diff4 = out_grad_this_bin * w4 / count; + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + *(batch_grad_data + y_low * width + x_low) += diff1; + *(batch_grad_data + y_low * width + x_high) += diff2; + *(batch_grad_data + y_high * width + x_low) += diff3; + *(batch_grad_data + y_high * width + x_high) += diff4; + } +} + +template +void RoiAlignGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + const DenseTensor& out_grad, + int pooled_height, + int pooled_width, + float spatial_scale, + int sampling_ratio, + bool aligned, + DenseTensor* dx) { + auto in_dims = x.dims(); + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = boxes.dims()[0]; + + if (!dx) { + return; + } + + DenseTensor roi_batch_id_list = Empty(dev_ctx, {rois_num}); + int* box_batch_id_data = roi_batch_id_list.data(); + + int boxes_batch_size; + if (boxes_num) { + boxes_batch_size = boxes_num->numel(); + auto* boxes_num_data = boxes_num->data(); + int start = 0; + for (int n = 0; n < boxes_batch_size; ++n) { + for (int i = start; i < start + boxes_num_data[n]; ++i) { + box_batch_id_data[i] = n; + } + start += boxes_num_data[n]; + } + } else { + auto boxes_lod = boxes.lod().back(); + boxes_batch_size = boxes_lod.size() - 1; + for (int n = 0; n < boxes_batch_size; ++n) { + for (std::size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) { + box_batch_id_data[i] = n; + } + } + } + dev_ctx.template Alloc(dx); + + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, dx, static_cast(0)); + + int output_grad_size = out_grad.numel(); + + if ((!out_grad.IsInitialized()) || (output_grad_size <= 0)) { + return; + } + + const T* boxes_data = boxes.data(); + const T* out_grad_data = out_grad.data(); + T* dx_data = dev_ctx.template Alloc(dx); + + auto in_stride = phi::stride(x.dims()); + auto roi_stride = phi::stride(boxes.dims()); + auto out_stride = phi::stride(out_grad.dims()); + + T roi_offset = aligned ? T(0.5) : 0; + for (int n = 0; n < rois_num; ++n) { + int box_batch_idx = box_batch_id_data[n]; + T roi_xmin = boxes_data[0] * spatial_scale - roi_offset; + T roi_ymin = boxes_data[1] * spatial_scale - roi_offset; + T roi_xmax = boxes_data[2] * spatial_scale - roi_offset; + T roi_ymax = boxes_data[3] * spatial_scale - roi_offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + roi_width = std::max(roi_width, static_cast(1.)); + roi_height = std::max(roi_height, static_cast(1.)); + if (!aligned) { + roi_width = std::max(roi_width, static_cast(1.)); + roi_height = std::max(roi_height, static_cast(1.)); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + for (int c = 0; c < channels; ++c) { + T* batch_grad_data = + dx_data + box_batch_idx * in_stride[0] + c * in_stride[1]; + const T* batch_out_grad_data = + out_grad_data + n * out_stride[0] + c * out_stride[1]; + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int pool_index = ph * pooled_width + pw; + T out_grad_this_bin = batch_out_grad_data[pool_index]; + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_width / pooled_width); + T count = roi_bin_grid_h * roi_bin_grid_w; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_ymin + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_xmin + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + bilinear_interpolate_gradient(height, + width, + y, + x, + out_grad_this_bin, + count, + batch_grad_data); + } + } + } + } + } + boxes_data += roi_stride[0]; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(roi_align_grad, + CPU, + ALL_LAYOUT, + phi::RoiAlignGradKernel, + float, + double, + int) {} diff --git a/paddle/phi/kernels/cpu/roi_align_kernel.cc b/paddle/phi/kernels/cpu/roi_align_kernel.cc index 35ab99a98eba7e59853fb311d5b2307b69ae31b2..4752a9b3a48fdcce5f3211a7aadca663fb44aa05 100644 --- a/paddle/phi/kernels/cpu/roi_align_kernel.cc +++ b/paddle/phi/kernels/cpu/roi_align_kernel.cc @@ -179,7 +179,7 @@ void AvgPool(const std::vector& interpolated_values, } template -void ROIAlignKernel(const Context& dev_ctx, +void RoiAlignKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& boxes, paddle::optional boxes_num, @@ -315,4 +315,4 @@ void ROIAlignKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - roi_align, CPU, ALL_LAYOUT, phi::ROIAlignKernel, float, double, int) {} + roi_align, CPU, ALL_LAYOUT, phi::RoiAlignKernel, float, double, int) {} diff --git a/paddle/phi/kernels/gpu/roi_align_grad_kernel.cu b/paddle/phi/kernels/gpu/roi_align_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..cf076128b69396196f59a8accd0c282322f8f49a --- /dev/null +++ b/paddle/phi/kernels/gpu/roi_align_grad_kernel.cu @@ -0,0 +1,260 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roi_align_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; +static constexpr int kROISize = 4; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__device__ void BilinearInterpolateGradient(const int height, + const int width, + T y, + T x, + T* w1, + T* w2, + T* w3, + T* w4, + int* x_low, + int* x_high, + int* y_low, + int* y_high) { + if (y < -1.0 || y > height || x < -1.0 || x > width) { + return; + } + + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + *y_low = static_cast(y); + *x_low = static_cast(x); + if (*y_low >= height - 1) { + *y_high = *y_low = height - 1; + y = static_cast(*y_low); + } else { + *y_high = *y_low + 1; + } + if (*x_low >= width - 1) { + *x_high = *x_low = width - 1; + x = static_cast(*x_low); + } else { + *x_high = *x_low + 1; + } + T ly = y - *y_low, lx = x - *x_low; + T hy = 1. - ly, hx = 1. - lx; + *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; + + return; +} + +template +__global__ void GPURoiAlignBackward(const int nthreads, + const T* input_rois, + const T* out_grad, + const int num_rois, + const float spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + int* roi_batch_id_data, + T* input_grad, + const bool continuous_coordinate) { + CUDA_KERNEL_LOOP(i, nthreads) { + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; + const T* offset_input_rois = input_rois + n * kROISize; + int roi_batch_ind = roi_batch_id_data[n]; + + T roi_offset = continuous_coordinate ? T(0.5) : 0; + T roi_xmin = offset_input_rois[0] * spatial_scale - roi_offset; + T roi_ymin = offset_input_rois[1] * spatial_scale - roi_offset; + T roi_xmax = offset_input_rois[2] * spatial_scale - roi_offset; + T roi_ymax = offset_input_rois[3] * spatial_scale - roi_offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + if (!continuous_coordinate) { + roi_width = max(roi_width, static_cast(1.)); + roi_height = max(roi_height, static_cast(1.)); + } + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_input_grad = + input_grad + (roi_batch_ind * channels + c) * height * width; + + const T* offset_out_grad = + out_grad + (n * channels + c) * pooled_height * pooled_width; + const T out_grad_this_bin = offset_out_grad[ph * pooled_width + pw]; + + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + const T count = roi_bin_grid_h * roi_bin_grid_w; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_ymin + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_xmin + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T w1 = 0, w2 = 0, w3 = 0, w4 = 0; + int x_low = -1, x_high = -1, y_low = -1, y_high = -1; + BilinearInterpolateGradient(height, + width, + y, + x, + &w1, + &w2, + &w3, + &w4, + &x_low, + &x_high, + &y_low, + &y_high); + T diff1 = out_grad_this_bin * w1 / count; + T diff2 = out_grad_this_bin * w2 / count; + T diff3 = out_grad_this_bin * w3 / count; + T diff4 = out_grad_this_bin * w4 / count; + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + paddle::platform::CudaAtomicAdd( + offset_input_grad + y_low * width + x_low, diff1); + paddle::platform::CudaAtomicAdd( + offset_input_grad + y_low * width + x_high, diff2); + paddle::platform::CudaAtomicAdd( + offset_input_grad + y_high * width + x_low, diff3); + paddle::platform::CudaAtomicAdd( + offset_input_grad + y_high * width + x_high, diff4); + } + } + } + } +} + +template +void RoiAlignGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + const DenseTensor& out_grad, + int pooled_height, + int pooled_width, + float spatial_scale, + int sampling_ratio, + bool aligned, + DenseTensor* dx) { + int rois_num = boxes.dims()[0]; + int channels = x.dims()[1]; + int height = x.dims()[2]; + int width = x.dims()[3]; + + if (!dx) { + return; + } + + DenseTensor box_batch_id_list; + box_batch_id_list.Resize({rois_num}); + int* box_batch_size = dev_ctx.template HostAlloc(&box_batch_id_list); + + auto cplace = phi::CPUPlace(); + auto gplace = dev_ctx.GetPlace(); + if (boxes_num) { + int boxes_batch_size = boxes_num->numel(); + std::vector boxes_num_list(boxes_batch_size); + paddle::memory::Copy(cplace, + boxes_num_list.data(), + gplace, + boxes_num->data(), + sizeof(int) * boxes_batch_size, + 0); + int start = 0; + for (int n = 0; n < boxes_batch_size; ++n) { + for (size_t i = start; i < start + boxes_num_list[n]; ++i) { + box_batch_size[i] = n; + } + start += boxes_num_list[n]; + } + } else { + auto boxes_lod = boxes.lod().back(); + int boxes_batch_size = boxes_lod.size() - 1; + for (int n = 0; n < boxes_batch_size; ++n) { + for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) { + box_batch_size[i] = n; + } + } + } + auto roi_ptr = + paddle::memory::Alloc(dev_ctx, box_batch_id_list.numel() * sizeof(int)); + int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); + int bytes = box_batch_id_list.numel() * sizeof(int); + paddle::memory::Copy( + gplace, roi_id_data, cplace, box_batch_size, bytes, dev_ctx.stream()); + dev_ctx.template Alloc(dx); + + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, dx, static_cast(0)); + + int output_grad_size = out_grad.numel(); + int blocks = NumBlocks(output_grad_size); + int threads = kNumCUDAThreads; + + if (output_grad_size > 0) { + GPURoiAlignBackward<<>>( + output_grad_size, + boxes.data(), + out_grad.data(), + rois_num, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + roi_id_data, + dx->data(), + aligned); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + roi_align_grad, GPU, ALL_LAYOUT, phi::RoiAlignGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/roi_align_kernel.cu b/paddle/phi/kernels/gpu/roi_align_kernel.cu index 2f906fa4f663b6da65a3e986af2214dfb49f2ec0..cd4ed29cdd1dd7b48a9135597ca79ab401a0cfba 100644 --- a/paddle/phi/kernels/gpu/roi_align_kernel.cu +++ b/paddle/phi/kernels/gpu/roi_align_kernel.cu @@ -71,7 +71,7 @@ __device__ T BilinearInterpolate( } template -__global__ void GPUROIAlignForward(const int nthreads, +__global__ void GPURoiAlignForward(const int nthreads, const T* input_data, const T* input_rois, const float spatial_scale, @@ -137,7 +137,7 @@ __global__ void GPUROIAlignForward(const int nthreads, } template -void ROIAlignKernel(const Context& dev_ctx, +void RoiAlignKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& boxes, paddle::optional boxes_num, @@ -233,7 +233,7 @@ void ROIAlignKernel(const Context& dev_ctx, int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); paddle::memory::Copy( gplace, roi_id_data, cplace, roi_batch_id_data, bytes, dev_ctx.stream()); - GPUROIAlignForward<<>>( + GPURoiAlignForward<<>>( output_size, x.data(), boxes.data(), @@ -252,4 +252,4 @@ void ROIAlignKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - roi_align, GPU, ALL_LAYOUT, phi::ROIAlignKernel, float, double) {} + roi_align, GPU, ALL_LAYOUT, phi::RoiAlignKernel, float, double) {} diff --git a/paddle/phi/kernels/roi_align_grad_kernel.h b/paddle/phi/kernels/roi_align_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..eea1fa03886a4a02dbc614052e1f280c2610f1ad --- /dev/null +++ b/paddle/phi/kernels/roi_align_grad_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void RoiAlignGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + const DenseTensor& out_grad, + int pooled_height, + int pooled_width, + float spatial_scale, + int sampling_ratio, + bool aligned, + DenseTensor* dx); + +} // namespace phi diff --git a/paddle/phi/kernels/roi_align_kernel.h b/paddle/phi/kernels/roi_align_kernel.h index 16b52c563a592f0cc23ddca94f554f5dc49e8ccf..9734da53b7f453d492cc60ee8930f54e7ca74edc 100644 --- a/paddle/phi/kernels/roi_align_kernel.h +++ b/paddle/phi/kernels/roi_align_kernel.h @@ -20,7 +20,7 @@ namespace phi { template -void ROIAlignKernel(const Context& dev_ctx, +void RoiAlignKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& boxes, paddle::optional boxes_num, diff --git a/paddle/phi/ops/compat/roi_align_sig.cc b/paddle/phi/ops/compat/roi_align_sig.cc index 0549103b6fbcb8b2367c34c8a44fb3b52f318859..1717ec8f788091fc5eae59c40a32a30c355760e8 100644 --- a/paddle/phi/ops/compat/roi_align_sig.cc +++ b/paddle/phi/ops/compat/roi_align_sig.cc @@ -16,7 +16,7 @@ namespace phi { -KernelSignature ROIAlignOpArgumentMapping(const ArgumentMappingContext& ctx) { +KernelSignature RoiAlignOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("roi_align", {"X", "ROIs", "RoisNum"}, {"pooled_height", @@ -27,6 +27,19 @@ KernelSignature ROIAlignOpArgumentMapping(const ArgumentMappingContext& ctx) { {"Out"}); } +KernelSignature RoiAlignGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("roi_align_grad", + {"X", "ROIs", "RoisNum", GradVarName("Out")}, + {"pooled_height", + "pooled_width", + "spatial_scale", + "sampling_ratio", + "aligned"}, + {GradVarName("X")}); +} + } // namespace phi -PD_REGISTER_ARG_MAPPING_FN(roi_align, phi::ROIAlignOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(roi_align, phi::RoiAlignOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(roi_align_grad, phi::RoiAlignGradOpArgumentMapping);