Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
af349d61
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
af349d61
编写于
9月 29, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/functional): fix op mismatch when tracing NMSKeep
GitOrigin-RevId: e8f2cbb7557b7482df936faca80f4fcc15eef22b
上级
d502e79f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
44 addition
and
3 deletion
+44
-3
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+15
-3
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+7
-0
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+22
-0
未找到文件。
imperative/python/megengine/functional/nn.py
浏览文件 @
af349d61
...
@@ -17,6 +17,7 @@ from ..core.tensor import megbrain_graph, utils
...
@@ -17,6 +17,7 @@ from ..core.tensor import megbrain_graph, utils
from
..core.tensor.core
import
TensorBase
,
TensorWrapperBase
,
apply
from
..core.tensor.core
import
TensorBase
,
TensorWrapperBase
,
apply
from
..core.tensor.utils
import
astensor1d
from
..core.tensor.utils
import
astensor1d
from
..distributed
import
WORLD
,
is_distributed
from
..distributed
import
WORLD
,
is_distributed
from
..jit.tracing
import
is_tracing
from
..random
import
uniform
from
..random
import
uniform
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
.debug_param
import
get_conv_execution_strategy
from
.debug_param
import
get_conv_execution_strategy
...
@@ -1470,13 +1471,17 @@ def indexing_one_hot(
...
@@ -1470,13 +1471,17 @@ def indexing_one_hot(
return
result
return
result
def
nms
(
boxes
:
Tensor
,
scores
:
Tensor
,
iou_thresh
:
float
)
->
Tensor
:
def
nms
(
boxes
:
Tensor
,
scores
:
Tensor
,
iou_thresh
:
float
,
max_output
:
Optional
[
int
]
=
None
)
->
Tensor
:
r
"""
r
"""
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU).
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU).
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
:param iou_thresh: IoU threshold for overlapping.
:param iou_thresh: IoU threshold for overlapping.
:param scores: tensor of shape `(N,)`, the score of boxes.
:param scores: tensor of shape `(N,)`, the score of boxes.
:param max_output: the maximum number of boxes to keep; it is optional if this operator is not traced
otherwise it required to be specified; if it is not specified, all boxes are kept.
:return: indices of the elements that have been kept by NMS.
:return: indices of the elements that have been kept by NMS.
Examples:
Examples:
...
@@ -1515,12 +1520,19 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
...
@@ -1515,12 +1520,19 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
scores
=
scores
.
detach
()
scores
=
scores
.
detach
()
sorted_idx
=
argsort
(
scores
,
descending
=
True
)
sorted_idx
=
argsort
(
scores
,
descending
=
True
)
boxes
=
boxes
[
sorted_idx
]
boxes
=
boxes
[
sorted_idx
]
if
is_tracing
():
assert
(
max_output
is
not
None
and
max_output
>
0
),
"max_output should be specified under tracing"
if
max_output
is
None
:
max_output
=
boxes
.
shape
[
0
]
max_output
=
boxes
.
shape
[
0
]
op
=
builtin
.
NMSKeep
(
iou_thresh
,
max_output
)
op
=
builtin
.
NMSKeep
(
iou_thresh
,
max_output
)
inp
=
utils
.
convert_inputs
(
boxes
.
reshape
(
1
,
-
1
,
4
))
inp
=
utils
.
convert_inputs
(
boxes
.
reshape
(
1
,
-
1
,
4
))
indices
,
count
=
apply
(
op
,
*
inp
)
indices
,
count
=
apply
(
op
,
*
inp
)
indices
=
indices
[
0
][:
count
.
item
()
]
indices
=
indices
[
0
][:
count
[
0
]
]
keep_inds
=
sorted_idx
[
indices
]
keep_inds
=
sorted_idx
[
indices
]
return
keep_inds
return
keep_inds
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
af349d61
...
@@ -36,6 +36,13 @@ active_trace = None
...
@@ -36,6 +36,13 @@ active_trace = None
skip_tracing
=
False
skip_tracing
=
False
def
is_tracing
():
if
active_trace
is
None
:
return
False
else
:
return
not
skip_tracing
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
exclude_from_trace
():
def
exclude_from_trace
():
global
skip_tracing
global
skip_tracing
...
...
imperative/python/test/unit/test_tracing.py
浏览文件 @
af349d61
...
@@ -357,3 +357,25 @@ def test_trace_broadcast():
...
@@ -357,3 +357,25 @@ def test_trace_broadcast():
f
(
x1
)
f
(
x1
)
f
(
x2
)
f
(
x2
)
f
(
x3
)
f
(
x3
)
def
test_trace_nms
():
def
make_inputs
(
n
):
boxes
=
np
.
zeros
((
n
,
4
))
boxes
[:,
:
2
]
=
np
.
random
.
rand
(
n
,
2
)
*
100
boxes
[:,
2
:]
=
np
.
random
.
rand
(
n
,
2
)
*
100
+
100
scores
=
np
.
random
.
rand
(
n
)
return
tensor
(
boxes
),
tensor
(
scores
)
@
trace
(
symbolic
=
False
)
def
f
(
boxes
,
scores
):
results
=
F
.
nn
.
nms
(
boxes
,
scores
=
scores
,
iou_thresh
=
0.5
,
max_output
=
20
)
with
exclude_from_trace
():
_
=
F
.
nn
.
nms
(
boxes
,
scores
=
scores
,
iou_thresh
=
0.5
)
return
results
f
(
*
make_inputs
(
10
))
f
(
*
make_inputs
(
20
))
f
(
*
make_inputs
(
30
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录