Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
778b71fc
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
778b71fc
编写于
6月 27, 2018
作者:
B
baiyf
提交者:
GitHub
6月 27, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize bipartite_match_op in large scale input (#11730)
* optimize bipartite_match_op in large scale input
上级
c2289777
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
84 addition
and
31 deletion
+84
-31
paddle/fluid/operators/detection/bipartite_match_op.cc
paddle/fluid/operators/detection/bipartite_match_op.cc
+67
-31
python/paddle/fluid/tests/unittests/test_bipartite_match_op.py
...n/paddle/fluid/tests/unittests/test_bipartite_match_op.py
+17
-0
未找到文件。
paddle/fluid/operators/detection/bipartite_match_op.cc
浏览文件 @
778b71fc
...
...
@@ -51,6 +51,12 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
}
};
template
<
class
T
>
bool
DistPairDescend
(
std
::
tuple
<
int
,
int
,
T
>
pair1
,
std
::
tuple
<
int
,
int
,
T
>
pair2
)
{
return
std
::
get
<
2
>
(
pair1
)
>
std
::
get
<
2
>
(
pair2
);
}
template
<
typename
T
>
class
BipartiteMatchKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -58,46 +64,76 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
// The match_dist must be initialized to 0 at first.
void
BipartiteMatch
(
const
Tensor
&
dist
,
int
*
match_indices
,
T
*
match_dist
)
const
{
constexpr
T
kEPS
=
static_cast
<
T
>
(
1e-6
);
PADDLE_ENFORCE_EQ
(
dist
.
dims
().
size
(),
2
,
"The rank of dist must be 2."
);
int64_t
row
=
dist
.
dims
()[
0
];
int64_t
col
=
dist
.
dims
()[
1
];
auto
*
dist_data
=
dist
.
data
<
T
>
();
std
::
vector
<
int
>
row_pool
;
for
(
int
i
=
0
;
i
<
row
;
++
i
)
{
row_pool
.
push_back
(
i
);
}
while
(
row_pool
.
size
()
>
0
)
{
int
max_idx
=
-
1
;
int
max_row_idx
=
-
1
;
T
max_dist
=
-
1
;
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
if
(
match_indices
[
j
]
!=
-
1
)
{
continue
;
// Test result: When row==130 the speed of these two methods almost the same
if
(
row
>=
130
)
{
std
::
vector
<
std
::
tuple
<
int
,
int
,
T
>>
match_pair
;
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
match_pair
.
push_back
(
std
::
make_tuple
(
i
,
j
,
dist_data
[
i
*
col
+
j
]));
}
for
(
size_t
k
=
0
;
k
<
row_pool
.
size
();
++
k
)
{
int
m
=
row_pool
[
k
];
// distance is 0 between m-th row and j-th column
if
(
dist_data
[
m
*
col
+
j
]
<
kEPS
)
{
}
std
::
sort
(
match_pair
.
begin
(),
match_pair
.
end
(),
DistPairDescend
<
T
>
);
std
::
vector
<
int
>
row_indices
(
row
,
-
1
);
int64_t
idx
=
0
;
for
(
int64_t
k
=
0
;
k
<
row
*
col
;
++
k
)
{
int64_t
i
=
std
::
get
<
0
>
(
match_pair
[
k
]);
int64_t
j
=
std
::
get
<
1
>
(
match_pair
[
k
]);
T
dist
=
std
::
get
<
2
>
(
match_pair
[
k
]);
if
(
idx
>=
row
)
{
break
;
}
if
(
match_indices
[
j
]
==
-
1
&&
row_indices
[
i
]
==
-
1
&&
dist
>
0
)
{
match_indices
[
j
]
=
i
;
row_indices
[
i
]
=
j
;
match_dist
[
j
]
=
dist
;
idx
+=
1
;
}
}
}
else
{
constexpr
T
kEPS
=
static_cast
<
T
>
(
1e-6
);
std
::
vector
<
int
>
row_pool
;
for
(
int
i
=
0
;
i
<
row
;
++
i
)
{
row_pool
.
push_back
(
i
);
}
while
(
row_pool
.
size
()
>
0
)
{
int
max_idx
=
-
1
;
int
max_row_idx
=
-
1
;
T
max_dist
=
-
1
;
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
if
(
match_indices
[
j
]
!=
-
1
)
{
continue
;
}
if
(
dist_data
[
m
*
col
+
j
]
>
max_dist
)
{
max_idx
=
j
;
max_row_idx
=
m
;
max_dist
=
dist_data
[
m
*
col
+
j
];
for
(
size_t
k
=
0
;
k
<
row_pool
.
size
();
++
k
)
{
int
m
=
row_pool
[
k
];
// distance is 0 between m-th row and j-th column
if
(
dist_data
[
m
*
col
+
j
]
<
kEPS
)
{
continue
;
}
if
(
dist_data
[
m
*
col
+
j
]
>
max_dist
)
{
max_idx
=
j
;
max_row_idx
=
m
;
max_dist
=
dist_data
[
m
*
col
+
j
];
}
}
}
}
if
(
max_idx
==
-
1
)
{
// Cannot find good match.
break
;
}
else
{
PADDLE_ENFORCE_EQ
(
match_indices
[
max_idx
],
-
1
)
;
match_indices
[
max_idx
]
=
max_row_idx
;
match_dist
[
max_idx
]
=
max_dist
;
// Erase the row index.
row_pool
.
erase
(
std
::
find
(
row_pool
.
begin
(),
row_pool
.
end
(),
max_row_idx
));
if
(
max_idx
==
-
1
)
{
// Cannot find good match.
break
;
}
else
{
PADDLE_ENFORCE_EQ
(
match_indices
[
max_idx
],
-
1
);
match_indices
[
max_idx
]
=
max_row_idx
;
match_dist
[
max_idx
]
=
max_dist
;
// Erase the row index.
row_pool
.
erase
(
std
::
find
(
row_pool
.
begin
(),
row_pool
.
end
(),
max_row_idx
));
}
}
}
}
...
...
python/paddle/fluid/tests/unittests/test_bipartite_match_op.py
浏览文件 @
778b71fc
...
...
@@ -114,6 +114,23 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
self
.
check_output
()
class
TestBipartiteMatchOpWithoutLoDLargeScaleInput
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
'bipartite_match'
lod
=
[[
300
]]
dist
=
np
.
random
.
random
((
300
,
17
)).
astype
(
'float32'
)
match_indices
,
match_dist
=
batch_bipartite_match
(
dist
,
lod
[
0
])
self
.
inputs
=
{
'DistMat'
:
dist
}
self
.
outputs
=
{
'ColToRowMatchIndices'
:
match_indices
,
'ColToRowMatchDist'
:
match_dist
,
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestBipartiteMatchOpWithPerPredictionType
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
'bipartite_match'
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录