diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 36bc407238293c06da53c5897b68236ed59ad1fb..8111118462abbfcf660184ddf2c813a5de887ed6 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -57,7 +57,6 @@ initialized with parameters that define the distributions. @@MultivariateNormalCholesky @@MultivariateNormalDiagPlusVDVT @@MultivariateNormalDiagWithSoftplusStDev -@@matrix_diag_transform ### Other multivariate distributions @@ -67,6 +66,10 @@ initialized with parameters that define the distributions. @@WishartCholesky @@WishartFull +### Multivariate Utilities + +@@matrix_diag_transform + ## Transformed distributions @@TransformedDistribution @@ -86,7 +89,7 @@ representing the posterior or posterior predictive. @@normal_conjugates_known_sigma_posterior @@normal_conjugates_known_sigma_predictive -## Kullback Leibler Divergence +## Kullback-Leibler Divergence @@kl @@RegisterKL diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index b5d041c8c768aa936280fc7e3b3ce1c0ff239740..e82d604d58a2a9b9f1eb550cda597bc824afb0ce 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -25,7 +25,7 @@ import tensorflow as tf from tensorflow.contrib.distributions.python.ops import distribution_util -class DistributionUtilTest(tf.test.TestCase): +class AssertCloseTest(tf.test.TestCase): def testAssertCloseIntegerDtype(self): x = [1, 5, 10, 15, 20] @@ -110,6 +110,9 @@ class DistributionUtilTest(tf.test.TestCase): distribution_util.assert_integer_form(w)]): tf.identity(w).eval() + +class GetLogitsAndProbTest(tf.test.TestCase): + def testGetLogitsAndProbImproperArguments(self): with self.test_session(): with self.assertRaises(ValueError): @@ -229,6 +232,9 @@ class DistributionUtilTest(tf.test.TestCase): p=p4, multidimensional=True, validate_args=False) prob.eval() + +class LogCombinationsTest(tf.test.TestCase): + def testLogCombinationsBinomial(self): n = [2, 5, 12, 15] k = [1, 2, 4, 11] @@ -252,6 +258,9 @@ class DistributionUtilTest(tf.test.TestCase): log_binom = distribution_util.log_combinations(n, counts) self.assertEqual([2, 2], log_binom.get_shape()) + +class RotateTransposeTest(tf.test.TestCase): + def _np_rotate_transpose(self, x, shift): if not isinstance(x, np.ndarray): x = np.array(x) @@ -283,7 +292,10 @@ class DistributionUtilTest(tf.test.TestCase): sess.run(distribution_util.rotate_transpose(x, shift), feed_dict={x: x_value, shift: shift_value})) - def testChooseVector(self): + +class PickVectorTest(tf.test.TestCase): + + def testCorrectlyPicksVector(self): with self.test_session(): x = np.arange(10, 12) y = np.arange(15, 18) @@ -301,5 +313,47 @@ class DistributionUtilTest(tf.test.TestCase): tf.constant(False), x, y)) # No eval. +class FillLowerTriangularTest(tf.test.TestCase): + + def testCorrectlyMakes1x1LowerTril(self): + with self.test_session(): + x = np.array([[1.], [2], [3]]) + expected = np.array([[[1.]], [[2]], [[3]]]) + actual = distribution_util.fill_lower_triangular(x) + self.assertAllEqual(expected.shape, actual.get_shape()) + self.assertAllEqual(expected, actual.eval()) + + def testCorrectlyMakesNoBatchLowerTril(self): + with self.test_session(): + x = np.arange(9) + expected = np.array( + [[0., 0., 0.], + [1., 2., 0.], + [3., 4., 5.]]) + actual = distribution_util.fill_lower_triangular(x) + self.assertAllEqual(expected.shape, actual.get_shape()) + self.assertAllEqual(expected, actual.eval()) + + def testCorrectlyMakesBatchLowerTril(self): + with self.test_session(): + x = np.reshape(np.arange(24), (2, 2, 6)) + expected = np.array( + [[[[0., 0., 0.], + [1., 2., 0.], + [3., 4., 5.]], + [[6., 0., 0.], + [7., 8., 0.], + [9., 10., 11.]]], + [[[12., 0., 0.], + [13., 14., 0.], + [15., 16., 17.]], + [[18., 0., 0.], + [19., 20., 0.], + [21., 22., 23.]]]]) + actual = distribution_util.fill_lower_triangular(x) + self.assertAllEqual(expected.shape, actual.get_shape()) + self.assertAllEqual(expected, actual.eval()) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 89950d6aa22b401be77a6ce2bf2597ae030bfaa7..ad5fa5b5aee75b29290486b80c926776054e0dc7 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -20,11 +20,13 @@ from __future__ import print_function import functools import hashlib +import math import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -376,7 +378,7 @@ def pick_vector(cond, TypeError: if `cond` is not a constant and `true_vector.dtype != false_vector.dtype` """ - with ops.op_scope((cond, true_vector, false_vector), name): + with ops.name_scope(name, values=(cond, true_vector, false_vector)): cond = ops.convert_to_tensor(cond, name="cond") if cond.dtype != dtypes.bool: raise TypeError("%s.dtype=%s which is not %s" % @@ -405,6 +407,101 @@ def gen_new_seed(seed, salt): return None +def fill_lower_triangular(x, name="fill_lower_triangular"): + """Creates a (batch of) lower triangular matrix from a vector of inputs. + + If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1, + b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., + `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`. + + Note: This function is very slow; possibly 10x slower than zero-ing out the + upper-triangular portion of a full matrix. + + Example: + + ```python + fill_lower_triangular([1, 2, 3, 4, 5, 6]) + # Returns: [[1, 0, 0], + # [2, 3, 0], + # [4, 5, 6]] + ``` + + Args: + x: `Tensor` representing lower triangular elements. + name: `String`. The name to give this op. + + Returns: + tril: `Tensor` with lower triangular elements filled from `x`. + """ + with ops.name_scope(name, values=(x,)): + x = ops.convert_to_tensor(x, name="x") + ndims = x.get_shape().ndims + if ndims is not None and x.get_shape()[-1].value is not None: + d = x.get_shape()[-1].value + # d = n^2/2 + n/2 implies n is: + n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.)) + final_shape = x.get_shape()[:-1].concatenate( + tensor_shape.TensorShape([n, n])) + else: + ndims = array_ops.rank(x) + d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32) + # d = n^2/2 + n/2 implies n is: + n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.), + dtype=dtypes.int32) + final_shape = x.get_shape()[:-1].concatenate( + tensor_shape.TensorShape([None, None])) + + # Make ids for each batch dim. + if (x.get_shape().ndims is not None and + x.get_shape()[:-1].is_fully_defined()): + batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32) + m = np.prod(batch_shape) + else: + batch_shape = array_ops.shape(x)[:-1] + m = array_ops.reduce_prod(batch_shape) + + # Flatten batch dims. + y = array_ops.reshape(x, [-1, d]) + + # Prepend a zero to each row. + y = array_ops.pad(y, paddings=[[0, 0], [1, 0]]) + + # Make ids for each batch dim. + if x.get_shape()[:-1].is_fully_defined(): + m = np.asarray(np.prod(x.get_shape()[:-1].as_list()), dtype=np.int32) + else: + m = array_ops.reduce_prod(array_ops.shape(x)[:-1]) + batch_ids = math_ops.range(m) + + def make_tril_ids(n): + """Internal helper to create vector of linear indices into y.""" + cols = array_ops.reshape(array_ops.tile(math_ops.range(n), [n]), [n, n]) + rows = array_ops.tile( + array_ops.expand_dims(math_ops.range(n), -1), [1, n]) + pred = math_ops.greater(cols, rows) + tril_ids = array_ops.tile(array_ops.reshape( + math_ops.cumsum(math_ops.range(n)), [n, 1]), [1, n]) + cols + tril_ids = math_ops.select(pred, + array_ops.zeros([n, n], dtype=dtypes.int32), + tril_ids + 1) + tril_ids = array_ops.reshape(tril_ids, [-1]) + return tril_ids + tril_ids = make_tril_ids(n) + + # Assemble the ids into pairs. + idx = array_ops.pack([ + array_ops.tile(array_ops.expand_dims(batch_ids, -1), [1, n*n]), + array_ops.tile([tril_ids], [m, 1])]) + idx = array_ops.transpose(idx, [1, 2, 0]) + + y = array_ops.gather_nd(y, idx) + y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]])) + + y.set_shape(y.get_shape().merge_with(final_shape)) + + return y + + class AppendDocstring(object): """Helper class to promote private subclass docstring to public counterpart.