diff --git a/paddle/operators/mine_hard_examples_op.cc b/paddle/operators/mine_hard_examples_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..051cc24706d69ec4f38524af1dd510bf079c74c7 --- /dev/null +++ b/paddle/operators/mine_hard_examples_op.cc @@ -0,0 +1,330 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +enum MiningType { kNone = 0, kMaxNegative, kHardExample }; + +template +bool SortScoreDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +inline bool IsEligibleMining(const MiningType mining_type, const int match_idx, + const float match_dist, + const float neg_dist_threshold) { + if (mining_type == MiningType::kMaxNegative) { + return match_idx == -1 && match_dist < neg_dist_threshold; + } else if (mining_type == MiningType::kHardExample) { + return true; + } else { + return false; + } +} + +inline MiningType GetMiningType(std::string str) { + if (str == "max_negative") { + return MiningType::kMaxNegative; + } else if (str == "hard_example") { + return MiningType::kHardExample; + } else { + return MiningType::kNone; + } +} + +template +class MineHardExamplesKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_cls_loss = ctx.Input("ClsLoss"); + auto* in_loc_loss = ctx.Input("LocLoss"); + auto* in_matched_indices = ctx.Input("MatchIndices"); + auto* in_match_dist = ctx.Input("MatchDist"); + float neg_pos_ratio = ctx.Attr("neg_pos_ratio"); + T neg_dist_threshold = + static_cast(ctx.Attr("neg_dist_threshold")); + int sample_size = ctx.Attr("sample_size"); + MiningType mining_type = + GetMiningType(ctx.Attr("mining_type")); + + auto out_neg_indices = ctx.Output("NegIndices"); + auto out_match_indices = + ctx.Output("UpdatedMatchIndices"); + + framework::Copy(*in_matched_indices, ctx.GetPlace(), out_match_indices); + + int batch_size = in_matched_indices->dims()[0]; + int prior_num = in_matched_indices->dims()[1]; + + auto match_indices = framework::EigenMatrix::From(*in_matched_indices); + + auto match_indices_et = + framework::EigenMatrix::From(*out_match_indices); + + auto match_dist = framework::EigenMatrix::From(*in_match_dist); + + const T* cls_loss = in_cls_loss->data(); + const T* loc_loss = nullptr; + if (in_loc_loss) { + loc_loss = in_loc_loss->data(); + } + + std::vector> all_neg_indices; + std::vector batch_starts = {0}; + for (int n = 0; n < batch_size; ++n) { + std::vector> loss_idx; + int neg_sel = 0; + for (int m = 0; m < prior_num; ++m) { + if (IsEligibleMining(mining_type, match_indices(n, m), match_dist(n, m), + neg_dist_threshold)) { + T loss = cls_loss[n * prior_num + m]; + if (mining_type == MiningType::kHardExample && loc_loss != nullptr) { + loss = cls_loss[n * prior_num + m] + loc_loss[n * prior_num + m]; + } + loss_idx.push_back(std::make_pair(loss, m)); + ++neg_sel; + } + } + + if (mining_type == MiningType::kMaxNegative) { + int num_pos = 0; + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) != -1) ++num_pos; + } + neg_sel = std::min(static_cast(num_pos * neg_pos_ratio), neg_sel); + } else if (mining_type == MiningType::kHardExample) { + neg_sel = std::min(sample_size, neg_sel); + } + + std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend); + std::set sel_indices; + std::vector neg_indices; + std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel, + std::inserter(sel_indices, sel_indices.begin()), + [](std::pair& l) -> int { + return static_cast(l.second); + }); + + if (mining_type == MiningType::kHardExample) { + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) > -1) { + if (sel_indices.find(m) == sel_indices.end()) { + match_indices_et(n, m) = -1; + } + } else { + if (sel_indices.find(m) != sel_indices.end()) { + neg_indices.push_back(m); + } + } + } + } else { + neg_indices.resize(sel_indices.size()); + std::copy(sel_indices.begin(), sel_indices.end(), neg_indices.begin()); + } + + all_neg_indices.push_back(neg_indices); + batch_starts.push_back(batch_starts.back() + neg_indices.size()); + } + + framework::LoD out_neg_indices_lod; + out_neg_indices_lod.emplace_back(batch_starts); + int neg_offset = 0; + auto neg_data = out_neg_indices->mutable_data( + framework::make_ddim({static_cast(batch_starts.back()), 1}), + ctx.GetPlace()); + + for (auto neg_indices : all_neg_indices) { + std::copy(neg_indices.begin(), neg_indices.end(), neg_data + neg_offset); + neg_offset += neg_indices.size(); + } + out_neg_indices->set_lod(out_neg_indices_lod); + return; + } +}; + +class MineHardExamplesOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("ClsLoss"), + "Input(ClsLoss) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("MatchIndices"), + "Input(MatchIndices) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("MatchDist"), + "Input(MatchDist) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("NegIndices"), + "Output(NegIndices) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("UpdatedMatchIndices"), + "Output(UpdatedMatchIndices) of MineHardExamplesOp should " + "not be null."); + + auto cls_loss_dims = ctx->GetInputDim("ClsLoss"); + auto idx_dims = ctx->GetInputDim("MatchIndices"); + auto dis_dims = ctx->GetInputDim("MatchDist"); + + PADDLE_ENFORCE_EQ(cls_loss_dims.size(), 2UL, + "The shape of ClsLoss is [N, Np]."); + PADDLE_ENFORCE_EQ(idx_dims.size(), 2UL, + "The shape of MatchIndices is [N, Np]."); + PADDLE_ENFORCE_EQ(dis_dims.size(), 2UL, + "The shape of MatchDist is [N, Np]."); + + if (ctx->HasInput("LocLoss")) { + auto loc_loss_dims = ctx->GetInputDim("LocLoss"); + PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL, + "The shape of LocLoss is [N, Np]."); + PADDLE_ENFORCE_EQ(cls_loss_dims[0], loc_loss_dims[0], + "Batch size of ClsLoss and LocLoss must be the same."); + PADDLE_ENFORCE_EQ( + cls_loss_dims[1], loc_loss_dims[1], + "Prior box number of ClsLoss and LocLoss must be the same."); + } + + PADDLE_ENFORCE_EQ( + cls_loss_dims[0], idx_dims[0], + "Batch size of ClsLoss and MatchIndices must be the same."); + PADDLE_ENFORCE_EQ( + cls_loss_dims[1], idx_dims[1], + "Prior box number of ClsLoss and MatchIndices must be the same."); + + PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0], + "Batch size of ClsLoss and MatchDist must be the same."); + PADDLE_ENFORCE_EQ( + cls_loss_dims[1], idx_dims[1], + "Prior box number of ClsLoss and MatchDist must be the same."); + + auto mining_type = + GetMiningType(ctx->Attrs().Get("mining_type")); + + PADDLE_ENFORCE_NE(mining_type, MiningType::kNone, + "mining_type must be hard_example or max_negative"); + + if (mining_type == MiningType::kMaxNegative) { + auto neg_pos_ratio = ctx->Attrs().Get("neg_pos_ratio"); + auto neg_dist_threshold = ctx->Attrs().Get("neg_dist_threshold"); + PADDLE_ENFORCE_GT( + neg_pos_ratio, 0.0f, + "neg_pos_ratio must greater than zero in max_negative mode"); + PADDLE_ENFORCE_GT( + neg_dist_threshold, 0.0f, + "neg_dist_threshold must greater than zero in max_negative mode"); + } else if (mining_type == MiningType::kHardExample) { + auto sample_size = ctx->Attrs().Get("sample_size"); + PADDLE_ENFORCE_GT( + sample_size, 0, + "sample_size must greater than zero in hard_example mode"); + } + + ctx->SetOutputDim("UpdatedMatchIndices", idx_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("ClsLoss")->type()), + ctx.device_context()); + } +}; + +class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MineHardExamplesOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "ClsLoss", + "(Tensor, default Tensor), The classification loss with shape " + "[N, Np], N is the batch size and Np is the number of prior box."); + AddInput("LocLoss", + "(Tensor, optional, default Tensor), The localization loss " + "with shape [N, Np], N is the batch size and Np is the number of " + "prior box.") + .AsDispensable(); + AddInput("MatchIndices", + "(Tensor, Tensor), Matched indices with shape [N, Np], N is " + "the batch size and Np is the number of prior box. " + "MatchIndices[i][j] equal -1 means the j-th prior box in i-th " + "instance does not match any entity, otherwise means it is " + "matched to row."); + AddInput("MatchDist", + "(Tensor, default Tensor) Matched indices with shape [N, " + "Np], N is the batch size and Np is the number of prior box."); + AddAttr("neg_pos_ratio", + "(float) The ratio of the negative box to the positive " + "box. Use only when mining_type is max_negative.") + .SetDefault(1.0); + AddAttr("neg_dist_threshold", + "(float) The negative overlap upper bound for the unmatched " + "predictions. Use only when mining_type is max_negative.") + .SetDefault(0.5); + AddAttr("sample_size", + "(float) The max sample size of negative box. Use only when " + "mining_type is hard_example.") + .SetDefault(0); + AddAttr("mining_type", + "(float) The mining algorithm name, the value is " + "hard_example or max_negative.") + .SetDefault("max_negative") + .InEnum({"hard_example", "max_negative"}); + + AddOutput( + "NegIndices", + "(LoDTensor) The output of negative example indices. a LoDTensor " + "with shape [Neg, 1]. The size of lod[0] minus 1 is batch size, " + "and each element is the prior box index. " + "For example, the batch size is 2, the lod is [[0, 1, 2]], " + "the sample 0's box 1(MatchIndices[0][1]) is selected, " + "and sample 1's box 0 is selected. The output NegIndices is " + "[[1], [0]]."); + + AddOutput("UpdatedMatchIndices", + "(Tensor) The output of updated MatchIndices, a tensor with " + "shape [N, Np]. Only update when mining_type is " + "hard_example. The input MatchIndices elements will be update to " + "-1 when it is not in the candidate high loss list of negative " + "examples."); + + AddComment(R"DOC( +Mine hard examples Operator. +This operator implements hard example mining to select a subset of negative box indices. +For each image, selects the box with highest losses. subject to the condition that the +box cannot have an Matcht > neg_dist_threshold when mining_type is max_negative. +The selected number is min(sample_size, max_negative_box_number) when mining_type is +hard_example, or min(neg_pos_ratio * positive_box_number, max_negative_box_number) +when mining_type is max_negative, where the max_negative_box_number is the count of +MatchIndices elements with value -1. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(mine_hard_examples, ops::MineHardExamplesOp, + ops::MineHardExamplesOpMaker); + +REGISTER_OP_CPU_KERNEL( + mine_hard_examples, + ops::MineHardExamplesKernel, + ops::MineHardExamplesKernel); diff --git a/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py b/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py new file mode 100755 index 0000000000000000000000000000000000000000..c27573c3d69037bc48e0b6a90636b3f027f15a41 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py @@ -0,0 +1,100 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +class TestMineHardExamplesOp(OpTest): + def set_data(self): + self.init_test_data() + self.inputs = { + 'ClsLoss': self.cls_loss, + 'LocLoss': self.loc_loss, + 'MatchIndices': self.match_indices, + 'MatchDist': self.match_dis + } + + self.attrs = { + 'neg_pos_ratio': self.neg_pos_ratio, + 'neg_overlap': self.neg_overlap, + 'sample_size': self.sample_size, + 'mining_type': self.mining_type + } + + self.outputs = { + 'NegIndices': (self.neg_indices, self.neg_indices_lod), + 'UpdatedMatchIndices': self.updated_match_indices + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + return + + def setUp(self): + self.op_type = "mine_hard_examples" + self.set_data() + + def init_test_data(self): + self.neg_pos_ratio = 1.0 + self.neg_overlap = 0.5 + self.sample_size = 0 + self.mining_type = "max_negative" + self.cls_loss = np.array([[0.1, 0.1, 0.3], + [0.3, 0.1, 0.1]]).astype('float32') + + self.loc_loss = np.array([[0.1, 0.2, 0.3], + [0.3, 0.4, 0.1]]).astype('float32') + + self.match_dis = np.array([[0.2, 0.4, 0.8], + [0.1, 0.9, 0.3]]).astype('float32') + + self.match_indices = np.array([[0, -1, -1], + [-1, 0, -1]]).astype('int32') + + self.updated_match_indices = self.match_indices + + self.neg_indices_lod = [[0, 1, 2]] + self.neg_indices = np.array([[1], [0]]).astype('int32') + + +class TestMineHardExamplesOpHardExample(TestMineHardExamplesOp): + def init_test_data(self): + super(TestMineHardExamplesOpHardExample, self).init_test_data() + self.mining_type = "hard_example" + self.sample_size = 2 + + self.cls_loss = np.array([[0.5, 0.1, 0.3], + [0.3, 0.1, 0.1]]).astype('float32') + + self.loc_loss = np.array([[0.2, 0.2, 0.3], + [0.3, 0.1, 0.2]]).astype('float32') + + self.match_indices = np.array([[0, -1, -1], + [-1, 0, -1]]).astype('int32') + + self.updated_match_indices = np.array([[0, -1, -1], + [-1, -1, -1]]).astype('int32') + + self.neg_indices_lod = [[0, 1, 3]] + self.neg_indices = np.array([[2], [0], [2]]).astype('int32') + + +if __name__ == '__main__': + unittest.main()