未验证 提交 778b71fc 编写于 作者: B baiyf 提交者: GitHub

Optimize bipartite_match_op for large-scale input (#11730)

* Optimize bipartite_match_op for large-scale input
上级 c2289777
......@@ -51,6 +51,12 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
}
};
// Comparator for std::sort over (row, col, distance) tuples.
//
// Returns true when pair1's distance (tuple element 2) is strictly greater
// than pair2's, which yields a descending-by-distance order. The row and
// column indices (elements 0 and 1) do not take part in the comparison, so
// tuples with equal distance compare as equivalent.
//
// The tuples are taken by const reference so that the comparator does not
// copy its operands on every comparison performed by the sort.
template <class T>
bool DistPairDescend(const std::tuple<int, int, T>& pair1,
                     const std::tuple<int, int, T>& pair2) {
  return std::get<2>(pair1) > std::get<2>(pair2);
}
template <typename T>
class BipartiteMatchKernel : public framework::OpKernel<T> {
public:
......@@ -58,11 +64,40 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
// Matches columns of the 2-D matrix `dist` (shape row x col) to rows, writing
// one result per column: match_indices[j] receives the matched row index and
// match_dist[j] the corresponding distance.
//
// The match_dist must be initialized to 0 at first. The large-input path
// below also treats match_indices[j] == -1 as "column j is still unmatched",
// so it relies on match_indices holding -1 for every unmatched column on
// entry.
void BipartiteMatch(const Tensor& dist, int* match_indices,
T* match_dist) const {
// NOTE(review): kEPS is declared again inside the else branch below and is
// not used in the visible large-input path; this outer declaration looks
// like the pre-change line of the diff -- confirm which one is current.
constexpr T kEPS = static_cast<T>(1e-6);
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
int64_t row = dist.dims()[0];
int64_t col = dist.dims()[1];
auto* dist_data = dist.data<T>();
// Benchmark result: at row == 130 the two methods run at almost the same
// speed, so the sort-based method is used from that row count upwards.
if (row >= 130) {
// Flatten the matrix into (row index, column index, distance) tuples.
// dist_data is indexed as i * col + j, i.e. row-major.
// NOTE(review): match_pair grows to row * col elements; a reserve() before
// the loops would avoid repeated reallocation.
std::vector<std::tuple<int, int, T>> match_pair;
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
match_pair.push_back(std::make_tuple(i, j, dist_data[i * col + j]));
}
}
// Visit candidate pairs from the largest distance to the smallest.
std::sort(match_pair.begin(), match_pair.end(), DistPairDescend<T>);
// row_indices[i] holds the column matched to row i, or -1 if none yet.
std::vector<int> row_indices(row, -1);
// idx counts the matches made so far.
int64_t idx = 0;
for (int64_t k = 0; k < row * col; ++k) {
int64_t i = std::get<0>(match_pair[k]);
int64_t j = std::get<1>(match_pair[k]);
// NOTE(review): this local shadows the `dist` tensor parameter for the
// rest of the loop body.
T dist = std::get<2>(match_pair[k]);
// Once `row` matches exist every row is taken, so no further pair can
// be accepted.
if (idx >= row) {
break;
}
// Accept the pair only if both its column and its row are still free and
// the distance is strictly positive.
if (match_indices[j] == -1 && row_indices[i] == -1 && dist > 0) {
match_indices[j] = i;
row_indices[i] = j;
match_dist[j] = dist;
idx += 1;
}
}
} else {
// Small-input path (row < 130).
constexpr T kEPS = static_cast<T>(1e-6);
// Start with every row index in the pool.
std::vector<int> row_pool;
for (int i = 0; i < row; ++i) {
row_pool.push_back(i);
......@@ -101,6 +136,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
}
}
}
}
void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist,
T overlap_threshold) const {
......
......@@ -114,6 +114,23 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
self.check_output()
class TestBipartiteMatchOpWithoutLoDLargeScaleInput(OpTest):
    """Output check for bipartite_match on a 300 x 17 random distance matrix.

    The input is passed without LoD; the expected outputs come from
    batch_bipartite_match called with a single segment of 300 rows.
    """

    def setUp(self):
        self.op_type = 'bipartite_match'

        # One segment covering all rows of the distance matrix.
        num_rows, num_cols = 300, 17
        segment_lengths = [num_rows]

        dist = np.random.random((num_rows, num_cols)).astype('float32')
        match_indices, match_dist = batch_bipartite_match(dist,
                                                          segment_lengths)

        self.inputs = {'DistMat': dist}
        self.outputs = {
            'ColToRowMatchIndices': match_indices,
            'ColToRowMatchDist': match_dist,
        }

    def test_check_output(self):
        self.check_output()
class TestBipartiteMatchOpWithPerPredictionType(OpTest):
def setUp(self):
self.op_type = 'bipartite_match'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册