未验证 提交 29f49229 编写于 作者: W wawltor 提交者: GitHub

optimize the error meesage for detetion_map_op

 optimize the error meesage for detetion_map_op
上级 daf5aa9b
...@@ -25,45 +25,55 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -25,45 +25,55 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("DetectRes"), OP_INOUT_CHECK(ctx->HasInput("DetectRes"), "Input", "DetectRes",
"Input(DetectRes) of DetectionMAPOp should not be null."); "DetectionMAP");
PADDLE_ENFORCE(ctx->HasInput("Label"), OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "DetectionMAP");
"Input(Label) of DetectionMAPOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("AccumPosCount"), "Output", "AccumPosCount",
PADDLE_ENFORCE( "DetectionMAP");
ctx->HasOutput("AccumPosCount"), OP_INOUT_CHECK(ctx->HasOutput("AccumTruePos"), "Output", "AccumTruePos",
"Output(AccumPosCount) of DetectionMAPOp should not be null."); "DetectionMAP");
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasOutput("AccumFalsePos"), "Output", "AccumFalsePos",
ctx->HasOutput("AccumTruePos"), "DetectionMAP");
"Output(AccumTruePos) of DetectionMAPOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("MAP"), "Output", "MAP", "DetectionMAP");
PADDLE_ENFORCE(
ctx->HasOutput("AccumFalsePos"),
"Output(AccumFalsePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MAP"),
"Output(MAP) of DetectionMAPOp should not be null.");
auto det_dims = ctx->GetInputDim("DetectRes"); auto det_dims = ctx->GetInputDim("DetectRes");
PADDLE_ENFORCE_EQ(det_dims.size(), 2UL, PADDLE_ENFORCE_EQ(
"The rank of Input(DetectRes) must be 2, " det_dims.size(), 2UL,
"the shape is [N, 6]."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(det_dims[1], 6UL, "Input(DetectRes) ndim must be 2, the shape is [N, 6],"
"The shape is of Input(DetectRes) [N, 6]."); "but received the ndim is %d",
det_dims.size()));
PADDLE_ENFORCE_EQ(
det_dims[1], 6UL,
platform::errors::InvalidArgument(
"The shape is of Input(DetectRes) [N, 6], but received"
" shape is [N, %d]",
det_dims[1]));
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2, PADDLE_ENFORCE_EQ(label_dims.size(), 2,
"The rank of Input(Label) must be 2, " platform::errors::InvalidArgument(
"the shape is [N, 6]."); "The ndim of Input(Label) must be 2, but received %d",
label_dims.size()));
if (ctx->IsRuntime() || label_dims[1] > 0) { if (ctx->IsRuntime() || label_dims[1] > 0) {
PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5, PADDLE_ENFORCE_EQ(
"The shape of Input(Label) is [N, 6] or [N, 5]."); (label_dims[1] == 6 || label_dims[1] == 5), true,
platform::errors::InvalidArgument(
"The shape of Input(Label) is [N, 6] or [N, 5], but received "
"[N, %d]",
label_dims[1]));
} }
if (ctx->HasInput("PosCount")) { if (ctx->HasInput("PosCount")) {
PADDLE_ENFORCE(ctx->HasInput("TruePos"), PADDLE_ENFORCE(
"Input(TruePos) of DetectionMAPOp should not be null when " ctx->HasInput("TruePos"),
"Input(TruePos) is not null."); platform::errors::InvalidArgument(
"Input(TruePos) of DetectionMAPOp should not be null when "
"Input(PosCount) is not null."));
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasInput("FalsePos"), ctx->HasInput("FalsePos"),
"Input(FalsePos) of DetectionMAPOp should not be null when " platform::errors::InvalidArgument(
"Input(FalsePos) is not null."); "Input(FalsePos) of DetectionMAPOp should not be null when "
"Input(PosCount) is not null."));
} }
ctx->SetOutputDim("MAP", framework::make_ddim({1})); ctx->SetOutputDim("MAP", framework::make_ddim({1}));
...@@ -170,8 +180,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -170,8 +180,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("integral") .SetDefault("integral")
.InEnum({"integral", "11point"}) .InEnum({"integral", "11point"})
.AddCustomChecker([](const std::string& ap_type) { .AddCustomChecker([](const std::string& ap_type) {
PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone, PADDLE_ENFORCE_NE(
"The ap_type should be 'integral' or '11point."); GetAPType(ap_type), APType::kNone,
platform::errors::InvalidArgument(
"The ap_type should be 'integral' or '11point."));
}); });
AddComment(R"DOC( AddComment(R"DOC(
Detection mAP evaluate operator. Detection mAP evaluate operator.
......
...@@ -78,11 +78,16 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -78,11 +78,16 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
auto& label_lod = in_label->lod(); auto& label_lod = in_label->lod();
auto& detect_lod = in_detect->lod(); auto& detect_lod = in_detect->lod();
PADDLE_ENFORCE_EQ(label_lod.size(), 1UL, PADDLE_ENFORCE_EQ(
"Only support one level sequence now."); label_lod.size(), 1UL,
platform::errors::InvalidArgument("Only support LodTensor of lod_level "
"with 1 in label, but received %d.",
label_lod.size()));
PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(), PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(),
"The batch_size of input(Label) and input(Detection) " platform::errors::InvalidArgument(
"must be the same."); "The batch_size of input(Label) and input(Detection) "
"must be the same, but received %d:%d",
label_lod[0].size(), detect_lod[0].size()));
std::vector<std::map<int, std::vector<Box>>> gt_boxes; std::vector<std::map<int, std::vector<Box>>> gt_boxes;
std::vector<std::map<int, std::vector<std::pair<T, Box>>>> detect_boxes; std::vector<std::map<int, std::vector<std::pair<T, Box>>>> detect_boxes;
...@@ -185,7 +190,12 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -185,7 +190,12 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
box.is_difficult = true; box.is_difficult = true;
boxes[label].push_back(box); boxes[label].push_back(box);
} else { } else {
PADDLE_ENFORCE_EQ(input_label.dims()[1], 5); PADDLE_ENFORCE_EQ(
input_label.dims()[1], 5,
platform::errors::InvalidArgument(
"The input label width"
" must be 5, but received %d, please check your input data",
input_label.dims()[1]));
Box box(labels(i, 1), labels(i, 2), labels(i, 3), labels(i, 4)); Box box(labels(i, 1), labels(i, 2), labels(i, 3), labels(i, 4));
boxes[label].push_back(box); boxes[label].push_back(box);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册