提交 91a21883 编写于 作者: W wanghaox

update detection_map

上级 006ef1fd
......@@ -24,25 +24,28 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Detection"),
"Input(Detection) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("DetectRes"),
"Input(DetectRes) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutPosCount"),
"Output(OutPosCount) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutTruePos"),
"Output(OutTruePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutFalsePos"),
"Output(OutFalsePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumPosCount"),
"Output(AccumPosCount) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumTruePos"),
"Output(AccumTruePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumFalsePos"),
"Output(AccumFalsePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MAP"),
"Output(MAP) of DetectionMAPOp should not be null.");
auto det_dims = ctx->GetInputDim("Detection");
auto det_dims = ctx->GetInputDim("DetectRes");
PADDLE_ENFORCE_EQ(det_dims.size(), 2UL,
"The rank of Input(Detection) must be 2, "
"The rank of Input(DetectRes) must be 2, "
"the shape is [N, 6].");
PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
"The shape is of Input(Detection) [N, 6].");
"The shape is of Input(DetectRes) [N, 6].");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"The rank of Input(Label) must be 2, "
......@@ -50,8 +53,17 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(label_dims[1], 6UL,
"The shape is of Input(Label) [N, 6].");
auto map_dim = framework::make_ddim({1});
ctx->SetOutputDim("MAP", map_dim);
if (ctx->HasInput("PosCount")) {
PADDLE_ENFORCE(ctx->HasInput("TruePos"),
"Input(TruePos) of DetectionMAPOp should not be null when "
"Input(TruePos) is not null.");
PADDLE_ENFORCE(
ctx->HasInput("FalsePos"),
"Input(FalsePos) of DetectionMAPOp should not be null when "
"Input(FalsePos) is not null.");
}
ctx->SetOutputDim("MAP", framework::make_ddim({1}));
}
protected:
......@@ -59,7 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::Tensor>("Detection")->type()),
ctx.Input<framework::Tensor>("DetectRes")->type()),
ctx.device_context());
}
};
......@@ -68,6 +80,14 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax], M is the total "
"number of detect results in this mini-batch. For each instance, "
"the offsets in first dimension are called LoD, the number of "
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected data.");
AddInput("Label",
"(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the"
"Labeled ground-truth data. Each row has 6 values: "
......@@ -76,38 +96,43 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"instance, the offsets in first dimension are called LoD, "
"the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
"means there is no ground-truth data.");
AddInput("Detection",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax], M is the total "
"number of detections in this mini-batch. For each instance, "
"the offsets in first dimension are called LoD, the number of "
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected data.");
AddInput("PosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the "
"input positive example count of each class.")
"input positive example count of each class, Ncls is the count of "
"input classification. "
"This input is used to pass the AccumPosCount generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. "
"When the input(PosCount) is empty, the cumulative "
"calculation is not carried out, and only the results of the "
"current mini-batch are calculated.")
.AsDispensable();
AddInput("TruePos",
"(LodTensor) A 2-D LodTensor with shape [Ntp, 2], store the "
"input true positive example of each class.")
"(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the "
"input true positive example of each class."
"This input is used to pass the AccumTruePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable();
AddInput("FalsePos",
"(LodTensor) A 2-D LodTensor with shape [Nfp, 2], store the "
"input false positive example of each class.")
"(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the "
"input false positive example of each class."
"This input is used to pass the AccumFalsePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable();
AddOutput("OutPosCount",
AddOutput("AccumPosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the "
"positive example count of each class. It combines the input "
"input(PosCount) and the positive example count computed from "
"input(Detection) and input(Label).");
AddOutput("OutTruePos",
"(LodTensor) A LodTensor with shape [Ntp', 2], store the "
AddOutput("AccumTruePos",
"(LoDTensor) A LoDTensor with shape [Ntp', 2], store the "
"true positive example of each class. It combines the "
"input(TruePos) and the true positive examples computed from "
"input(Detection) and input(Label).");
AddOutput("OutFalsePos",
"(LodTensor) A LodTensor with shape [Nfp', 2], store the "
AddOutput("AccumFalsePos",
"(LoDTensor) A LoDTensor with shape [Nfp', 2], store the "
"false positive example of each class. It combines the "
"input(FalsePos) and the false positive examples computed from "
"input(Detection) and input(Label).");
......@@ -115,9 +140,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) A tensor with shape [1], store the mAP evaluate "
"result of the detection.");
AddAttr<float>("overlap_threshold",
AddAttr<float>(
"overlap_threshold",
"(float) "
"The jaccard overlap threshold of detection output and "
"The lower bound jaccard overlap threshold of detection output and "
"ground-truth data.")
.SetDefault(.3f);
AddAttr<bool>("evaluate_difficult",
......
......@@ -54,7 +54,7 @@ template <typename Place, typename T>
class DetectionMAPOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_detect = ctx.Input<framework::LoDTensor>("Detection");
auto* in_detect = ctx.Input<framework::LoDTensor>("DetectRes");
auto* in_label = ctx.Input<framework::LoDTensor>("Label");
auto* out_map = ctx.Output<framework::Tensor>("MAP");
......@@ -62,9 +62,9 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
auto* in_true_pos = ctx.Input<framework::LoDTensor>("TruePos");
auto* in_false_pos = ctx.Input<framework::LoDTensor>("FalsePos");
auto* out_pos_count = ctx.Output<framework::Tensor>("OutPosCount");
auto* out_true_pos = ctx.Output<framework::LoDTensor>("OutTruePos");
auto* out_false_pos = ctx.Output<framework::LoDTensor>("OutFalsePos");
auto* out_pos_count = ctx.Output<framework::Tensor>("AccumPosCount");
auto* out_true_pos = ctx.Output<framework::LoDTensor>("AccumTruePos");
auto* out_false_pos = ctx.Output<framework::LoDTensor>("AccumFalsePos");
float overlap_threshold = ctx.Attr<float>("overlap_threshold");
float evaluate_difficult = ctx.Attr<bool>("evaluate_difficult");
......@@ -265,28 +265,22 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
label_pos_count[i] = pos_count_data[i];
}
const T* true_pos_data = input_true_pos.data<T>();
auto true_pos_data_lod = input_true_pos.lod();
for (int i = 0; i < true_pos_data_lod.size(); ++i) {
for (int j = true_pos_data_lod[0][i]; j < true_pos_data_lod[0][i + 1];
++j) {
T score = true_pos_data[j * 2];
auto SetData = [](const framework::LoDTensor& pos_tensor,
std::map<int, std::vector<std::pair<T, int>>>& pos) {
const T* pos_data = pos_tensor.data<T>();
auto pos_data_lod = pos_tensor.lod();
for (int i = 0; i < pos_data_lod.size(); ++i) {
for (int j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) {
T score = pos_data[j * 2];
int flag = 1;
if (true_pos_data[j * 2 + 1] < kEPS) flag = 0;
true_pos[i].push_back(std::make_pair(score, flag));
}
}
const T* false_pos_data = input_false_pos.data<T>();
auto false_pos_data_lod = input_false_pos.lod();
for (int i = 0; i < false_pos_data_lod.size(); ++i) {
for (int j = false_pos_data_lod[0][i]; j < false_pos_data_lod[0][i + 1];
++j) {
T score = false_pos_data[j * 2];
int flag = 1;
if (false_pos_data[j * 2 + 1] < kEPS) flag = 0;
false_pos[i].push_back(std::make_pair(score, flag));
if (pos_data[j * 2 + 1] < kEPS) flag = 0;
pos[i].push_back(std::make_pair(score, flag));
}
}
};
SetData(input_true_pos, true_pos);
SetData(input_false_pos, false_pos);
return;
}
......
......@@ -37,7 +37,7 @@ class TestDetectionMAPOp(OpTest):
self.inputs = {
'Label': (self.label, self.label_lod),
'Detection': (self.detect, self.detect_lod),
'DetectRes': (self.detect, self.detect_lod),
'PosCount': self.class_pos_count,
'TruePos': (self.true_pos, self.true_pos_lod),
'FalsePos': (self.false_pos, self.false_pos_lod)
......@@ -45,7 +45,7 @@ class TestDetectionMAPOp(OpTest):
else:
self.inputs = {
'Label': (self.label, self.label_lod),
'Detection': (self.detect, self.detect_lod),
'DetectRes': (self.detect, self.detect_lod),
}
self.attrs = {
......@@ -61,9 +61,9 @@ class TestDetectionMAPOp(OpTest):
self.outputs = {
'MAP': self.mAP,
'OutPosCount': self.out_class_pos_count,
'OutTruePos': (self.out_true_pos, self.out_true_pos_lod),
'OutFalsePos': (self.out_false_pos, self.out_false_pos_lod)
'AccumPosCount': self.out_class_pos_count,
'AccumTruePos': (self.out_true_pos, self.out_true_pos_lod),
'AccumFalsePos': (self.out_false_pos, self.out_false_pos_lod)
}
def init_test_case(self):
......@@ -175,9 +175,7 @@ class TestDetectionMAPOp(OpTest):
false_pos[label].append([score, fp])
for (label, label_pos_num) in label_count.items():
if label_pos_num == 0 or label not in true_pos:
continue
if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label]
label_false_pos = false_pos[label]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册