提交 4284b857 编写于 作者: W wanghaox

update mine_hard_examples op

上级 62dc593e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -38,7 +38,7 @@ inline bool IsEligibleMining(const MiningType mining_type, const int match_idx, ...@@ -38,7 +38,7 @@ inline bool IsEligibleMining(const MiningType mining_type, const int match_idx,
} }
} }
MiningType GetMiningType(std::string str) { inline MiningType GetMiningType(std::string str) {
if (str == "max_negative") { if (str == "max_negative") {
return MiningType::kMaxNegative; return MiningType::kMaxNegative;
} else if (str == "hard_example") { } else if (str == "hard_example") {
...@@ -112,7 +112,7 @@ class MineHardExamplesKernel : public framework::OpKernel<T> { ...@@ -112,7 +112,7 @@ class MineHardExamplesKernel : public framework::OpKernel<T> {
neg_sel = std::min(sample_size, neg_sel); neg_sel = std::min(sample_size, neg_sel);
} }
std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend<int>); std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend<size_t>);
std::set<int> sel_indices; std::set<int> sel_indices;
std::vector<int> neg_indices; std::vector<int> neg_indices;
std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel, std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel,
...@@ -121,18 +121,27 @@ class MineHardExamplesKernel : public framework::OpKernel<T> { ...@@ -121,18 +121,27 @@ class MineHardExamplesKernel : public framework::OpKernel<T> {
return static_cast<int>(l.second); return static_cast<int>(l.second);
}); });
for (int m = 0; m < prior_num; ++m) { if (mining_type == MiningType::kHardExample) {
if (match_indices(n, m) > -1) { for (int m = 0; m < prior_num; ++m) {
if (mining_type == MiningType::kHardExample && if (match_indices(n, m) > -1) {
sel_indices.find(m) == sel_indices.end()) { if (sel_indices.find(m) == sel_indices.end()) {
match_indices_et(n, m) = -1; match_indices_et(n, m) = -1;
}
} else {
if (sel_indices.find(m) != sel_indices.end()) {
neg_indices.push_back(m);
}
} }
} else { }
if (sel_indices.find(m) != sel_indices.end()) { } else {
for (int m = 0; m < prior_num; ++m) {
if (match_indices(n, m) == -1 &&
sel_indices.find(m) != sel_indices.end()) {
neg_indices.push_back(m); neg_indices.push_back(m);
} }
} }
} }
all_neg_indices.push_back(neg_indices); all_neg_indices.push_back(neg_indices);
batch_starts.push_back(batch_starts.back() + neg_indices.size()); batch_starts.push_back(batch_starts.back() + neg_indices.size());
} }
...@@ -253,7 +262,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -253,7 +262,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
"[N, Np], N is the batch size and Np is the number of prior box."); "[N, Np], N is the batch size and Np is the number of prior box.");
AddInput("LocLoss", AddInput("LocLoss",
"(Tensor, optional, default Tensor<float>), The localization loss " "(Tensor, optional, default Tensor<float>), The localization loss "
"wit shape [N, Np], N is the batch size and Np is the number of " "with shape [N, Np], N is the batch size and Np is the number of "
"prior box.") "prior box.")
.AsDispensable(); .AsDispensable();
AddInput("MatchIndices", AddInput("MatchIndices",
...@@ -267,15 +276,15 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -267,15 +276,15 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
"Np], N is the batch size and Np is the number of prior box."); "Np], N is the batch size and Np is the number of prior box.");
AddAttr<float>("neg_pos_ratio", AddAttr<float>("neg_pos_ratio",
"(float) The ratio of the negative box to the positive " "(float) The ratio of the negative box to the positive "
"box. Use only when mining_type is equal to max_negative.") "box. Use only when mining_type is max_negative.")
.SetDefault(1.0); .SetDefault(1.0);
AddAttr<float>("neg_dist_threshold", AddAttr<float>("neg_dist_threshold",
"(float) The negative box dis value threshold. " "(float) The negative overlap upper bound for the unmatched "
"Use only when mining_type is equal to max_negative.") "predictions. Use only when mining_type is max_negative.")
.SetDefault(0.5); .SetDefault(0.5);
AddAttr<int>("sample_size", AddAttr<int>("sample_size",
"(float) The max sample size of negative box. Use only when " "(float) The max sample size of negative box. Use only when "
"mining_type is equal to hard_example.") "mining_type is hard_example.")
.SetDefault(0); .SetDefault(0);
AddAttr<std::string>("mining_type", AddAttr<std::string>("mining_type",
"(float) The mining algorithm name, the value is " "(float) The mining algorithm name, the value is "
...@@ -295,7 +304,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -295,7 +304,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("UpdatedMatchIndices", AddOutput("UpdatedMatchIndices",
"(Tensor<int>) The output of updated MatchIndices, a tensor with " "(Tensor<int>) The output of updated MatchIndices, a tensor with "
"shape [N, Np]. Only update when mining_type is equal to " "shape [N, Np]. Only update when mining_type is "
"hard_example. The input MatchIndices elements will be update to " "hard_example. The input MatchIndices elements will be update to "
"-1 when it is not in the candidate high loss list of negative " "-1 when it is not in the candidate high loss list of negative "
"examples."); "examples.");
...@@ -303,11 +312,12 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -303,11 +312,12 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Mine hard examples Operator. Mine hard examples Operator.
This operator implements hard example mining to select a subset of negative box indices. This operator implements hard example mining to select a subset of negative box indices.
For each image, selects the box with highest losses. subject to the condition that the box cannot have For each image, selects the box with highest losses. subject to the condition that the
an Matcht > neg_dist_threshold when mining_type is equals max_negative. The selected number is box cannot have an Matcht > neg_dist_threshold when mining_type is max_negative.
min(sample_size, max_negative_box_number) when mining_type is equals hard_example, The selected number is min(sample_size, max_negative_box_number) when mining_type is
or min(neg_pos_ratio * positive_box_number, max_negative_box_number) when mining_type is hard_example, or min(neg_pos_ratio * positive_box_number, max_negative_box_number)
equals max_negative, where the max_negative_box_number is the count of MatchIndices elements with value -1. when mining_type is max_negative, where the max_negative_box_number is the count of
MatchIndices elements with value -1.
)DOC"); )DOC");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册