提交 9083e9dc 编写于 作者: X Xun Deng

fixed prob, survival function of exponential distribution

上级 39e27911
......@@ -198,9 +198,9 @@ class Exponential(Distribution):
return self._entropy(rate) + self._kl_loss(dist, rate_b, rate)
def _prob(self, value, rate=None):
def _log_prob(self, value, rate=None):
r"""
pdf of Exponential distribution.
log_pdf of Exponential distribution.
Args:
Args:
......@@ -211,15 +211,16 @@ class Exponential(Distribution):
Value should be greater or equal to zero.
.. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
log_pdf(x) = \log(rate) - rate * x if x >= 0 else 0
"""
value = self._check_value(value, "value")
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
prob = self.exp(self.log(rate) - rate * value)
prob = self.log(rate) - rate * value
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
comp = self.less(value, zeros)
return self.select(comp, zeros, prob)
return self.select(comp, neginf, prob)
def _cdf(self, value, rate=None):
r"""
......@@ -243,6 +244,27 @@ class Exponential(Distribution):
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)
def _log_survival(self, value, rate=None):
r"""
log survival_function of Exponential distribution.
Args:
value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self.rate.
Note:
Value should be greater or equal to zero.
.. math::
log_survival_function(x) = -1 * \lambda * x if x >= 0 else 0
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
sf = -1. * rate * value
zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, sf)
def _kl_loss(self, dist, rate_b, rate=None):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册