Commit 2522285c authored by Joshua V. Dillon, committed by TensorFlower Gardener

Allow fully dynamic batch/event overrides.

Previously we assumed the override was statically known to be on or off (even though the override shape itself could be non-static). This CL allows a fully dynamic specification of the batch_shape and/or event_shape override; i.e., we no longer require that the decision to override a shape be known during graph construction.
Change: 143576091
Parent e1eae197
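To illustrate, a minimal usage sketch distilled from the tests in this CL. It assumes the tf.contrib.distributions API as of this commit (Normal(mu=..., sigma=...), the bijector module, and the batch_shape/event_shape arguments to TransformedDistribution exercised below); the shift and scale values are illustrative.

import numpy as np
import tensorflow as tf

ds = tf.contrib.distributions
bs = tf.contrib.distributions.bijector

# The overrides are ordinary int32 Tensors; feeding them at run time means
# the decision to override need not be known while building the graph.
batch_shape = tf.placeholder(tf.int32, name="dynamic_batch_shape")
event_shape = tf.placeholder(tf.int32, name="dynamic_event_shape")

shift = np.array([[-1., 0., 1.], [-1., -2., -3.]], dtype=np.float32)  # [2, 3]
tril = np.tile(np.tril(np.ones([3, 3], np.float32)), [2, 1, 1])       # [2, 3, 3]

# Lift a scalar Normal into a batch of two 3-dimensional MVNs.
fake_mvn = ds.TransformedDistribution(
    distribution=ds.Normal(mu=0., sigma=1.),
    bijector=bs.Affine(shift=shift, scale_tril=tril),
    batch_shape=batch_shape,
    event_shape=event_shape,
    validate_args=True)

with tf.Session() as sess:
  x = sess.run(fake_mvn.sample(5, seed=0),
               feed_dict={batch_shape: np.array([2], np.int32),
                          event_shape: np.array([3], np.int32)})
  print(x.shape)  # (5, 2, 3)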
@@ -85,20 +85,20 @@ class DistributionTest(test.TestCase):
sigma = 2.
normal = ds.Normal(mu, sigma, validate_args=True)
self.assertTrue(tensor_util.constant_value(normal.is_scalar_event))
self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch))
self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch()))
normal = ds.Normal([mu], [sigma], validate_args=True)
self.assertTrue(tensor_util.constant_value(normal.is_scalar_event))
self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch))
self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch()))
mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True)
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event))
self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch))
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch()))
mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True)
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event))
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch))
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch()))
# We now test every codepath within the underlying is_scalar_helper
# function.
@@ -20,11 +20,13 @@ from __future__ import print_function
import numpy as np
from scipy import stats
from tensorflow.contrib import distributions
from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijector as bijector_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -73,8 +75,7 @@ class TransformedDistributionTest(test.TestCase):
# Note: the Jacobian callable only works for this example; more generally
# you may or may not need a reduce_sum.
log_normal = ds.TransformedDistribution(
distribution=ds.Normal(
mu=mu, sigma=sigma),
distribution=ds.Normal(mu=mu, sigma=sigma),
bijector=bs.Exp(event_ndims=0))
sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu))
@@ -105,41 +106,38 @@ class TransformedDistributionTest(test.TestCase):
mu = 3.0
sigma = 0.02
log_normal = ds.TransformedDistribution(
distribution=ds.Normal(
mu=mu, sigma=sigma),
distribution=ds.Normal(mu=mu, sigma=sigma),
bijector=bs.Exp(event_ndims=0))
sample = log_normal.sample(1)
sample_val, log_pdf_val = sess.run([sample, log_normal.log_pdf(sample)])
self.assertAllClose(
stats.lognorm.logpdf(
sample_val, s=sigma, scale=np.exp(mu)),
stats.lognorm.logpdf(sample_val, s=sigma, scale=np.exp(mu)),
log_pdf_val,
atol=1e-2)
def testConditioning(self):
with self.test_session():
conditional_normal = ds.TransformedDistribution(
distribution=ds.Normal(
mu=0., sigma=1.),
distribution=ds.Normal(mu=0., sigma=1.),
bijector=_ChooseLocation(loc=[-100., 100.]))
z = [-1, +1, -1, -1, +1]
self.assertAllClose(
np.sign(
conditional_normal.sample(
5, bijector_kwargs={"z": z}).eval()),
z)
np.sign(conditional_normal.sample(
5, bijector_kwargs={"z": z}).eval()), z)
def testShapeChangingBijector(self):
with self.test_session():
softmax = bs.SoftmaxCentered()
standard_normal = ds.Normal(mu=0., sigma=1.)
multi_logit_normal = ds.TransformedDistribution(
distribution=standard_normal, bijector=softmax)
x = [[-np.log(3.), 0.], [np.log(3), np.log(5)]]
distribution=standard_normal,
bijector=softmax)
x = [[-np.log(3.), 0.],
[np.log(3), np.log(5)]]
y = softmax.forward(x).eval()
expected_log_pdf = (stats.norm(
loc=0., scale=1.).logpdf(x) - np.sum(np.log(y), axis=-1))
expected_log_pdf = (stats.norm(loc=0., scale=1.).logpdf(x) -
np.sum(np.log(y), axis=-1))
self.assertAllClose(expected_log_pdf,
multi_logit_normal.log_prob(y).eval())
self.assertAllClose(
@@ -152,7 +150,9 @@ class TransformedDistributionTest(test.TestCase):
with self.test_session():
shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
actual_mvn = ds.MultivariateNormalDiag(shift, diag, validate_args=True)
actual_mvn_entropy = np.concatenate([
[stats.multivariate_normal(shift[i], np.diag(diag[i]**2)).entropy()]
for i in range(len(diag))])
fake_mvn = ds.TransformedDistribution(
ds.MultivariateNormalDiag(
array_ops.zeros_like(shift),
@@ -160,11 +160,10 @@ class TransformedDistributionTest(test.TestCase):
validate_args=True),
bs.AffineLinearOperator(
shift,
scale=la.LinearOperatorDiag(
diag, is_non_singular=True),
scale=la.LinearOperatorDiag(diag, is_non_singular=True),
validate_args=True),
validate_args=True)
self.assertAllClose(actual_mvn.entropy().eval(),
self.assertAllClose(actual_mvn_entropy,
fake_mvn.entropy().eval())
@@ -181,121 +180,151 @@ class ScalarToMultiTest(test.TestCase):
dtype=np.float32)
def _testMVN(self,
base_distribution,
batch_shape=None,
event_shape=None,
base_distribution_class,
base_distribution_kwargs,
batch_shape=(),
event_shape=(),
not_implemented_message=None):
with self.test_session() as sess:
# Overriding shapes must be compatible w/bijector; most bijectors are
# batch_shape agnostic and only care about event_ndims.
# In the case of `Affine`, if we got it wrong then it would fire an
# exception due to incompatible dimensions.
fake_mvn = ds.TransformedDistribution(
distribution=base_distribution[0](validate_args=True,
**base_distribution[1]),
bijector=bs.Affine(
shift=self._shift, scale_tril=self._tril),
batch_shape_pl = array_ops.placeholder(
dtypes.int32, name="dynamic_batch_shape")
event_shape_pl = array_ops.placeholder(
dtypes.int32, name="dynamic_event_shape")
feed_dict = {batch_shape_pl: np.array(batch_shape, dtype=np.int32),
event_shape_pl: np.array(event_shape, dtype=np.int32)}
fake_mvn_dynamic = ds.TransformedDistribution(
distribution=base_distribution_class(validate_args=True,
**base_distribution_kwargs),
bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
batch_shape=batch_shape_pl,
event_shape=event_shape_pl,
validate_args=True)
fake_mvn_static = ds.TransformedDistribution(
distribution=base_distribution_class(validate_args=True,
**base_distribution_kwargs),
bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
batch_shape=batch_shape,
event_shape=event_shape,
validate_args=True)
actual_mean = np.tile(self._shift, [2, 1]) # Affine elided this tile.
actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1]))
actual_mvn = ds.MultivariateNormalFull(mu=actual_mean, sigma=actual_cov)
# Ensure sample works by checking first, second moments.
n = 5e3
y = fake_mvn.sample(int(n), seed=0)
sample_mean = math_ops.reduce_mean(y, 0)
centered_y = array_ops.transpose(y - sample_mean, [1, 2, 0])
sample_cov = math_ops.matmul(centered_y, centered_y, transpose_b=True) / n
[sample_mean_, sample_cov_] = sess.run([sample_mean, sample_cov])
self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)
# Ensure all other functions work as intended.
x = fake_mvn.sample(5, seed=0).eval()
self.assertAllEqual([5, 2, 3], x.shape)
self.assertAllEqual(actual_mvn.get_event_shape(),
fake_mvn.get_event_shape())
self.assertAllEqual(actual_mvn.event_shape().eval(),
fake_mvn.event_shape().eval())
self.assertAllEqual(actual_mvn.get_batch_shape(),
fake_mvn.get_batch_shape())
self.assertAllEqual(actual_mvn.batch_shape().eval(),
fake_mvn.batch_shape().eval())
self.assertAllClose(
actual_mvn.log_prob(x).eval(),
fake_mvn.log_prob(x).eval(),
atol=0.,
rtol=1e-7)
self.assertAllClose(
actual_mvn.prob(x).eval(),
fake_mvn.prob(x).eval(),
atol=0.,
rtol=1e-6)
self.assertAllClose(
actual_mvn.entropy().eval(),
fake_mvn.entropy().eval(),
atol=0.,
rtol=1e-6)
for unsupported_fn in (fake_mvn.log_cdf, fake_mvn.cdf,
fake_mvn.survival_function,
fake_mvn.log_survival_function):
def actual_mvn_log_prob(x):
return np.concatenate([
[stats.multivariate_normal(
actual_mean[i], actual_cov[i]).logpdf(x[:, i, :])]
for i in range(len(actual_cov))]).T
actual_mvn_entropy = np.concatenate([
[stats.multivariate_normal(
actual_mean[i], actual_cov[i]).entropy()]
for i in range(len(actual_cov))])
self.assertAllEqual([3], fake_mvn_static.get_event_shape())
self.assertAllEqual([2], fake_mvn_static.get_batch_shape())
self.assertAllEqual(tensor_shape.TensorShape(None),
fake_mvn_dynamic.get_event_shape())
self.assertAllEqual(tensor_shape.TensorShape(None),
fake_mvn_dynamic.get_batch_shape())
x = fake_mvn_static.sample(5, seed=0).eval()
for unsupported_fn in (fake_mvn_static.log_cdf,
fake_mvn_static.cdf,
fake_mvn_static.survival_function,
fake_mvn_static.log_survival_function):
with self.assertRaisesRegexp(NotImplementedError,
not_implemented_message):
self.assertRaisesRegexp(unsupported_fn(x))
unsupported_fn(x)
num_samples = 5e3
for fake_mvn, feed_dict in ((fake_mvn_static, {}),
(fake_mvn_dynamic, feed_dict)):
# Ensure sample works by checking first, second moments.
y = fake_mvn.sample(int(num_samples), seed=0)
x = y[0:5, ...]
sample_mean = math_ops.reduce_mean(y, 0)
centered_y = array_ops.transpose(y - sample_mean, [1, 2, 0])
sample_cov = math_ops.matmul(
centered_y, centered_y, transpose_b=True) / num_samples
[
sample_mean_,
sample_cov_,
x_,
fake_event_shape_,
fake_batch_shape_,
fake_log_prob_,
fake_prob_,
fake_entropy_,
] = sess.run([
sample_mean,
sample_cov,
x,
fake_mvn.event_shape(),
fake_mvn.batch_shape(),
fake_mvn.log_prob(x),
fake_mvn.prob(x),
fake_mvn.entropy(),
], feed_dict=feed_dict)
self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)
# Ensure all other functions work as intended.
self.assertAllEqual([5, 2, 3], x_.shape)
self.assertAllEqual([3], fake_event_shape_)
self.assertAllEqual([2], fake_batch_shape_)
self.assertAllClose(actual_mvn_log_prob(x_), fake_log_prob_,
atol=0., rtol=1e-6)
self.assertAllClose(np.exp(actual_mvn_log_prob(x_)), fake_prob_,
atol=0., rtol=1e-5)
self.assertAllClose(actual_mvn_entropy, fake_entropy_,
atol=0., rtol=1e-6)
def testScalarBatchScalarEvent(self):
self._testMVN(
base_distribution=[ds.Normal, {
"mu": 0.,
"sigma": 1.
}],
base_distribution_class=ds.Normal,
base_distribution_kwargs={"mu": 0., "sigma": 1.},
batch_shape=[2],
event_shape=[3],
not_implemented_message="not implemented when overriding event_shape")
def testScalarBatchNonScalarEvent(self):
self._testMVN(
base_distribution=[
ds.MultivariateNormalDiag, {
"mu": [0., 0., 0.],
"diag_stdev": [1., 1, 1]
}
],
base_distribution_class=ds.MultivariateNormalDiag,
base_distribution_kwargs={"mu": [0., 0., 0.], "diag_stdev": [1., 1, 1]},
batch_shape=[2],
not_implemented_message="not implemented$")
with self.test_session():
# Can't override event_shape for scalar batch, non-scalar event.
with self.assertRaisesRegexp(ValueError, "requires scalar"):
with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
ds.TransformedDistribution(
distribution=ds.MultivariateNormalDiag(
mu=[0.], diag_stdev=[1.]),
bijector=bs.Affine(
shift=self._shift, scale_tril=self._tril),
distribution=ds.MultivariateNormalDiag(mu=[0.], diag_stdev=[1.]),
bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
batch_shape=[2],
event_shape=[3],
validate_args=True)
def testNonScalarBatchScalarEvent(self):
self._testMVN(
base_distribution=[ds.Normal, {
"mu": [0., 0],
"sigma": [1., 1]
}],
base_distribution_class=ds.Normal,
base_distribution_kwargs={"mu": [0., 0], "sigma": [1., 1]},
event_shape=[3],
not_implemented_message="not implemented when overriding event_shape")
with self.test_session():
# Can't override batch_shape for non-scalar batch, scalar event.
with self.assertRaisesRegexp(ValueError, "requires scalar"):
with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
ds.TransformedDistribution(
distribution=ds.Normal(
mu=[0.], sigma=[1.]),
bijector=bs.Affine(
shift=self._shift, scale_tril=self._tril),
distribution=ds.Normal(mu=[0.], sigma=[1.]),
bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
batch_shape=[2],
event_shape=[3],
validate_args=True)
@@ -304,12 +333,11 @@ class ScalarToMultiTest(test.TestCase):
with self.test_session():
# Can't override event_shape and/or batch_shape for non_scalar batch,
# non-scalar event.
with self.assertRaisesRegexp(ValueError, "requires scalar"):
with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
ds.TransformedDistribution(
distribution=ds.MultivariateNormalDiag(
mu=[[0.]], diag_stdev=[[1.]]),
bijector=bs.Affine(
shift=self._shift, scale_tril=self._tril),
distribution=ds.MultivariateNormalDiag(mu=[[0.]],
diag_stdev=[[1.]]),
bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
batch_shape=[2],
event_shape=[3],
validate_args=True)
@@ -526,19 +526,33 @@ class Distribution(_BaseDistribution):
"""
return self._get_event_shape()
@property
def is_scalar_event(self):
"""Indicates that `event_shape==[]`."""
return ops.convert_to_tensor(
self._is_scalar_helper(self.get_event_shape, self.event_shape),
name="is_scalar_event")
def is_scalar_event(self, name="is_scalar_event"):
"""Indicates that `event_shape == []`.
@property
def is_scalar_batch(self):
"""Indicates that `batch_shape==[]`."""
return ops.convert_to_tensor(
self._is_scalar_helper(self.get_batch_shape, self.batch_shape),
name="is_scalar_batch")
Args:
name: The name to give this op.
Returns:
is_scalar_event: `Boolean` `scalar` `Tensor`.
"""
with self._name_scope(name):
return ops.convert_to_tensor(
self._is_scalar_helper(self.get_event_shape, self.event_shape),
name="is_scalar_event")
def is_scalar_batch(self, name="is_scalar_batch"):
"""Indicates that `batch_shape == []`.
Args:
name: The name to give this op.
Returns:
is_scalar_batch: `Boolean` `scalar` `Tensor`.
"""
with self._name_scope(name):
return ops.convert_to_tensor(
self._is_scalar_helper(self.get_batch_shape, self.batch_shape),
name="is_scalar_batch")
def _sample_n(self, n, seed=None):
raise NotImplementedError("sample_n is not implemented")
@@ -888,7 +902,7 @@ class Distribution(_BaseDistribution):
"""Helper function to standardize op scope."""
with ops.name_scope(self.name):
with ops.name_scope(name, values=(
(values or []) + self._graph_parents)) as scope:
([] if values is None else values) + self._graph_parents)) as scope:
yield scope
def _expand_sample_shape_to_vector(self, x, name):
@@ -22,6 +22,8 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import bijector as bijectors
from tensorflow.contrib.distributions.python.ops import distribution as distributions
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -47,19 +49,46 @@ _condition_kwargs_dict = {
# graph construction.
def _static_value(x):
"""Returns the static value of a `Tensor` or `None`."""
return tensor_util.constant_value(ops.convert_to_tensor(x))
def _logical_and(*args):
"""Convenience function which attempts to statically `reduce_all`."""
args_ = [_static_value(x) for x in args]
if any(x is not None and not bool(x) for x in args_):
return constant_op.constant(False)
if all(x is not None and bool(x) for x in args_):
return constant_op.constant(True)
if len(args) == 2:
return math_ops.logical_and(*args)
return math_ops.reduce_all(args)
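# Illustrative behavior (a sketch, not part of the committed code): when every
# argument is statically known the result is folded at graph construction,
#   _logical_and(constant_op.constant(True), constant_op.constant(False))
#     => constant_op.constant(False)
# while any statically unknown argument falls back to a runtime op,
#   _logical_and(dynamic_bool, constant_op.constant(True))
#     => math_ops.logical_and(dynamic_bool, ...).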
def _logical_equal(x, y):
"""Convenience function which attempts to statically compute `x == y`."""
x_ = _static_value(x)
y_ = _static_value(y)
if x_ is None or y_ is None:
return math_ops.equal(x, y)
return constant_op.constant(np.array_equal(x_, y_))
def _logical_not(x):
"""Convenience function which attempts to statically apply `logical_not`."""
if tensor_util.constant_value(x) is not None:
return not tensor_util.constant_value(x)
return math_ops.logical_not(x)
x_ = _static_value(x)
if x_ is None:
return math_ops.logical_not(x)
return constant_op.constant(np.logical_not(x_))
def _concat_vectors(*args):
"""Convenience function which flattens input vectors."""
vals = [tensor_util.constant_value(ops.convert_to_tensor(x)) for x in args]
if any(x is None for x in vals):
"""Convenience function which concatenates input vectors."""
args_ = [_static_value(x) for x in args]
if any(x_ is None for x_ in args_):
return array_ops.concat_v2(args, 0)
return [x for v in vals for x in v]
return constant_op.constant([x_ for vec_ in args_ for x_ in vec_])
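# Illustrative behavior (a sketch, not part of the committed code): all-static
# inputs are concatenated at graph construction,
#   _concat_vectors([1], [2, 3])  => constant_op.constant([1, 2, 3])
# whereas any statically unknown input defers to array_ops.concat_v2(args, 0).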
def _pick_scalar_condition(pred, cond_true, cond_false):
@@ -67,20 +96,36 @@ def _pick_scalar_condition(pred, cond_true, cond_false):
# Note: This function is only valid if all of pred, cond_true, and cond_false
# are scalars. This means its semantics are arguably more like tf.cond than
# tf.select even though we use tf.select to implement it.
pred_static = tensor_util.constant_value(pred)
if pred_static is None:
return math_ops.select(pred, cond_true, cond_false)
return cond_true if pred_static else cond_false
pred_ = _static_value(pred)
if pred_ is None:
return array_ops.where(pred, cond_true, cond_false)
return cond_true if pred_ else cond_false
def _ones_like(x):
"""Convenience function attempts to statically construct `ones_like`."""
# Should only be used for small vectors.
if x.get_shape().is_fully_defined():
return np.ones(x.get_shape().as_list(), dtype=x.dtype.as_numpy_dtype())
return array_ops.ones(x.get_shape().as_list(), dtype=x.dtype)
return array_ops.ones_like(x)
def _ndims_from_shape(shape):
"""Returns `Tensor`'s `rank` implied by a `Tensor` shape."""
if shape.get_shape().ndims not in (None, 1):
raise ValueError("input is not a valid shape: not 1D")
if not shape.dtype.is_integer:
raise TypeError("input is not a valid shape: wrong dtype")
if shape.get_shape().is_fully_defined():
return constant_op.constant(shape.get_shape().as_list()[0])
return array_ops.shape(shape)[0]
def _is_scalar_from_shape(shape):
"""Returns `True` `Tensor` if `Tensor` shape implies a scalar."""
return _logical_equal(_ndims_from_shape(shape), 0)
class TransformedDistribution(distributions.Distribution):
"""A Transformed Distribution.
@@ -234,9 +279,9 @@ class TransformedDistribution(distributions.Distribution):
bijector: The object responsible for calculating the transformation.
Typically an instance of `Bijector`. `None` means `Identity()`.
batch_shape: `integer` vector `Tensor` which overrides `distribution`
`batch_shape`; valid only if `distribution.is_scalar_batch`.
`batch_shape`; valid only if `distribution.is_scalar_batch()`.
event_shape: `integer` vector `Tensor` which overrides `distribution`
`event_shape`; valid only if `distribution.is_scalar_event`.
`event_shape`; valid only if `distribution.is_scalar_event()`.
validate_args: Python Boolean. Whether to validate input with asserts.
If `validate_args` is `False`, and the inputs are invalid,
correct behavior is not guaranteed.
@@ -245,53 +290,55 @@ class TransformedDistribution(distributions.Distribution):
"""
parameters = locals()
parameters.pop("self")
if bijector is None:
bijector = bijectors.Identity(validate_args=validate_args)
name = name or bijector.name + distribution.name
name = name or (("" if bijector is None else bijector.name) +
distribution.name)
with ops.name_scope(name, values=[event_shape, batch_shape]):
if batch_shape is not None:
batch_shape = self._maybe_validate_shape_override(
ops.convert_to_tensor(batch_shape, name="batch_shape"),
distribution.is_scalar_batch, validate_args)
self._override_batch_shape = batch_shape
if event_shape is not None:
event_shape = self._maybe_validate_shape_override(
ops.convert_to_tensor(event_shape, name="event_shape"),
distribution.is_scalar_event, validate_args)
self._override_event_ndims = (
event_shape.get_shape().ndims
if event_shape.get_shape().ndims is not None
else array_ops.rank(event_shape, name="event_ndims"))
else:
self._override_event_ndims = 0
self._override_event_shape = event_shape
# For convenience we define some handy constants.
self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")
if bijector is None:
bijector = bijectors.Identity(validate_args=validate_args)
# We will keep track of a static and dynamic version of
# self._is_{batch,event}_override. This way we can do more prior to graph
# execution, including possibly raising Python exceptions.
self._override_batch_shape = self._maybe_validate_shape_override(
batch_shape, distribution.is_scalar_batch(), validate_args,
"batch_shape")
self._is_batch_override = _logical_not(_logical_equal(
_ndims_from_shape(self._override_batch_shape), self._zero))
self._is_maybe_batch_override = bool(
tensor_util.constant_value(self._override_batch_shape) is None or
tensor_util.constant_value(self._override_batch_shape))
self._override_event_shape = self._maybe_validate_shape_override(
event_shape, distribution.is_scalar_event(), validate_args,
"event_shape")
self._is_event_override = _logical_not(_logical_equal(
_ndims_from_shape(self._override_event_shape), self._zero))
self._is_maybe_event_override = bool(
tensor_util.constant_value(self._override_event_shape) is None or
tensor_util.constant_value(self._override_event_shape))
# To convert a scalar distribution into a multivariate distribution we
# will draw dims from the sample dims, which are otherwise iid. This is
# easy to do except in the case that:
# batch_shape is None and
# event_shape is not None and
# not distribution.is_scalar_batch.
# When that case happens the event dims will incorrectly be to the left of
# the batch dims. In this case we'll cyclically permute left the new dims.
if batch_shape is None and event_shape is not None:
self._needs_rotation = ops.convert_to_tensor(
_logical_not(distribution.is_scalar_batch), name="needs_rotation")
n = _pick_scalar_condition(self._needs_rotation,
self._override_event_ndims, 0)
# We'll be reducing the head dims (if at all), i.e., this will be []
# if we don't need to reduce.
self._reduce_event_indices = math_ops.range(
n - self._override_event_ndims, n)
else:
self._needs_rotation = ops.convert_to_tensor(False,
name="needs_rotation")
# We'll be reducing the tail dims (if at all), i.e., this will be []
# if we don't need to reduce.
self._reduce_event_indices = (
math_ops.range(-self._override_event_ndims, 0)
if event_shape is not None else [])
# easy to do except in the case that the base distribution has batch dims
# and we're overriding event shape. When that case happens the event dims
# will incorrectly be to the left of the batch dims. In this case we'll
# cyclically permute left the new dims.
self._needs_rotation = _logical_and(
self._is_event_override,
_logical_not(self._is_batch_override),
_logical_not(distribution.is_scalar_batch()))
override_event_ndims = _ndims_from_shape(self._override_event_shape)
self._rotate_ndims = _pick_scalar_condition(
self._needs_rotation, override_event_ndims, 0)
# We'll be reducing the head dims (if at all), i.e., this will be []
# if we don't need to reduce.
self._reduce_event_indices = math_ops.range(
self._rotate_ndims - override_event_ndims, self._rotate_ndims)
self._distribution = distribution
self._bijector = bijector
@@ -319,28 +366,29 @@ class TransformedDistribution(distributions.Distribution):
return self._bijector
def _event_shape(self):
return self.bijector.forward_event_shape(
self.distribution.event_shape()
if self._override_event_shape is None
else self._override_event_shape)
return self.bijector.forward_event_shape(distribution_util.pick_vector(
self._is_event_override,
self._override_event_shape,
self.distribution.event_shape()))
def _get_event_shape(self):
static_override = tensor_util.constant_value(self._override_event_shape)
return self.bijector.get_forward_event_shape(
self.distribution.get_event_shape()
if self._override_event_shape is None
else tensor_shape.TensorShape(
tensor_util.constant_value(self._override_event_shape)))
if static_override is not None and not static_override
else tensor_shape.TensorShape(static_override))
def _batch_shape(self):
if self._override_batch_shape is None:
return self.distribution.batch_shape()
return self._override_batch_shape
return distribution_util.pick_vector(
self._is_batch_override,
self._override_batch_shape,
self.distribution.batch_shape())
def _get_batch_shape(self):
if self._override_batch_shape is None:
static_override = tensor_util.constant_value(self._override_batch_shape)
if static_override is not None and not static_override:
return self.distribution.get_batch_shape()
return tensor_shape.TensorShape(tensor_util.constant_value(
self._override_batch_shape))
return tensor_shape.TensorShape(static_override)
@distribution_util.AppendDocstring(
"""Samples from the base distribution and then passes through
@@ -350,20 +398,11 @@ class TransformedDistribution(distributions.Distribution):
bijector_kwargs=None, distribution_kwargs=None):
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
if (self._override_batch_shape is None and
self._override_event_shape is None):
sample_shape = [n]
else:
if (self._override_batch_shape is not None and
self._override_event_shape is not None):
sample_shape = [[n],
self._override_batch_shape,
self._override_event_shape]
elif self._override_batch_shape is not None:
sample_shape = [[n], self._override_batch_shape]
elif self._override_event_shape is not None:
sample_shape = [self._override_event_shape, [n]]
sample_shape = _concat_vectors(*sample_shape)
sample_shape = _concat_vectors(
distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
self._override_batch_shape,
self._override_event_shape,
distribution_util.pick_vector(self._needs_rotation, [n], self._empty))
x = self.distribution.sample(sample_shape=sample_shape, seed=seed,
**distribution_kwargs)
x = self._maybe_rotate_dims(x)
@@ -383,7 +422,7 @@ class TransformedDistribution(distributions.Distribution):
y, **bijector_kwargs)
x = self._maybe_rotate_dims(x, rotate_right=True)
log_prob = self.distribution.log_prob(x, **distribution_kwargs)
if self._override_event_shape is not None:
if self._is_maybe_event_override:
log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
return ildj + log_prob
@@ -401,14 +440,14 @@ class TransformedDistribution(distributions.Distribution):
y, **bijector_kwargs)
x = self._maybe_rotate_dims(x, rotate_right=True)
prob = self.distribution.prob(x, **distribution_kwargs)
if self._override_event_shape is not None:
if self._is_maybe_event_override:
prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
return math_ops.exp(ildj) * prob
@distribution_util.AppendDocstring(
condition_kwargs_dict=_condition_kwargs_dict)
def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None):
if self._override_event_shape is not None:
if self._is_maybe_event_override:
raise NotImplementedError("log_cdf is not implemented when overriding "
"event_shape")
bijector_kwargs = bijector_kwargs or {}
@@ -419,7 +458,7 @@ class TransformedDistribution(distributions.Distribution):
@distribution_util.AppendDocstring(
condition_kwargs_dict=_condition_kwargs_dict)
def _cdf(self, y, bijector_kwargs=None, distribution_kwargs=None):
if self._override_event_shape is not None:
if self._is_maybe_event_override:
raise NotImplementedError("cdf is not implemented when overriding "
"event_shape")
bijector_kwargs = bijector_kwargs or {}
@@ -431,7 +470,7 @@ class TransformedDistribution(distributions.Distribution):
condition_kwargs_dict=_condition_kwargs_dict)
def _log_survival_function(self, y,
bijector_kwargs=None, distribution_kwargs=None):
if self._override_event_shape is not None:
if self._is_maybe_event_override:
raise NotImplementedError("log_survival_function is not implemented when "
"overriding event_shape")
bijector_kwargs = bijector_kwargs or {}
@@ -443,7 +482,7 @@ class TransformedDistribution(distributions.Distribution):
condition_kwargs_dict=_condition_kwargs_dict)
def _survival_function(self, y,
bijector_kwargs=None, distribution_kwargs=None):
if self._override_event_shape is not None:
if self._is_maybe_event_override:
raise NotImplementedError("survival_function is not implemented when "
"overriding event_shape")
bijector_kwargs = bijector_kwargs or {}
@@ -462,64 +501,77 @@ class TransformedDistribution(distributions.Distribution):
# E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
# where c can be anything.
entropy = self.distribution.entropy()
if self._override_event_shape is not None:
if self._is_maybe_event_override:
# H[X] = sum_i H[X_i] if X_i are mutually independent.
# This means that a reduce_sum is a simple rescaling.
entropy *= math_ops.cast(math_ops.reduce_prod(self._override_event_shape),
dtype=entropy.dtype.base_dtype)
if self._override_batch_shape is not None:
entropy = array_ops.reshape(entropy,
_ones_like(self._override_batch_shape))
entropy = array_ops.tile(entropy, self._override_batch_shape)
if self._is_maybe_batch_override:
new_shape = array_ops.concat_v2([
_ones_like(self._override_batch_shape),
self.distribution.batch_shape()], 0)
entropy = array_ops.reshape(entropy, new_shape)
multiples = array_ops.concat_v2([
self._override_batch_shape,
_ones_like(self.distribution.batch_shape())], 0)
entropy = array_ops.tile(entropy, multiples)
dummy = 0.
return entropy - self.bijector.inverse_log_det_jacobian(dummy)
def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
validate_args):
validate_args, name):
"""Helper to __init__ which ensures override batch/event_shape are valid."""
if override_shape is None:
override_shape = []
override_shape = ops.convert_to_tensor(override_shape, dtype=dtypes.int32,
name=name)
if not override_shape.dtype.is_integer:
raise TypeError("shape override must be an integer")
override_is_scalar = _is_scalar_from_shape(override_shape)
if tensor_util.constant_value(override_is_scalar):
return self._empty
dynamic_assertions = []
if override_shape.get_shape().ndims is not None:
if override_shape.get_shape().ndims != 1:
raise ValueError("shape override must be a vector")
elif validate_args:
is_vector = check_ops.assert_rank(
dynamic_assertions += [check_ops.assert_rank(
override_shape, 1,
message="shape override must be a vector")
override_shape = control_flow_ops.with_dependencies(
[is_vector], override_shape)
message="shape override must be a vector")]
if override_shape.get_shape().is_fully_defined():
if any(s <= 0 for s in override_shape.get_shape().as_list()):
if tensor_util.constant_value(override_shape) is not None:
if any(s <= 0 for s in tensor_util.constant_value(override_shape)):
raise ValueError("shape override must have positive elements")
elif validate_args:
is_positive = check_ops.assert_positive(
dynamic_assertions += [check_ops.assert_positive(
override_shape,
message="shape override must have positive elements")
override_shape = control_flow_ops.with_dependencies(
[is_positive], override_shape)
message="shape override must have positive elements")]
if tensor_util.constant_value(base_is_scalar) is not None:
if not tensor_util.constant_value(base_is_scalar):
raise ValueError("shape override requires scalar distribution.")
is_both_nonscalar = _logical_and(_logical_not(base_is_scalar),
_logical_not(override_is_scalar))
if tensor_util.constant_value(is_both_nonscalar) is not None:
if tensor_util.constant_value(is_both_nonscalar):
raise ValueError("base distribution not scalar")
elif validate_args:
is_scalar = check_ops.assert_equal(
base_is_scalar, True,
message="shape override requires scalar distribution.")
override_shape = control_flow_ops.with_dependencies(
[is_scalar], override_shape)
dynamic_assertions += [check_ops.assert_equal(
is_both_nonscalar, False,
message="base distribution not scalar")]
return override_shape
if not dynamic_assertions:
return override_shape
return control_flow_ops.with_dependencies(
dynamic_assertions, override_shape)
def _maybe_rotate_dims(self, x, rotate_right=False):
"""Helper which rolls left event_dims left or right event_dims right."""
if tensor_util.constant_value(self._needs_rotation) is False:
return x
ndims = array_ops.rank(x)
n = _pick_scalar_condition(self._needs_rotation,
self._override_event_ndims, 0)
if rotate_right:
n = ndims - n
n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
return array_ops.transpose(
x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))
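# Illustration (a sketch, not part of the committed code): for x of rank 3
# with self._rotate_ndims == 1, the forward call uses perm [1, 2, 0] (the
# leftmost dim rolls to the end); with rotate_right=True, n = 3 - 1 = 2 and
# the perm is [2, 0, 1], which undoes the forward roll.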