提交 91a21883 编写于 作者: W wanghaox

update detection_map

上级 006ef1fd
...@@ -24,25 +24,28 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -24,25 +24,28 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Detection"), PADDLE_ENFORCE(ctx->HasInput("DetectRes"),
"Input(Detection) of DetectionMAPOp should not be null."); "Input(DetectRes) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of DetectionMAPOp should not be null."); "Input(Label) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutPosCount"), PADDLE_ENFORCE(
"Output(OutPosCount) of DetectionMAPOp should not be null."); ctx->HasOutput("AccumPosCount"),
PADDLE_ENFORCE(ctx->HasOutput("OutTruePos"), "Output(AccumPosCount) of DetectionMAPOp should not be null.");
"Output(OutTruePos) of DetectionMAPOp should not be null."); PADDLE_ENFORCE(
PADDLE_ENFORCE(ctx->HasOutput("OutFalsePos"), ctx->HasOutput("AccumTruePos"),
"Output(OutFalsePos) of DetectionMAPOp should not be null."); "Output(AccumTruePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumFalsePos"),
"Output(AccumFalsePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MAP"), PADDLE_ENFORCE(ctx->HasOutput("MAP"),
"Output(MAP) of DetectionMAPOp should not be null."); "Output(MAP) of DetectionMAPOp should not be null.");
auto det_dims = ctx->GetInputDim("Detection"); auto det_dims = ctx->GetInputDim("DetectRes");
PADDLE_ENFORCE_EQ(det_dims.size(), 2UL, PADDLE_ENFORCE_EQ(det_dims.size(), 2UL,
"The rank of Input(Detection) must be 2, " "The rank of Input(DetectRes) must be 2, "
"the shape is [N, 6]."); "the shape is [N, 6].");
PADDLE_ENFORCE_EQ(det_dims[1], 6UL, PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
"The shape is of Input(Detection) [N, 6]."); "The shape is of Input(DetectRes) [N, 6].");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"The rank of Input(Label) must be 2, " "The rank of Input(Label) must be 2, "
...@@ -50,8 +53,17 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -50,8 +53,17 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(label_dims[1], 6UL, PADDLE_ENFORCE_EQ(label_dims[1], 6UL,
"The shape is of Input(Label) [N, 6]."); "The shape is of Input(Label) [N, 6].");
auto map_dim = framework::make_ddim({1}); if (ctx->HasInput("PosCount")) {
ctx->SetOutputDim("MAP", map_dim); PADDLE_ENFORCE(ctx->HasInput("TruePos"),
"Input(TruePos) of DetectionMAPOp should not be null when "
"Input(TruePos) is not null.");
PADDLE_ENFORCE(
ctx->HasInput("FalsePos"),
"Input(FalsePos) of DetectionMAPOp should not be null when "
"Input(FalsePos) is not null.");
}
ctx->SetOutputDim("MAP", framework::make_ddim({1}));
} }
protected: protected:
...@@ -59,7 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -59,7 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( framework::ToDataType(
ctx.Input<framework::Tensor>("Detection")->type()), ctx.Input<framework::Tensor>("DetectRes")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -68,6 +80,14 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -68,6 +80,14 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker) DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax], M is the total "
"number of detect results in this mini-batch. For each instance, "
"the offsets in first dimension are called LoD, the number of "
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected data.");
AddInput("Label", AddInput("Label",
"(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the" "(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the"
"Labeled ground-truth data. Each row has 6 values: " "Labeled ground-truth data. Each row has 6 values: "
...@@ -76,38 +96,43 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -76,38 +96,43 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"instance, the offsets in first dimension are called LoD, " "instance, the offsets in first dimension are called LoD, "
"the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, " "the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
"means there is no ground-truth data."); "means there is no ground-truth data.");
AddInput("Detection",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax], M is the total "
"number of detections in this mini-batch. For each instance, "
"the offsets in first dimension are called LoD, the number of "
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected data.");
AddInput("PosCount", AddInput("PosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the " "(Tensor) A tensor with shape [Ncls, 1], store the "
"input positive example count of each class.") "input positive example count of each class, Ncls is the count of "
"input classification. "
"This input is used to pass the AccumPosCount generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. "
"When the input(PosCount) is empty, the cumulative "
"calculation is not carried out, and only the results of the "
"current mini-batch are calculated.")
.AsDispensable(); .AsDispensable();
AddInput("TruePos", AddInput("TruePos",
"(LodTensor) A 2-D LodTensor with shape [Ntp, 2], store the " "(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the "
"input true positive example of each class.") "input true positive example of each class."
"This input is used to pass the AccumTruePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable(); .AsDispensable();
AddInput("FalsePos", AddInput("FalsePos",
"(LodTensor) A 2-D LodTensor with shape [Nfp, 2], store the " "(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the "
"input false positive example of each class.") "input false positive example of each class."
"This input is used to pass the AccumFalsePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable(); .AsDispensable();
AddOutput("OutPosCount", AddOutput("AccumPosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the " "(Tensor) A tensor with shape [Ncls, 1], store the "
"positive example count of each class. It combines the input " "positive example count of each class. It combines the input "
"input(PosCount) and the positive example count computed from " "input(PosCount) and the positive example count computed from "
"input(Detection) and input(Label)."); "input(Detection) and input(Label).");
AddOutput("OutTruePos", AddOutput("AccumTruePos",
"(LodTensor) A LodTensor with shape [Ntp', 2], store the " "(LoDTensor) A LoDTensor with shape [Ntp', 2], store the "
"true positive example of each class. It combines the " "true positive example of each class. It combines the "
"input(TruePos) and the true positive examples computed from " "input(TruePos) and the true positive examples computed from "
"input(Detection) and input(Label)."); "input(Detection) and input(Label).");
AddOutput("OutFalsePos", AddOutput("AccumFalsePos",
"(LodTensor) A LodTensor with shape [Nfp', 2], store the " "(LoDTensor) A LoDTensor with shape [Nfp', 2], store the "
"false positive example of each class. It combines the " "false positive example of each class. It combines the "
"input(FalsePos) and the false positive examples computed from " "input(FalsePos) and the false positive examples computed from "
"input(Detection) and input(Label)."); "input(Detection) and input(Label).");
...@@ -115,10 +140,11 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -115,10 +140,11 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) A tensor with shape [1], store the mAP evaluate " "(Tensor) A tensor with shape [1], store the mAP evaluate "
"result of the detection."); "result of the detection.");
AddAttr<float>("overlap_threshold", AddAttr<float>(
"(float) " "overlap_threshold",
"The jaccard overlap threshold of detection output and " "(float) "
"ground-truth data.") "The lower bound jaccard overlap threshold of detection output and "
"ground-truth data.")
.SetDefault(.3f); .SetDefault(.3f);
AddAttr<bool>("evaluate_difficult", AddAttr<bool>("evaluate_difficult",
"(bool, default true) " "(bool, default true) "
......
...@@ -54,7 +54,7 @@ template <typename Place, typename T> ...@@ -54,7 +54,7 @@ template <typename Place, typename T>
class DetectionMAPOpKernel : public framework::OpKernel<T> { class DetectionMAPOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_detect = ctx.Input<framework::LoDTensor>("Detection"); auto* in_detect = ctx.Input<framework::LoDTensor>("DetectRes");
auto* in_label = ctx.Input<framework::LoDTensor>("Label"); auto* in_label = ctx.Input<framework::LoDTensor>("Label");
auto* out_map = ctx.Output<framework::Tensor>("MAP"); auto* out_map = ctx.Output<framework::Tensor>("MAP");
...@@ -62,9 +62,9 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -62,9 +62,9 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
auto* in_true_pos = ctx.Input<framework::LoDTensor>("TruePos"); auto* in_true_pos = ctx.Input<framework::LoDTensor>("TruePos");
auto* in_false_pos = ctx.Input<framework::LoDTensor>("FalsePos"); auto* in_false_pos = ctx.Input<framework::LoDTensor>("FalsePos");
auto* out_pos_count = ctx.Output<framework::Tensor>("OutPosCount"); auto* out_pos_count = ctx.Output<framework::Tensor>("AccumPosCount");
auto* out_true_pos = ctx.Output<framework::LoDTensor>("OutTruePos"); auto* out_true_pos = ctx.Output<framework::LoDTensor>("AccumTruePos");
auto* out_false_pos = ctx.Output<framework::LoDTensor>("OutFalsePos"); auto* out_false_pos = ctx.Output<framework::LoDTensor>("AccumFalsePos");
float overlap_threshold = ctx.Attr<float>("overlap_threshold"); float overlap_threshold = ctx.Attr<float>("overlap_threshold");
float evaluate_difficult = ctx.Attr<bool>("evaluate_difficult"); float evaluate_difficult = ctx.Attr<bool>("evaluate_difficult");
...@@ -265,28 +265,22 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -265,28 +265,22 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
label_pos_count[i] = pos_count_data[i]; label_pos_count[i] = pos_count_data[i];
} }
const T* true_pos_data = input_true_pos.data<T>(); auto SetData = [](const framework::LoDTensor& pos_tensor,
auto true_pos_data_lod = input_true_pos.lod(); std::map<int, std::vector<std::pair<T, int>>>& pos) {
for (int i = 0; i < true_pos_data_lod.size(); ++i) { const T* pos_data = pos_tensor.data<T>();
for (int j = true_pos_data_lod[0][i]; j < true_pos_data_lod[0][i + 1]; auto pos_data_lod = pos_tensor.lod();
++j) { for (int i = 0; i < pos_data_lod.size(); ++i) {
T score = true_pos_data[j * 2]; for (int j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) {
int flag = 1; T score = pos_data[j * 2];
if (true_pos_data[j * 2 + 1] < kEPS) flag = 0; int flag = 1;
true_pos[i].push_back(std::make_pair(score, flag)); if (pos_data[j * 2 + 1] < kEPS) flag = 0;
} pos[i].push_back(std::make_pair(score, flag));
} }
const T* false_pos_data = input_false_pos.data<T>();
auto false_pos_data_lod = input_false_pos.lod();
for (int i = 0; i < false_pos_data_lod.size(); ++i) {
for (int j = false_pos_data_lod[0][i]; j < false_pos_data_lod[0][i + 1];
++j) {
T score = false_pos_data[j * 2];
int flag = 1;
if (false_pos_data[j * 2 + 1] < kEPS) flag = 0;
false_pos[i].push_back(std::make_pair(score, flag));
} }
} };
SetData(input_true_pos, true_pos);
SetData(input_false_pos, false_pos);
return; return;
} }
......
...@@ -37,7 +37,7 @@ class TestDetectionMAPOp(OpTest): ...@@ -37,7 +37,7 @@ class TestDetectionMAPOp(OpTest):
self.inputs = { self.inputs = {
'Label': (self.label, self.label_lod), 'Label': (self.label, self.label_lod),
'Detection': (self.detect, self.detect_lod), 'DetectRes': (self.detect, self.detect_lod),
'PosCount': self.class_pos_count, 'PosCount': self.class_pos_count,
'TruePos': (self.true_pos, self.true_pos_lod), 'TruePos': (self.true_pos, self.true_pos_lod),
'FalsePos': (self.false_pos, self.false_pos_lod) 'FalsePos': (self.false_pos, self.false_pos_lod)
...@@ -45,7 +45,7 @@ class TestDetectionMAPOp(OpTest): ...@@ -45,7 +45,7 @@ class TestDetectionMAPOp(OpTest):
else: else:
self.inputs = { self.inputs = {
'Label': (self.label, self.label_lod), 'Label': (self.label, self.label_lod),
'Detection': (self.detect, self.detect_lod), 'DetectRes': (self.detect, self.detect_lod),
} }
self.attrs = { self.attrs = {
...@@ -61,9 +61,9 @@ class TestDetectionMAPOp(OpTest): ...@@ -61,9 +61,9 @@ class TestDetectionMAPOp(OpTest):
self.outputs = { self.outputs = {
'MAP': self.mAP, 'MAP': self.mAP,
'OutPosCount': self.out_class_pos_count, 'AccumPosCount': self.out_class_pos_count,
'OutTruePos': (self.out_true_pos, self.out_true_pos_lod), 'AccumTruePos': (self.out_true_pos, self.out_true_pos_lod),
'OutFalsePos': (self.out_false_pos, self.out_false_pos_lod) 'AccumFalsePos': (self.out_false_pos, self.out_false_pos_lod)
} }
def init_test_case(self): def init_test_case(self):
...@@ -175,9 +175,7 @@ class TestDetectionMAPOp(OpTest): ...@@ -175,9 +175,7 @@ class TestDetectionMAPOp(OpTest):
false_pos[label].append([score, fp]) false_pos[label].append([score, fp])
for (label, label_pos_num) in label_count.items(): for (label, label_pos_num) in label_count.items():
if label_pos_num == 0 or label not in true_pos: if label_pos_num == 0 or label not in true_pos: continue
continue
label_true_pos = true_pos[label] label_true_pos = true_pos[label]
label_false_pos = false_pos[label] label_false_pos = false_pos[label]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册