Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a640eeff
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a640eeff
编写于
10月 09, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge): remove batched_nms
GitOrigin-RevId: 01f9ee137ccdb3dd5bc8f3662174531a18f6222b
上级
205a39b0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
15 addition
and
98 deletion
+15
-98
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+15
-79
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+0
-19
未找到文件。
imperative/python/megengine/functional/nn.py
浏览文件 @
a640eeff
...
...
@@ -30,7 +30,6 @@ __all__ = [
"adaptive_avg_pool2d"
,
"adaptive_max_pool2d"
,
"avg_pool2d"
,
"batched_nms"
,
"batch_norm2d"
,
"conv2d"
,
"conv_transpose2d"
,
...
...
@@ -391,14 +390,14 @@ def softplus(inp: Tensor) -> Tensor:
.. math::
\text{softplus}(x) = \log(1 + \exp(x))
softplus is a smooth approximation to the ReLU function and can be used
to constrain the output to be always positive.
For numerical stability the implementation follows this transformation:
.. math::
\text{softplus}(x) = \log(1 + \exp(x))
= \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
\text{softplus}(x) = \log(1 + \exp(x))
= \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
= \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)
:param inp: input tensor.
...
...
@@ -414,9 +413,9 @@ def softplus(inp: Tensor) -> Tensor:
x = tensor(np.arange(-3, 3, dtype=np.float32))
y = F.softplus(x)
print(y.numpy())
Outputs:
.. testoutput::
[0.0486 0.1269 0.3133 0.6931 1.3133 2.1269]
...
...
@@ -435,11 +434,11 @@ def log_softmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
For numerical stability the implementation follows this transformation:
.. math::
\operatorname{logsoftmax}(x)
\operatorname{logsoftmax}(x)
= \log (\frac{\exp (x)}{\sum_{i}(\exp (x_{i}))})
= x - \log (\sum_{i}(\exp (x_{i})))
= x - logsumexp(x)
:param inp: input tensor.
:param axis: axis along which log_softmax will be applied.
...
...
@@ -456,7 +455,7 @@ def log_softmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
print(y.numpy())
Outputs:
.. testoutput::
[[-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]
...
...
@@ -505,9 +504,9 @@ def logsumexp(
)
->
Tensor
:
r
"""
Calculates the logarithm of the inputs' exponential sum along the given :attr:`axis`.
.. math::
\operatorname{logsumexp}(\boldsymbol{x})= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
For numerical stability, the implementation follows this transformation:
...
...
@@ -516,7 +515,7 @@ def logsumexp(
\operatorname{logsumexp}(\boldsymbol{x})= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
= \operatorname{logsumexp}(\boldsymbol{x})=b+\log \sum_{j=1}^{n} \exp \left(x_{j}-b\right)
where
.. math::
...
...
@@ -527,7 +526,7 @@ def logsumexp(
:param keepdims: whether to retain :attr:`axis` or not for the output tensor.
Examples:
.. testcode::
import numpy as np
...
...
@@ -1080,7 +1079,7 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
Outputs:
.. testoutput::
[7.3485 1. ]
"""
...
...
@@ -1471,7 +1470,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
:param iou_thresh: IoU threshold for overlapping.
:param scores: tensor of shape `(N,)`, the score of boxes.
:return: indices of the elements that have been kept by NMS.
Examples:
.. testcode::
...
...
@@ -1492,7 +1491,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
Outputs:
.. testoutput::
[75 69]
"""
...
...
@@ -1518,69 +1517,6 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
return
keep_inds
def
batched_nms
(
boxes
:
Tensor
,
scores
:
Tensor
,
idxs
:
Tensor
,
iou_thresh
:
float
,
)
->
Tensor
:
r
"""
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
:param iou_thresh: ``IoU`` threshold for overlapping.
:param idxs: tensor of shape `(N,)`, the class indexs of boxes in the batch.
:param scores: tensor of shape `(N,)`, the score of boxes.
:return: indices of the elements that have been kept by NMS.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = np.zeros((100,4))
np.random.seed(42)
x[:,:2] = np.random.rand(100,2)*20
x[:,2:] = np.random.rand(100,2)*20 + 100
scores = tensor(np.random.rand(100))
idxs = tensor(np.random.randint(0, 10, 100))
inp = tensor(x)
result = F.batched_nms(inp, scores, idxs, iou_thresh=0.6)
print(result.numpy())
Outputs:
.. testoutput::
[75 41 99 98 69 64 11 27 35 18]
"""
assert
(
boxes
.
ndim
==
2
and
boxes
.
shape
[
1
]
==
4
),
"the expected shape of boxes is (N, 4)"
assert
scores
.
ndim
==
1
,
"the expected shape of scores is (N,)"
assert
idxs
.
ndim
==
1
,
"the expected shape of idxs is (N,)"
assert
boxes
.
shape
[
0
]
==
scores
.
shape
[
0
]
==
idxs
.
shape
[
0
]
boxes
=
boxes
.
detach
()
scores
=
scores
.
detach
()
idxs
=
idxs
.
detach
()
max_coordinate
=
boxes
.
max
()
offsets
=
idxs
.
astype
(
"float32"
)
*
(
max_coordinate
+
1
)
boxes
=
boxes
+
offsets
.
reshape
(
-
1
,
1
).
broadcast
(
boxes
.
shape
[
0
],
4
)
sorted_idx
=
argsort
(
scores
,
descending
=
True
)
boxes
=
boxes
[
sorted_idx
]
max_output
=
boxes
.
shape
[
0
]
op
=
builtin
.
NMSKeep
(
iou_thresh
,
max_output
)
inp
=
utils
.
convert_inputs
(
boxes
.
reshape
(
1
,
-
1
,
4
))
indices
,
count
=
apply
(
op
,
*
inp
)
indices
=
indices
[
0
][:
count
.
item
()]
keep_inds
=
sorted_idx
[
indices
]
return
keep_inds
from
.loss
import
*
# isort:skip
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
a640eeff
...
...
@@ -361,25 +361,6 @@ def test_nms():
np
.
testing
.
assert_equal
(
result
.
numpy
(),
np
.
array
([
2
,
1
,
3
],
dtype
=
np
.
int32
))
def
test_batched_nms
():
x
=
np
.
array
(
[
[
0
,
0
,
100
,
100
],
[
0.5
,
0.5
,
1.5
,
1.5
],
[
20
,
20
,
100
,
100
],
[
0.5
,
0.5
,
1.0
,
1.0
],
[
10
,
10
,
100
,
100
],
[
0.5
,
0.5
,
1.0
,
1.0
],
],
dtype
=
np
.
float32
,
)
inp
=
tensor
(
x
)
scores
=
tensor
([
0.6
,
0.9
,
0.5
,
0.6
,
0.8
,
0.7
],
dtype
=
np
.
float32
)
idxs
=
tensor
([
0
,
1
,
0
,
1
,
0
,
1
],
dtype
=
np
.
int32
)
results
=
F
.
batched_nms
(
inp
,
scores
=
scores
,
idxs
=
idxs
,
iou_thresh
=
0.5
)
np
.
testing
.
assert_equal
(
results
.
numpy
(),
np
.
array
([
1
,
4
,
5
],
dtype
=
np
.
int32
))
@
pytest
.
mark
.
skip
(
reason
=
"cuda does not support nchw int8"
)
def
test_conv_bias
():
inp_scale
=
1.5
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录