diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index 4f22529d6bf829d373853073400df0fef10d8e1b..325c6bdc37ef686cdf5d9fc2732adb44457ea379 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -85,20 +85,20 @@ class DistributionTest(test.TestCase):
     sigma = 2.

     normal = ds.Normal(mu, sigma, validate_args=True)
-    self.assertTrue(tensor_util.constant_value(normal.is_scalar_event))
-    self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch))
+    self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
+    self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch()))

     normal = ds.Normal([mu], [sigma], validate_args=True)
-    self.assertTrue(tensor_util.constant_value(normal.is_scalar_event))
-    self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch))
+    self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
+    self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch()))

     mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True)
-    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event))
-    self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch))
+    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
+    self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch()))

     mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True)
-    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event))
-    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch))
+    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
+    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch()))

     # We now test every codepath within the underlying is_scalar_helper
     # function.
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index 49169d9cf0ec271c291799815e5e5fef7dd69cfd..4770cd8b95646300a9651ed03de79816c4583bd3 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -20,11 +20,13 @@ from __future__ import print_function

 import numpy as np
 from scipy import stats
+
 from tensorflow.contrib import distributions
 from tensorflow.contrib import linalg
 from tensorflow.contrib.distributions.python.ops import bijector as bijector_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -73,8 +75,7 @@ class TransformedDistributionTest(test.TestCase):
       # Note: the Jacobian callable only works for this example; more generally
       # you may or may not need a reduce_sum.
       log_normal = ds.TransformedDistribution(
-          distribution=ds.Normal(
-              mu=mu, sigma=sigma),
+          distribution=ds.Normal(mu=mu, sigma=sigma),
           bijector=bs.Exp(event_ndims=0))
       sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu))
@@ -105,41 +106,38 @@ class TransformedDistributionTest(test.TestCase):
       mu = 3.0
       sigma = 0.02
       log_normal = ds.TransformedDistribution(
-          distribution=ds.Normal(
-              mu=mu, sigma=sigma),
+          distribution=ds.Normal(mu=mu, sigma=sigma),
           bijector=bs.Exp(event_ndims=0))

       sample = log_normal.sample(1)
       sample_val, log_pdf_val = sess.run([sample, log_normal.log_pdf(sample)])
       self.assertAllClose(
-          stats.lognorm.logpdf(
-              sample_val, s=sigma, scale=np.exp(mu)),
+          stats.lognorm.logpdf(sample_val, s=sigma, scale=np.exp(mu)),
           log_pdf_val,
           atol=1e-2)

   def testConditioning(self):
     with self.test_session():
       conditional_normal = ds.TransformedDistribution(
-          distribution=ds.Normal(
-              mu=0., sigma=1.),
+          distribution=ds.Normal(mu=0., sigma=1.),
           bijector=_ChooseLocation(loc=[-100., 100.]))
       z = [-1, +1, -1, -1, +1]
       self.assertAllClose(
-          np.sign(
-              conditional_normal.sample(
-                  5, bijector_kwargs={"z": z}).eval()),
-          z)
+          np.sign(conditional_normal.sample(
+              5, bijector_kwargs={"z": z}).eval()), z)

   def testShapeChangingBijector(self):
     with self.test_session():
       softmax = bs.SoftmaxCentered()
       standard_normal = ds.Normal(mu=0., sigma=1.)
       multi_logit_normal = ds.TransformedDistribution(
-          distribution=standard_normal, bijector=softmax)
-      x = [[-np.log(3.), 0.], [np.log(3), np.log(5)]]
+          distribution=standard_normal,
+          bijector=softmax)
+      x = [[-np.log(3.), 0.],
+           [np.log(3), np.log(5)]]
       y = softmax.forward(x).eval()
-      expected_log_pdf = (stats.norm(
-          loc=0., scale=1.).logpdf(x) - np.sum(np.log(y), axis=-1))
+      expected_log_pdf = (stats.norm(loc=0., scale=1.).logpdf(x) -
+                          np.sum(np.log(y), axis=-1))
       self.assertAllClose(expected_log_pdf,
                           multi_logit_normal.log_prob(y).eval())
       self.assertAllClose(
@@ -152,7 +150,9 @@ class TransformedDistributionTest(test.TestCase):
     with self.test_session():
       shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
       diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
-      actual_mvn = ds.MultivariateNormalDiag(shift, diag, validate_args=True)
+      actual_mvn_entropy = np.concatenate([
+          [stats.multivariate_normal(shift[i], np.diag(diag[i]**2)).entropy()]
+          for i in range(len(diag))])
       fake_mvn = ds.TransformedDistribution(
           ds.MultivariateNormalDiag(
               array_ops.zeros_like(shift),
@@ -160,11 +160,10 @@ class TransformedDistributionTest(test.TestCase):
               validate_args=True),
           bs.AffineLinearOperator(
               shift,
-              scale=la.LinearOperatorDiag(
-                  diag, is_non_singular=True),
+              scale=la.LinearOperatorDiag(diag, is_non_singular=True),
               validate_args=True),
           validate_args=True)
-      self.assertAllClose(actual_mvn.entropy().eval(),
+      self.assertAllClose(actual_mvn_entropy,
                           fake_mvn.entropy().eval())

@@ -181,121 +180,151 @@ class ScalarToMultiTest(test.TestCase):
         dtype=np.float32)

   def _testMVN(self,
-               base_distribution,
-               batch_shape=None,
-               event_shape=None,
+               base_distribution_class,
+               base_distribution_kwargs,
+               batch_shape=(),
+               event_shape=(),
                not_implemented_message=None):
     with self.test_session() as sess:
       # Overriding shapes must be compatible w/bijector; most bijectors are
       # batch_shape agnostic and only care about event_ndims.
       # In the case of `Affine`, if we got it wrong then it would fire an
       # exception due to incompatible dimensions.
-      fake_mvn = ds.TransformedDistribution(
-          distribution=base_distribution[0](validate_args=True,
-                                            **base_distribution[1]),
-          bijector=bs.Affine(
-              shift=self._shift, scale_tril=self._tril),
+      batch_shape_pl = array_ops.placeholder(
+          dtypes.int32, name="dynamic_batch_shape")
+      event_shape_pl = array_ops.placeholder(
+          dtypes.int32, name="dynamic_event_shape")
+      feed_dict = {batch_shape_pl: np.array(batch_shape, dtype=np.int32),
+                   event_shape_pl: np.array(event_shape, dtype=np.int32)}
+      fake_mvn_dynamic = ds.TransformedDistribution(
+          distribution=base_distribution_class(validate_args=True,
+                                               **base_distribution_kwargs),
+          bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
+          batch_shape=batch_shape_pl,
+          event_shape=event_shape_pl,
+          validate_args=True)
+
+      fake_mvn_static = ds.TransformedDistribution(
+          distribution=base_distribution_class(validate_args=True,
+                                               **base_distribution_kwargs),
+          bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
           batch_shape=batch_shape,
           event_shape=event_shape,
           validate_args=True)

       actual_mean = np.tile(self._shift, [2, 1])  # Affine elided this tile.
       actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1]))
-      actual_mvn = ds.MultivariateNormalFull(mu=actual_mean, sigma=actual_cov)
-
-      # Ensure sample works by checking first, second moments.
-      n = 5e3
-      y = fake_mvn.sample(int(n), seed=0)
-      sample_mean = math_ops.reduce_mean(y, 0)
-      centered_y = array_ops.transpose(y - sample_mean, [1, 2, 0])
-      sample_cov = math_ops.matmul(centered_y, centered_y, transpose_b=True) / n
-      [sample_mean_, sample_cov_] = sess.run([sample_mean, sample_cov])
-      self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
-      self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)
-
-      # Ensure all other functions work as intended.
-      x = fake_mvn.sample(5, seed=0).eval()
-      self.assertAllEqual([5, 2, 3], x.shape)
-      self.assertAllEqual(actual_mvn.get_event_shape(),
-                          fake_mvn.get_event_shape())
-      self.assertAllEqual(actual_mvn.event_shape().eval(),
-                          fake_mvn.event_shape().eval())
-      self.assertAllEqual(actual_mvn.get_batch_shape(),
-                          fake_mvn.get_batch_shape())
-      self.assertAllEqual(actual_mvn.batch_shape().eval(),
-                          fake_mvn.batch_shape().eval())
-      self.assertAllClose(
-          actual_mvn.log_prob(x).eval(),
-          fake_mvn.log_prob(x).eval(),
-          atol=0.,
-          rtol=1e-7)
-      self.assertAllClose(
-          actual_mvn.prob(x).eval(),
-          fake_mvn.prob(x).eval(),
-          atol=0.,
-          rtol=1e-6)
-      self.assertAllClose(
-          actual_mvn.entropy().eval(),
-          fake_mvn.entropy().eval(),
-          atol=0.,
-          rtol=1e-6)
-      for unsupported_fn in (fake_mvn.log_cdf, fake_mvn.cdf,
-                             fake_mvn.survival_function,
-                             fake_mvn.log_survival_function):
+
+      def actual_mvn_log_prob(x):
+        return np.concatenate([
+            [stats.multivariate_normal(
+                actual_mean[i], actual_cov[i]).logpdf(x[:, i, :])]
+            for i in range(len(actual_cov))]).T
+
+      actual_mvn_entropy = np.concatenate([
+          [stats.multivariate_normal(
+              actual_mean[i], actual_cov[i]).entropy()]
+          for i in range(len(actual_cov))])
+
+      self.assertAllEqual([3], fake_mvn_static.get_event_shape())
+      self.assertAllEqual([2], fake_mvn_static.get_batch_shape())
+
+      self.assertAllEqual(tensor_shape.TensorShape(None),
+                          fake_mvn_dynamic.get_event_shape())
+      self.assertAllEqual(tensor_shape.TensorShape(None),
+                          fake_mvn_dynamic.get_batch_shape())
+
+      x = fake_mvn_static.sample(5, seed=0).eval()
+      for unsupported_fn in (fake_mvn_static.log_cdf,
+                             fake_mvn_static.cdf,
+                             fake_mvn_static.survival_function,
+                             fake_mvn_static.log_survival_function):
         with self.assertRaisesRegexp(NotImplementedError,
                                      not_implemented_message):
-          self.assertRaisesRegexp(unsupported_fn(x))
+          unsupported_fn(x)
+
+      num_samples = 5e3
+      for fake_mvn, feed_dict in ((fake_mvn_static, {}),
+                                  (fake_mvn_dynamic, feed_dict)):
+        # Ensure sample works by checking first, second moments.
+        y = fake_mvn.sample(int(num_samples), seed=0)
+        x = y[0:5, ...]
+        sample_mean = math_ops.reduce_mean(y, 0)
+        centered_y = array_ops.transpose(y - sample_mean, [1, 2, 0])
+        sample_cov = math_ops.matmul(
+            centered_y, centered_y, transpose_b=True) / num_samples
+        [
+            sample_mean_,
+            sample_cov_,
+            x_,
+            fake_event_shape_,
+            fake_batch_shape_,
+            fake_log_prob_,
+            fake_prob_,
+            fake_entropy_,
+        ] = sess.run([
+            sample_mean,
+            sample_cov,
+            x,
+            fake_mvn.event_shape(),
+            fake_mvn.batch_shape(),
+            fake_mvn.log_prob(x),
+            fake_mvn.prob(x),
+            fake_mvn.entropy(),
+        ], feed_dict=feed_dict)
+
+        self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
+        self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)
+
+        # Ensure all other functions work as intended.
+        self.assertAllEqual([5, 2, 3], x_.shape)
+        self.assertAllEqual([3], fake_event_shape_)
+        self.assertAllEqual([2], fake_batch_shape_)
+        self.assertAllClose(actual_mvn_log_prob(x_), fake_log_prob_,
+                            atol=0., rtol=1e-6)
+        self.assertAllClose(np.exp(actual_mvn_log_prob(x_)), fake_prob_,
+                            atol=0., rtol=1e-5)
+        self.assertAllClose(actual_mvn_entropy, fake_entropy_,
+                            atol=0., rtol=1e-6)

   def testScalarBatchScalarEvent(self):
     self._testMVN(
-        base_distribution=[ds.Normal, {
-            "mu": 0.,
-            "sigma": 1.
-        }],
+        base_distribution_class=ds.Normal,
+        base_distribution_kwargs={"mu": 0., "sigma": 1.},
         batch_shape=[2],
         event_shape=[3],
         not_implemented_message="not implemented when overriding event_shape")

   def testScalarBatchNonScalarEvent(self):
     self._testMVN(
-        base_distribution=[
-            ds.MultivariateNormalDiag, {
-                "mu": [0., 0., 0.],
-                "diag_stdev": [1., 1, 1]
-            }
-        ],
+        base_distribution_class=ds.MultivariateNormalDiag,
+        base_distribution_kwargs={"mu": [0., 0., 0.], "diag_stdev": [1., 1, 1]},
         batch_shape=[2],
         not_implemented_message="not implemented$")

     with self.test_session():
       # Can't override event_shape for scalar batch, non-scalar event.
-      with self.assertRaisesRegexp(ValueError, "requires scalar"):
+      with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
         ds.TransformedDistribution(
-            distribution=ds.MultivariateNormalDiag(
-                mu=[0.], diag_stdev=[1.]),
-            bijector=bs.Affine(
-                shift=self._shift, scale_tril=self._tril),
+            distribution=ds.MultivariateNormalDiag(mu=[0.], diag_stdev=[1.]),
+            bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
             batch_shape=[2],
             event_shape=[3],
             validate_args=True)

   def testNonScalarBatchScalarEvent(self):
     self._testMVN(
-        base_distribution=[ds.Normal, {
-            "mu": [0., 0],
-            "sigma": [1., 1]
-        }],
+        base_distribution_class=ds.Normal,
+        base_distribution_kwargs={"mu": [0., 0], "sigma": [1., 1]},
         event_shape=[3],
         not_implemented_message="not implemented when overriding event_shape")

     with self.test_session():
       # Can't override batch_shape for non-scalar batch, scalar event.
-      with self.assertRaisesRegexp(ValueError, "requires scalar"):
+      with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
         ds.TransformedDistribution(
-            distribution=ds.Normal(
-                mu=[0.], sigma=[1.]),
-            bijector=bs.Affine(
-                shift=self._shift, scale_tril=self._tril),
+            distribution=ds.Normal(mu=[0.], sigma=[1.]),
+            bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
             batch_shape=[2],
             event_shape=[3],
             validate_args=True)
@@ -304,12 +333,11 @@ class ScalarToMultiTest(test.TestCase):
     with self.test_session():
       # Can't override event_shape and/or batch_shape for non_scalar batch,
       # non-scalar event.
-      with self.assertRaisesRegexp(ValueError, "requires scalar"):
+      with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
         ds.TransformedDistribution(
-            distribution=ds.MultivariateNormalDiag(
-                mu=[[0.]], diag_stdev=[[1.]]),
-            bijector=bs.Affine(
-                shift=self._shift, scale_tril=self._tril),
+            distribution=ds.MultivariateNormalDiag(mu=[[0.]],
+                                                   diag_stdev=[[1.]]),
+            bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
             batch_shape=[2],
             event_shape=[3],
             validate_args=True)
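
The test changes above track the API change in distribution.py below: `is_scalar_event` and `is_scalar_batch` turn from properties into methods that return boolean scalar `Tensor`s, whose values fold to Python booleans via `tensor_util.constant_value` whenever shapes are statically known. A minimal sketch of the new calling pattern, assuming the TF 1.x contrib API used throughout this patch:

    # Sketch only: mirrors the updated tests; assumes TF 1.x contrib APIs.
    from tensorflow.contrib import distributions as ds
    from tensorflow.python.framework import tensor_util

    normal = ds.Normal(mu=0., sigma=1.)
    # Both calls now return boolean scalar Tensors; constant_value()
    # recovers the static answer when shapes are known at graph time.
    assert tensor_util.constant_value(normal.is_scalar_event())
    assert tensor_util.constant_value(normal.is_scalar_batch())

    batched = ds.Normal(mu=[0., 1.], sigma=[1., 1.])
    assert not tensor_util.constant_value(batched.is_scalar_batch())
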
diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py
index eb1c9852901ef446b5e46adcc7c8c1f8bcf2d886..74d5319613e74cc7f47cdc694f1cb8d838f70e87 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution.py
@@ -526,19 +526,33 @@ class Distribution(_BaseDistribution):
     """
     return self._get_event_shape()

-  @property
-  def is_scalar_event(self):
-    """Indicates that `event_shape==[]`."""
-    return ops.convert_to_tensor(
-        self._is_scalar_helper(self.get_event_shape, self.event_shape),
-        name="is_scalar_event")
+  def is_scalar_event(self, name="is_scalar_event"):
+    """Indicates that `event_shape == []`.

-  @property
-  def is_scalar_batch(self):
-    """Indicates that `batch_shape==[]`."""
-    return ops.convert_to_tensor(
-        self._is_scalar_helper(self.get_batch_shape, self.batch_shape),
-        name="is_scalar_batch")
+    Args:
+      name: The name to give this op.
+
+    Returns:
+      is_scalar_event: `Boolean` `scalar` `Tensor`.
+    """
+    with self._name_scope(name):
+      return ops.convert_to_tensor(
+          self._is_scalar_helper(self.get_event_shape, self.event_shape),
+          name="is_scalar_event")
+
+  def is_scalar_batch(self, name="is_scalar_batch"):
+    """Indicates that `batch_shape == []`.
+
+    Args:
+      name: The name to give this op.
+
+    Returns:
+      is_scalar_batch: `Boolean` `scalar` `Tensor`.
+    """
+    with self._name_scope(name):
+      return ops.convert_to_tensor(
+          self._is_scalar_helper(self.get_batch_shape, self.batch_shape),
+          name="is_scalar_batch")

   def _sample_n(self, n, seed=None):
     raise NotImplementedError("sample_n is not implemented")
@@ -888,7 +902,7 @@ class Distribution(_BaseDistribution):
     """Helper function to standardize op scope."""
     with ops.name_scope(self.name):
       with ops.name_scope(name, values=(
-          (values or []) + self._graph_parents)) as scope:
+          ([] if values is None else values) + self._graph_parents)) as scope:
         yield scope

   def _expand_sample_shape_to_vector(self, x, name):
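
The `_name_scope` fix above is subtle: `(values or [])` evaluates `bool(values)`, which raises for array-likes with ambiguous truthiness, whereas the `is None` test never inspects the value itself. A pure-Python sketch of one way the two patterns diverge (the helper names here are illustrative, not from the patch):

    import numpy as np

    def merge_old(values, graph_parents):
      # Old pattern: `values or []` calls bool(values) first.
      return (values or []) + graph_parents

    def merge_new(values, graph_parents):
      # New pattern: only an identity test against None.
      return ([] if values is None else list(values)) + graph_parents

    parents = ["graph_parent"]
    print(merge_new(np.array([1., 2.]), parents))  # [1.0, 2.0, 'graph_parent']
    print(merge_old(np.array([1., 2.]), parents))  # ValueError: truth value is ambiguous
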
diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
index 6aa521e1af70892c5b31cad4d79919c3ee083127..624b68e01f9728b8699a35140a1e8f37ea00b76d 100644
--- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
@@ -22,6 +22,8 @@ import numpy as np
 from tensorflow.contrib.distributions.python.ops import bijector as bijectors
 from tensorflow.contrib.distributions.python.ops import distribution as distributions
 from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
@@ -47,19 +49,46 @@ _condition_kwargs_dict = {
 # graph construction.


+def _static_value(x):
+  """Returns the static value of a `Tensor` or `None`."""
+  return tensor_util.constant_value(ops.convert_to_tensor(x))
+
+
+def _logical_and(*args):
+  """Convenience function which attempts to statically `reduce_all`."""
+  args_ = [_static_value(x) for x in args]
+  if any(x is not None and not bool(x) for x in args_):
+    return constant_op.constant(False)
+  if all(x is not None and bool(x) for x in args_):
+    return constant_op.constant(True)
+  if len(args) == 2:
+    return math_ops.logical_and(*args)
+  return math_ops.reduce_all(args)
+
+
+def _logical_equal(x, y):
+  """Convenience function which attempts to statically compute `x == y`."""
+  x_ = _static_value(x)
+  y_ = _static_value(y)
+  if x_ is None or y_ is None:
+    return math_ops.equal(x, y)
+  return constant_op.constant(np.array_equal(x_, y_))
+
+
 def _logical_not(x):
   """Convenience function which attempts to statically apply `logical_not`."""
-  if tensor_util.constant_value(x) is not None:
-    return not tensor_util.constant_value(x)
-  return math_ops.logical_not(x)
+  x_ = _static_value(x)
+  if x_ is None:
+    return math_ops.logical_not(x)
+  return constant_op.constant(np.logical_not(x_))


 def _concat_vectors(*args):
-  """Convenience function which flattens input vectors."""
-  vals = [tensor_util.constant_value(ops.convert_to_tensor(x)) for x in args]
-  if any(x is None for x in vals):
+  """Convenience function which concatenates input vectors."""
+  args_ = [_static_value(x) for x in args]
+  if any(x_ is None for x_ in args_):
     return array_ops.concat_v2(args, 0)
-  return [x for v in vals for x in v]
+  return constant_op.constant([x_ for vec_ in args_ for x_ in vec_])


 def _pick_scalar_condition(pred, cond_true, cond_false):
@@ -67,20 +96,36 @@ def _pick_scalar_condition(pred, cond_true, cond_false):
   # Note: This function is only valid if all of pred, cond_true, and cond_false
   # are scalars. This means its semantics are arguably more like tf.cond than
   # tf.select even though we use tf.select to implement it.
-  pred_static = tensor_util.constant_value(pred)
-  if pred_static is None:
-    return math_ops.select(pred, cond_true, cond_false)
-  return cond_true if pred_static else cond_false
+  pred_ = _static_value(pred)
+  if pred_ is None:
+    return array_ops.where(pred, cond_true, cond_false)
+  return cond_true if pred_ else cond_false


 def _ones_like(x):
   """Convenience function attempts to statically construct `ones_like`."""
   # Should only be used for small vectors.
   if x.get_shape().is_fully_defined():
-    return np.ones(x.get_shape().as_list(), dtype=x.dtype.as_numpy_dtype())
+    return array_ops.ones(x.get_shape().as_list(), dtype=x.dtype)
   return array_ops.ones_like(x)


+def _ndims_from_shape(shape):
+  """Returns `Tensor`'s `rank` implied by a `Tensor` shape."""
+  if shape.get_shape().ndims not in (None, 1):
+    raise ValueError("input is not a valid shape: not 1D")
+  if not shape.dtype.is_integer:
+    raise TypeError("input is not a valid shape: wrong dtype")
+  if shape.get_shape().is_fully_defined():
+    return constant_op.constant(shape.get_shape().as_list()[0])
+  return array_ops.shape(shape)[0]
+
+
+def _is_scalar_from_shape(shape):
+  """Returns `True` `Tensor` if `Tensor` shape implies a scalar."""
+  return _logical_equal(_ndims_from_shape(shape), 0)
+
+
 class TransformedDistribution(distributions.Distribution):
   """A Transformed Distribution.
@@ -234,9 +279,9 @@ class TransformedDistribution(distributions.Distribution):
       bijector: The object responsible for calculating the transformation.
         Typically an instance of `Bijector`.
        `None` means `Identity()`.
       batch_shape: `integer` vector `Tensor` which overrides `distribution`
-        `batch_shape`; valid only if `distribution.is_scalar_batch`.
+        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
       event_shape: `integer` vector `Tensor` which overrides `distribution`
-        `event_shape`; valid only if `distribution.is_scalar_event`.
+        `event_shape`; valid only if `distribution.is_scalar_event()`.
       validate_args: Python Boolean. Whether to validate input with asserts.
         If `validate_args` is `False`, and the inputs are invalid,
         correct behavior is not guaranteed.
@@ -245,53 +290,55 @@ class TransformedDistribution(distributions.Distribution):
     """
     parameters = locals()
     parameters.pop("self")
-    if bijector is None:
-      bijector = bijectors.Identity(validate_args=validate_args)
-    name = name or bijector.name + distribution.name
+    name = name or (("" if bijector is None else bijector.name) +
+                    distribution.name)
     with ops.name_scope(name, values=[event_shape, batch_shape]):
-      if batch_shape is not None:
-        batch_shape = self._maybe_validate_shape_override(
-            ops.convert_to_tensor(batch_shape, name="batch_shape"),
-            distribution.is_scalar_batch, validate_args)
-      self._override_batch_shape = batch_shape
-
-      if event_shape is not None:
-        event_shape = self._maybe_validate_shape_override(
-            ops.convert_to_tensor(event_shape, name="event_shape"),
-            distribution.is_scalar_event, validate_args)
-        self._override_event_ndims = (
-            event_shape.get_shape().ndims
-            if event_shape.get_shape().ndims is not None
-            else array_ops.rank(event_shape, name="event_ndims"))
-      else:
-        self._override_event_ndims = 0
-      self._override_event_shape = event_shape
+      # For convenience we define some handy constants.
+      self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
+      self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")
+
+      if bijector is None:
+        bijector = bijectors.Identity(validate_args=validate_args)
+
+      # We will keep track of a static and dynamic version of
+      # self._is_{batch,event}_override. This way we can do more prior to graph
+      # execution, including possibly raising Python exceptions.
+
+      self._override_batch_shape = self._maybe_validate_shape_override(
+          batch_shape, distribution.is_scalar_batch(), validate_args,
+          "batch_shape")
+      self._is_batch_override = _logical_not(_logical_equal(
+          _ndims_from_shape(self._override_batch_shape), self._zero))
+      self._is_maybe_batch_override = bool(
+          tensor_util.constant_value(self._override_batch_shape) is None or
+          tensor_util.constant_value(self._override_batch_shape))
+
+      self._override_event_shape = self._maybe_validate_shape_override(
+          event_shape, distribution.is_scalar_event(), validate_args,
+          "event_shape")
+      self._is_event_override = _logical_not(_logical_equal(
+          _ndims_from_shape(self._override_event_shape), self._zero))
+      self._is_maybe_event_override = bool(
+          tensor_util.constant_value(self._override_event_shape) is None or
+          tensor_util.constant_value(self._override_event_shape))

       # To convert a scalar distribution into a multivariate distribution we
       # will draw dims from the sample dims, which are otherwise iid. This is
-      # easy to do except in the case that:
-      #   batch_shape is None and
-      #   event_shape is not None and
-      #   not distribution.is_scalar_batch.
-      # When that case happens the event dims will incorrectly be to the left of
-      # the batch dims. In this case we'll cyclically permute left the new dims.
-      if batch_shape is None and event_shape is not None:
-        self._needs_rotation = ops.convert_to_tensor(
-            _logical_not(distribution.is_scalar_batch), name="needs_rotation")
-        n = _pick_scalar_condition(self._needs_rotation,
-                                   self._override_event_ndims, 0)
-        # We'll be reducing the head dims (if at all), i.e., this will be []
-        # if we don't need to reduce.
-        self._reduce_event_indices = math_ops.range(
-            n - self._override_event_ndims, n)
-      else:
-        self._needs_rotation = ops.convert_to_tensor(False,
-                                                     name="needs_rotation")
-        # We'll be reducing the tail dims (if at all), i.e., this will be []
-        # if we don't need to reduce.
-        self._reduce_event_indices = (
-            math_ops.range(-self._override_event_ndims, 0)
-            if event_shape is not None else [])
+      # easy to do except in the case that the base distribution has batch dims
+      # and we're overriding event shape. When that case happens the event dims
+      # will incorrectly be to the left of the batch dims. In this case we'll
+      # cyclically permute left the new dims.
+      self._needs_rotation = _logical_and(
+          self._is_event_override,
+          _logical_not(self._is_batch_override),
+          _logical_not(distribution.is_scalar_batch()))
+      override_event_ndims = _ndims_from_shape(self._override_event_shape)
+      self._rotate_ndims = _pick_scalar_condition(
+          self._needs_rotation, override_event_ndims, 0)
+      # We'll be reducing the head dims (if at all), i.e., this will be []
+      # if we don't need to reduce.
+      self._reduce_event_indices = math_ops.range(
+          self._rotate_ndims - override_event_ndims, self._rotate_ndims)

       self._distribution = distribution
       self._bijector = bijector
@@ -319,28 +366,29 @@ class TransformedDistribution(distributions.Distribution):
     return self._bijector

   def _event_shape(self):
-    return self.bijector.forward_event_shape(
-        self.distribution.event_shape()
-        if self._override_event_shape is None
-        else self._override_event_shape)
+    return self.bijector.forward_event_shape(distribution_util.pick_vector(
+        self._is_event_override,
+        self._override_event_shape,
+        self.distribution.event_shape()))

   def _get_event_shape(self):
+    static_override = tensor_util.constant_value(self._override_event_shape)
     return self.bijector.get_forward_event_shape(
         self.distribution.get_event_shape()
-        if self._override_event_shape is None
-        else tensor_shape.TensorShape(
-            tensor_util.constant_value(self._override_event_shape)))
+        if static_override is not None and not static_override
+        else tensor_shape.TensorShape(static_override))

   def _batch_shape(self):
-    if self._override_batch_shape is None:
-      return self.distribution.batch_shape()
-    return self._override_batch_shape
+    return distribution_util.pick_vector(
+        self._is_batch_override,
+        self._override_batch_shape,
+        self.distribution.batch_shape())

   def _get_batch_shape(self):
-    if self._override_batch_shape is None:
+    static_override = tensor_util.constant_value(self._override_batch_shape)
+    if static_override is not None and not static_override:
       return self.distribution.get_batch_shape()
-    return tensor_shape.TensorShape(tensor_util.constant_value(
-        self._override_batch_shape))
+    return tensor_shape.TensorShape(static_override)

   @distribution_util.AppendDocstring(
       """Samples from the base distribution and then passes through
@@ -350,20 +398,11 @@ class TransformedDistribution(distributions.Distribution):
                 bijector_kwargs=None, distribution_kwargs=None):
     bijector_kwargs = bijector_kwargs or {}
     distribution_kwargs = distribution_kwargs or {}
-    if (self._override_batch_shape is None and
-        self._override_event_shape is None):
-      sample_shape = [n]
-    else:
-      if (self._override_batch_shape is not None and
-          self._override_event_shape is not None):
-        sample_shape = [[n],
-                        self._override_batch_shape,
-                        self._override_event_shape]
-      elif self._override_batch_shape is not None:
-        sample_shape = [[n], self._override_batch_shape]
-      elif self._override_event_shape is not None:
-        sample_shape = [self._override_event_shape, [n]]
-      sample_shape = _concat_vectors(*sample_shape)
+    sample_shape = _concat_vectors(
+        distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
+        self._override_batch_shape,
+        self._override_event_shape,
+        distribution_util.pick_vector(self._needs_rotation, [n], self._empty))
     x = self.distribution.sample(sample_shape=sample_shape, seed=seed,
                                  **distribution_kwargs)
     x = self._maybe_rotate_dims(x)
@@ -383,7 +422,7 @@ class TransformedDistribution(distributions.Distribution):
                                                 y, **bijector_kwargs)
     x = self._maybe_rotate_dims(x, rotate_right=True)
     log_prob = self.distribution.log_prob(x, **distribution_kwargs)
-    if self._override_event_shape is not None:
+    if self._is_maybe_event_override:
       log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
     return ildj + log_prob

@@ -401,14 +440,14 @@ class TransformedDistribution(distributions.Distribution):
                                                 y, **bijector_kwargs)
     x = self._maybe_rotate_dims(x, rotate_right=True)
     prob = self.distribution.prob(x, **distribution_kwargs)
-    if self._override_event_shape is not None:
+    if self._is_maybe_event_override:
       prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
     return math_ops.exp(ildj) * prob

   @distribution_util.AppendDocstring(
       condition_kwargs_dict=_condition_kwargs_dict)
   def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None):
-    if self._override_event_shape is not None:
+    if self._is_maybe_event_override:
       raise NotImplementedError("log_cdf is not implemented when overriding "
                                 "event_shape")
     bijector_kwargs = bijector_kwargs or {}
@@ -419,7 +458,7 @@ class TransformedDistribution(distributions.Distribution):
   @distribution_util.AppendDocstring(
       condition_kwargs_dict=_condition_kwargs_dict)
   def _cdf(self, y, bijector_kwargs=None, distribution_kwargs=None):
-    if self._override_event_shape is not None:
+    if self._is_maybe_event_override:
       raise NotImplementedError("cdf is not implemented when overriding "
                                 "event_shape")
     bijector_kwargs = bijector_kwargs or {}
@@ -431,7 +470,7 @@ class TransformedDistribution(distributions.Distribution):
       condition_kwargs_dict=_condition_kwargs_dict)
   def _log_survival_function(self, y,
                              bijector_kwargs=None, distribution_kwargs=None):
-    if self._override_event_shape is not None:
+    if self._is_maybe_event_override:
       raise NotImplementedError("log_survival_function is not implemented when "
                                 "overriding event_shape")
     bijector_kwargs = bijector_kwargs or {}
@@ -443,7 +482,7 @@ class TransformedDistribution(distributions.Distribution):
       condition_kwargs_dict=_condition_kwargs_dict)
   def _survival_function(self, y,
                          bijector_kwargs=None, distribution_kwargs=None):
-    if self._override_event_shape is not None:
+    if self._is_maybe_event_override:
       raise NotImplementedError("survival_function is not implemented when "
                                 "overriding event_shape")
     bijector_kwargs = bijector_kwargs or {}
@@ -462,64 +501,77 @@ class TransformedDistribution(distributions.Distribution):
     #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
     # where c can be anything.
     entropy = self.distribution.entropy()
-    if self._override_event_shape is not None:
+    if self._is_maybe_event_override:
       # H[X] = sum_i H[X_i] if X_i are mutually independent.
       # This means that a reduce_sum is a simple rescaling.
       entropy *= math_ops.cast(math_ops.reduce_prod(self._override_event_shape),
                                dtype=entropy.dtype.base_dtype)
-    if self._override_batch_shape is not None:
-      entropy = array_ops.reshape(entropy,
-                                  _ones_like(self._override_batch_shape))
-      entropy = array_ops.tile(entropy, self._override_batch_shape)
+    if self._is_maybe_batch_override:
+      new_shape = array_ops.concat_v2([
+          _ones_like(self._override_batch_shape),
+          self.distribution.batch_shape()], 0)
+      entropy = array_ops.reshape(entropy, new_shape)
+      multiples = array_ops.concat_v2([
+          self._override_batch_shape,
+          _ones_like(self.distribution.batch_shape())], 0)
+      entropy = array_ops.tile(entropy, multiples)
     dummy = 0.
     return entropy - self.bijector.inverse_log_det_jacobian(dummy)

   def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
-                                     validate_args):
+                                     validate_args, name):
     """Helper to __init__ which ensures override batch/event_shape are valid."""
+    if override_shape is None:
+      override_shape = []
+
+    override_shape = ops.convert_to_tensor(override_shape, dtype=dtypes.int32,
+                                           name=name)
+
     if not override_shape.dtype.is_integer:
       raise TypeError("shape override must be an integer")

+    override_is_scalar = _is_scalar_from_shape(override_shape)
+    if tensor_util.constant_value(override_is_scalar):
+      return self._empty
+
+    dynamic_assertions = []
+
     if override_shape.get_shape().ndims is not None:
       if override_shape.get_shape().ndims != 1:
         raise ValueError("shape override must be a vector")
     elif validate_args:
-      is_vector = check_ops.assert_rank(
+      dynamic_assertions += [check_ops.assert_rank(
           override_shape, 1,
-          message="shape override must be a vector")
-      override_shape = control_flow_ops.with_dependencies(
-          [is_vector], override_shape)
+          message="shape override must be a vector")]

-    if override_shape.get_shape().is_fully_defined():
-      if any(s <= 0 for s in override_shape.get_shape().as_list()):
+    if tensor_util.constant_value(override_shape) is not None:
+      if any(s <= 0 for s in tensor_util.constant_value(override_shape)):
         raise ValueError("shape override must have positive elements")
     elif validate_args:
-      is_positive = check_ops.assert_positive(
+      dynamic_assertions += [check_ops.assert_positive(
           override_shape,
-          message="shape override must have positive elements")
-      override_shape = control_flow_ops.with_dependencies(
-          [is_positive], override_shape)
+          message="shape override must have positive elements")]

-    if tensor_util.constant_value(base_is_scalar) is not None:
-      if not tensor_util.constant_value(base_is_scalar):
-        raise ValueError("shape override requires scalar distribution.")
+    is_both_nonscalar = _logical_and(_logical_not(base_is_scalar),
+                                     _logical_not(override_is_scalar))
+    if tensor_util.constant_value(is_both_nonscalar) is not None:
+      if tensor_util.constant_value(is_both_nonscalar):
+        raise ValueError("base distribution not scalar")
     elif validate_args:
-      is_scalar = check_ops.assert_equal(
-          base_is_scalar, True,
-          message="shape override requires scalar distribution.")
-      override_shape = control_flow_ops.with_dependencies(
-          [is_scalar], override_shape)
+      dynamic_assertions += [check_ops.assert_equal(
+          is_both_nonscalar, False,
+          message="base distribution not scalar")]

-    return override_shape
+    if not dynamic_assertions:
+      return override_shape
+    return control_flow_ops.with_dependencies(
+        dynamic_assertions, override_shape)

   def _maybe_rotate_dims(self, x, rotate_right=False):
     """Helper which rolls left event_dims left or right event_dims right."""
     if tensor_util.constant_value(self._needs_rotation) is False:
       return x
     ndims = array_ops.rank(x)
-    n = _pick_scalar_condition(self._needs_rotation,
-                               self._override_event_ndims, 0)
-    if rotate_right:
-      n = ndims - n
+    n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
     return array_ops.transpose(
         x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))
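
One detail worth spelling out from `_sample_n` and `_maybe_rotate_dims` above: overridden event dims are drawn from the sample dims, which places them on the left, so they must be cyclically permuted past the batch dims. A pure-NumPy sketch of the same left rotation, under the assumption of a single overridden event dim (the helper name here is illustrative, not from the patch):

    import numpy as np

    def rotate_left(x, n):
      # Same permutation as transpose(x, concat([range(n, ndims),
      # range(0, n)], 0)) in _maybe_rotate_dims.
      perm = list(range(n, x.ndim)) + list(range(0, n))
      return np.transpose(x, perm)

    x = np.zeros([3, 1000, 2])      # [event, sample, batch] after sampling
    print(rotate_left(x, 1).shape)  # (1000, 2, 3): event dim now rightmost
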