未验证 提交 4948f7b3 编写于 作者: Q qingqing01 提交者: GitHub

Enhance bipartite_match_op to support argmax matching after bipartite matching. (#8580)

* Enhance bipartite_match_op to support argmax matching after bipartite matching.

* Fix typo error.
上级 dce0383f
...@@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
} }
} }
void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist,
T overlap_threshold) const {
constexpr T kEPS = static_cast<T>(1e-6);
int64_t row = dist.dims()[0];
int64_t col = dist.dims()[1];
auto* dist_data = dist.data<T>();
for (int64_t j = 0; j < col; ++j) {
if (match_indices[j] != -1) {
// the j-th column has been matched to one entity.
continue;
}
int max_row_idx = -1;
T max_dist = -1;
for (int i = 0; i < row; ++i) {
T dist = dist_data[i * col + j];
if (dist < kEPS) {
// distance is 0 between m-th row and j-th column
continue;
}
if (dist >= overlap_threshold && dist > max_dist) {
max_row_idx = i;
max_dist = dist;
}
}
if (max_row_idx != -1) {
PADDLE_ENFORCE_EQ(match_indices[j], -1);
match_indices[j] = max_row_idx;
match_dist[j] = max_dist;
}
}
}
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dist_mat = context.Input<LoDTensor>("DistMat"); auto* dist_mat = context.Input<LoDTensor>("DistMat");
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices"); auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
...@@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
int* indices = match_indices->data<int>(); int* indices = match_indices->data<int>();
T* dist = match_dist->data<T>(); T* dist = match_dist->data<T>();
auto type = context.Attr<std::string>("match_type");
auto threshold = context.Attr<float>("dist_threshold");
if (n == 1) { if (n == 1) {
BipartiteMatch(*dist_mat, indices, dist); BipartiteMatch(*dist_mat, indices, dist);
if (type == "per_prediction") {
ArgMaxMatch(*dist_mat, indices, dist, threshold);
}
} else { } else {
auto lod = dist_mat->lod().back(); auto lod = dist_mat->lod().back();
for (size_t i = 0; i < lod.size() - 1; ++i) { for (size_t i = 0; i < lod.size() - 1; ++i) {
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]); Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
BipartiteMatch(one_ins, indices + i * col, dist + i * col); BipartiteMatch(one_ins, indices + i * col, dist + i * col);
if (type == "per_prediction") {
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
}
} }
} }
} }
...@@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
"This tensor can contain LoD information to represent a batch of " "This tensor can contain LoD information to represent a batch of "
"inputs. One instance of this batch can contain different numbers of " "inputs. One instance of this batch can contain different numbers of "
"entities."); "entities.");
AddAttr<std::string>(
"match_type",
"(string, defalut: per_prediction) "
"The type of matching method, should be 'bipartite' or "
"'per_prediction', 'bipartite' by defalut.")
.SetDefault("bipartite")
.InEnum({"bipartite", "per_prediction"});
AddAttr<float>(
"dist_threshold",
"(float, defalut: 0.5) "
"If `match_type` is 'per_prediction', this threshold is to determine "
"the extra matching bboxes based on the maximum distance.")
.SetDefault(0.5);
AddOutput("ColToRowMatchIndices", AddOutput("ColToRowMatchIndices",
"(Tensor) A 2-D Tensor with shape [N, M] in int type. " "(Tensor) A 2-D Tensor with shape [N, M] in int type. "
"N is the batch size. If ColToRowMatchIndices[i][j] is -1, it " "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
...@@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can ...@@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can
find the matched column for each row, also can find the matched row for find the matched column for each row, also can find the matched row for
each column. And this operator only calculate matched indices from column each column. And this operator only calculate matched indices from column
to row. For each instance, the number of matched indices is the number of to row. For each instance, the number of matched indices is the number of
of columns of the input ditance matrix. of columns of the input distance matrix.
There are two outputs to save matched indices and distance. There are two outputs to save matched indices and distance.
A simple description, this algothrim matched the best (maximum distance) A simple description, this algorithm matched the best (maximum distance)
row entity to the column entity and the matched indices are not duplicated row entity to the column entity and the matched indices are not duplicated
in each row of ColToRowMatchIndices. If the column entity is not matched in each row of ColToRowMatchIndices. If the column entity is not matched
any row entity, set -1 in ColToRowMatchIndices. any row entity, set -1 in ColToRowMatchIndices.
......
...@@ -132,7 +132,10 @@ def detection_output(scores, ...@@ -132,7 +132,10 @@ def detection_output(scores,
return nmsed_outs return nmsed_outs
def bipartite_match(dist_matrix, name=None): def bipartite_match(dist_matrix,
match_type=None,
dist_threshold=None,
name=None):
""" """
**Bipartite matchint operator** **Bipartite matchint operator**
...@@ -164,6 +167,11 @@ def bipartite_match(dist_matrix, name=None): ...@@ -164,6 +167,11 @@ def bipartite_match(dist_matrix, name=None):
This tensor can contain LoD information to represent a batch of This tensor can contain LoD information to represent a batch of
inputs. One instance of this batch can contain different numbers of inputs. One instance of this batch can contain different numbers of
entities. entities.
match_type(string|None): The type of matching method, should be
'bipartite' or 'per_prediction', 'bipartite' by default.
dist_threshold(float|None): If `match_type` is 'per_prediction',
this threshold is to determine the extra matching bboxes based
on the maximum distance, 0.5 by default.
Returns: Returns:
match_indices(Variable): A 2-D Tensor with shape [N, M] in int type. match_indices(Variable): A 2-D Tensor with shape [N, M] in int type.
N is the batch size. If match_indices[i][j] is -1, it N is the batch size. If match_indices[i][j] is -1, it
...@@ -183,6 +191,10 @@ def bipartite_match(dist_matrix, name=None): ...@@ -183,6 +191,10 @@ def bipartite_match(dist_matrix, name=None):
helper.append_op( helper.append_op(
type='bipartite_match', type='bipartite_match',
inputs={'DistMat': dist_matrix}, inputs={'DistMat': dist_matrix},
attrs={
'match_type': match_type,
'dist_threshold': dist_threshold,
},
outputs={ outputs={
'ColToRowMatchIndices': match_indices, 'ColToRowMatchIndices': match_indices,
'ColToRowMatchDist': match_distance 'ColToRowMatchDist': match_distance
...@@ -333,7 +345,7 @@ def ssd_loss(location, ...@@ -333,7 +345,7 @@ def ssd_loss(location,
loc_loss_weight (float): Weight for localization loss, 1.0 by default. loc_loss_weight (float): Weight for localization loss, 1.0 by default.
conf_loss_weight (float): Weight for confidence loss, 1.0 by default. conf_loss_weight (float): Weight for confidence loss, 1.0 by default.
match_type (str): The type of matching method during training, should match_type (str): The type of matching method during training, should
be 'bipartite' or 'per_prediction'. be 'bipartite' or 'per_prediction', 'per_prediction' by default.
mining_type (str): The hard example mining type, should be 'hard_example' mining_type (str): The hard example mining type, should be 'hard_example'
or 'max_negative', now only support `max_negative`. or 'max_negative', now only support `max_negative`.
...@@ -381,7 +393,8 @@ def ssd_loss(location, ...@@ -381,7 +393,8 @@ def ssd_loss(location,
# 1.1 Compute IOU similarity between ground-truth boxes and prior boxes. # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
iou = iou_similarity(x=gt_box, y=prior_box) iou = iou_similarity(x=gt_box, y=prior_box)
# 1.2 Compute matched boundding box by bipartite matching algorithm. # 1.2 Compute matched boundding box by bipartite matching algorithm.
matched_indices, matched_dist = bipartite_match(iou) matched_indices, matched_dist = bipartite_match(iou, match_type,
overlap_threshold)
# 2. Compute confidence for mining hard examples # 2. Compute confidence for mining hard examples
# 2.1. Get the target label based on matched indices # 2.1. Get the target label based on matched indices
......
...@@ -46,7 +46,20 @@ def bipartite_match(distance, match_indices, match_dist): ...@@ -46,7 +46,20 @@ def bipartite_match(distance, match_indices, match_dist):
idx += 1 idx += 1
def argmax_match(distance, match_indices, match_dist, threshold):
    """Match every still-unmatched column to its best row above a threshold.

    Reference implementation of the 'per_prediction' step that follows
    bipartite matching. The outputs are updated in place.

    Args:
        distance (numpy.array): 2-D distance matrix; rows are indexed first,
            columns second.
        match_indices (numpy.array): 1-D array with one entry per column.
            An entry of -1 means the column is unmatched; other entries are
            left untouched.
        match_dist (numpy.array): 1-D array with one entry per column holding
            the distance of the match; written only when a match is made.
        threshold (float): Minimum distance a row must reach in a column to
            be eligible as that column's match.
    """
    num_cols = distance.shape[1]
    # `range` works on both Python 2 and 3; `xrange` does not exist on 3.
    for j in range(num_cols):
        if match_indices[j] != -1:
            continue
        col_dist = distance[:, j]
        # Row indices whose distance in this column reaches the threshold.
        indices = np.argwhere(col_dist >= threshold).flatten()
        if len(indices) < 1:
            continue
        match_indices[j] = indices[np.argmax(col_dist[indices])]
        match_dist[j] = col_dist[match_indices[j]]
def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
"""Bipartite Matching algorithm for batch input. """Bipartite Matching algorithm for batch input.
Arg: Arg:
distance (numpy.array) : The distance of two entries with shape [M, N]. distance (numpy.array) : The distance of two entries with shape [M, N].
...@@ -59,6 +72,9 @@ def batch_bipartite_match(distance, lod): ...@@ -59,6 +72,9 @@ def batch_bipartite_match(distance, lod):
for i in range(len(lod) - 1): for i in range(len(lod) - 1):
bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :], bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
match_dist[i, :]) match_dist[i, :])
if match_type == 'per_prediction':
argmax_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
match_dist[i, :], dist_threshold)
return match_indices, match_dist return match_indices, match_dist
...@@ -71,8 +87,8 @@ class TestBipartiteMatchOpWithLoD(OpTest): ...@@ -71,8 +87,8 @@ class TestBipartiteMatchOpWithLoD(OpTest):
self.inputs = {'DistMat': (dist, lod)} self.inputs = {'DistMat': (dist, lod)}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': (match_indices), 'ColToRowMatchIndices': match_indices,
'ColToRowMatchDist': (match_dist), 'ColToRowMatchDist': match_dist,
} }
def test_check_output(self): def test_check_output(self):
...@@ -96,5 +112,27 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): ...@@ -96,5 +112,27 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
self.check_output() self.check_output()
class TestBipartiteMatchOpWithPerPredictionType(OpTest):
    """Checks the bipartite_match op with match_type 'per_prediction'.

    The expected outputs are produced by the Python reference
    `batch_bipartite_match` on a random LoD input.
    """

    def setUp(self):
        # One LoD level with offsets 0/5/11/23, i.e. 23 rows in total.
        offsets = [0, 5, 11, 23]
        matching = 'per_prediction'
        threshold = 0.5

        distances = np.random.random((23, 237)).astype('float32')
        expected_indices, expected_dist = batch_bipartite_match(
            distances, offsets, matching, threshold)

        self.op_type = 'bipartite_match'
        self.inputs = {'DistMat': (distances, [offsets])}
        self.outputs = {
            'ColToRowMatchIndices': expected_indices,
            'ColToRowMatchDist': expected_dist,
        }
        self.attrs = {
            'match_type': matching,
            'dist_threshold': threshold,
        }

    def test_check_output(self):
        # Compare the operator's outputs against the reference values.
        self.check_output()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册