未验证 提交 579173d8 编写于 作者: Z zyfncg 提交者: GitHub

[Phi] Move infershape of roi_pool to phi (#40682)

* move infershape of roi_pool to phi

* polish code
上级 7f93e2b0
......@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/roi_pool_kernel.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
......@@ -27,74 +29,6 @@ class ROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "roi_pool");
OP_INOUT_CHECK(ctx->HasInput("ROIs"), "Input", "ROIs", "roi_pool");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "roi_pool");
OP_INOUT_CHECK(ctx->HasOutput("Argmax"), "Output", "Argmax", "roi_pool");
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
if (ctx->HasInput("RoisNum")) {
auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
platform::errors::InvalidArgument(
"The second dimension of RoisNum should "
"be 1, but received dimension is %d",
rois_num_dims.size()));
}
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"The input data should be a four-dimensional "
"tensor with [N,C,H,W], but received input data with "
" %d dimension",
input_dims.size()));
PADDLE_ENFORCE_EQ(
rois_dims.size(), 2,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ...], but received ROIs is "
"%d-dimensional LoDTensor",
rois_dims.size()));
PADDLE_ENFORCE_EQ(
rois_dims[1], phi::kROISize,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ...]. But the second dimension of "
"the received data is %d",
rois_dims[1]));
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE_GT(pooled_height, 0,
platform::errors::OutOfRange(
"The pooled output height must be greater than 0"
"but received height is %d",
pooled_height));
PADDLE_ENFORCE_GT(pooled_width, 0,
platform::errors::OutOfRange(
"The pooled output width must be greater than 0"
"but received width is %d",
pooled_width));
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
platform::errors::OutOfRange(
"The spatial scale must be greater than 0, "
"but received spatial scale is %f",
spatial_scale));
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] = input_dims[1];
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("Argmax", out_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -213,9 +147,13 @@ class ROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(roi_pool, RoiPoolInferShapeFunctor,
PD_INFER_META(phi::RoiPoolInferMeta));
REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
ops::ROIPoolGradMaker<paddle::framework::OpDesc>,
ops::ROIPoolGradMaker<paddle::imperative::OpBase>);
ops::ROIPoolGradMaker<paddle::imperative::OpBase>,
RoiPoolInferShapeFunctor);
REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP_VERSION(roi_pool)
......
......@@ -340,29 +340,29 @@ void RoiAlignInferMeta(const MetaTensor& x,
PADDLE_ENFORCE_EQ(
boxes_num_dims.size(),
1,
phi::errors::InvalidArgument("The size of RoisNum should be 1"
phi::errors::InvalidArgument("The size of boxes_num should be 1"
", but received size = %d",
boxes_num_dims.size()));
}
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"The format of Input(X) in"
"RoIAlignOp is NCHW. And the rank of input must be 4. "
"The format of Input(x) in"
"RoiAlignOp is NCHW. And the rank of input must be 4. "
"But received rank = %d",
input_dims.size()));
PADDLE_ENFORCE_EQ(boxes_dims.size(),
2,
phi::errors::InvalidArgument("The rank of Input(ROIs) "
"in RoIAlignOp should be 2. "
"But the rank of RoIs is %d",
phi::errors::InvalidArgument("The rank of Input(boxes) "
"in RoiAlignOp should be 2. "
"But the rank of boxes is %d",
boxes_dims.size()));
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(boxes_dims[1],
4,
phi::errors::InvalidArgument(
"The second dimension "
"of Input(ROIs) should be 4. But received the "
"of Input(boxes) should be 4. But received the "
"dimension = %d",
boxes_dims[1]));
}
......@@ -370,21 +370,21 @@ void RoiAlignInferMeta(const MetaTensor& x,
PADDLE_ENFORCE_GT(pooled_height,
0,
phi::errors::InvalidArgument(
"The 'pooled_height' attribute in RoIAlignOp is "
"The 'pooled_height' attribute in RoiAlignOp is "
"invalid. The height must be greater than 0. But "
"received 'pooled_height' = %d",
pooled_height));
PADDLE_ENFORCE_GT(pooled_width,
0,
phi::errors::InvalidArgument(
"The 'pooled_width' attribute in RoIAlignOp is "
"The 'pooled_width' attribute in RoiAlignOp is "
"invalid. The width must be greater than 0. But "
"received 'pooled_width' = %d",
pooled_width));
PADDLE_ENFORCE_GT(spatial_scale,
0.0f,
phi::errors::InvalidArgument(
"The 'spatial_scale' attribute in RoIAlignOp is "
"The 'spatial_scale' attribute in RoiAlignOp is "
"invalid. The scale must be greater than 0. But "
"received 'spatial_scale' = %f",
spatial_scale));
......@@ -399,6 +399,81 @@ void RoiAlignInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void RoiPoolInferMeta(const MetaTensor& x,
const MetaTensor& boxes,
paddle::optional<const MetaTensor&> boxes_num,
int pooled_height,
int pooled_width,
float spatial_scale,
MetaTensor* out,
MetaTensor* arg_max) {
auto input_dims = x.dims();
auto boxes_dims = boxes.dims();
if (boxes_num) {
auto boxes_num_dims = boxes_num->dims();
PADDLE_ENFORCE_EQ(
boxes_num_dims.size(),
1,
phi::errors::InvalidArgument("The second dimension of boxes_num should "
"be 1, but received dimension is %d",
boxes_num_dims.size()));
}
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"The input data should be a four-dimensional "
"tensor with [N,C,H,W], but received input data with "
" %d dimension",
input_dims.size()));
PADDLE_ENFORCE_EQ(
boxes_dims.size(),
2,
phi::errors::InvalidArgument(
"boxes should be a 2-D LoDTensor with shape (num_boxes, 4)"
"given as [[x1, y1, x2, y2], ...], but received boxes is "
"%d-dimensional LoDTensor",
boxes_dims.size()));
PADDLE_ENFORCE_EQ(
boxes_dims[1],
4,
phi::errors::InvalidArgument(
"boxes should be a 2-D LoDTensor with shape (num_boxes, 4)"
"given as [[x1, y1, x2, y2], ...]. But the second dimension of "
"the received data is %d",
boxes_dims[1]));
PADDLE_ENFORCE_GT(
pooled_height,
0,
phi::errors::OutOfRange("The pooled output height must be greater than 0"
"but received height is %d",
pooled_height));
PADDLE_ENFORCE_GT(
pooled_width,
0,
phi::errors::OutOfRange("The pooled output width must be greater than 0"
"but received width is %d",
pooled_width));
PADDLE_ENFORCE_GT(
spatial_scale,
0.0f,
phi::errors::OutOfRange("The spatial scale must be greater than 0, "
"but received spatial scale is %f",
spatial_scale));
auto out_dims = input_dims;
out_dims[0] = boxes_dims[0];
out_dims[1] = input_dims[1];
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
arg_max->set_dims(out_dims);
arg_max->set_dtype(DataType::INT64);
}
void ScatterInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& updates,
......
......@@ -84,6 +84,15 @@ void RoiAlignInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void RoiPoolInferMeta(const MetaTensor& x,
const MetaTensor& boxes,
paddle::optional<const MetaTensor&> boxes_num,
int pooled_height,
int pooled_width,
float spatial_scale,
MetaTensor* out,
MetaTensor* arg_max);
void ScatterInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& updates,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册