提交 60bb6beb 编写于 作者: P peixu_ren

Complement the arg passing conventions in distribution and bijector base classes

上级 c68b92e0
......@@ -69,31 +69,31 @@ class Bijector(Cell):
def is_injective(self):
return self._is_injective
def forward(self, *args):
def forward(self, *args, **kwargs):
"""
Forward transformation: transform the input value to another distribution.
"""
return self._forward(*args)
return self._forward(*args, **kwargs)
def inverse(self, *args):
def inverse(self, *args, **kwargs):
"""
Inverse transformation: transform the input value back to the original distribution.
"""
return self._inverse(*args)
return self._inverse(*args, **kwargs)
def forward_log_jacobian(self, *args):
def forward_log_jacobian(self, *args, **kwargs):
"""
Logarithm of the derivative of forward transformation.
"""
return self._forward_log_jacobian(*args)
return self._forward_log_jacobian(*args, **kwargs)
def inverse_log_jacobian(self, *args):
def inverse_log_jacobian(self, *args, **kwargs):
"""
Logarithm of the derivative of forward transformation.
"""
return self._inverse_log_jacobian(*args)
return self._inverse_log_jacobian(*args, **kwargs)
def __call__(self, *args):
def __call__(self, *args, **kwargs):
"""
Call Bijector directly.
This __call__ may go into two directions:
......@@ -107,9 +107,9 @@ class Bijector(Cell):
"""
if isinstance(args[0], Distribution):
return TransformedDistribution(self, args[0])
return super(Bijector, self).__call__(*args)
return super(Bijector, self).__call__(*args, **kwargs)
def construct(self, name, *args):
def construct(self, name, *args, **kwargs):
"""
Override construct in Cell.
......@@ -120,11 +120,11 @@ class Bijector(Cell):
Always raise RuntimeError as Distribution should not be called directly.
"""
if name == 'forward':
return self.forward(*args)
return self.forward(*args, **kwargs)
if name == 'inverse':
return self.inverse(*args)
return self.inverse(*args, **kwargs)
if name == 'forward_log_jacobian':
return self.forward_log_jacobian(*args)
return self.forward_log_jacobian(*args, **kwargs)
if name == 'inverse_log_jacobian':
return self.inverse_log_jacobian(*args)
return self.inverse_log_jacobian(*args, **kwargs)
return None
......@@ -27,7 +27,7 @@ class Distribution(Cell):
Note:
Derived class should override operations such as ,_mean, _prob,
and _log_prob. Arguments should be passed in through *args.
and _log_prob. Arguments should be passed in through *args or **kwargs.
Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution.
......@@ -171,7 +171,7 @@ class Distribution(Cell):
if hasattr(self, '_cross_entropy'):
self._call_cross_entropy = self._cross_entropy
def log_prob(self, *args):
def log_prob(self, *args, **kwargs):
"""
Evaluate the log probability(pdf or pmf) at the given value.
......@@ -179,18 +179,18 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_prob(*args)
return self._call_log_prob(*args, **kwargs)
def _calc_prob_from_log_prob(self, *args):
def _calc_prob_from_log_prob(self, *args, **kwargs):
r"""
Evaluate prob from log probability.
.. math::
probability(x) = \exp(log_likehood(x))
"""
return self.exp(self._log_prob(*args))
return self.exp(self._log_prob(*args, **kwargs))
def prob(self, *args):
def prob(self, *args, **kwargs):
"""
Evaluate the probability (pdf or pmf) at given value.
......@@ -198,18 +198,18 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return self._call_prob(*args)
return self._call_prob(*args, **kwargs)
def _calc_log_prob_from_prob(self, *args):
def _calc_log_prob_from_prob(self, *args, **kwargs):
r"""
Evaluate log probability from probability.
.. math::
log_prob(x) = \log(prob(x))
"""
return self.log(self._prob(*args))
return self.log(self._prob(*args, **kwargs))
def cdf(self, *args):
def cdf(self, *args, **kwargs):
"""
Evaluate the cdf at given value.
......@@ -217,36 +217,36 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return self._call_cdf(*args)
return self._call_cdf(*args, **kwargs)
def _calc_cdf_from_log_cdf(self, *args):
def _calc_cdf_from_log_cdf(self, *args, **kwargs):
r"""
Evaluate cdf from log_cdf.
.. math::
cdf(x) = \exp(log_cdf(x))
"""
return self.exp(self._log_cdf(*args))
return self.exp(self._log_cdf(*args, **kwargs))
def _calc_cdf_from_survival(self, *args):
def _calc_cdf_from_survival(self, *args, **kwargs):
r"""
Evaluate cdf from survival function.
.. math::
cdf(x) = 1 - (survival_function(x))
"""
return 1.0 - self._survival_function(*args)
return 1.0 - self._survival_function(*args, **kwargs)
def _calc_cdf_from_log_survival(self, *args):
def _calc_cdf_from_log_survival(self, *args, **kwargs):
r"""
Evaluate cdf from log survival function.
.. math::
cdf(x) = 1 - (\exp(log_survival(x)))
"""
return 1.0 - self.exp(self._log_survival(*args))
return 1.0 - self.exp(self._log_survival(*args, **kwargs))
def log_cdf(self, *args):
def log_cdf(self, *args, **kwargs):
"""
Evaluate the log cdf at given value.
......@@ -254,18 +254,18 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_cdf(*args)
return self._call_log_cdf(*args, **kwargs)
def _calc_log_cdf_from_call_cdf(self, *args):
def _calc_log_cdf_from_call_cdf(self, *args, **kwargs):
r"""
Evaluate log cdf from cdf.
.. math::
log_cdf(x) = \log(cdf(x))
"""
return self.log(self._call_cdf(*args))
return self.log(self._call_cdf(*args, **kwargs))
def survival_function(self, *args):
def survival_function(self, *args, **kwargs):
"""
Evaluate the survival function at given value.
......@@ -273,27 +273,27 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return self._call_survival(*args)
return self._call_survival(*args, **kwargs)
def _calc_survival_from_call_cdf(self, *args):
def _calc_survival_from_call_cdf(self, *args, **kwargs):
r"""
Evaluate survival function from cdf.
.. math::
survival_function(x) = 1 - (cdf(x))
"""
return 1.0 - self._call_cdf(*args)
return 1.0 - self._call_cdf(*args, **kwargs)
def _calc_survival_from_log_survival(self, *args):
def _calc_survival_from_log_survival(self, *args, **kwargs):
r"""
Evaluate survival function from log survival function.
.. math::
survival(x) = \exp(survival_function(x))
"""
return self.exp(self._log_survival(*args))
return self.exp(self._log_survival(*args, **kwargs))
def log_survival(self, *args):
def log_survival(self, *args, **kwargs):
"""
Evaluate the log survival function at given value.
......@@ -301,18 +301,18 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_survival(*args)
return self._call_log_survival(*args, **kwargs)
def _calc_log_survival_from_call_survival(self, *args):
def _calc_log_survival_from_call_survival(self, *args, **kwargs):
r"""
Evaluate log survival function from survival function.
.. math::
log_survival(x) = \log(survival_function(x))
"""
return self.log(self._call_survival(*args))
return self.log(self._call_survival(*args, **kwargs))
def kl_loss(self, *args):
def kl_loss(self, *args, **kwargs):
"""
Evaluate the KL divergence, i.e. KL(a||b).
......@@ -320,72 +320,72 @@ class Distribution(Cell):
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return self._kl_loss(*args)
return self._kl_loss(*args, **kwargs)
def mean(self, *args):
def mean(self, *args, **kwargs):
"""
Evaluate the mean.
Note:
Dist_spec_args are optional.
"""
return self._mean(*args)
return self._mean(*args, **kwargs)
def mode(self, *args):
def mode(self, *args, **kwargs):
"""
Evaluate the mode.
Note:
Dist_spec_args are optional.
"""
return self._mode(*args)
return self._mode(*args, **kwargs)
def sd(self, *args):
def sd(self, *args, **kwargs):
"""
Evaluate the standard deviation.
Note:
Dist_spec_args are optional.
"""
return self._call_sd(*args)
return self._call_sd(*args, **kwargs)
def var(self, *args):
def var(self, *args, **kwargs):
"""
Evaluate the variance.
Note:
Dist_spec_args are optional.
"""
return self._call_var(*args)
return self._call_var(*args, **kwargs)
def _calc_sd_from_var(self, *args):
def _calc_sd_from_var(self, *args, **kwargs):
r"""
Evaluate log probability from probability.
.. math::
STD(x) = \sqrt(VAR(x))
"""
return self.sqrt(self._var(*args))
return self.sqrt(self._var(*args, **kwargs))
def _calc_var_from_sd(self, *args):
def _calc_var_from_sd(self, *args, **kwargs):
r"""
Evaluate log probability from probability.
.. math::
VAR(x) = STD(x) ^ 2
"""
return self.sq(self._sd(*args))
return self.sq(self._sd(*args, **kwargs))
def entropy(self, *args):
def entropy(self, *args, **kwargs):
"""
Evaluate the entropy.
Note:
Dist_spec_args are optional.
"""
return self._entropy(*args)
return self._entropy(*args, **kwargs)
def cross_entropy(self, *args):
def cross_entropy(self, *args, **kwargs):
"""
Evaluate the cross_entropy between distribution a and b.
......@@ -393,32 +393,29 @@ class Distribution(Cell):
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return self._call_cross_entropy(*args)
return self._call_cross_entropy(*args, **kwargs)
def _calc_cross_entropy(self, *args):
def _calc_cross_entropy(self, *args, **kwargs):
r"""
Evaluate cross_entropy from entropy and kl divergence.
.. math::
H(X, Y) = H(X) + KL(X||Y)
"""
return self._entropy(*args) + self._kl_loss(*args)
return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs)
def sample(self, *args):
def sample(self, *args, **kwargs):
"""
Sampling function.
Args:
*args (list): arguments passed in through construct.
Note:
Shape of the sample is default to ().
Dist_spec_args are optional.
"""
return self._sample(*args)
return self._sample(*args, **kwargs)
def construct(self, name, *args):
def construct(self, name, *args, **kwargs):
"""
Override construct in Cell.
......@@ -433,31 +430,31 @@ class Distribution(Cell):
"""
if name == 'log_prob':
return self._call_log_prob(*args)
return self._call_log_prob(*args, **kwargs)
if name == 'prob':
return self._call_prob(*args)
return self._call_prob(*args, **kwargs)
if name == 'cdf':
return self._call_cdf(*args)
return self._call_cdf(*args, **kwargs)
if name == 'log_cdf':
return self._call_log_cdf(*args)
return self._call_log_cdf(*args, **kwargs)
if name == 'survival_function':
return self._call_survival(*args)
return self._call_survival(*args, **kwargs)
if name == 'log_survival':
return self._call_log_survival(*args)
return self._call_log_survival(*args, **kwargs)
if name == 'kl_loss':
return self._kl_loss(*args)
return self._kl_loss(*args, **kwargs)
if name == 'mean':
return self._mean(*args)
return self._mean(*args, **kwargs)
if name == 'mode':
return self._mode(*args)
return self._mode(*args, **kwargs)
if name == 'sd':
return self._call_sd(*args)
return self._call_sd(*args, **kwargs)
if name == 'var':
return self._call_var(*args)
return self._call_var(*args, **kwargs)
if name == 'entropy':
return self._entropy(*args)
return self._entropy(*args, **kwargs)
if name == 'cross_entropy':
return self._call_cross_entropy(*args)
return self._call_cross_entropy(*args, **kwargs)
if name == 'sample':
return self._sample(*args)
return self._sample(*args, **kwargs)
return None
......@@ -256,8 +256,5 @@ class Normal(Distribution):
sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample = mean + sample_norm * sd
return sample
sample_norm = C.normal(sample_shape, mean, sd, self.seed)
return sample_norm
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册