提交 4863a607 编写于 作者: J Joshua V. Dillon 提交者: TensorFlower Gardener

Add fill_lower_triangular utility function.

Change: 137863303
上级 1f8936bb
......@@ -57,7 +57,6 @@ initialized with parameters that define the distributions.
@@MultivariateNormalCholesky
@@MultivariateNormalDiagPlusVDVT
@@MultivariateNormalDiagWithSoftplusStDev
@@matrix_diag_transform
### Other multivariate distributions
......@@ -67,6 +66,10 @@ initialized with parameters that define the distributions.
@@WishartCholesky
@@WishartFull
### Multivariate Utilities
@@matrix_diag_transform
## Transformed distributions
@@TransformedDistribution
......@@ -86,7 +89,7 @@ representing the posterior or posterior predictive.
@@normal_conjugates_known_sigma_posterior
@@normal_conjugates_known_sigma_predictive
## Kullback Leibler Divergence
## Kullback-Leibler Divergence
@@kl
@@RegisterKL
......
......@@ -25,7 +25,7 @@ import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import distribution_util
class DistributionUtilTest(tf.test.TestCase):
class AssertCloseTest(tf.test.TestCase):
def testAssertCloseIntegerDtype(self):
x = [1, 5, 10, 15, 20]
......@@ -110,6 +110,9 @@ class DistributionUtilTest(tf.test.TestCase):
distribution_util.assert_integer_form(w)]):
tf.identity(w).eval()
class GetLogitsAndProbTest(tf.test.TestCase):
def testGetLogitsAndProbImproperArguments(self):
with self.test_session():
with self.assertRaises(ValueError):
......@@ -229,6 +232,9 @@ class DistributionUtilTest(tf.test.TestCase):
p=p4, multidimensional=True, validate_args=False)
prob.eval()
class LogCombinationsTest(tf.test.TestCase):
def testLogCombinationsBinomial(self):
n = [2, 5, 12, 15]
k = [1, 2, 4, 11]
......@@ -252,6 +258,9 @@ class DistributionUtilTest(tf.test.TestCase):
log_binom = distribution_util.log_combinations(n, counts)
self.assertEqual([2, 2], log_binom.get_shape())
class RotateTransposeTest(tf.test.TestCase):
def _np_rotate_transpose(self, x, shift):
if not isinstance(x, np.ndarray):
x = np.array(x)
......@@ -283,7 +292,10 @@ class DistributionUtilTest(tf.test.TestCase):
sess.run(distribution_util.rotate_transpose(x, shift),
feed_dict={x: x_value, shift: shift_value}))
def testChooseVector(self):
class PickVectorTest(tf.test.TestCase):
def testCorrectlyPicksVector(self):
with self.test_session():
x = np.arange(10, 12)
y = np.arange(15, 18)
......@@ -301,5 +313,47 @@ class DistributionUtilTest(tf.test.TestCase):
tf.constant(False), x, y)) # No eval.
class FillLowerTriangularTest(tf.test.TestCase):
  """Checks `distribution_util.fill_lower_triangular` against fixed outputs."""

  def _check_fill(self, vec, want):
    """Asserts the static shape and the evaluated value of the filled result."""
    with self.test_session():
      got = distribution_util.fill_lower_triangular(vec)
      self.assertAllEqual(want.shape, got.get_shape())
      self.assertAllEqual(want, got.eval())

  def testCorrectlyMakes1x1LowerTril(self):
    # A batch of three length-1 vectors becomes three 1x1 matrices.
    vec = np.array([[1.], [2], [3]])
    want = np.array([[[1.]], [[2]], [[3]]])
    self._check_fill(vec, want)

  def testCorrectlyMakesNoBatchLowerTril(self):
    # The input has nine entries; the expected 3x3 matrix holds the first six.
    vec = np.arange(9)
    want = np.array(
        [[0., 0., 0.],
         [1., 2., 0.],
         [3., 4., 5.]])
    self._check_fill(vec, want)

  def testCorrectlyMakesBatchLowerTril(self):
    # Batch shape [2, 2] with six elements per vector -> shape [2, 2, 3, 3].
    vec = np.reshape(np.arange(24), (2, 2, 6))
    want = np.array(
        [[[[0., 0., 0.],
           [1., 2., 0.],
           [3., 4., 5.]],
          [[6., 0., 0.],
           [7., 8., 0.],
           [9., 10., 11.]]],
         [[[12., 0., 0.],
           [13., 14., 0.],
           [15., 16., 17.]],
          [[18., 0., 0.],
           [19., 20., 0.],
           [21., 22., 23.]]]])
    self._check_fill(vec, want)
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == "__main__":
tf.test.main()
......@@ -20,11 +20,13 @@ from __future__ import print_function
import functools
import hashlib
import math
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
......@@ -376,7 +378,7 @@ def pick_vector(cond,
TypeError: if `cond` is not a constant and
`true_vector.dtype != false_vector.dtype`
"""
with ops.op_scope((cond, true_vector, false_vector), name):
with ops.name_scope(name, values=(cond, true_vector, false_vector)):
cond = ops.convert_to_tensor(cond, name="cond")
if cond.dtype != dtypes.bool:
raise TypeError("%s.dtype=%s which is not %s" %
......@@ -405,6 +407,101 @@ def gen_new_seed(seed, salt):
return None
def fill_lower_triangular(x, name="fill_lower_triangular"):
  """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  When `d` is not exactly of the form `n(n+1)/2`, `n` is rounded down and the
  trailing entries of each vector are not used.

  Note: This function is very slow; possibly 10x slower than zero-ing out the
  upper-triangular portion of a full matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  Args:
    x: `Tensor` representing lower triangular elements.
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.
  """
  with ops.name_scope(name, values=(x,)):
    x = ops.convert_to_tensor(x, name="x")
    static_shape = x.get_shape()

    if (static_shape.ndims is not None and
        static_shape[-1].value is not None):
      # The vector length is known at graph-construction time.
      d = static_shape[-1].value
      # d = n^2/2 + n/2 implies n is:
      n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
      final_shape = static_shape[:-1].concatenate(
          tensor_shape.TensorShape([n, n]))
    else:
      # Keep `d` as an integer: it is used below as a reshape dimension.
      d = array_ops.shape(x)[-1]
      # d = n^2/2 + n/2 implies n is:
      n = math_ops.cast(
          0.5 * (math_ops.sqrt(
              1. + 8. * math_ops.cast(d, dtype=dtypes.float32)) - 1.),
          dtype=dtypes.int32)
      final_shape = static_shape[:-1].concatenate(
          tensor_shape.TensorShape([None, None]))

    # Batch shape and the number of vectors `m` in the flattened batch.
    if (static_shape.ndims is not None and
        static_shape[:-1].is_fully_defined()):
      batch_shape = np.asarray(static_shape[:-1].as_list(), dtype=np.int32)
      # np.prod may widen the dtype; cast back to int32 for range/tile.
      m = np.asarray(np.prod(batch_shape), dtype=np.int32)
    else:
      batch_shape = array_ops.shape(x)[:-1]
      m = math_ops.reduce_prod(batch_shape)

    # Flatten batch dims.
    y = array_ops.reshape(x, [-1, d])

    # Prepend a zero to each row; index 0 then serves as the value for every
    # upper-triangular position.
    y = array_ops.pad(y, paddings=[[0, 0], [1, 0]])

    # One id per flattened batch member.
    batch_ids = math_ops.range(m)

    def make_tril_ids(n):
      """Internal helper to create vector of linear indices into y."""
      cols = array_ops.reshape(array_ops.tile(math_ops.range(n), [n]), [n, n])
      rows = array_ops.tile(
          array_ops.expand_dims(math_ops.range(n), -1), [1, n])
      # True strictly above the diagonal.
      pred = math_ops.greater(cols, rows)
      # Row i starts at offset 0 + 1 + ... + i within the packed vector.
      tril_ids = array_ops.tile(array_ops.reshape(
          math_ops.cumsum(math_ops.range(n)), [n, 1]), [1, n]) + cols
      # Shift by one for the prepended zero; point upper entries at that zero.
      tril_ids = math_ops.select(pred,
                                 array_ops.zeros([n, n], dtype=dtypes.int32),
                                 tril_ids + 1)
      tril_ids = array_ops.reshape(tril_ids, [-1])
      return tril_ids

    tril_ids = make_tril_ids(n)

    # Assemble the ids into (batch id, element id) pairs.
    idx = array_ops.pack([
        array_ops.tile(array_ops.expand_dims(batch_ids, -1), [1, n*n]),
        array_ops.tile([tril_ids], [m, 1])])
    idx = array_ops.transpose(idx, [1, 2, 0])

    y = array_ops.gather_nd(y, idx)
    y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
    y.set_shape(y.get_shape().merge_with(final_shape))
    return y
class AppendDocstring(object):
"""Helper class to promote private subclass docstring to public counterpart.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册