提交 3159eeca 编写于 作者: M Megvii Engine Team

fix(init): fix fan_in and fan_out for group conv2d

GitOrigin-RevId: a6f41063f081c06710dd0c157ff9794bae57bab9
上级 51c03f3e
......@@ -74,7 +74,7 @@ def calculate_gain(
) -> float:
r"""Returns a recommended gain value (see the table below) for the given nonlinearity
function.
================= ====================================================
nonlinearity gain
================= ====================================================
......@@ -126,6 +126,11 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes
input tensor is stored in ``NCHW`` format.
Note:
The group conv2d kernel shape in MegEngine is ``(G, O/G, I/G, K, K)``. This
function calculates ``fan_out = O/G * K * K`` as default, but PyTorch uses
``fan_out = O * K * K``.
Args:
tensor: weight tensor in ``NCHW`` format.
"""
......@@ -141,6 +146,10 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
fan_in = shape[1]
fan_out = shape[0]
else:
if ndim >= 5:
# ignore the groups dimension of group conv2d and group conv3d
# FIXME: will be wrong for conv3d
shape = shape[1:]
num_input_fmaps = shape[1]
num_output_fmaps = shape[0]
receptive_field_size = 1
......@@ -154,7 +163,7 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
r"""Calculates fan_in / fan_out value for given weight tensor, depending on given
``mode``.
See :func:`calculate_fan_in_and_fan_out` for details.
Args:
......@@ -175,11 +184,11 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)`
where
.. math::
a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}
Also known as Glorot initialization. Detailed information can be retrieved from
`Understanding the difficulty of training deep feedforward neural networks` -
Glorot, X. & Bengio, Y. (2010).
......@@ -197,11 +206,11 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
r"""Fills tensor with random values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}
Also known as Glorot initialization. Detailed information can be retrieved from
`Understanding the difficulty of training deep feedforward neural networks` -
Glorot, X. & Bengio, Y. (2010).
......@@ -220,11 +229,11 @@ def msra_uniform_(
) -> None:
r"""Fills tensor wilth random values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}
Detailed information can be retrieved from
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet
classification`
......@@ -251,11 +260,11 @@ def msra_normal_(
) -> None:
r"""Fills tensor wilth random values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}
Detailed information can be retrieved from
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet
classification`
......
......@@ -10,7 +10,7 @@ import numpy as np
import pytest
from megengine import tensor
from megengine.module import Conv2d, Linear
from megengine.module import Conv1d, Conv2d, Conv3d, Linear
from megengine.module.init import calculate_fan_in_and_fan_out, fill_
......@@ -32,7 +32,34 @@ def test_calculate_fan_in_and_fan_out():
with pytest.raises(ValueError):
calculate_fan_in_and_fan_out(l.bias)
l = Conv1d(in_channels=2, out_channels=3, kernel_size=5)
fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
assert fanin == 2 * 5
assert fanout == 3 * 5
# FIXME: will be wrong for group conv1d
# l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2)
# fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
# assert fanin == 2 // 2 * 5
# assert fanout == 4 // 2 * 5
l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
assert fanin == 2 * 5 * 7
assert fanout == 3 * 5 * 7
l = Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2)
fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
assert fanin == 2 // 2 * 5 * 7
assert fanout == 4 // 2 * 5 * 7
# FIXME: will be wrong for conv3d
# l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9))
# fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
# assert fanin == 2 * 5 * 7 * 9
# assert fanout == 3 * 5 * 7 * 9
l = Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2)
fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
assert fanin == 2 // 2 * 5 * 7 * 9
assert fanout == 4 // 2 * 5 * 7 * 9
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册