From 579173d8e6a8bb6039671cb6b215730b7b9614c5 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 18 Mar 2022 15:07:18 +0800 Subject: [PATCH] [Phi] Move infershape of roi_pool to phi (#40682) * move infershape of roi_pool to phi * polish code --- paddle/fluid/operators/roi_pool_op.cc | 78 +++------------------- paddle/phi/infermeta/ternary.cc | 95 ++++++++++++++++++++++++--- paddle/phi/infermeta/ternary.h | 9 +++ 3 files changed, 102 insertions(+), 80 deletions(-) diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc index 9fd66590cb7..12e33d56c00 100644 --- a/paddle/fluid/operators/roi_pool_op.cc +++ b/paddle/fluid/operators/roi_pool_op.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/kernels/roi_pool_kernel.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/ternary.h" namespace paddle { namespace operators { @@ -27,74 +29,6 @@ class ROIPoolOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "roi_pool"); - OP_INOUT_CHECK(ctx->HasInput("ROIs"), "Input", "ROIs", "roi_pool"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "roi_pool"); - OP_INOUT_CHECK(ctx->HasOutput("Argmax"), "Output", "Argmax", "roi_pool"); - - auto input_dims = ctx->GetInputDim("X"); - auto rois_dims = ctx->GetInputDim("ROIs"); - - if (ctx->HasInput("RoisNum")) { - auto rois_num_dims = ctx->GetInputDim("RoisNum"); - PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1, - platform::errors::InvalidArgument( - "The second dimension of RoisNum should " - "be 1, but received dimension is %d", - rois_num_dims.size())); - } - PADDLE_ENFORCE_EQ(input_dims.size(), 4, - platform::errors::InvalidArgument( - "The input data should be a four-dimensional " - "tensor with [N,C,H,W], but received input data with " - " %d dimension", - input_dims.size())); - PADDLE_ENFORCE_EQ( - rois_dims.size(), 2, - platform::errors::InvalidArgument( - "ROIs should be a 2-D LoDTensor with shape (num_rois, 4)" - "given as [[x1, y1, x2, y2], ...], but received ROIs is " - "%d-dimensional LoDTensor", - rois_dims.size())); - PADDLE_ENFORCE_EQ( - rois_dims[1], phi::kROISize, - platform::errors::InvalidArgument( - "ROIs should be a 2-D LoDTensor with shape (num_rois, 4)" - "given as [[x1, y1, x2, y2], ...]. But the second dimension of " - "the received data is %d", - rois_dims[1])); - - int pooled_height = ctx->Attrs().Get("pooled_height"); - int pooled_width = ctx->Attrs().Get("pooled_width"); - float spatial_scale = ctx->Attrs().Get("spatial_scale"); - - PADDLE_ENFORCE_GT(pooled_height, 0, - platform::errors::OutOfRange( - "The pooled output height must be greater than 0" - "but received height is %d", - pooled_height)); - PADDLE_ENFORCE_GT(pooled_width, 0, - platform::errors::OutOfRange( - "The pooled output width must be greater than 0" - "but received width is %d", - pooled_width)); - PADDLE_ENFORCE_GT(spatial_scale, 0.0f, - platform::errors::OutOfRange( - "The spatial scale must be greater than 0, " - "but received spatial scale is %f", - spatial_scale)); - - auto out_dims = input_dims; - out_dims[0] = rois_dims[0]; - out_dims[1] = input_dims[1]; - out_dims[2] = pooled_height; - out_dims[3] = pooled_width; - - ctx->SetOutputDim("Out", out_dims); - ctx->SetOutputDim("Argmax", out_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -213,9 +147,13 @@ class ROIPoolGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(roi_pool, RoiPoolInferShapeFunctor, + PD_INFER_META(phi::RoiPoolInferMeta)); + REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, ops::ROIPoolGradMaker, - ops::ROIPoolGradMaker); + ops::ROIPoolGradMaker, + RoiPoolInferShapeFunctor); REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp); REGISTER_OP_VERSION(roi_pool) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 837750710c9..556fb874470 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -340,29 +340,29 @@ void RoiAlignInferMeta(const MetaTensor& x, PADDLE_ENFORCE_EQ( boxes_num_dims.size(), 1, - phi::errors::InvalidArgument("The size of RoisNum should be 1" + phi::errors::InvalidArgument("The size of boxes_num should be 1" ", but received size = %d", boxes_num_dims.size())); } PADDLE_ENFORCE_EQ(input_dims.size(), 4, phi::errors::InvalidArgument( - "The format of Input(X) in" - "RoIAlignOp is NCHW. And the rank of input must be 4. " + "The format of Input(x) in" + "RoiAlignOp is NCHW. And the rank of input must be 4. " "But received rank = %d", input_dims.size())); PADDLE_ENFORCE_EQ(boxes_dims.size(), 2, - phi::errors::InvalidArgument("The rank of Input(ROIs) " - "in RoIAlignOp should be 2. " - "But the rank of RoIs is %d", + phi::errors::InvalidArgument("The rank of Input(boxes) " + "in RoiAlignOp should be 2. " + "But the rank of boxes is %d", boxes_dims.size())); if (config.is_runtime) { PADDLE_ENFORCE_EQ(boxes_dims[1], 4, phi::errors::InvalidArgument( "The second dimension " - "of Input(ROIs) should be 4. But received the " + "of Input(boxes) should be 4. But received the " "dimension = %d", boxes_dims[1])); } @@ -370,21 +370,21 @@ void RoiAlignInferMeta(const MetaTensor& x, PADDLE_ENFORCE_GT(pooled_height, 0, phi::errors::InvalidArgument( - "The 'pooled_height' attribute in RoIAlignOp is " + "The 'pooled_height' attribute in RoiAlignOp is " "invalid. The height must be greater than 0. But " "received 'pooled_height' = %d", pooled_height)); PADDLE_ENFORCE_GT(pooled_width, 0, phi::errors::InvalidArgument( - "The 'pooled_width' attribute in RoIAlignOp is " + "The 'pooled_width' attribute in RoiAlignOp is " "invalid. The width must be greater than 0. But " "received 'pooled_width' = %d", pooled_width)); PADDLE_ENFORCE_GT(spatial_scale, 0.0f, phi::errors::InvalidArgument( - "The 'spatial_scale' attribute in RoIAlignOp is " + "The 'spatial_scale' attribute in RoiAlignOp is " "invalid. The scale must be greater than 0. But " "received 'spatial_scale' = %f", spatial_scale)); @@ -399,6 +399,81 @@ void RoiAlignInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void RoiPoolInferMeta(const MetaTensor& x, + const MetaTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + MetaTensor* out, + MetaTensor* arg_max) { + auto input_dims = x.dims(); + auto boxes_dims = boxes.dims(); + + if (boxes_num) { + auto boxes_num_dims = boxes_num->dims(); + PADDLE_ENFORCE_EQ( + boxes_num_dims.size(), + 1, + phi::errors::InvalidArgument("The second dimension of boxes_num should " + "be 1, but received dimension is %d", + boxes_num_dims.size())); + } + PADDLE_ENFORCE_EQ(input_dims.size(), + 4, + phi::errors::InvalidArgument( + "The input data should be a four-dimensional " + "tensor with [N,C,H,W], but received input data with " + " %d dimension", + input_dims.size())); + PADDLE_ENFORCE_EQ( + boxes_dims.size(), + 2, + phi::errors::InvalidArgument( + "boxes should be a 2-D LoDTensor with shape (num_boxes, 4)" + "given as [[x1, y1, x2, y2], ...], but received boxes is " + "%d-dimensional LoDTensor", + boxes_dims.size())); + PADDLE_ENFORCE_EQ( + boxes_dims[1], + 4, + phi::errors::InvalidArgument( + "boxes should be a 2-D LoDTensor with shape (num_boxes, 4)" + "given as [[x1, y1, x2, y2], ...]. But the second dimension of " + "the received data is %d", + boxes_dims[1])); + + PADDLE_ENFORCE_GT( + pooled_height, + 0, + phi::errors::OutOfRange("The pooled output height must be greater than 0" + "but received height is %d", + pooled_height)); + PADDLE_ENFORCE_GT( + pooled_width, + 0, + phi::errors::OutOfRange("The pooled output width must be greater than 0" + "but received width is %d", + pooled_width)); + PADDLE_ENFORCE_GT( + spatial_scale, + 0.0f, + phi::errors::OutOfRange("The spatial scale must be greater than 0, " + "but received spatial scale is %f", + spatial_scale)); + + auto out_dims = input_dims; + out_dims[0] = boxes_dims[0]; + out_dims[1] = input_dims[1]; + out_dims[2] = pooled_height; + out_dims[3] = pooled_width; + + out->set_dims(out_dims); + out->set_dtype(x.dtype()); + arg_max->set_dims(out_dims); + arg_max->set_dtype(DataType::INT64); +} + void ScatterInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& updates, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 0e7b9cb12a4..42a0f35dc1d 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -84,6 +84,15 @@ void RoiAlignInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void RoiPoolInferMeta(const MetaTensor& x, + const MetaTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + MetaTensor* out, + MetaTensor* arg_max); + void ScatterInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& updates, -- GitLab