提交 2beb65b1 编写于 作者: M Megvii Engine Team

fix(mge/module): tensor shape will not work when constructing numpy array

GitOrigin-RevId: 5a0d705970e70aa057863066257e0091ad0733d0
上级 651920c7
......@@ -12,6 +12,8 @@ from typing import Optional, Tuple, Union
import numpy as np
from ..functional import full
from ..random import gaussian, uniform
from ..tensor import Tensor
......@@ -21,7 +23,7 @@ def fill_(tensor: Tensor, val: Union[float, int]) -> None:
:param tensor: An n-dimentional tensor to be initialized
:param val: The value to be filled throughout the tensor
"""
tensor.set_value(np.full(tensor.shape, val, tensor.dtype))
tensor.set_value(full(shape=tensor.shape, value=val, dtype=tensor.dtype))
def zeros_(tensor: Tensor) -> None:
......@@ -48,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
:param a: Lower bound of the sampling interval
:param b: Upper bound of the sampling interval
"""
tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype))
tensor.set_value(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype))
def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
......@@ -59,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
:param mean: The mean of the normal distribution
:param std: The standard deviation of the normal distribution
"""
tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(np.float32))
tensor.set_value(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype))
def calculate_gain(
......
......@@ -6,10 +6,21 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
from megengine import tensor
from megengine.module import Conv2d, Linear
from megengine.module.init import calculate_fan_in_and_fan_out
from megengine.module.init import calculate_fan_in_and_fan_out, fill_
def test_fill_():
x = tensor(np.zeros((2, 3, 4)), dtype=np.float32)
fill_(x, 5.0)
np.testing.assert_array_equal(
x.numpy(), np.full(shape=(2, 3, 4), fill_value=5.0, dtype=np.float32)
)
def test_calculate_fan_in_and_fan_out():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册