未验证 提交 c2f5a3ad 编写于 作者: W wangguanzhong 提交者: GitHub

enhance the error message of roi_align, test=develop (#23649)

上级 cec234b1
...@@ -23,35 +23,59 @@ class ROIAlignOp : public framework::OperatorWithKernel { ...@@ -23,35 +23,59 @@ class ROIAlignOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of ROIAlignOp should not be null."); platform::errors::NotFound("Input(X) of ROIAlignOp "
PADDLE_ENFORCE(ctx->HasInput("ROIs"), "is not found."));
"Input(ROIs) of ROIAlignOp should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true,
PADDLE_ENFORCE(ctx->HasOutput("Out"), platform::errors::NotFound("Input(ROIs) of ROIAlignOp "
"Output(Out) of ROIAlignOp should not be null."); "is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of ROIAlignOp "
"is not found."));
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs"); auto rois_dims = ctx->GetInputDim("ROIs");
PADDLE_ENFORCE(input_dims.size() == 4, PADDLE_ENFORCE_EQ(
"The format of input tensor is NCHW."); input_dims.size(), 4,
PADDLE_ENFORCE(rois_dims.size() == 2, platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" "The format of Input(X) in"
"given as [[x1, y1, x2, y2], ...]."); "RoIAlignOp is NCHW. And the rank of input must be 4. "
"But received rank = %d",
input_dims.size()));
PADDLE_ENFORCE_EQ(rois_dims.size(), 2, platform::errors::InvalidArgument(
"The rank of Input(ROIs) "
"in RoIAlignOp should be 2. "
"But the rank of RoIs is %d",
rois_dims.size()));
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE(rois_dims[1] == 4, PADDLE_ENFORCE_EQ(rois_dims[1], 4,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" platform::errors::InvalidArgument(
"given as [[x1, y1, x2, y2], ...]."); "The second dimension "
"of Input(ROIs) should be 4. But received the "
"dimension = %d",
rois_dims[1]));
} }
int pooled_height = ctx->Attrs().Get<int>("pooled_height"); int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width"); int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale"); float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE_GT(pooled_height, 0, PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled output height must greater than 0"); platform::errors::InvalidArgument(
"The pooled output "
"height must greater than 0. But received "
"pooled_height = %d",
pooled_height));
PADDLE_ENFORCE_GT(pooled_width, 0, PADDLE_ENFORCE_GT(pooled_width, 0,
"The pooled output width must greater than 0"); platform::errors::InvalidArgument(
"The pooled output "
"width must greater than 0. But received "
"pooled_width = %d",
pooled_width));
PADDLE_ENFORCE_GT(spatial_scale, 0.0f, PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must greater than 0"); platform::errors::InvalidArgument(
"The spatial scale "
"must greater than 0 But received spatial_scale = %f",
spatial_scale));
auto out_dims = input_dims; auto out_dims = input_dims;
out_dims[0] = rois_dims[0]; out_dims[0] = rois_dims[0];
...@@ -76,10 +100,13 @@ class ROIAlignGradOp : public framework::OperatorWithKernel { ...@@ -76,10 +100,13 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE_EQ(
"The GRAD@Out of ROIAlignGradOp should not be null."); ctx->HasInput(framework::GradVarName("Out")), true,
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")), platform::errors::NotFound("The GRAD@Out of ROIAlignGradOp "
"The GRAD@X of ROIAlignGradOp should not be null."); "is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutputs(framework::GradVarName("X")), true,
platform::errors::NotFound("The GRAD@X of ROIAlignGradOp "
"is not found."));
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
} }
......
...@@ -266,7 +266,11 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -266,7 +266,11 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size, rois_batch_size, batch_size,
"The rois_batch_size and imgs batch_size must be the same."); platform::errors::InvalidArgument(
"The rois_batch_size and imgs "
"batch_size must be the same. But received rois_batch_size = %d, "
"batch_size = %d",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size]; int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
"The rois_num from input and lod must be the same."); "The rois_num from input and lod must be the same.");
......
...@@ -172,7 +172,11 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -172,7 +172,11 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size, rois_batch_size, batch_size,
"The rois_batch_size and imgs batch_size must be the same."); platform::errors::InvalidArgument(
"The rois_batch_size and imgs "
"batch_size must be the same. But received rois_batch_size = %d, "
"batch_size = %d",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size]; int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
"The rois_num from input and lod must be the same."); "The rois_num from input and lod must be the same.");
......
...@@ -6723,6 +6723,9 @@ def roi_align(input, ...@@ -6723,6 +6723,9 @@ def roi_align(input,
spatial_scale=0.5, spatial_scale=0.5,
sampling_ratio=-1) sampling_ratio=-1)
""" """
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'roi_align')
check_variable_and_dtype(rois, 'rois', ['float32', 'float64'], 'roi_align')
helper = LayerHelper('roi_align', **locals()) helper = LayerHelper('roi_align', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
align_out = helper.create_variable_for_type_inference(dtype) align_out = helper.create_variable_for_type_inference(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册