提交 060192ea 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Affine bijector will default to identity matrix in the presence of no scale args.

Also throw an error when the low-rank update is ill-specified.
Change: 141566486
上级 bc885f9c
......@@ -646,8 +646,7 @@ class AffineBijectorTest(tf.test.TestCase):
for run in (static_run, dynamic_run):
mu = -1.
# Corresponds to scale = 2
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=None, scale_diag=[2.])
bijector = bijectors.Affine(shift=mu, scale_diag=[2.])
self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
x = [1., 2, 3] # Three scalar samples (no batches).
self.assertAllClose([1., 3, 5], run(bijector.forward, x))
......@@ -718,8 +717,7 @@ class AffineBijectorTest(tf.test.TestCase):
mu = [1.]
# One batch, scalar.
# Corresponds to scale = 1.
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=None, scale_diag=[1.])
bijector = bijectors.Affine(shift=mu, scale_diag=[1.])
self.assertEqual(
0, bijector.shaper.event_ndims.eval()) # "is scalar"
x = [1.] # One sample from one batches.
......@@ -765,8 +763,7 @@ class AffineBijectorTest(tf.test.TestCase):
mu = [1., -1]
# Univariate, two batches.
# Corresponds to scale = 1.
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=None, scale_diag=[1.])
bijector = bijectors.Affine(shift=mu, scale_diag=[1.])
self.assertEqual(
0, bijector.shaper.event_ndims.eval()) # "is scalar"
x = [1., 1] # One sample from each of two batches.
......@@ -824,9 +821,7 @@ class AffineBijectorTest(tf.test.TestCase):
mu = [1., -1]
# Multivariate
# Corresponds to scale = [[2., 0], [0, 1.]]
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=None,
scale_diag=[2., 1], event_ndims=1)
bijector = bijectors.Affine(shift=mu, scale_diag=[2., 1], event_ndims=1)
self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
x = [1., 1]
# matmul(sigma, x) + shift
......@@ -865,7 +860,7 @@ class AffineBijectorTest(tf.test.TestCase):
event_ndims: event_ndims_value}
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=None,
shift=mu,
scale_diag=scale_diag, event_ndims=event_ndims)
self.assertEqual(1, sess.run(bijector.shaper.event_ndims, feed_dict))
self.assertAllClose([[3., 1]], sess.run(bijector.forward(x), feed_dict))
......@@ -913,7 +908,7 @@ class AffineBijectorTest(tf.test.TestCase):
# Corresponds to 1 2x2 matrix, with twos on the diagonal.
scale_diag = [[2., 2]]
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=None,
shift=mu,
scale_diag=scale_diag, event_ndims=1)
self.assertEqual(
1, bijector.shaper.event_ndims.eval()) # "is vector"
......@@ -939,7 +934,7 @@ class AffineBijectorTest(tf.test.TestCase):
scale_diag_value, event_ndims: event_ndims_value}
bijector = bijectors.Affine(
shift=mu, scale_identity_multiplier=None,
shift=mu,
scale_diag=scale_diag, event_ndims=event_ndims)
self.assertEqual(1, sess.run(bijector.shaper.event_ndims, feed_dict))
self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict))
......@@ -1010,7 +1005,6 @@ class AffineBijectorTest(tf.test.TestCase):
# scale = [[2., 0], [2, 3]]
bijector = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_diag=[1., 2.],
scale_tril=[[1., 0], [2., 1]],
event_ndims=1)
......@@ -1068,7 +1062,6 @@ class AffineBijectorTest(tf.test.TestCase):
event_ndims=1)
bijector_ref = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_diag=[10., 2, 3],
event_ndims=1)
......@@ -1102,14 +1095,12 @@ class AffineBijectorTest(tf.test.TestCase):
# Corresponds to scale = [[10, 0, 0], [0, 3, 0], [0, 0, 5]]
bijector = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_diag=[2., 3, 4],
scale_perturb_diag=[2., 1],
scale_perturb_factor=[[2., 0], [0., 0], [0, 1]],
event_ndims=1)
bijector_ref = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_diag=[10., 3, 5],
event_ndims=1)
......@@ -1142,14 +1133,12 @@ class AffineBijectorTest(tf.test.TestCase):
# Corresponds to scale = [[10, 0, 0], [1, 3, 0], [2, 3, 5]]
bijector = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_tril=[[2., 0, 0], [1, 3, 0], [2, 3, 4]],
scale_perturb_diag=[2., 1],
scale_perturb_factor=[[2., 0], [0., 0], [0, 1]],
event_ndims=1)
bijector_ref = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]],
event_ndims=1)
......@@ -1182,14 +1171,12 @@ class AffineBijectorTest(tf.test.TestCase):
# Corresponds to scale = [[6, 0, 0], [1, 3, 0], [2, 3, 5]]
bijector = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_tril=[[2., 0, 0], [1, 3, 0], [2, 3, 4]],
scale_perturb_diag=None,
scale_perturb_factor=[[2., 0], [0., 0], [0, 1]],
event_ndims=1)
bijector_ref = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_tril=[[6., 0, 0], [1, 3, 0], [2, 3, 5]],
event_ndims=1)
......@@ -1212,7 +1199,6 @@ class AffineBijectorTest(tf.test.TestCase):
mu = [1., -1]
bijector = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
# Has zero on the diagonal.
scale_diag=[0., 1],
event_ndims=1,
......@@ -1242,7 +1228,6 @@ class AffineBijectorTest(tf.test.TestCase):
# Check Diag matrix with zero scaling.
bijector = bijectors.Affine(
shift=mu,
scale_identity_multiplier=None,
scale_diag=[0.0],
validate_args=True)
with self.assertRaisesOpError("Condition x > 0"):
......@@ -1271,10 +1256,19 @@ class AffineBijectorTest(tf.test.TestCase):
v = scale_perturb_factor
d2 = scale_perturb_diag
# No scale.
if c is None and d1 is None and tril is None:
# Ambiguous low rank update.
if v is None and d2 is not None:
return None
if c is None and d1 is None and tril is None:
# Special case when no scale args are passed in. This means use an
# identity matrix.
if v is None and d2 is None:
c = 1.
# No scale.
else:
return None
matrix = np.float32(0.)
if c is not None:
# Infer the dimension from x.
......@@ -1321,11 +1315,6 @@ class AffineBijectorTest(tf.test.TestCase):
bijector_args = dict({"event_ndims": 1}, **args)
# Special case this, since the default value for this in the bijector
# is 1.0.
if "scale_identity_multiplier" not in bijector_args:
bijector_args["scale_identity_multiplier"] = None
# We haven't specified enough information for the scale.
if scale is None:
with self.assertRaisesRegexp(ValueError, ("must be specified.")):
......@@ -1671,7 +1660,7 @@ class InvertBijectorTest(tf.test.TestCase):
bijectors.Identity(),
bijectors.Exp(event_ndims=1),
bijectors.Affine(
shift=[0., 1.], scale_identity_multiplier=None,
shift=[0., 1.],
scale_diag=[2., 3.], event_ndims=1),
bijectors.Softplus(event_ndims=1),
bijectors.SoftmaxCentered(event_ndims=1),
......
......@@ -56,6 +56,7 @@ import re
import numpy as np
import six
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
from tensorflow.contrib.distributions.python.ops import operator_pd_identity
......@@ -1412,7 +1413,7 @@ class Affine(Bijector):
def __init__(self,
shift,
scale_identity_multiplier=1.0,
scale_identity_multiplier=None,
scale_diag=None,
scale_tril=None,
scale_perturb_diag=None,
......@@ -1491,7 +1492,9 @@ class Affine(Bijector):
super(Affine, self).__init__(
batch_ndims=self._infer_batch_ndims(),
event_ndims=event_ndims,
graph_parents=[self._shift, self._scale],
graph_parents=[self._shift] + (
[self._scale] if contrib_framework.is_tensor(self._scale)
else self._scale.inputs),
is_constant_jacobian=True,
validate_args=validate_args,
name=name)
......@@ -1532,10 +1535,19 @@ class Affine(Bijector):
"""
# Special case, only handling a scaled identity matrix. We don't know its
# dimensions, so this is special cased.
self._is_only_identity_multiplier = (identity_multiplier is not None and
diag is None and
# We don't check identity_multiplier, since below we set it to 1. if all
# other scale args are None.
self._is_only_identity_multiplier = (diag is None and
tril is None and
perturb_factor is None)
# When no args are specified, treat this as if it were an identity matrix.
if self._is_only_identity_multiplier and identity_multiplier is None:
identity_multiplier = 1.
# Ambiguous definition of low rank update.
if perturb_diag is not None and perturb_factor is None:
raise ValueError("When perturb_diag is specified, perturb_factor must be "
"specified.")
# TODO(srvasude): Create a Linear Operator corresponding to a lower
# triangular matrix, and make VDVTUpdate use that, removing this special
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册