From 2beb65b19d1193640d5fbaa6415e88c889c2af03 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Sep 2020 18:11:43 +0800 Subject: [PATCH] fix(mge/module): tensor shape will not work when constructing numpy array GitOrigin-RevId: 5a0d705970e70aa057863066257e0091ad0733d0 --- imperative/python/megengine/module/init.py | 8 +++++--- imperative/python/test/unit/module/test_init.py | 13 ++++++++++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/imperative/python/megengine/module/init.py b/imperative/python/megengine/module/init.py index c2cb5075..7d8e06f8 100644 --- a/imperative/python/megengine/module/init.py +++ b/imperative/python/megengine/module/init.py @@ -12,6 +12,8 @@ from typing import Optional, Tuple, Union import numpy as np +from ..functional import full +from ..random import gaussian, uniform from ..tensor import Tensor @@ -21,7 +23,7 @@ def fill_(tensor: Tensor, val: Union[float, int]) -> None: :param tensor: An n-dimentional tensor to be initialized :param val: The value to be filled throughout the tensor """ - tensor.set_value(np.full(tensor.shape, val, tensor.dtype)) + tensor.set_value(full(shape=tensor.shape, value=val, dtype=tensor.dtype)) def zeros_(tensor: Tensor) -> None: @@ -48,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: :param a: Lower bound of the sampling interval :param b: Upper bound of the sampling interval """ - tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype)) + tensor.set_value(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype)) def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: @@ -59,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: :param mean: The mean of the normal distribution :param std: The standard deviation of the normal distribution """ - tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(np.float32)) + tensor.set_value(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype)) def calculate_gain( diff --git a/imperative/python/test/unit/module/test_init.py b/imperative/python/test/unit/module/test_init.py index 06bc4339..e12acb3e 100644 --- a/imperative/python/test/unit/module/test_init.py +++ b/imperative/python/test/unit/module/test_init.py @@ -6,10 +6,21 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np import pytest +from megengine import tensor from megengine.module import Conv2d, Linear -from megengine.module.init import calculate_fan_in_and_fan_out +from megengine.module.init import calculate_fan_in_and_fan_out, fill_ + + +def test_fill_(): + x = tensor(np.zeros((2, 3, 4)), dtype=np.float32) + fill_(x, 5.0) + + np.testing.assert_array_equal( + x.numpy(), np.full(shape=(2, 3, 4), fill_value=5.0, dtype=np.float32) + ) def test_calculate_fan_in_and_fan_out(): -- GitLab