提交 3a11d123 编写于 作者: P peixu_ren

Update random uniform op invocation

上级 64b0feb7
......@@ -15,6 +15,7 @@
"""Bernoulli Distribution"""
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type
......@@ -116,7 +117,7 @@ class Bernoulli(Distribution):
self.select = P.Select()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)
self.uniform = C.uniform
def extend_repr(self):
if self.is_scalar_batch:
......@@ -256,7 +257,6 @@ class Bernoulli(Distribution):
probs1 = self.probs if probs is None else probs
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed)
sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self.dtype)
return sample
return self.cast(sample, self.dtype)
......@@ -15,6 +15,7 @@
"""Exponential Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type
......@@ -107,7 +108,8 @@ class Exponential(Distribution):
self.minval = np.finfo(np.float).tiny
# ops needed for the class
# ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
......@@ -118,7 +120,7 @@ class Exponential(Distribution):
self.shape = P.Shape()
self.sqrt = P.Sqrt()
self.sq = P.Square()
self.uniform = P.UniformReal(seed=seed)
self.uniform = C.uniform
def extend_repr(self):
if self.is_scalar_batch:
......@@ -251,5 +253,6 @@ class Exponential(Distribution):
rate = self.rate if rate is None else rate
minval = self.const(self.minval)
maxval = self.const(1.0)
sample = self.uniform(shape + self.shape(rate), minval, maxval)
return -self.log(sample) / rate
sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed)
sample = -self.log(sample_uniform) / rate
return self.cast(sample, self.dtype)
......@@ -15,6 +15,7 @@
"""Geometric Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type
......@@ -109,6 +110,7 @@ class Geometric(Distribution):
self.minval = np.finfo(np.float).tiny
# ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.fill = P.Fill()
......@@ -121,7 +123,7 @@ class Geometric(Distribution):
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)
self.uniform = C.uniform
def extend_repr(self):
if self.is_scalar_batch:
......@@ -269,5 +271,6 @@ class Geometric(Distribution):
probs = self.probs if probs is None else probs
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval, self.seed)
sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
return self.cast(sample, self.dtype)
......@@ -14,6 +14,7 @@
# ============================================================================
"""Uniform Distribution"""
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type
......@@ -108,7 +109,8 @@ class Uniform(Distribution):
self._low = low
self._high = high
# ops needed for the class
# ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
......@@ -121,8 +123,8 @@ class Uniform(Distribution):
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)
self.zeroslike = P.ZerosLike()
self.uniform = C.uniform
def extend_repr(self):
if self.is_scalar_batch:
......@@ -284,6 +286,6 @@ class Uniform(Distribution):
broadcast_shape = self.shape(low + high)
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed)
sample = (high - low) * sample_uniform + low
return sample
return self.cast(sample, self.dtype)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册