From 4284b857cb61f9ad090044834f3c0f62c339c0b2 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Fri, 2 Feb 2018 15:45:13 +0800 Subject: [PATCH] update mine_hard_examples op --- paddle/operators/mine_hard_examples_op.cc | 52 ++++++++++++++--------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/paddle/operators/mine_hard_examples_op.cc b/paddle/operators/mine_hard_examples_op.cc index 603368f93ca..2a3bd139ed2 100644 --- a/paddle/operators/mine_hard_examples_op.cc +++ b/paddle/operators/mine_hard_examples_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ inline bool IsEligibleMining(const MiningType mining_type, const int match_idx, } } -MiningType GetMiningType(std::string str) { +inline MiningType GetMiningType(std::string str) { if (str == "max_negative") { return MiningType::kMaxNegative; } else if (str == "hard_example") { @@ -112,7 +112,7 @@ class MineHardExamplesKernel : public framework::OpKernel { neg_sel = std::min(sample_size, neg_sel); } - std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend); + std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend); std::set sel_indices; std::vector neg_indices; std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel, @@ -121,18 +121,27 @@ class MineHardExamplesKernel : public framework::OpKernel { return static_cast(l.second); }); - for (int m = 0; m < prior_num; ++m) { - if (match_indices(n, m) > -1) { - if (mining_type == MiningType::kHardExample && - sel_indices.find(m) == sel_indices.end()) { - match_indices_et(n, m) = -1; + if (mining_type == MiningType::kHardExample) { + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) > -1) { + if (sel_indices.find(m) == sel_indices.end()) { + match_indices_et(n, m) = -1; + } + } else { + if (sel_indices.find(m) != sel_indices.end()) { + neg_indices.push_back(m); + } } - } else { - if (sel_indices.find(m) != sel_indices.end()) { + } + } else { + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) == -1 && + sel_indices.find(m) != sel_indices.end()) { neg_indices.push_back(m); } } } + all_neg_indices.push_back(neg_indices); batch_starts.push_back(batch_starts.back() + neg_indices.size()); } @@ -253,7 +262,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { "[N, Np], N is the batch size and Np is the number of prior box."); AddInput("LocLoss", "(Tensor, optional, default Tensor), The localization loss " - "wit shape [N, Np], N is the batch size and Np is the number of " + "with shape [N, Np], N is the batch size and Np is the number of " "prior box.") .AsDispensable(); AddInput("MatchIndices", @@ -267,15 +276,15 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { "Np], N is the batch size and Np is the number of prior box."); AddAttr("neg_pos_ratio", "(float) The ratio of the negative box to the positive " - "box. Use only when mining_type is equal to max_negative.") + "box. Use only when mining_type is max_negative.") .SetDefault(1.0); AddAttr("neg_dist_threshold", - "(float) The negative box dis value threshold. " - "Use only when mining_type is equal to max_negative.") + "(float) The negative overlap upper bound for the unmatched " + "predictions. Use only when mining_type is max_negative.") .SetDefault(0.5); AddAttr("sample_size", "(float) The max sample size of negative box. Use only when " - "mining_type is equal to hard_example.") + "mining_type is hard_example.") .SetDefault(0); AddAttr("mining_type", "(float) The mining algorithm name, the value is " @@ -295,7 +304,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("UpdatedMatchIndices", "(Tensor) The output of updated MatchIndices, a tensor with " - "shape [N, Np]. Only update when mining_type is equal to " + "shape [N, Np]. Only update when mining_type is " "hard_example. The input MatchIndices elements will be update to " "-1 when it is not in the candidate high loss list of negative " "examples."); @@ -303,11 +312,12 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Mine hard examples Operator. This operator implements hard example mining to select a subset of negative box indices. -For each image, selects the box with highest losses. subject to the condition that the box cannot have -an Matcht > neg_dist_threshold when mining_type is equals max_negative. The selected number is -min(sample_size, max_negative_box_number) when mining_type is equals hard_example, -or min(neg_pos_ratio * positive_box_number, max_negative_box_number) when mining_type is -equals max_negative, where the max_negative_box_number is the count of MatchIndices elements with value -1. +For each image, selects the box with highest losses. subject to the condition that the +box cannot have an Matcht > neg_dist_threshold when mining_type is max_negative. +The selected number is min(sample_size, max_negative_box_number) when mining_type is +hard_example, or min(neg_pos_ratio * positive_box_number, max_negative_box_number) +when mining_type is max_negative, where the max_negative_box_number is the count of +MatchIndices elements with value -1. )DOC"); } }; -- GitLab