提交 362850b9 编写于 作者: W wawltor

optimize the error meesage for detetion_map_op, test=develop

上级 5641ea2b
......@@ -25,45 +25,55 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("DetectRes"),
"Input(DetectRes) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumPosCount"),
"Output(AccumPosCount) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumTruePos"),
"Output(AccumTruePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumFalsePos"),
"Output(AccumFalsePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MAP"),
"Output(MAP) of DetectionMAPOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("DetectRes"), "Input", "DetectRes",
"DetectionMAP");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "DetectionMAP");
OP_INOUT_CHECK(ctx->HasOutput("AccumPosCount"), "Output", "AccumPosCount",
"DetectionMAP");
OP_INOUT_CHECK(ctx->HasOutput("AccumTruePos"), "Output", "AccumTruePos",
"DetectionMAP");
OP_INOUT_CHECK(ctx->HasOutput("AccumFalsePos"), "Output", "AccumFalsePos",
"DetectionMAP");
OP_INOUT_CHECK(ctx->HasOutput("MAP"), "Output", "MAP", "DetectionMAP");
auto det_dims = ctx->GetInputDim("DetectRes");
PADDLE_ENFORCE_EQ(det_dims.size(), 2UL,
"The rank of Input(DetectRes) must be 2, "
"the shape is [N, 6].");
PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
"The shape is of Input(DetectRes) [N, 6].");
PADDLE_ENFORCE_EQ(
det_dims.size(), 2UL,
platform::errors::InvalidArgument(
"Input(DetectRes) ndim must be 2, the shape is [N, 6],"
"but received the ndim is %d",
det_dims.size()));
PADDLE_ENFORCE_EQ(
det_dims[1], 6UL,
platform::errors::InvalidArgument(
"The shape is of Input(DetectRes) [N, 6], but received"
" shape is [N, %d]",
det_dims[1]));
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
"The rank of Input(Label) must be 2, "
"the shape is [N, 6].");
platform::errors::InvalidArgument(
"The ndim of Input(Label) must be 2, but received %d",
label_dims.size()));
if (ctx->IsRuntime() || label_dims[1] > 0) {
PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5,
"The shape of Input(Label) is [N, 6] or [N, 5].");
PADDLE_ENFORCE_EQ(
(label_dims[1] == 6 || label_dims[1] == 5), true,
platform::errors::InvalidArgument(
"The shape of Input(Label) is [N, 6] or [N, 5], but received "
"[N, %d]",
label_dims[1]));
}
if (ctx->HasInput("PosCount")) {
PADDLE_ENFORCE(ctx->HasInput("TruePos"),
"Input(TruePos) of DetectionMAPOp should not be null when "
"Input(TruePos) is not null.");
PADDLE_ENFORCE(
ctx->HasInput("TruePos"),
platform::errors::InvalidArgument(
"Input(TruePos) of DetectionMAPOp should not be null when "
"Input(TruePos) is not null."));
PADDLE_ENFORCE(
ctx->HasInput("FalsePos"),
"Input(FalsePos) of DetectionMAPOp should not be null when "
"Input(FalsePos) is not null.");
platform::errors::InvalidArgument(
"Input(FalsePos) of DetectionMAPOp should not be null when "
"Input(FalsePos) is not null."));
}
ctx->SetOutputDim("MAP", framework::make_ddim({1}));
......@@ -170,8 +180,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("integral")
.InEnum({"integral", "11point"})
.AddCustomChecker([](const std::string& ap_type) {
PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone,
"The ap_type should be 'integral' or '11point.");
PADDLE_ENFORCE_NE(
GetAPType(ap_type), APType::kNone,
platform::errors::InvalidArgument(
"The ap_type should be 'integral' or '11point."));
});
AddComment(R"DOC(
Detection mAP evaluate operator.
......
......@@ -78,11 +78,16 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
auto& label_lod = in_label->lod();
auto& detect_lod = in_detect->lod();
PADDLE_ENFORCE_EQ(label_lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
label_lod.size(), 1UL,
platform::errors::InvalidArgument("Only support LodTensor of lod_level "
"with 1 in label, but received %d.",
label_lod.size()));
PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(),
"The batch_size of input(Label) and input(Detection) "
"must be the same.");
platform::errors::InvalidArgument(
"The batch_size of input(Label) and input(Detection) "
"must be the same, but received %d:%d",
label_lod[0].size(), detect_lod[0].size()));
std::vector<std::map<int, std::vector<Box>>> gt_boxes;
std::vector<std::map<int, std::vector<std::pair<T, Box>>>> detect_boxes;
......@@ -185,7 +190,12 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
box.is_difficult = true;
boxes[label].push_back(box);
} else {
PADDLE_ENFORCE_EQ(input_label.dims()[1], 5);
PADDLE_ENFORCE_EQ(
input_label.dims()[1], 5,
platform::errors::InvalidArgument(
"The input label width"
" must be 5, but received %d, please check your input data",
input_label.dims()[1]));
Box box(labels(i, 1), labels(i, 2), labels(i, 3), labels(i, 4));
boxes[label].push_back(box);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册