From 91a2188301b82151560c59501cca45785d34cfcb Mon Sep 17 00:00:00 2001 From: wanghaox Date: Mon, 12 Feb 2018 10:39:59 +0800 Subject: [PATCH] update detection_map --- paddle/fluid/operators/detection_map_op.cc | 98 ++++++++++++------- paddle/fluid/operators/detection_map_op.h | 44 ++++----- .../v2/fluid/tests/test_detection_map_op.py | 14 ++- 3 files changed, 87 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index cc4b6202c0..48308a11b4 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -24,25 +24,28 @@ class DetectionMAPOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Detection"), - "Input(Detection) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("DetectRes"), + "Input(DetectRes) of DetectionMAPOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) of DetectionMAPOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutPosCount"), - "Output(OutPosCount) of DetectionMAPOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutTruePos"), - "Output(OutTruePos) of DetectionMAPOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutFalsePos"), - "Output(OutFalsePos) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumPosCount"), + "Output(AccumPosCount) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumTruePos"), + "Output(AccumTruePos) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumFalsePos"), + "Output(AccumFalsePos) of DetectionMAPOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("MAP"), "Output(MAP) of DetectionMAPOp should not be null."); - auto det_dims = ctx->GetInputDim("Detection"); + auto det_dims = ctx->GetInputDim("DetectRes"); PADDLE_ENFORCE_EQ(det_dims.size(), 2UL, - "The rank of Input(Detection) must be 2, " + "The rank of Input(DetectRes) must be 2, " "the shape is [N, 6]."); PADDLE_ENFORCE_EQ(det_dims[1], 6UL, - "The shape is of Input(Detection) [N, 6]."); + "The shape is of Input(DetectRes) [N, 6]."); auto label_dims = ctx->GetInputDim("Label"); PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, "The rank of Input(Label) must be 2, " @@ -50,8 +53,17 @@ class DetectionMAPOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(label_dims[1], 6UL, "The shape is of Input(Label) [N, 6]."); - auto map_dim = framework::make_ddim({1}); - ctx->SetOutputDim("MAP", map_dim); + if (ctx->HasInput("PosCount")) { + PADDLE_ENFORCE(ctx->HasInput("TruePos"), + "Input(TruePos) of DetectionMAPOp should not be null when " + "Input(TruePos) is not null."); + PADDLE_ENFORCE( + ctx->HasInput("FalsePos"), + "Input(FalsePos) of DetectionMAPOp should not be null when " + "Input(FalsePos) is not null."); + } + + ctx->SetOutputDim("MAP", framework::make_ddim({1})); } protected: @@ -59,7 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( framework::ToDataType( - ctx.Input("Detection")->type()), + ctx.Input("DetectRes")->type()), ctx.device_context()); } }; @@ -68,6 +80,14 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { public: DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("DetectRes", + "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the " + "detections. Each row has 6 values: " + "[label, confidence, xmin, ymin, xmax, ymax], M is the total " + "number of detect results in this mini-batch. For each instance, " + "the offsets in first dimension are called LoD, the number of " + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " + "no detected data."); AddInput("Label", "(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the" "Labeled ground-truth data. Each row has 6 values: " @@ -76,38 +96,43 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { "instance, the offsets in first dimension are called LoD, " "the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, " "means there is no ground-truth data."); - AddInput("Detection", - "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the " - "detections. Each row has 6 values: " - "[label, confidence, xmin, ymin, xmax, ymax], M is the total " - "number of detections in this mini-batch. For each instance, " - "the offsets in first dimension are called LoD, the number of " - "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " - "no detected data."); AddInput("PosCount", "(Tensor) A tensor with shape [Ncls, 1], store the " - "input positive example count of each class.") + "input positive example count of each class, Ncls is the count of " + "input classification. " + "This input is used to pass the AccumPosCount generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. " + "When the input(PosCount) is empty, the cumulative " + "calculation is not carried out, and only the results of the " + "current mini-batch are calculated.") .AsDispensable(); AddInput("TruePos", - "(LodTensor) A 2-D LodTensor with shape [Ntp, 2], store the " - "input true positive example of each class.") + "(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the " + "input true positive example of each class." + "This input is used to pass the AccumTruePos generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. ") .AsDispensable(); AddInput("FalsePos", - "(LodTensor) A 2-D LodTensor with shape [Nfp, 2], store the " - "input false positive example of each class.") + "(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the " + "input false positive example of each class." + "This input is used to pass the AccumFalsePos generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. ") .AsDispensable(); - AddOutput("OutPosCount", + AddOutput("AccumPosCount", "(Tensor) A tensor with shape [Ncls, 1], store the " "positive example count of each class. It combines the input " "input(PosCount) and the positive example count computed from " "input(Detection) and input(Label)."); - AddOutput("OutTruePos", - "(LodTensor) A LodTensor with shape [Ntp', 2], store the " + AddOutput("AccumTruePos", + "(LoDTensor) A LoDTensor with shape [Ntp', 2], store the " "true positive example of each class. It combines the " "input(TruePos) and the true positive examples computed from " "input(Detection) and input(Label)."); - AddOutput("OutFalsePos", - "(LodTensor) A LodTensor with shape [Nfp', 2], store the " + AddOutput("AccumFalsePos", + "(LoDTensor) A LoDTensor with shape [Nfp', 2], store the " "false positive example of each class. It combines the " "input(FalsePos) and the false positive examples computed from " "input(Detection) and input(Label)."); @@ -115,10 +140,11 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) A tensor with shape [1], store the mAP evaluate " "result of the detection."); - AddAttr("overlap_threshold", - "(float) " - "The jaccard overlap threshold of detection output and " - "ground-truth data.") + AddAttr( + "overlap_threshold", + "(float) " + "The lower bound jaccard overlap threshold of detection output and " + "ground-truth data.") .SetDefault(.3f); AddAttr("evaluate_difficult", "(bool, default true) " diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h index 0379a3328a..0f5f588e9c 100644 --- a/paddle/fluid/operators/detection_map_op.h +++ b/paddle/fluid/operators/detection_map_op.h @@ -54,7 +54,7 @@ template class DetectionMAPOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_detect = ctx.Input("Detection"); + auto* in_detect = ctx.Input("DetectRes"); auto* in_label = ctx.Input("Label"); auto* out_map = ctx.Output("MAP"); @@ -62,9 +62,9 @@ class DetectionMAPOpKernel : public framework::OpKernel { auto* in_true_pos = ctx.Input("TruePos"); auto* in_false_pos = ctx.Input("FalsePos"); - auto* out_pos_count = ctx.Output("OutPosCount"); - auto* out_true_pos = ctx.Output("OutTruePos"); - auto* out_false_pos = ctx.Output("OutFalsePos"); + auto* out_pos_count = ctx.Output("AccumPosCount"); + auto* out_true_pos = ctx.Output("AccumTruePos"); + auto* out_false_pos = ctx.Output("AccumFalsePos"); float overlap_threshold = ctx.Attr("overlap_threshold"); float evaluate_difficult = ctx.Attr("evaluate_difficult"); @@ -265,28 +265,22 @@ class DetectionMAPOpKernel : public framework::OpKernel { label_pos_count[i] = pos_count_data[i]; } - const T* true_pos_data = input_true_pos.data(); - auto true_pos_data_lod = input_true_pos.lod(); - for (int i = 0; i < true_pos_data_lod.size(); ++i) { - for (int j = true_pos_data_lod[0][i]; j < true_pos_data_lod[0][i + 1]; - ++j) { - T score = true_pos_data[j * 2]; - int flag = 1; - if (true_pos_data[j * 2 + 1] < kEPS) flag = 0; - true_pos[i].push_back(std::make_pair(score, flag)); - } - } - const T* false_pos_data = input_false_pos.data(); - auto false_pos_data_lod = input_false_pos.lod(); - for (int i = 0; i < false_pos_data_lod.size(); ++i) { - for (int j = false_pos_data_lod[0][i]; j < false_pos_data_lod[0][i + 1]; - ++j) { - T score = false_pos_data[j * 2]; - int flag = 1; - if (false_pos_data[j * 2 + 1] < kEPS) flag = 0; - false_pos[i].push_back(std::make_pair(score, flag)); + auto SetData = [](const framework::LoDTensor& pos_tensor, + std::map>>& pos) { + const T* pos_data = pos_tensor.data(); + auto pos_data_lod = pos_tensor.lod(); + for (int i = 0; i < pos_data_lod.size(); ++i) { + for (int j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) { + T score = pos_data[j * 2]; + int flag = 1; + if (pos_data[j * 2 + 1] < kEPS) flag = 0; + pos[i].push_back(std::make_pair(score, flag)); + } } - } + }; + + SetData(input_true_pos, true_pos); + SetData(input_false_pos, false_pos); return; } diff --git a/python/paddle/v2/fluid/tests/test_detection_map_op.py b/python/paddle/v2/fluid/tests/test_detection_map_op.py index ec57ca4ad5..70ccd885d8 100644 --- a/python/paddle/v2/fluid/tests/test_detection_map_op.py +++ b/python/paddle/v2/fluid/tests/test_detection_map_op.py @@ -37,7 +37,7 @@ class TestDetectionMAPOp(OpTest): self.inputs = { 'Label': (self.label, self.label_lod), - 'Detection': (self.detect, self.detect_lod), + 'DetectRes': (self.detect, self.detect_lod), 'PosCount': self.class_pos_count, 'TruePos': (self.true_pos, self.true_pos_lod), 'FalsePos': (self.false_pos, self.false_pos_lod) @@ -45,7 +45,7 @@ class TestDetectionMAPOp(OpTest): else: self.inputs = { 'Label': (self.label, self.label_lod), - 'Detection': (self.detect, self.detect_lod), + 'DetectRes': (self.detect, self.detect_lod), } self.attrs = { @@ -61,9 +61,9 @@ class TestDetectionMAPOp(OpTest): self.outputs = { 'MAP': self.mAP, - 'OutPosCount': self.out_class_pos_count, - 'OutTruePos': (self.out_true_pos, self.out_true_pos_lod), - 'OutFalsePos': (self.out_false_pos, self.out_false_pos_lod) + 'AccumPosCount': self.out_class_pos_count, + 'AccumTruePos': (self.out_true_pos, self.out_true_pos_lod), + 'AccumFalsePos': (self.out_false_pos, self.out_false_pos_lod) } def init_test_case(self): @@ -175,9 +175,7 @@ class TestDetectionMAPOp(OpTest): false_pos[label].append([score, fp]) for (label, label_pos_num) in label_count.items(): - if label_pos_num == 0 or label not in true_pos: - continue - + if label_pos_num == 0 or label not in true_pos: continue label_true_pos = true_pos[label] label_false_pos = false_pos[label] -- GitLab