提交 19d80b87 编写于 作者: P peixu_ren

Fix minor errors in probabilistic programming

上级 380db207
...@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value ...@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
from .multitype_ops.add_impl import hyper_add from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal from .random_ops import set_seed, normal
__all__ = [ __all__ = [
...@@ -48,5 +48,6 @@ __all__ = [ ...@@ -48,5 +48,6 @@ __all__ = [
'zeros_like', 'zeros_like',
'ones_like', 'ones_like',
'zip_operation', 'zip_operation',
'set_seed',
'normal', 'normal',
'clip_by_value',] 'clip_by_value',]
...@@ -15,8 +15,11 @@ ...@@ -15,8 +15,11 @@
"""Operations for random number generatos.""" """Operations for random number generatos."""
from mindspore.ops.primitive import constexpr
from .. import operations as P from .. import operations as P
from .. import functional as F
from ..primitive import constexpr
from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
# set graph-level RNG seed # set graph-level RNG seed
_GRAPH_SEED = 0 _GRAPH_SEED = 0
...@@ -31,17 +34,17 @@ def get_seed(): ...@@ -31,17 +34,17 @@ def get_seed():
return _GRAPH_SEED return _GRAPH_SEED
def normal(shape, mean, stddev, seed): def normal(shape, mean, stddev, seed=0):
""" """
Generates random numbers according to the Normal (or Gaussian) random number distribution. Generates random numbers according to the Normal (or Gaussian) random number distribution.
It is defined as: It is defined as:
Args: Args:
- **shape** (tuple) - The shape of random tensor to be generated. shape (tuple): The shape of random tensor to be generated.
- **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak. mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
With float32 data type. With float32 data type.
- **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type. stddev (Tensor): The deviation σ distribution parameter. With float32 data type.
- **seed** (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0. Default: 0.
Returns: Returns:
...@@ -52,9 +55,13 @@ def normal(shape, mean, stddev, seed): ...@@ -52,9 +55,13 @@ def normal(shape, mean, stddev, seed):
>>> shape = (4, 16) >>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32) >>> mean = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32) >>> stddev = Tensor(1.0, mstype.float32)
>>> C.set_seed(10)
>>> output = C.normal(shape, mean, stddev, seed=5) >>> output = C.normal(shape, mean, stddev, seed=5)
""" """
set_seed(10) mean_dtype = F.dtype(mean)
stddev_dtype = F.dtype(stddev)
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
seed1 = get_seed() seed1 = get_seed()
seed2 = seed seed2 = seed
stdnormal = P.StandardNormal(seed1, seed2) stdnormal = P.StandardNormal(seed1, seed2)
......
...@@ -29,7 +29,7 @@ class Net(nn.Cell): ...@@ -29,7 +29,7 @@ class Net(nn.Cell):
self.stdnormal = P.StandardNormal(seed, seed2) self.stdnormal = P.StandardNormal(seed, seed2)
def construct(self): def construct(self):
return self.stdnormal(self.shape, self.seed, self.seed2) return self.stdnormal(self.shape)
def test_net(): def test_net():
......
...@@ -29,7 +29,7 @@ class Net(nn.Cell): ...@@ -29,7 +29,7 @@ class Net(nn.Cell):
self.stdnormal = P.StandardNormal(seed, seed2) self.stdnormal = P.StandardNormal(seed, seed2)
def construct(self): def construct(self):
return self.stdnormal(self.shape, self.seed, self.seed2) return self.stdnormal(self.shape)
def test_net(): def test_net():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册