Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3159eeca
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3159eeca
编写于
9月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(init): fix fan_in and fan_out for group conv2d
GitOrigin-RevId: a6f41063f081c06710dd0c157ff9794bae57bab9
上级
51c03f3e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
47 addition
and
11 deletion
+47
-11
imperative/python/megengine/module/init.py
imperative/python/megengine/module/init.py
+19
-10
imperative/python/test/unit/module/test_init.py
imperative/python/test/unit/module/test_init.py
+28
-1
未找到文件。
imperative/python/megengine/module/init.py
浏览文件 @
3159eeca
...
...
@@ -74,7 +74,7 @@ def calculate_gain(
)
->
float
:
r
"""Returns a recommended gain value (see the table below) for the given nonlinearity
function.
================= ====================================================
nonlinearity gain
================= ====================================================
...
...
@@ -126,6 +126,11 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
r
"""Calculates fan_in / fan_out value for given weight tensor. This function assumes
input tensor is stored in ``NCHW`` format.
Note:
The group conv2d kernel shape in MegEngine is ``(G, O/G, I/G, K, K)``. This
function calculates ``fan_out = O/G * K * K`` as default, but PyTorch uses
``fan_out = O * K * K``.
Args:
tensor: weight tensor in ``NCHW`` format.
"""
...
...
@@ -141,6 +146,10 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
fan_in
=
shape
[
1
]
fan_out
=
shape
[
0
]
else
:
if
ndim
>=
5
:
# ignore the groups dimension of group conv2d and group conv3d
# FIXME: will be wrong for conv3d
shape
=
shape
[
1
:]
num_input_fmaps
=
shape
[
1
]
num_output_fmaps
=
shape
[
0
]
receptive_field_size
=
1
...
...
@@ -154,7 +163,7 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
def
calculate_correct_fan
(
tensor
:
Tensor
,
mode
:
str
)
->
float
:
r
"""Calculates fan_in / fan_out value for given weight tensor, depending on given
``mode``.
See :func:`calculate_fan_in_and_fan_out` for details.
Args:
...
...
@@ -175,11 +184,11 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
def
xavier_uniform_
(
tensor
:
Tensor
,
gain
:
float
=
1.0
)
->
None
:
r
"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)`
where
.. math::
a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}
Also known as Glorot initialization. Detailed information can be retrieved from
`Understanding the difficulty of training deep feedforward neural networks` -
Glorot, X. & Bengio, Y. (2010).
...
...
@@ -197,11 +206,11 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
def
xavier_normal_
(
tensor
:
Tensor
,
gain
:
float
=
1.0
)
->
None
:
r
"""Fills tensor with random values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}
Also known as Glorot initialization. Detailed information can be retrieved from
`Understanding the difficulty of training deep feedforward neural networks` -
Glorot, X. & Bengio, Y. (2010).
...
...
@@ -220,11 +229,11 @@ def msra_uniform_(
)
->
None
:
r
"""Fills tensor wilth random values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}
Detailed information can be retrieved from
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet
classification`
...
...
@@ -251,11 +260,11 @@ def msra_normal_(
)
->
None
:
r
"""Fills tensor wilth random values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}
Detailed information can be retrieved from
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet
classification`
...
...
imperative/python/test/unit/module/test_init.py
浏览文件 @
3159eeca
...
...
@@ -10,7 +10,7 @@ import numpy as np
import
pytest
from
megengine
import
tensor
from
megengine.module
import
Conv
2
d
,
Linear
from
megengine.module
import
Conv
1d
,
Conv2d
,
Conv3
d
,
Linear
from
megengine.module.init
import
calculate_fan_in_and_fan_out
,
fill_
...
...
@@ -32,7 +32,34 @@ def test_calculate_fan_in_and_fan_out():
with
pytest
.
raises
(
ValueError
):
calculate_fan_in_and_fan_out
(
l
.
bias
)
l
=
Conv1d
(
in_channels
=
2
,
out_channels
=
3
,
kernel_size
=
5
)
fanin
,
fanout
=
calculate_fan_in_and_fan_out
(
l
.
weight
)
assert
fanin
==
2
*
5
assert
fanout
==
3
*
5
# FIXME: will be wrong for group conv1d
# l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2)
# fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
# assert fanin == 2 // 2 * 5
# assert fanout == 4 // 2 * 5
l
=
Conv2d
(
in_channels
=
2
,
out_channels
=
3
,
kernel_size
=
(
5
,
7
))
fanin
,
fanout
=
calculate_fan_in_and_fan_out
(
l
.
weight
)
assert
fanin
==
2
*
5
*
7
assert
fanout
==
3
*
5
*
7
l
=
Conv2d
(
in_channels
=
2
,
out_channels
=
4
,
kernel_size
=
(
5
,
7
),
groups
=
2
)
fanin
,
fanout
=
calculate_fan_in_and_fan_out
(
l
.
weight
)
assert
fanin
==
2
//
2
*
5
*
7
assert
fanout
==
4
//
2
*
5
*
7
# FIXME: will be wrong for conv3d
# l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9))
# fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
# assert fanin == 2 * 5 * 7 * 9
# assert fanout == 3 * 5 * 7 * 9
l
=
Conv3d
(
in_channels
=
2
,
out_channels
=
4
,
kernel_size
=
(
5
,
7
,
9
),
groups
=
2
)
fanin
,
fanout
=
calculate_fan_in_and_fan_out
(
l
.
weight
)
assert
fanin
==
2
//
2
*
5
*
7
*
9
assert
fanout
==
4
//
2
*
5
*
7
*
9
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录