From 3159eecadd93f4803dea5c36fbd823ec344e63a6 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 26 Sep 2021 17:23:32 +0800
Subject: [PATCH] fix(init): fix fan_in and fan_out for group conv2d

GitOrigin-RevId: a6f41063f081c06710dd0c157ff9794bae57bab9
---
 imperative/python/megengine/module/init.py    | 29 ++++++++++++-------
 .../python/test/unit/module/test_init.py      | 29 ++++++++++++++++++-
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/imperative/python/megengine/module/init.py b/imperative/python/megengine/module/init.py
index 848347552..2bf73fde5 100644
--- a/imperative/python/megengine/module/init.py
+++ b/imperative/python/megengine/module/init.py
@@ -74,7 +74,7 @@ def calculate_gain(
 ) -> float:
     r"""Returns a recommended gain value (see the table below) for the given nonlinearity
     function.
-    
+
     ================= ====================================================
     nonlinearity      gain
     ================= ====================================================
@@ -126,6 +126,11 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
     r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes
     input tensor is stored in ``NCHW`` format.
 
+    Note:
+        The group conv2d kernel shape in MegEngine is ``(G, O/G, I/G, K, K)``. This
+        function calculates ``fan_out = O/G * K * K`` as default, but PyTorch uses
+        ``fan_out = O * K * K``.
+
     Args:
         tensor: weight tensor in ``NCHW`` format.
     """
@@ -141,6 +146,10 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
         fan_in = shape[1]
         fan_out = shape[0]
     else:
+        if ndim >= 5:
+            # ignore the groups dimension of group conv2d and group conv3d
+            # FIXME: will be wrong for conv3d
+            shape = shape[1:]
         num_input_fmaps = shape[1]
         num_output_fmaps = shape[0]
         receptive_field_size = 1
@@ -154,7 +163,7 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
 
 def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
     r"""Calculates fan_in / fan_out value for given weight tensor, depending on given ``mode``.
-    
+
     See :func:`calculate_fan_in_and_fan_out` for details.
 
     Args:
@@ -175,11 +184,11 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
 
 def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
     r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)` where
-    
+
     .. math::
 
         a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}
-    
+
     Also known as Glorot initialization. Detailed information can be retrieved from
     `Understanding the difficulty of training deep feedforward neural networks` -
     Glorot, X. & Bengio, Y. (2010).
@@ -197,11 +206,11 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
 
 def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
     r"""Fills tensor with random values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
-    
+
     .. math::
 
         \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}
-    
+
     Also known as Glorot initialization. Detailed information can be retrieved from
     `Understanding the difficulty of training deep feedforward neural networks` -
     Glorot, X. & Bengio, Y. (2010).
@@ -220,11 +229,11 @@ def msra_uniform_(
 ) -> None:
     r"""Fills tensor wilth random values sampled from
     :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
-    
+
     .. math::
 
         \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}
-    
+
     Detailed information can be retrieved from
     `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
     classification`
@@ -251,11 +260,11 @@ def msra_normal_(
 ) -> None:
     r"""Fills tensor wilth random values sampled from
     :math:`\mathcal{N}(0, \text{std}^2)` where
-    
+
     .. math::
 
         \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}
-    
+
     Detailed information can be retrieved from
     `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
     classification`
diff --git a/imperative/python/test/unit/module/test_init.py b/imperative/python/test/unit/module/test_init.py
index 9f3a019ea..b28f60e17 100644
--- a/imperative/python/test/unit/module/test_init.py
+++ b/imperative/python/test/unit/module/test_init.py
@@ -10,7 +10,7 @@ import numpy as np
 import pytest
 
 from megengine import tensor
-from megengine.module import Conv2d, Linear
+from megengine.module import Conv1d, Conv2d, Conv3d, Linear
 from megengine.module.init import calculate_fan_in_and_fan_out, fill_
 
 
@@ -32,7 +32,34 @@ def test_calculate_fan_in_and_fan_out():
     with pytest.raises(ValueError):
         calculate_fan_in_and_fan_out(l.bias)
 
+    l = Conv1d(in_channels=2, out_channels=3, kernel_size=5)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 * 5
+    assert fanout == 3 * 5
+
+    # FIXME: will be wrong for group conv1d
+    # l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2)
+    # fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    # assert fanin == 2 // 2 * 5
+    # assert fanout == 4 // 2 * 5
+
     l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
     fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
     assert fanin == 2 * 5 * 7
     assert fanout == 3 * 5 * 7
+
+    l = Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 // 2 * 5 * 7
+    assert fanout == 4 // 2 * 5 * 7
+
+    # FIXME: will be wrong for conv3d
+    # l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9))
+    # fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    # assert fanin == 2 * 5 * 7 * 9
+    # assert fanout == 3 * 5 * 7 * 9
+
+    l = Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 // 2 * 5 * 7 * 9
+    assert fanout == 4 // 2 * 5 * 7 * 9
-- 
GitLab
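Reviewer note, not part of the patch: to make the fan computation concrete, below is a minimal standalone sketch of the logic that the patched calculate_fan_in_and_fan_out implements. The helper name fan_in_and_fan_out and the use of a plain shape tuple (instead of a MegEngine Tensor) are illustrative assumptions for this sketch only.

from functools import reduce

def fan_in_and_fan_out(shape):
    # Sketch of the patched logic. `shape` is a weight shape tuple, e.g.
    # (O, I, KH, KW) for conv2d or (G, O/G, I/G, KH, KW) for group conv2d.
    ndim = len(shape)
    if ndim < 2:
        raise ValueError("need at least 2 dimensions to compute fan_in / fan_out")
    if ndim == 2:  # Linear weight: (out_features, in_features)
        return shape[1], shape[0]
    if ndim >= 5:
        # drop the leading groups dimension of a group conv kernel;
        # as the patch's FIXME notes, this is still wrong for plain conv3d
        shape = shape[1:]
    receptive_field_size = reduce(lambda x, y: x * y, shape[2:], 1)
    return shape[1] * receptive_field_size, shape[0] * receptive_field_size

print(fan_in_and_fan_out((3, 2, 5, 7)))     # conv2d:       (70, 105)
print(fan_in_and_fan_out((2, 2, 1, 5, 7)))  # group conv2d: (35, 70)

For the group conv2d case added to the test (in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2) the kernel shape is (2, 2, 1, 5, 7); after dropping the groups dimension the sketch gives fan_in = 1 * 5 * 7 = 35 and fan_out = 2 * 5 * 7 = 70, matching the asserted 2 // 2 * 5 * 7 and 4 // 2 * 5 * 7.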