提交 ea00b57d 编写于 作者: M Megvii Engine Team

refactor(mge/module): use mge rng to sample from uniform/normal distribution

GitOrigin-RevId: 6ec8f99af5cc3f82fe70fe357e274f67eb1d4c36
上级 a1597cfc
......@@ -12,7 +12,8 @@ from typing import Optional, Tuple, Union
import numpy as np
from ..core import Tensor
from ..core import Tensor, Graph
from ..random import gaussian, uniform
def fill_(tensor: Tensor, val: Union[float, int]) -> None:
......@@ -48,7 +49,8 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
:param a: Lower bound of the sampling interval
:param b: Upper bound of the sampling interval
"""
tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype))
with Graph(eager_evaluation=True):
tensor.set_value((b - a) * uniform(tensor.shape) + a)
def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
......@@ -59,7 +61,8 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
:param mean: The mean of the normal distribution
:param std: The standard deviation of the normal distribution
"""
tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(np.float32))
with Graph(eager_evaluation=True):
tensor.set_value(gaussian(tensor.shape, mean=mean, std=std))
def calculate_gain(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册