提交 93d9cc3c 编写于 作者: P peixu_ren

Add erf and erfc as generic functions for all the backend and fix notation in power_transform.

上级 c165a6d0
...@@ -17,14 +17,14 @@ from mindspore.ops import operations as P ...@@ -17,14 +17,14 @@ from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from ..distribution._utils.utils import CheckTensor from ..distribution._utils.utils import CheckTensor
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
from .bijector import Bijector from .bijector import Bijector
class PowerTransform(Bijector): class PowerTransform(Bijector):
r""" r"""
Power Bijector. Power Bijector.
This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c is power. This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c >= 0 is the power.
The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`. The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`.
...@@ -61,10 +61,10 @@ class PowerTransform(Bijector): ...@@ -61,10 +61,10 @@ class PowerTransform(Bijector):
validator.check_number("power", power, 0, Rel.GE, self.name) validator.check_number("power", power, 0, Rel.GE, self.name)
self._power = power self._power = power
self.pow = P.Pow() self.pow = P.Pow()
self.exp = exp_by_step self.exp = exp_generic
self.expm1 = expm1_by_step self.expm1 = expm1_generic
self.log = log_by_step self.log = log_generic
self.log1p = log1p_by_step self.log1p = log1p_generic
self.checktensor = CheckTensor() self.checktensor = CheckTensor()
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor, CheckTensor from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.custom_ops import log_by_step from ..distribution._utils.custom_ops import log_generic
from .bijector import Bijector from .bijector import Bijector
...@@ -69,7 +69,7 @@ class ScalarAffine(Bijector): ...@@ -69,7 +69,7 @@ class ScalarAffine(Bijector):
param=param) param=param)
self.abs = P.Abs() self.abs = P.Abs()
self.log = log_by_step self.log = log_generic
self.checktensor = CheckTensor() self.checktensor = CheckTensor()
......
...@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype ...@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
from mindspore.nn.layer.activation import LogSigmoid from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor, CheckTensor from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
from .bijector import Bijector from .bijector import Bijector
...@@ -61,9 +61,9 @@ class Softplus(Bijector): ...@@ -61,9 +61,9 @@ class Softplus(Bijector):
super(Softplus, self).__init__(name=name, param=param) super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness) self._sharpness = cast_to_tensor(sharpness)
self.exp = exp_by_step self.exp = exp_generic
self.log = log_by_step self.log = log_generic
self.expm1 = expm1_by_step self.expm1 = expm1_generic
self.abs = P.Abs() self.abs = P.Abs()
self.fill = P.Fill() self.fill = P.Fill()
self.greater = P.Greater() self.greater = P.Greater()
......
...@@ -28,8 +28,10 @@ __all__ = [ ...@@ -28,8 +28,10 @@ __all__ = [
'check_scalar_from_param', 'check_scalar_from_param',
'check_prob', 'check_prob',
'check_type', 'check_type',
'exp_by_step', 'exp_generic',
'expm1_by_step', 'expm1_generic',
'log_by_step', 'log_generic',
'log1p_by_step', 'log1p_generic',
'erf_generic',
'erfc_generic',
] ]
...@@ -17,8 +17,7 @@ import numpy as np ...@@ -17,8 +17,7 @@ import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
def exp_generic(input_x):
def exp_by_step(input_x):
""" """
Log op on Ascend doesn't supprot int types. Log op on Ascend doesn't supprot int types.
Fix this with casting the type. Fix this with casting the type.
...@@ -30,14 +29,14 @@ def exp_by_step(input_x): ...@@ -30,14 +29,14 @@ def exp_by_step(input_x):
return exp(input_x) return exp(input_x)
def expm1_by_step(input_x): def expm1_generic(input_x):
""" """
Expm1 ops under GPU context. Expm1 ops under GPU context.
""" """
return exp_by_step(input_x) - 1.0 return exp_generic(input_x) - 1.0
def log_by_step(input_x): def log_generic(input_x):
""" """
Log op on Ascend is calculated as log(abs(x)). Log op on Ascend is calculated as log(abs(x)).
Fix this with putting negative values as nan. Fix this with putting negative values as nan.
...@@ -63,8 +62,166 @@ def log_by_step(input_x): ...@@ -63,8 +62,166 @@ def log_by_step(input_x):
return select(neg_x, nan, result) return select(neg_x, nan, result)
def log1p_by_step(x): def log1p_generic(x):
""" """
Log1p ops on GPU device or when device_target == GPU. Log1p ops on GPU device or when device_target == GPU.
""" """
return log_by_step(x + 1.0) return log_generic(x + 1.0)
def _evaluate_polynomial(x, coefficients):
poly = 0
for co in coefficients:
poly = poly * x + co
return poly
def erf_f32_generic(x):
"""
Calculate erf for dtype of f32
"""
k_erf_tcoefficient = [+7.853861353153693e-5,
-8.010193625184903e-4,
+5.188327685732524e-3,
-2.685381193529856e-2,
+1.128358514861418e-1,
-3.761262582423300e-1,
+1.128379165726710e+0]
poly = _evaluate_polynomial(x * x, k_erf_tcoefficient)
return x * poly
def erf_f64_generic(x):
"""
Calculate erf for dtype of f64
"""
k_erf_tcoefficient = [9.60497373987051638749e0,
9.00260197203842689217e1,
2.23200534594684319226e3,
7.00332514112805075473e3,
5.55923013010394962768e4]
k_erf_ucoefficient = [1.00000000000000000000e0,
3.35617141647503099647e1,
5.21357949780152679795e2,
4.59432382970980127987e3,
2.26290000613890934246e4,
4.92673942608635921086e4]
z = x * x
poly1 = _evaluate_polynomial(z, k_erf_tcoefficient)
poly2 = _evaluate_polynomial(z, k_erf_ucoefficient)
return x * poly1 / poly2
def erfc_f32_generic(x):
"""
Calculate erfc for dtype of f32
"""
k_maxlog = 88.72283905206835
k_erfc_pcoefficient = [+2.326819970068386e-2,
-1.387039388740657e-1,
+3.687424674597105e-1,
-5.824733027278666e-1,
+6.210004621745983e-1,
-4.944515323274145e-1,
+3.404879937665872e-1,
-2.741127028184656e-1,
+5.638259427386472e-1]
k_erfc_rcoefficient = [-1.047766399936249e+1,
+1.297719955372516e+1,
-7.495518717768503e+0,
+2.921019019210786e+0,
-1.015265279202700e+0,
+4.218463358204948e-1,
-2.820767439740514e-1,
+5.641895067754075e-1]
abs_cal = P.Abs()
select = P.Select()
less = P.Less()
fill = P.Fill()
dtype = P.DType()
shape = P.Shape()
abs_x = abs_cal(x)
z = exp_generic(-x * x)
q = 1 / abs_x
y = q * q
poly1 = _evaluate_polynomial(y, k_erfc_pcoefficient)
poly2 = _evaluate_polynomial(y, k_erfc_rcoefficient)
p = select(less(abs_x, 2.0), poly1, poly2)
y = z * q * p
zeros = fill(dtype(x), shape(x), 0)
y_clamp = select(less(z, -k_maxlog), zeros, y)
return select(less(x, 0), 2.0 - y_clamp, y_clamp)
def erfc_f64_generic(x):
"""
Calculate erfc for dtype of f64
"""
k_maxlog = 7.09782712893383996843e2
k_erfc_pcoefficient = [2.46196981473530512524e-10,
5.64189564831068821977e-1,
7.46321056442269912687e0,
4.86371970985681366614e1,
1.96520832956077098242e2,
5.26445194995477358631e2,
9.34528527171957607540e2,
1.02755188689515710272e3,
5.57535335369399327526e2]
k_erfc_qcoefficient = [1.00000000000000000000e0,
1.32281951154744992508e1,
8.67072140885989742329e1,
3.54937778887819891062e2,
9.75708501743205489753e2,
1.82390916687909736289e3,
2.24633760818710981792e3,
1.65666309194161350182e3,
5.57535340817727675546e2]
k_erfc_rcoefficient = [5.64189583547755073984e-1,
1.27536670759978104416e0,
5.01905042251180477414e0,
6.16021097993053585195e0,
7.40974269950448939160e0,
2.97886665372100240670e0]
k_erfc_scoefficient = [1.00000000000000000000e0,
2.26052863220117276590e0,
9.39603524938001434673e0,
1.20489539808096656605e1,
1.70814450747565897222e1,
9.60896809063285878198e0,
3.36907645100081516050e02]
abs_cal = P.Abs()
select = P.Select()
less = P.Less()
fill = P.Fill()
dtype = P.DType()
shape = P.Shape()
abs_x = abs_cal(x)
z = -x * x
exp_z = exp_generic(z)
temp1 = exp_z * _evaluate_polynomial(abs_x, k_erfc_pcoefficient) / _evaluate_polynomial(abs_x, k_erfc_qcoefficient)
temp2 = exp_z * _evaluate_polynomial(abs_x, k_erfc_rcoefficient) / _evaluate_polynomial(abs_x, k_erfc_scoefficient)
y = select(less(abs_x, 8.0), temp1, temp2)
zeros = fill(dtype(x), shape(x), 0)
y_clamp = select(less(z, k_maxlog), zeros, y)
poly2 = _evaluate_polynomial(y, k_erfc_rcoefficient)
p = select(less(abs_x, 2.0), poly1, poly2)
y = z * q * p
zeros = fill(dtype(x), shape(x), 0)
y_clamp = select(less(z, -k_maxlog), zeros, y)
return select(less(x, 0), 2.0 - y_clamp, y_clamp)
def erfc_generic(x):
select = P.Select()
greater = P.Greater()
abs_cal = P.Abs()
return select(greater(abs_cal(x), 1), erfc_f32_generic(x), 1 - erf_f32_generic(x))
def erf_generic(x):
select = P.Select()
less = P.Less()
abs_cal = P.Abs()
return select(less(abs_cal(x), 1), erf_f32_generic(x), 1 - erfc_f32_generic(x))
...@@ -18,7 +18,7 @@ from mindspore.ops import operations as P ...@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.custom_ops import exp_by_step, log_by_step from ._utils.custom_ops import exp_generic, log_generic, erf_generic
class Bernoulli(Distribution): class Bernoulli(Distribution):
...@@ -109,13 +109,13 @@ class Bernoulli(Distribution): ...@@ -109,13 +109,13 @@ class Bernoulli(Distribution):
self._probs = probs self._probs = probs
# ops needed for the class # ops needed for the class
self.exp = exp_by_step self.exp = exp_generic
self.log = log_by_step self.log = log_generic
self.erf = erf_generic
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.erf = P.Erf()
self.floor = P.Floor() self.floor = P.Floor()
self.fill = P.Fill() self.fill = P.Fill()
self.less = P.Less() self.less = P.Less()
......
...@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype ...@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import exp_by_step, log_by_step from ._utils.custom_ops import exp_generic, log_generic
class Exponential(Distribution): class Exponential(Distribution):
""" """
...@@ -112,8 +112,8 @@ class Exponential(Distribution): ...@@ -112,8 +112,8 @@ class Exponential(Distribution):
self.minval = np.finfo(np.float).tiny self.minval = np.finfo(np.float).tiny
# ops needed for the class # ops needed for the class
self.exp = exp_by_step self.exp = exp_generic
self.log = log_by_step self.log = log_generic
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
......
...@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype ...@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import exp_by_step, log_by_step from ._utils.custom_ops import exp_generic, log_generic
class Geometric(Distribution): class Geometric(Distribution):
...@@ -114,8 +114,8 @@ class Geometric(Distribution): ...@@ -114,8 +114,8 @@ class Geometric(Distribution):
self.minval = np.finfo(np.float).tiny self.minval = np.finfo(np.float).tiny
# ops needed for the class # ops needed for the class
self.exp = exp_by_step self.exp = exp_generic
self.log = log_by_step self.log = log_generic
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
......
...@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype ...@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
class Normal(Distribution): class Normal(Distribution):
""" """
...@@ -114,13 +114,13 @@ class Normal(Distribution): ...@@ -114,13 +114,13 @@ class Normal(Distribution):
self._sd_value = sd self._sd_value = sd
#ops needed for the class #ops needed for the class
self.exp = exp_by_step self.exp = exp_generic
self.expm1 = expm1_by_step self.expm1 = expm1_generic
self.log = log_by_step self.log = log_generic
self.erf = erf_generic
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.erf = P.Erf()
self.fill = P.Fill() self.fill = P.Fill()
self.shape = P.Shape() self.shape = P.Shape()
self.sq = P.Square() self.sq = P.Square()
......
...@@ -18,7 +18,7 @@ from mindspore.common import dtype as mstype ...@@ -18,7 +18,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_type, raise_not_impl_error from ._utils.utils import check_type, raise_not_impl_error
from ._utils.custom_ops import exp_by_step, log_by_step from ._utils.custom_ops import exp_generic, log_generic
class TransformedDistribution(Distribution): class TransformedDistribution(Distribution):
""" """
...@@ -55,8 +55,8 @@ class TransformedDistribution(Distribution): ...@@ -55,8 +55,8 @@ class TransformedDistribution(Distribution):
self._bijector = bijector self._bijector = bijector
self._distribution = distribution self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian self._is_linear_transformation = bijector.is_constant_jacobian
self.exp = exp_by_step self.exp = exp_generic
self.log = log_by_step self.log = log_generic
@property @property
def bijector(self): def bijector(self):
......
...@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype ...@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import exp_by_step, log_by_step from ._utils.custom_ops import exp_generic, log_generic
class Uniform(Distribution): class Uniform(Distribution):
""" """
...@@ -113,8 +113,8 @@ class Uniform(Distribution): ...@@ -113,8 +113,8 @@ class Uniform(Distribution):
self._high = high self._high = high
# ops needed for the class # ops needed for the class
self.exp = exp_by_step self.exp = exp_generic
self.log = log_by_step self.log = log_generic
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册