diff --git a/imperative/python/megengine/module/init.py b/imperative/python/megengine/module/init.py index c2cb50755a5f6fe0e0819d0fa3e87c57e6a73e80..7d8e06f86f265cb9c6ae9e7c52566e74fee03558 100644 --- a/imperative/python/megengine/module/init.py +++ b/imperative/python/megengine/module/init.py @@ -12,6 +12,8 @@ from typing import Optional, Tuple, Union import numpy as np +from ..functional import full +from ..random import gaussian, uniform from ..tensor import Tensor @@ -21,7 +23,7 @@ def fill_(tensor: Tensor, val: Union[float, int]) -> None: :param tensor: An n-dimentional tensor to be initialized :param val: The value to be filled throughout the tensor """ - tensor.set_value(np.full(tensor.shape, val, tensor.dtype)) + tensor.set_value(full(shape=tensor.shape, value=val, dtype=tensor.dtype)) def zeros_(tensor: Tensor) -> None: @@ -48,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: :param a: Lower bound of the sampling interval :param b: Upper bound of the sampling interval """ - tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype)) + tensor.set_value(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype)) def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: @@ -59,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: :param mean: The mean of the normal distribution :param std: The standard deviation of the normal distribution """ - tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(np.float32)) + tensor.set_value(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype)) def calculate_gain( diff --git a/imperative/python/test/unit/module/test_init.py b/imperative/python/test/unit/module/test_init.py index 06bc433961f29503ffc930786913e9a2066d5f26..e12acb3e0cdccf011acfd5ca1875b9d0cdf783fe 100644 --- a/imperative/python/test/unit/module/test_init.py +++ b/imperative/python/test/unit/module/test_init.py @@ -6,10 +6,21 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np import pytest +from megengine import tensor from megengine.module import Conv2d, Linear -from megengine.module.init import calculate_fan_in_and_fan_out +from megengine.module.init import calculate_fan_in_and_fan_out, fill_ + + +def test_fill_(): + x = tensor(np.zeros((2, 3, 4)), dtype=np.float32) + fill_(x, 5.0) + + np.testing.assert_array_equal( + x.numpy(), np.full(shape=(2, 3, 4), fill_value=5.0, dtype=np.float32) + ) def test_calculate_fan_in_and_fan_out():