From 19d80b87a91b7d5e7157db657a27512c78695c77 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Wed, 22 Jul 2020 14:50:45 -0400 Subject: [PATCH] Fix minor errors in probabilistic programming --- mindspore/ops/composite/__init__.py | 3 ++- mindspore/ops/composite/random_ops.py | 21 ++++++++++++------- .../test_aicpu_ops/test_standard_normal.py | 2 +- tests/st/ops/gpu/test_standard_normal.py | 2 +- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py index bb5e2960f..530bf9e1b 100644 --- a/mindspore/ops/composite/__init__.py +++ b/mindspore/ops/composite/__init__.py @@ -27,7 +27,7 @@ from .clip_ops import clip_by_value from .multitype_ops.add_impl import hyper_add from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.zeros_like_impl import zeros_like -from .random_ops import normal +from .random_ops import set_seed, normal __all__ = [ @@ -48,5 +48,6 @@ __all__ = [ 'zeros_like', 'ones_like', 'zip_operation', + 'set_seed', 'normal', 'clip_by_value',] diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index db338f567..53fa58c4d 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -15,8 +15,11 @@ """Operations for random number generatos.""" -from mindspore.ops.primitive import constexpr from .. import operations as P +from .. import functional as F +from ..primitive import constexpr +from .multitype_ops import _constexpr_utils as const_utils +from ...common import dtype as mstype # set graph-level RNG seed _GRAPH_SEED = 0 @@ -31,17 +34,17 @@ def get_seed(): return _GRAPH_SEED -def normal(shape, mean, stddev, seed): +def normal(shape, mean, stddev, seed=0): """ Generates random numbers according to the Normal (or Gaussian) random number distribution. It is defined as: Args: - - **shape** (tuple) - The shape of random tensor to be generated. - - **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak. + shape (tuple): The shape of random tensor to be generated. + mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak. With float32 data type. - - **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type. - - **seed** (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. + stddev (Tensor): The deviation σ distribution parameter. With float32 data type. + seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. Default: 0. Returns: @@ -52,9 +55,13 @@ def normal(shape, mean, stddev, seed): >>> shape = (4, 16) >>> mean = Tensor(1.0, mstype.float32) >>> stddev = Tensor(1.0, mstype.float32) + >>> C.set_seed(10) >>> output = C.normal(shape, mean, stddev, seed=5) """ - set_seed(10) + mean_dtype = F.dtype(mean) + stddev_dtype = F.dtype(stddev) + const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal") + const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal") seed1 = get_seed() seed2 = seed stdnormal = P.StandardNormal(seed1, seed2) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py b/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py index 847e3e623..818ae092b 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py @@ -29,7 +29,7 @@ class Net(nn.Cell): self.stdnormal = P.StandardNormal(seed, seed2) def construct(self): - return self.stdnormal(self.shape, self.seed, self.seed2) + return self.stdnormal(self.shape) def test_net(): diff --git a/tests/st/ops/gpu/test_standard_normal.py b/tests/st/ops/gpu/test_standard_normal.py index dd89848c9..efa4a99d7 100644 --- a/tests/st/ops/gpu/test_standard_normal.py +++ b/tests/st/ops/gpu/test_standard_normal.py @@ -29,7 +29,7 @@ class Net(nn.Cell): self.stdnormal = P.StandardNormal(seed, seed2) def construct(self): - return self.stdnormal(self.shape, self.seed, self.seed2) + return self.stdnormal(self.shape) def test_net(): -- GitLab