提交 cf4353f7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3220 Add random normal op at MindSpore front-end

Merge pull request !3220 from peixu_ren/custom_gpu
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Bernoulli Distribution""" """Bernoulli Distribution"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype from ...common import dtype as mstype
...@@ -53,6 +54,7 @@ class Bernoulli(Distribution): ...@@ -53,6 +54,7 @@ class Bernoulli(Distribution):
check_prob(self._probs) check_prob(self._probs)
else: else:
self._probs = probs self._probs = probs
self.seed = seed
# ops needed for the class # ops needed for the class
self.log = P.Log() self.log = P.Log()
...@@ -64,7 +66,6 @@ class Bernoulli(Distribution): ...@@ -64,7 +66,6 @@ class Bernoulli(Distribution):
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.less = P.Less() self.less = P.Less()
self.cast = P.Cast() self.cast = P.Cast()
self.normal = P.Normal(seed=seed)
self.erf = P.Erf() self.erf = P.Erf()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
...@@ -159,7 +160,7 @@ class Bernoulli(Distribution): ...@@ -159,7 +160,7 @@ class Bernoulli(Distribution):
mean_zero = self.const(0.0) mean_zero = self.const(0.0)
sd_one = self.const(1.0) sd_one = self.const(1.0)
sqrt_two = self.sqrt(self.const(2.0)) sqrt_two = self.sqrt(self.const(2.0))
sample_norm = self.normal(sample_shape, mean_zero, sd_one) sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two))) sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two)))
sample = self.less(sample_uniform, probs1) sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self._dtype) sample = self.cast(sample, self._dtype)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Normal Distribution""" """Normal Distribution"""
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero from ._utils.utils import convert_to_batch, check_greater_equal_zero
from ...common import dtype as mstype from ...common import dtype as mstype
...@@ -60,6 +61,7 @@ class Normal(Distribution): ...@@ -60,6 +61,7 @@ class Normal(Distribution):
else: else:
self._mean_value = mean self._mean_value = mean
self._sd_value = sd self._sd_value = sd
self.seed = seed
#ops needed for the class #ops needed for the class
self.exp = P.Exp() self.exp = P.Exp()
...@@ -70,7 +72,6 @@ class Normal(Distribution): ...@@ -70,7 +72,6 @@ class Normal(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv() self.realdiv = P.RealDiv()
self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
self.normal = P.Normal(seed=seed)
self.shape = P.Shape() self.shape = P.Shape()
self.zeroslike = P.ZerosLike() self.zeroslike = P.ZerosLike()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
...@@ -163,7 +164,7 @@ class Normal(Distribution): ...@@ -163,7 +164,7 @@ class Normal(Distribution):
sample_shape = shape + batch_shape sample_shape = shape + batch_shape
mean_zero = self.const(0.0) mean_zero = self.const(0.0)
sd_one = self.const(1.0) sd_one = self.const(1.0)
sample_norm = self.normal(sample_shape, mean_zero, sd_one) sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample = self.add(mean, self.mul(sample_norm, sd)) sample = self.add(mean, self.mul(sample_norm, sd))
return sample return sample
return None return None
...@@ -27,6 +27,7 @@ from .clip_ops import clip_by_value ...@@ -27,6 +27,7 @@ from .clip_ops import clip_by_value
from .multitype_ops.add_impl import hyper_add from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal
__all__ = [ __all__ = [
...@@ -47,4 +48,5 @@ __all__ = [ ...@@ -47,4 +48,5 @@ __all__ = [
'zeros_like', 'zeros_like',
'ones_like', 'ones_like',
'zip_operation', 'zip_operation',
'normal',
'clip_by_value',] 'clip_by_value',]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operations for random number generatos."""
from mindspore.ops.primitive import constexpr
from .. import operations as P
# set graph-level RNG seed
_GRAPH_SEED = 0
@constexpr
def set_seed(seed):
global _GRAPH_SEED
_GRAPH_SEED = seed
@constexpr
def get_seed():
return _GRAPH_SEED
def normal(shape, mean, stddev, seed):
"""
Generates random numbers according to the Normal (or Gaussian) random number distribution.
It is defined as:
Args:
- **shape** (tuple) - The shape of random tensor to be generated.
- **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak.
With float32 data type.
- **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type.
- **seed** (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev.
The dtype is float32.
Examples:
>>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32)
>>> output = C.normal(shape, mean, stddev, seed=5)
"""
set_seed(10)
seed1 = get_seed()
seed2 = seed
stdnormal = P.StandardNormal(seed1, seed2)
rnd = stdnormal(shape)
value = rnd * stddev + mean
return value
...@@ -55,7 +55,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A ...@@ -55,7 +55,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal, from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, Laplace) RandomCategorical, Laplace)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm, from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D, BiasAdd, Conv2D,
...@@ -175,7 +175,7 @@ __all__ = [ ...@@ -175,7 +175,7 @@ __all__ = [
'HSigmoid', 'HSigmoid',
'Tanh', 'Tanh',
'RandomChoiceWithMask', 'RandomChoiceWithMask',
'Normal', 'StandardNormal',
'Gamma', 'Gamma',
'Poisson', 'Poisson',
'UniformInt', 'UniformInt',
......
...@@ -22,6 +22,48 @@ from ..primitive import PrimitiveWithInfer, prim_attr_register ...@@ -22,6 +22,48 @@ from ..primitive import PrimitiveWithInfer, prim_attr_register
from .._utils import get_broadcast_shape from .._utils import get_broadcast_shape
class StandardNormal(PrimitiveWithInfer):
r"""
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
Args:
seed (int): Random seed. Default: 0.
seed2 (int): Random seed2. Default: 0.
Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
Outputs:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev.
The dtype is float32.
Examples:
>>> shape = (4, 16)
>>> stdnormal = P.StandardNormal(seed=2)
>>> output = stdnormal(shape)
"""
@prim_attr_register
def __init__(self, seed=0, seed2=0):
"""Init StandardNormal"""
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
validator.check_value_type('seed', seed, [int], self.name)
validator.check_value_type('seed2', seed2, [int], self.name)
def __infer__(self, shape):
shape_v = shape["value"]
if shape_v is None:
raise ValueError(f"For {self.name}, shape must be const.")
validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
out = {
'shape': shape_v,
'dtype': mstype.float32,
'value': None}
return out
class Laplace(PrimitiveWithInfer): class Laplace(PrimitiveWithInfer):
r""" r"""
Generates random numbers according to the Laplace random number distribution. Generates random numbers according to the Laplace random number distribution.
...@@ -393,46 +435,3 @@ class RandomCategorical(PrimitiveWithInfer): ...@@ -393,46 +435,3 @@ class RandomCategorical(PrimitiveWithInfer):
return {'shape': (x_shape), return {'shape': (x_shape),
'dtype': (self.dtype), 'dtype': (self.dtype),
'value': None} 'value': None}
class Normal(PrimitiveWithInfer):
"""
Generates random samples from a normal(Gaussian) distribution.
Args:
seed (int): Random seed. Default: 0.
Inputs:
- **shape** (tuple[int]) - The shape of output tensor. Only constant value is allowed.
- **mean** (Tensor) - The mean of the distribution, with float32 data type.
- **stddev** (Tensor) - The standard deviation of the distribution, with float32 data type.
Outputs:
Tensor, with the given shape from the specific distribution and float32 data type.
Examples:
>>> normal = P.Normal()
>>> mean = Tensor(0., mstype.float32)
>>> stddev = Tensor(1., mstype.float32)
>>> out = normal((32, 3, 3), mean, stddev)
"""
@prim_attr_register
def __init__(self, seed=0):
"""Init Normal"""
validator.check_value_type("seed", seed, [int], self.name)
def __infer__(self, shape, mean, stddev):
shape_value = shape["value"]
if shape_value is None:
raise ValueError(f"For {self.name}, shape must be const.")
validator.check_value_type("shape", shape_value, [tuple], self.name)
for i, shape_i in enumerate(shape_value):
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GE, self.name)
validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
validator.check_tensor_type_same({"stddev": stddev["dtype"]}, [mstype.float32], self.name)
out = {"shape": shape_value,
"dtype": mstype.float32,
"value": None}
return out
...@@ -43,7 +43,6 @@ def test_net_1D(): ...@@ -43,7 +43,6 @@ def test_net_1D():
net = Net(shape, seed) net = Net(shape, seed)
tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
output = net(tmean, tstddev) output = net(tmean, tstddev)
print(output.asnumpy())
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)
...@@ -55,5 +54,4 @@ def test_net_ND(): ...@@ -55,5 +54,4 @@ def test_net_ND():
net = Net(shape, seed) net = Net(shape, seed)
tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
output = net(tmean, tstddev) output = net(tmean, tstddev)
print(output.asnumpy())
assert output.shape == (3, 2, 2) assert output.shape == (3, 2, 2)
...@@ -13,13 +13,8 @@ ...@@ -13,13 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import numpy as np
import pytest
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
...@@ -43,5 +38,4 @@ def test_net(): ...@@ -43,5 +38,4 @@ def test_net():
shape = (3, 2, 4) shape = (3, 2, 4)
net = Net(shape, seed, seed2) net = Net(shape, seed, seed2)
output = net() output = net()
print(output.asnumpy())
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self, shape, seed=0):
super(Net, self).__init__()
self.shape = shape
self.seed = seed
def construct(self, mean, stddev):
return C.normal(self.shape, mean, stddev, self.seed)
def test_net_1D():
seed = 10
shape = (3, 2, 4)
mean = 1.0
stddev = 1.0
net = Net(shape, seed)
tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
output = net(tmean, tstddev)
assert output.shape == (3, 2, 4)
def test_net_ND():
seed = 10
shape = (3, 1, 2)
mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
stddev = np.array([1.0]).astype(np.float32)
net = Net(shape, seed)
tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
output = net(tmean, tstddev)
assert output.shape == (3, 2, 2)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.context as context
import mindspore.nn as nn
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self, shape, seed=0, seed2=0):
super(Net, self).__init__()
self.shape = shape
self.seed = seed
self.seed2 = seed2
self.stdnormal = P.StandardNormal(seed, seed2)
def construct(self):
return self.stdnormal(self.shape, self.seed, self.seed2)
def test_net():
seed = 10
seed2 = 10
shape = (3, 2, 4)
net = Net(shape, seed, seed2)
output = net()
assert output.shape == (3, 2, 4)
...@@ -571,10 +571,10 @@ class NormalNet(nn.Cell): ...@@ -571,10 +571,10 @@ class NormalNet(nn.Cell):
def __init__(self, shape=None, seed=0): def __init__(self, shape=None, seed=0):
super(NormalNet, self).__init__() super(NormalNet, self).__init__()
self.shape = shape self.shape = shape
self.normal = P.Normal(seed=seed) self.seed = seed
def construct(self, mean, stddev): def construct(self, mean, stddev):
out = self.normal(self.shape, mean, stddev) out = C.normal(self.shape, mean, stddev, self.seed)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册