MegEngine 天元 / MegEngine
Commit 2bd84d67
Authored on Sep 28, 2020 by Megvii Engine Team
feat(mge): add adaptive pooling python wrapper
GitOrigin-RevId: 789f1511ec76e41bfb7cd8e6430da527af288570
Parent: edb32495
Showing 4 changed files with 220 additions and 1 deletion (+220, -1)
imperative/python/megengine/functional/nn.py  (+45, -1)
imperative/python/megengine/module/__init__.py  (+1, -0)
imperative/python/megengine/module/adaptive_pooling.py  (+114, -0)
imperative/python/test/unit/functional/test_functional.py  (+60, -0)
imperative/python/megengine/functional/nn.py
@@ -13,7 +13,7 @@ from ..core._imperative_rt import CompNode
 from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
-from ..core.tensor import utils
+from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
 from ..core.tensor.utils import astensor1d
 from ..distributed import WORLD, is_distributed
@@ -27,6 +27,8 @@ from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshap
 from .types import _pair, _pair_nonzero
 
 __all__ = [
+    "adaptive_avg_pool2d",
+    "adaptive_max_pool2d",
     "avg_pool2d",
     "batched_nms",
     "batch_norm2d",
@@ -324,6 +326,48 @@ def avg_pool2d(
     return output
 
 
+def adaptive_max_pool2d(
+    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
+) -> Tensor:
+    """Applies a 2D max adaptive pooling over an input.
+
+    Refer to :class:`~.AdaptiveMaxPool2d` for more information.
+
+    :param inp: The input tensor.
+    :param oshp: (OH, OW) size of the output shape.
+    :return: output tensor.
+    """
+    assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
+    if isinstance(oshp, int):
+        oshp = (oshp, oshp)
+
+    op = builtin.AdaptivePooling(mode="MAX", format="NCHW",)
+    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
+    (output,) = apply(op, inp, oshp)
+    return output
+
+
+def adaptive_avg_pool2d(
+    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
+) -> Tensor:
+    """Applies a 2D average adaptive pooling over an input.
+
+    Refer to :class:`~.AdaptiveAvgPool2d` for more information.
+
+    :param inp: The input tensor.
+    :param oshp: (OH, OW) size of the output shape.
+    :return: output tensor.
+    """
+    assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
+    if isinstance(oshp, int):
+        oshp = (oshp, oshp)
+
+    op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",)
+    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
+    (output,) = apply(op, inp, oshp)
+    return output
+
+
 def prelu(inp: Tensor, weight: Tensor) -> Tensor:
     r"""
     Applies the element-wise PReLU function.
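A minimal usage sketch of the two functional wrappers added above (an illustration assuming a MegEngine build that includes this commit; the expected values mirror the unit tests at the end of this commit):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.arange(16, dtype="float32").reshape(1, 1, 4, 4))

    # A (OH, OW) tuple target: output shape is (1, 1, 2, 2).
    y_max = F.adaptive_max_pool2d(x, (2, 2))
    # A single int is expanded to a square (2, 2) target shape.
    y_avg = F.adaptive_avg_pool2d(x, 2)

    print(y_max.numpy())  # [[[[ 5.  7.]  [13. 15.]]]]
    print(y_avg.numpy())  # [[[[ 2.5  4.5]  [10.5 12.5]]]]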
imperative/python/megengine/module/__init__.py
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
+from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
 from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
imperative/python/megengine/module/adaptive_pooling.py  (new file, mode 100644)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from typing import Tuple, Union

from ..functional import adaptive_avg_pool2d, adaptive_max_pool2d
from ..tensor import Parameter, Tensor
from .module import Module


class _AdaptivePoolNd(Module):
    def __init__(
        self, oshp: Union[Tuple[int, int], int, Tensor],
    ):
        super(_AdaptivePoolNd, self).__init__()
        self.oshp = oshp

    @abstractmethod
    def forward(self, inp):
        pass


class AdaptiveMaxPool2d(_AdaptivePoolNd):
    r"""Applies a 2D max adaptive pooling over an input.

    For instance, given an input of the size :math:`(N, C, H, W)` and
    an output shape :math:`(OH, OW)`, this layer generates the output of
    the size :math:`(N, C, OH, OW)` through a process described as:

    .. math::
        \begin{aligned}
            out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1}
                \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
                \text{stride[1]} \times w + n)
        \end{aligned}

    Kernel_size and stride can be inferred from input shape and out shape:
        padding: (0, 0)
        stride: (floor(IH / OH), floor(IW / OW))
        kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w)

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.AdaptiveMaxPool2d((2, 2))
        inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4))
        oup = m(inp)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[[[ 5.  7.]
           [13. 15.]]]]

    """

    def forward(self, inp):
        return adaptive_max_pool2d(inp, self.oshp)


class AdaptiveAvgPool2d(_AdaptivePoolNd):
    r"""Applies a 2D average pooling over an input.

    For instance, given an input of the size :math:`(N, C, H, W)` and
    an output shape :math:`(OH, OW)`, this layer generates the output of
    the size :math:`(N, C, OH, OW)` through a process described as:

    .. math::

        out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
            input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)

    Kernel_size and stride can be inferred from input shape and out shape:
        padding: (0, 0)
        stride: (floor(IH / OH), floor(IW / OW))
        kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w)

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.AdaptiveAvgPool2d((2, 2))
        inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4))
        oup = m(inp)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[[[ 2.5  4.5]
           [10.5 12.5]]]]

    """

    def forward(self, inp):
        return adaptive_avg_pool2d(inp, self.oshp)
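Both docstrings above spell out how the pooling geometry is inferred from the input and output sizes: zero padding, stride = floor(IH / OH), and kernel_size = IH - (OH - 1) * stride_h (likewise for width). A small plain-Python sketch of that arithmetic, using the formulas exactly as stated (the helper name and variables are hypothetical, not MegEngine API):

    def inferred_pool_params(ih, iw, oh, ow):
        # padding is fixed at (0, 0)
        stride_h, stride_w = ih // oh, iw // ow
        kh = ih - (oh - 1) * stride_h
        kw = iw - (ow - 1) * stride_w
        return (stride_h, stride_w), (kh, kw)

    # For the 4x4 -> 2x2 docstring examples: stride (2, 2) and kernel (2, 2),
    # i.e. four disjoint 2x2 windows, matching the printed outputs.
    print(inferred_pool_params(4, 4, 2, 2))  # ((2, 2), (2, 2))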
imperative/python/test/unit/functional/test_functional.py
@@ -206,6 +206,66 @@ def test_roi_pooling():
     assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
 
 
+def test_adaptive_avg_pool2d():
+    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
+    oshp = (2, 2)
+    grad = Grad().wrt(inp, callback=_save_to(inp))
+    outp = F.adaptive_avg_pool2d(inp, oshp,)
+    assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
+    np.testing.assert_equal(
+        outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
+    )
+
+    grad(outp, tensor(F.ones_like(outp)))
+    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
+    np.testing.assert_equal(
+        inp.grad.numpy(),
+        np.array(
+            [
+                [
+                    [
+                        [0.25, 0.25, 0.25, 0.25],
+                        [0.25, 0.25, 0.25, 0.25],
+                        [0.25, 0.25, 0.25, 0.25],
+                        [0.25, 0.25, 0.25, 0.25],
+                    ]
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+
+def test_adaptive_max_pool2d():
+    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
+    oshp = (2, 2)
+    grad = Grad().wrt(inp, callback=_save_to(inp))
+    outp = F.adaptive_max_pool2d(inp, oshp,)
+    assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
+    np.testing.assert_equal(
+        outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
+    )
+
+    grad(outp, tensor(F.ones_like(outp)))
+    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
+    np.testing.assert_equal(
+        inp.grad.numpy(),
+        np.array(
+            [
+                [
+                    [
+                        [0.0, 0.0, 0.0, 0.0],
+                        [0.0, 1.0, 0.0, 1.0],
+                        [0.0, 0.0, 0.0, 0.0],
+                        [0.0, 1.0, 0.0, 1.0],
+                    ]
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+
 def test_one_hot():
     def onehot_low_dimension():
         inp = tensor(np.arange(1, 4, dtype=np.int32))
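The expected gradients in these tests follow from the two pooling modes: with a 2x2 averaging window each input element receives 1 / (kH * kW) = 0.25 of the upstream gradient, while max pooling routes the full gradient only to the maximum of each window (the entries 5, 7, 13 and 15). Assuming the repository's pytest-based test layout, the new cases can be run in isolation with something like:

    pytest imperative/python/test/unit/functional/test_functional.py -k "adaptive"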