提交 22598e5c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4084 Complement the arg passing conventions

Merge pull request !4084 from peixu_ren/custom_bijector
...@@ -69,31 +69,31 @@ class Bijector(Cell): ...@@ -69,31 +69,31 @@ class Bijector(Cell):
def is_injective(self): def is_injective(self):
return self._is_injective return self._is_injective
def forward(self, *args): def forward(self, *args, **kwargs):
""" """
Forward transformation: transform the input value to another distribution. Forward transformation: transform the input value to another distribution.
""" """
return self._forward(*args) return self._forward(*args, **kwargs)
def inverse(self, *args): def inverse(self, *args, **kwargs):
""" """
Inverse transformation: transform the input value back to the original distribution. Inverse transformation: transform the input value back to the original distribution.
""" """
return self._inverse(*args) return self._inverse(*args, **kwargs)
def forward_log_jacobian(self, *args): def forward_log_jacobian(self, *args, **kwargs):
""" """
Logarithm of the derivative of forward transformation. Logarithm of the derivative of forward transformation.
""" """
return self._forward_log_jacobian(*args) return self._forward_log_jacobian(*args, **kwargs)
def inverse_log_jacobian(self, *args): def inverse_log_jacobian(self, *args, **kwargs):
""" """
Logarithm of the derivative of forward transformation. Logarithm of the derivative of forward transformation.
""" """
return self._inverse_log_jacobian(*args) return self._inverse_log_jacobian(*args, **kwargs)
def __call__(self, *args): def __call__(self, *args, **kwargs):
""" """
Call Bijector directly. Call Bijector directly.
This __call__ may go into two directions: This __call__ may go into two directions:
...@@ -107,9 +107,9 @@ class Bijector(Cell): ...@@ -107,9 +107,9 @@ class Bijector(Cell):
""" """
if isinstance(args[0], Distribution): if isinstance(args[0], Distribution):
return TransformedDistribution(self, args[0]) return TransformedDistribution(self, args[0])
return super(Bijector, self).__call__(*args) return super(Bijector, self).__call__(*args, **kwargs)
def construct(self, name, *args): def construct(self, name, *args, **kwargs):
""" """
Override construct in Cell. Override construct in Cell.
...@@ -120,11 +120,11 @@ class Bijector(Cell): ...@@ -120,11 +120,11 @@ class Bijector(Cell):
Always raise RuntimeError as Distribution should not be called directly. Always raise RuntimeError as Distribution should not be called directly.
""" """
if name == 'forward': if name == 'forward':
return self.forward(*args) return self.forward(*args, **kwargs)
if name == 'inverse': if name == 'inverse':
return self.inverse(*args) return self.inverse(*args, **kwargs)
if name == 'forward_log_jacobian': if name == 'forward_log_jacobian':
return self.forward_log_jacobian(*args) return self.forward_log_jacobian(*args, **kwargs)
if name == 'inverse_log_jacobian': if name == 'inverse_log_jacobian':
return self.inverse_log_jacobian(*args) return self.inverse_log_jacobian(*args, **kwargs)
return None return None
...@@ -27,7 +27,7 @@ class Distribution(Cell): ...@@ -27,7 +27,7 @@ class Distribution(Cell):
Note: Note:
Derived class should override operations such as ,_mean, _prob, Derived class should override operations such as ,_mean, _prob,
and _log_prob. Arguments should be passed in through *args. and _log_prob. Arguments should be passed in through *args or **kwargs.
Dist_spec_args are unique for each type of distribution. For example, mean and sd Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution. are the dist_spec_args for a Normal distribution.
...@@ -171,7 +171,7 @@ class Distribution(Cell): ...@@ -171,7 +171,7 @@ class Distribution(Cell):
if hasattr(self, '_cross_entropy'): if hasattr(self, '_cross_entropy'):
self._call_cross_entropy = self._cross_entropy self._call_cross_entropy = self._cross_entropy
def log_prob(self, *args): def log_prob(self, *args, **kwargs):
""" """
Evaluate the log probability(pdf or pmf) at the given value. Evaluate the log probability(pdf or pmf) at the given value.
...@@ -179,18 +179,18 @@ class Distribution(Cell): ...@@ -179,18 +179,18 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_prob(*args) return self._call_log_prob(*args, **kwargs)
def _calc_prob_from_log_prob(self, *args): def _calc_prob_from_log_prob(self, *args, **kwargs):
r""" r"""
Evaluate prob from log probability. Evaluate prob from log probability.
.. math:: .. math::
probability(x) = \exp(log_likehood(x)) probability(x) = \exp(log_likehood(x))
""" """
return self.exp(self._log_prob(*args)) return self.exp(self._log_prob(*args, **kwargs))
def prob(self, *args): def prob(self, *args, **kwargs):
""" """
Evaluate the probability (pdf or pmf) at given value. Evaluate the probability (pdf or pmf) at given value.
...@@ -198,18 +198,18 @@ class Distribution(Cell): ...@@ -198,18 +198,18 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_prob(*args) return self._call_prob(*args, **kwargs)
def _calc_log_prob_from_prob(self, *args): def _calc_log_prob_from_prob(self, *args, **kwargs):
r""" r"""
Evaluate log probability from probability. Evaluate log probability from probability.
.. math:: .. math::
log_prob(x) = \log(prob(x)) log_prob(x) = \log(prob(x))
""" """
return self.log(self._prob(*args)) return self.log(self._prob(*args, **kwargs))
def cdf(self, *args): def cdf(self, *args, **kwargs):
""" """
Evaluate the cdf at given value. Evaluate the cdf at given value.
...@@ -217,36 +217,36 @@ class Distribution(Cell): ...@@ -217,36 +217,36 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_cdf(*args) return self._call_cdf(*args, **kwargs)
def _calc_cdf_from_log_cdf(self, *args): def _calc_cdf_from_log_cdf(self, *args, **kwargs):
r""" r"""
Evaluate cdf from log_cdf. Evaluate cdf from log_cdf.
.. math:: .. math::
cdf(x) = \exp(log_cdf(x)) cdf(x) = \exp(log_cdf(x))
""" """
return self.exp(self._log_cdf(*args)) return self.exp(self._log_cdf(*args, **kwargs))
def _calc_cdf_from_survival(self, *args): def _calc_cdf_from_survival(self, *args, **kwargs):
r""" r"""
Evaluate cdf from survival function. Evaluate cdf from survival function.
.. math:: .. math::
cdf(x) = 1 - (survival_function(x)) cdf(x) = 1 - (survival_function(x))
""" """
return 1.0 - self._survival_function(*args) return 1.0 - self._survival_function(*args, **kwargs)
def _calc_cdf_from_log_survival(self, *args): def _calc_cdf_from_log_survival(self, *args, **kwargs):
r""" r"""
Evaluate cdf from log survival function. Evaluate cdf from log survival function.
.. math:: .. math::
cdf(x) = 1 - (\exp(log_survival(x))) cdf(x) = 1 - (\exp(log_survival(x)))
""" """
return 1.0 - self.exp(self._log_survival(*args)) return 1.0 - self.exp(self._log_survival(*args, **kwargs))
def log_cdf(self, *args): def log_cdf(self, *args, **kwargs):
""" """
Evaluate the log cdf at given value. Evaluate the log cdf at given value.
...@@ -254,18 +254,18 @@ class Distribution(Cell): ...@@ -254,18 +254,18 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_cdf(*args) return self._call_log_cdf(*args, **kwargs)
def _calc_log_cdf_from_call_cdf(self, *args): def _calc_log_cdf_from_call_cdf(self, *args, **kwargs):
r""" r"""
Evaluate log cdf from cdf. Evaluate log cdf from cdf.
.. math:: .. math::
log_cdf(x) = \log(cdf(x)) log_cdf(x) = \log(cdf(x))
""" """
return self.log(self._call_cdf(*args)) return self.log(self._call_cdf(*args, **kwargs))
def survival_function(self, *args): def survival_function(self, *args, **kwargs):
""" """
Evaluate the survival function at given value. Evaluate the survival function at given value.
...@@ -273,27 +273,27 @@ class Distribution(Cell): ...@@ -273,27 +273,27 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_survival(*args) return self._call_survival(*args, **kwargs)
def _calc_survival_from_call_cdf(self, *args): def _calc_survival_from_call_cdf(self, *args, **kwargs):
r""" r"""
Evaluate survival function from cdf. Evaluate survival function from cdf.
.. math:: .. math::
survival_function(x) = 1 - (cdf(x)) survival_function(x) = 1 - (cdf(x))
""" """
return 1.0 - self._call_cdf(*args) return 1.0 - self._call_cdf(*args, **kwargs)
def _calc_survival_from_log_survival(self, *args): def _calc_survival_from_log_survival(self, *args, **kwargs):
r""" r"""
Evaluate survival function from log survival function. Evaluate survival function from log survival function.
.. math:: .. math::
survival(x) = \exp(survival_function(x)) survival(x) = \exp(survival_function(x))
""" """
return self.exp(self._log_survival(*args)) return self.exp(self._log_survival(*args, **kwargs))
def log_survival(self, *args): def log_survival(self, *args, **kwargs):
""" """
Evaluate the log survival function at given value. Evaluate the log survival function at given value.
...@@ -301,18 +301,18 @@ class Distribution(Cell): ...@@ -301,18 +301,18 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_survival(*args) return self._call_log_survival(*args, **kwargs)
def _calc_log_survival_from_call_survival(self, *args): def _calc_log_survival_from_call_survival(self, *args, **kwargs):
r""" r"""
Evaluate log survival function from survival function. Evaluate log survival function from survival function.
.. math:: .. math::
log_survival(x) = \log(survival_function(x)) log_survival(x) = \log(survival_function(x))
""" """
return self.log(self._call_survival(*args)) return self.log(self._call_survival(*args, **kwargs))
def kl_loss(self, *args): def kl_loss(self, *args, **kwargs):
""" """
Evaluate the KL divergence, i.e. KL(a||b). Evaluate the KL divergence, i.e. KL(a||b).
...@@ -320,72 +320,72 @@ class Distribution(Cell): ...@@ -320,72 +320,72 @@ class Distribution(Cell):
Args must include type of the distribution, parameters of distribution b. Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional. Parameters for distribution a are optional.
""" """
return self._kl_loss(*args) return self._kl_loss(*args, **kwargs)
def mean(self, *args): def mean(self, *args, **kwargs):
""" """
Evaluate the mean. Evaluate the mean.
Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._mean(*args) return self._mean(*args, **kwargs)
def mode(self, *args): def mode(self, *args, **kwargs):
""" """
Evaluate the mode. Evaluate the mode.
Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._mode(*args) return self._mode(*args, **kwargs)
def sd(self, *args): def sd(self, *args, **kwargs):
""" """
Evaluate the standard deviation. Evaluate the standard deviation.
Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_sd(*args) return self._call_sd(*args, **kwargs)
def var(self, *args): def var(self, *args, **kwargs):
""" """
Evaluate the variance. Evaluate the variance.
Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_var(*args) return self._call_var(*args, **kwargs)
def _calc_sd_from_var(self, *args): def _calc_sd_from_var(self, *args, **kwargs):
r""" r"""
Evaluate log probability from probability. Evaluate log probability from probability.
.. math:: .. math::
STD(x) = \sqrt(VAR(x)) STD(x) = \sqrt(VAR(x))
""" """
return self.sqrt(self._var(*args)) return self.sqrt(self._var(*args, **kwargs))
def _calc_var_from_sd(self, *args): def _calc_var_from_sd(self, *args, **kwargs):
r""" r"""
Evaluate log probability from probability. Evaluate log probability from probability.
.. math:: .. math::
VAR(x) = STD(x) ^ 2 VAR(x) = STD(x) ^ 2
""" """
return self.sq(self._sd(*args)) return self.sq(self._sd(*args, **kwargs))
def entropy(self, *args): def entropy(self, *args, **kwargs):
""" """
Evaluate the entropy. Evaluate the entropy.
Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._entropy(*args) return self._entropy(*args, **kwargs)
def cross_entropy(self, *args): def cross_entropy(self, *args, **kwargs):
""" """
Evaluate the cross_entropy between distribution a and b. Evaluate the cross_entropy between distribution a and b.
...@@ -393,32 +393,29 @@ class Distribution(Cell): ...@@ -393,32 +393,29 @@ class Distribution(Cell):
Args must include type of the distribution, parameters of distribution b. Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional. Parameters for distribution a are optional.
""" """
return self._call_cross_entropy(*args) return self._call_cross_entropy(*args, **kwargs)
def _calc_cross_entropy(self, *args): def _calc_cross_entropy(self, *args, **kwargs):
r""" r"""
Evaluate cross_entropy from entropy and kl divergence. Evaluate cross_entropy from entropy and kl divergence.
.. math:: .. math::
H(X, Y) = H(X) + KL(X||Y) H(X, Y) = H(X) + KL(X||Y)
""" """
return self._entropy(*args) + self._kl_loss(*args) return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs)
def sample(self, *args): def sample(self, *args, **kwargs):
""" """
Sampling function. Sampling function.
Args:
*args (list): arguments passed in through construct.
Note: Note:
Shape of the sample is default to (). Shape of the sample is default to ().
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._sample(*args) return self._sample(*args, **kwargs)
def construct(self, name, *args): def construct(self, name, *args, **kwargs):
""" """
Override construct in Cell. Override construct in Cell.
...@@ -433,31 +430,31 @@ class Distribution(Cell): ...@@ -433,31 +430,31 @@ class Distribution(Cell):
""" """
if name == 'log_prob': if name == 'log_prob':
return self._call_log_prob(*args) return self._call_log_prob(*args, **kwargs)
if name == 'prob': if name == 'prob':
return self._call_prob(*args) return self._call_prob(*args, **kwargs)
if name == 'cdf': if name == 'cdf':
return self._call_cdf(*args) return self._call_cdf(*args, **kwargs)
if name == 'log_cdf': if name == 'log_cdf':
return self._call_log_cdf(*args) return self._call_log_cdf(*args, **kwargs)
if name == 'survival_function': if name == 'survival_function':
return self._call_survival(*args) return self._call_survival(*args, **kwargs)
if name == 'log_survival': if name == 'log_survival':
return self._call_log_survival(*args) return self._call_log_survival(*args, **kwargs)
if name == 'kl_loss': if name == 'kl_loss':
return self._kl_loss(*args) return self._kl_loss(*args, **kwargs)
if name == 'mean': if name == 'mean':
return self._mean(*args) return self._mean(*args, **kwargs)
if name == 'mode': if name == 'mode':
return self._mode(*args) return self._mode(*args, **kwargs)
if name == 'sd': if name == 'sd':
return self._call_sd(*args) return self._call_sd(*args, **kwargs)
if name == 'var': if name == 'var':
return self._call_var(*args) return self._call_var(*args, **kwargs)
if name == 'entropy': if name == 'entropy':
return self._entropy(*args) return self._entropy(*args, **kwargs)
if name == 'cross_entropy': if name == 'cross_entropy':
return self._call_cross_entropy(*args) return self._call_cross_entropy(*args, **kwargs)
if name == 'sample': if name == 'sample':
return self._sample(*args) return self._sample(*args, **kwargs)
return None return None
...@@ -256,8 +256,5 @@ class Normal(Distribution): ...@@ -256,8 +256,5 @@ class Normal(Distribution):
sd = self._sd_value if sd is None else sd sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape sample_shape = shape + batch_shape
mean_zero = self.const(0.0) sample_norm = C.normal(sample_shape, mean, sd, self.seed)
sd_one = self.const(1.0) return sample_norm
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample = mean + sample_norm * sd
return sample
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册