提交 4284b857 编写于 作者: W wanghaox

update mine_hard_examples op

上级 62dc593e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -38,7 +38,7 @@ inline bool IsEligibleMining(const MiningType mining_type, const int match_idx,
}
}
MiningType GetMiningType(std::string str) {
inline MiningType GetMiningType(std::string str) {
if (str == "max_negative") {
return MiningType::kMaxNegative;
} else if (str == "hard_example") {
......@@ -112,7 +112,7 @@ class MineHardExamplesKernel : public framework::OpKernel<T> {
neg_sel = std::min(sample_size, neg_sel);
}
std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend<int>);
std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend<size_t>);
std::set<int> sel_indices;
std::vector<int> neg_indices;
std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel,
......@@ -121,18 +121,27 @@ class MineHardExamplesKernel : public framework::OpKernel<T> {
return static_cast<int>(l.second);
});
for (int m = 0; m < prior_num; ++m) {
if (match_indices(n, m) > -1) {
if (mining_type == MiningType::kHardExample &&
sel_indices.find(m) == sel_indices.end()) {
match_indices_et(n, m) = -1;
if (mining_type == MiningType::kHardExample) {
for (int m = 0; m < prior_num; ++m) {
if (match_indices(n, m) > -1) {
if (sel_indices.find(m) == sel_indices.end()) {
match_indices_et(n, m) = -1;
}
} else {
if (sel_indices.find(m) != sel_indices.end()) {
neg_indices.push_back(m);
}
}
} else {
if (sel_indices.find(m) != sel_indices.end()) {
}
} else {
for (int m = 0; m < prior_num; ++m) {
if (match_indices(n, m) == -1 &&
sel_indices.find(m) != sel_indices.end()) {
neg_indices.push_back(m);
}
}
}
all_neg_indices.push_back(neg_indices);
batch_starts.push_back(batch_starts.back() + neg_indices.size());
}
......@@ -253,7 +262,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
"[N, Np], N is the batch size and Np is the number of prior box.");
AddInput("LocLoss",
"(Tensor, optional, default Tensor<float>), The localization loss "
"wit shape [N, Np], N is the batch size and Np is the number of "
"with shape [N, Np], N is the batch size and Np is the number of "
"prior box.")
.AsDispensable();
AddInput("MatchIndices",
......@@ -267,15 +276,15 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
"Np], N is the batch size and Np is the number of prior box.");
AddAttr<float>("neg_pos_ratio",
"(float) The ratio of the negative box to the positive "
"box. Use only when mining_type is equal to max_negative.")
"box. Use only when mining_type is max_negative.")
.SetDefault(1.0);
AddAttr<float>("neg_dist_threshold",
"(float) The negative box dis value threshold. "
"Use only when mining_type is equal to max_negative.")
"(float) The negative overlap upper bound for the unmatched "
"predictions. Use only when mining_type is max_negative.")
.SetDefault(0.5);
AddAttr<int>("sample_size",
"(float) The max sample size of negative box. Use only when "
"mining_type is equal to hard_example.")
"mining_type is hard_example.")
.SetDefault(0);
AddAttr<std::string>("mining_type",
"(float) The mining algorithm name, the value is "
......@@ -295,7 +304,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("UpdatedMatchIndices",
"(Tensor<int>) The output of updated MatchIndices, a tensor with "
"shape [N, Np]. Only update when mining_type is equal to "
"shape [N, Np]. Only update when mining_type is "
"hard_example. The input MatchIndices elements will be update to "
"-1 when it is not in the candidate high loss list of negative "
"examples.");
......@@ -303,11 +312,12 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Mine hard examples Operator.
This operator implements hard example mining to select a subset of negative box indices.
For each image, selects the box with highest losses. subject to the condition that the box cannot have
an Matcht > neg_dist_threshold when mining_type is equals max_negative. The selected number is
min(sample_size, max_negative_box_number) when mining_type is equals hard_example,
or min(neg_pos_ratio * positive_box_number, max_negative_box_number) when mining_type is
equals max_negative, where the max_negative_box_number is the count of MatchIndices elements with value -1.
For each image, selects the box with highest losses. subject to the condition that the
box cannot have an Matcht > neg_dist_threshold when mining_type is max_negative.
The selected number is min(sample_size, max_negative_box_number) when mining_type is
hard_example, or min(neg_pos_ratio * positive_box_number, max_negative_box_number)
when mining_type is max_negative, where the max_negative_box_number is the count of
MatchIndices elements with value -1.
)DOC");
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册