提交 3a11d123 编写于 作者: P peixu_ren

Update random uniform op invocation

上级 64b0feb7
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Bernoulli Distribution""" """Bernoulli Distribution"""
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type from ._utils.utils import cast_to_tensor, check_prob, check_type
...@@ -116,7 +117,7 @@ class Bernoulli(Distribution): ...@@ -116,7 +117,7 @@ class Bernoulli(Distribution):
self.select = P.Select() self.select = P.Select()
self.sq = P.Square() self.sq = P.Square()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed) self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
...@@ -256,7 +257,6 @@ class Bernoulli(Distribution): ...@@ -256,7 +257,6 @@ class Bernoulli(Distribution):
probs1 = self.probs if probs is None else probs probs1 = self.probs if probs is None else probs
l_zero = self.const(0.0) l_zero = self.const(0.0)
h_one = self.const(1.0) h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed)
sample = self.less(sample_uniform, probs1) sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self.dtype) return self.cast(sample, self.dtype)
return sample
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Exponential Distribution""" """Exponential Distribution"""
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type from ._utils.utils import cast_to_tensor, check_greater_zero, check_type
...@@ -107,7 +108,8 @@ class Exponential(Distribution): ...@@ -107,7 +108,8 @@ class Exponential(Distribution):
self.minval = np.finfo(np.float).tiny self.minval = np.finfo(np.float).tiny
# ops needed for the class # ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.exp = P.Exp() self.exp = P.Exp()
...@@ -118,7 +120,7 @@ class Exponential(Distribution): ...@@ -118,7 +120,7 @@ class Exponential(Distribution):
self.shape = P.Shape() self.shape = P.Shape()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.sq = P.Square() self.sq = P.Square()
self.uniform = P.UniformReal(seed=seed) self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
...@@ -251,5 +253,6 @@ class Exponential(Distribution): ...@@ -251,5 +253,6 @@ class Exponential(Distribution):
rate = self.rate if rate is None else rate rate = self.rate if rate is None else rate
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample = self.uniform(shape + self.shape(rate), minval, maxval) sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed)
return -self.log(sample) / rate sample = -self.log(sample_uniform) / rate
return self.cast(sample, self.dtype)
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Geometric Distribution""" """Geometric Distribution"""
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type from ._utils.utils import cast_to_tensor, check_prob, check_type
...@@ -109,6 +110,7 @@ class Geometric(Distribution): ...@@ -109,6 +110,7 @@ class Geometric(Distribution):
self.minval = np.finfo(np.float).tiny self.minval = np.finfo(np.float).tiny
# ops needed for the class # ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.fill = P.Fill() self.fill = P.Fill()
...@@ -121,7 +123,7 @@ class Geometric(Distribution): ...@@ -121,7 +123,7 @@ class Geometric(Distribution):
self.shape = P.Shape() self.shape = P.Shape()
self.sq = P.Square() self.sq = P.Square()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed) self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
...@@ -269,5 +271,6 @@ class Geometric(Distribution): ...@@ -269,5 +271,6 @@ class Geometric(Distribution):
probs = self.probs if probs is None else probs probs = self.probs if probs is None else probs
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval, self.seed)
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
return self.cast(sample, self.dtype)
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Uniform Distribution""" """Uniform Distribution"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type from ._utils.utils import convert_to_batch, check_greater, check_type
...@@ -108,7 +109,8 @@ class Uniform(Distribution): ...@@ -108,7 +109,8 @@ class Uniform(Distribution):
self._low = low self._low = low
self._high = high self._high = high
# ops needed for the class # ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.exp = P.Exp() self.exp = P.Exp()
...@@ -121,8 +123,8 @@ class Uniform(Distribution): ...@@ -121,8 +123,8 @@ class Uniform(Distribution):
self.shape = P.Shape() self.shape = P.Shape()
self.sq = P.Square() self.sq = P.Square()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)
self.zeroslike = P.ZerosLike() self.zeroslike = P.ZerosLike()
self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
...@@ -284,6 +286,6 @@ class Uniform(Distribution): ...@@ -284,6 +286,6 @@ class Uniform(Distribution):
broadcast_shape = self.shape(low + high) broadcast_shape = self.shape(low + high)
l_zero = self.const(0.0) l_zero = self.const(0.0)
h_one = self.const(1.0) h_one = self.const(1.0)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed)
sample = (high - low) * sample_uniform + low sample = (high - low) * sample_uniform + low
return sample return self.cast(sample, self.dtype)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册