Commit 2522285c authored by Joshua V. Dillon, committed by TensorFlower Gardener

Allow fully dynamic batch/event overrides.

Previously it was assumed that overrides were statically on or off (even though the override value itself could be non-static). This CL allows fully dynamic specification of a batch_shape and/or event_shape override, i.e., we no longer require that the decision to override the shape be known during graph construction.
Change: 143576091
Parent e1eae197
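Note: the diffs below update the tests and the `Distribution` base class accordingly. As a standalone illustration of what the change enables, here is a minimal sketch, assuming the `tf.contrib.distributions` API at this revision and module aliases (`ds`, `bs`) mirroring the tests' imports; the override shapes are ordinary `int32` placeholders, so the graph is built before their values are known:

```python
import numpy as np
import tensorflow as tf

from tensorflow.contrib import distributions as ds
from tensorflow.contrib.distributions.python.ops import bijector as bs

# The override shapes are fed at run time; nothing about them needs to
# be known while the graph is built.
batch_shape_pl = tf.placeholder(tf.int32, name="dynamic_batch_shape")
event_shape_pl = tf.placeholder(tf.int32, name="dynamic_event_shape")

log_normal = ds.TransformedDistribution(
    distribution=ds.Normal(mu=0., sigma=1.),
    bijector=bs.Exp(event_ndims=0),
    batch_shape=batch_shape_pl,
    event_shape=event_shape_pl,
    validate_args=True)

with tf.Session() as sess:
  samples = sess.run(
      log_normal.sample(4, seed=0),
      feed_dict={batch_shape_pl: np.array([2], dtype=np.int32),
                 event_shape_pl: np.array([3], dtype=np.int32)})
  print(samples.shape)  # (4, 2, 3): sample shape + batch + event
```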
@@ -85,20 +85,20 @@ class DistributionTest(test.TestCase):
     sigma = 2.
     normal = ds.Normal(mu, sigma, validate_args=True)
-    self.assertTrue(tensor_util.constant_value(normal.is_scalar_event))
-    self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch))
+    self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
+    self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch()))
     normal = ds.Normal([mu], [sigma], validate_args=True)
-    self.assertTrue(tensor_util.constant_value(normal.is_scalar_event))
-    self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch))
+    self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
+    self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch()))
     mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True)
-    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event))
-    self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch))
+    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
+    self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch()))
     mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True)
-    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event))
-    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch))
+    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
+    self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch()))
     # We now test every codepath within the underlying is_scalar_helper
     # function.
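The added parentheses above are the visible API change: `is_scalar_event` and `is_scalar_batch` become methods that build an op inside a name scope (see the `distribution.py` hunks below) instead of properties. A minimal sketch of the new call pattern, under the same API assumptions:

```python
from tensorflow.contrib import distributions as ds
from tensorflow.python.framework import tensor_util

normal = ds.Normal(mu=1., sigma=2., validate_args=True)

# Each call returns a boolean scalar Tensor rather than a property value.
is_scalar_event = normal.is_scalar_event()
is_scalar_batch = normal.is_scalar_batch()

# With fully static shapes, constant_value() folds the Tensor to a
# Python bool at graph-construction time.
print(tensor_util.constant_value(is_scalar_event))  # True
print(tensor_util.constant_value(is_scalar_batch))  # True
```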
@@ -20,11 +20,13 @@ from __future__ import print_function
 import numpy as np
 from scipy import stats
 from tensorflow.contrib import distributions
 from tensorflow.contrib import linalg
 from tensorflow.contrib.distributions.python.ops import bijector as bijector_lib
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -73,8 +75,7 @@ class TransformedDistributionTest(test.TestCase):
       # Note: the Jacobian callable only works for this example; more generally
       # you may or may not need a reduce_sum.
       log_normal = ds.TransformedDistribution(
-          distribution=ds.Normal(
-              mu=mu, sigma=sigma),
+          distribution=ds.Normal(mu=mu, sigma=sigma),
           bijector=bs.Exp(event_ndims=0))
       sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu))
@@ -105,41 +106,38 @@ class TransformedDistributionTest(test.TestCase):
       mu = 3.0
       sigma = 0.02
       log_normal = ds.TransformedDistribution(
-          distribution=ds.Normal(
-              mu=mu, sigma=sigma),
+          distribution=ds.Normal(mu=mu, sigma=sigma),
           bijector=bs.Exp(event_ndims=0))
       sample = log_normal.sample(1)
       sample_val, log_pdf_val = sess.run([sample, log_normal.log_pdf(sample)])
       self.assertAllClose(
-          stats.lognorm.logpdf(
-              sample_val, s=sigma, scale=np.exp(mu)),
+          stats.lognorm.logpdf(sample_val, s=sigma, scale=np.exp(mu)),
           log_pdf_val,
           atol=1e-2)
   def testConditioning(self):
     with self.test_session():
       conditional_normal = ds.TransformedDistribution(
-          distribution=ds.Normal(
-              mu=0., sigma=1.),
+          distribution=ds.Normal(mu=0., sigma=1.),
           bijector=_ChooseLocation(loc=[-100., 100.]))
       z = [-1, +1, -1, -1, +1]
       self.assertAllClose(
-          np.sign(
-              conditional_normal.sample(
-                  5, bijector_kwargs={"z": z}).eval()),
-          z)
+          np.sign(conditional_normal.sample(
+              5, bijector_kwargs={"z": z}).eval()), z)
   def testShapeChangingBijector(self):
     with self.test_session():
       softmax = bs.SoftmaxCentered()
       standard_normal = ds.Normal(mu=0., sigma=1.)
       multi_logit_normal = ds.TransformedDistribution(
-          distribution=standard_normal, bijector=softmax)
-      x = [[-np.log(3.), 0.], [np.log(3), np.log(5)]]
+          distribution=standard_normal,
+          bijector=softmax)
+      x = [[-np.log(3.), 0.],
+           [np.log(3), np.log(5)]]
       y = softmax.forward(x).eval()
-      expected_log_pdf = (stats.norm(
-          loc=0., scale=1.).logpdf(x) - np.sum(np.log(y), axis=-1))
+      expected_log_pdf = (stats.norm(loc=0., scale=1.).logpdf(x) -
+                          np.sum(np.log(y), axis=-1))
       self.assertAllClose(expected_log_pdf,
                           multi_logit_normal.log_prob(y).eval())
       self.assertAllClose(
@@ -152,7 +150,9 @@ class TransformedDistributionTest(test.TestCase):
     with self.test_session():
       shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
       diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
-      actual_mvn = ds.MultivariateNormalDiag(shift, diag, validate_args=True)
+      actual_mvn_entropy = np.concatenate([
+          [stats.multivariate_normal(shift[i], np.diag(diag[i]**2)).entropy()]
+          for i in range(len(diag))])
       fake_mvn = ds.TransformedDistribution(
           ds.MultivariateNormalDiag(
               array_ops.zeros_like(shift),
@@ -160,11 +160,10 @@ class TransformedDistributionTest(test.TestCase):
               validate_args=True),
           bs.AffineLinearOperator(
               shift,
-              scale=la.LinearOperatorDiag(
-                  diag, is_non_singular=True),
+              scale=la.LinearOperatorDiag(diag, is_non_singular=True),
               validate_args=True),
           validate_args=True)
-      self.assertAllClose(actual_mvn.entropy().eval(),
+      self.assertAllClose(actual_mvn_entropy,
                           fake_mvn.entropy().eval())
@@ -181,121 +180,151 @@ class ScalarToMultiTest(test.TestCase):
         dtype=np.float32)
   def _testMVN(self,
-               base_distribution,
-               batch_shape=None,
-               event_shape=None,
+               base_distribution_class,
+               base_distribution_kwargs,
+               batch_shape=(),
+               event_shape=(),
                not_implemented_message=None):
     with self.test_session() as sess:
       # Overriding shapes must be compatible w/bijector; most bijectors are
       # batch_shape agnostic and only care about event_ndims.
       # In the case of `Affine`, if we got it wrong then it would fire an
      # exception due to incompatible dimensions.
-      fake_mvn = ds.TransformedDistribution(
-          distribution=base_distribution[0](validate_args=True,
-                                            **base_distribution[1]),
-          bijector=bs.Affine(
-              shift=self._shift, scale_tril=self._tril),
+      batch_shape_pl = array_ops.placeholder(
+          dtypes.int32, name="dynamic_batch_shape")
+      event_shape_pl = array_ops.placeholder(
+          dtypes.int32, name="dynamic_event_shape")
+      feed_dict = {batch_shape_pl: np.array(batch_shape, dtype=np.int32),
+                   event_shape_pl: np.array(event_shape, dtype=np.int32)}
+      fake_mvn_dynamic = ds.TransformedDistribution(
+          distribution=base_distribution_class(validate_args=True,
+                                               **base_distribution_kwargs),
+          bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
+          batch_shape=batch_shape_pl,
+          event_shape=event_shape_pl,
+          validate_args=True)
+      fake_mvn_static = ds.TransformedDistribution(
+          distribution=base_distribution_class(validate_args=True,
+                                               **base_distribution_kwargs),
+          bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
           batch_shape=batch_shape,
           event_shape=event_shape,
           validate_args=True)
       actual_mean = np.tile(self._shift, [2, 1])  # Affine elided this tile.
       actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1]))
-      actual_mvn = ds.MultivariateNormalFull(mu=actual_mean, sigma=actual_cov)
-      # Ensure sample works by checking first, second moments.
-      n = 5e3
-      y = fake_mvn.sample(int(n), seed=0)
-      sample_mean = math_ops.reduce_mean(y, 0)
-      centered_y = array_ops.transpose(y - sample_mean, [1, 2, 0])
-      sample_cov = math_ops.matmul(centered_y, centered_y, transpose_b=True) / n
-      [sample_mean_, sample_cov_] = sess.run([sample_mean, sample_cov])
-      self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
-      self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)
-      # Ensure all other functions work as intended.
-      x = fake_mvn.sample(5, seed=0).eval()
-      self.assertAllEqual([5, 2, 3], x.shape)
-      self.assertAllEqual(actual_mvn.get_event_shape(),
-                          fake_mvn.get_event_shape())
-      self.assertAllEqual(actual_mvn.event_shape().eval(),
-                          fake_mvn.event_shape().eval())
-      self.assertAllEqual(actual_mvn.get_batch_shape(),
-                          fake_mvn.get_batch_shape())
-      self.assertAllEqual(actual_mvn.batch_shape().eval(),
-                          fake_mvn.batch_shape().eval())
-      self.assertAllClose(
-          actual_mvn.log_prob(x).eval(),
-          fake_mvn.log_prob(x).eval(),
-          atol=0.,
-          rtol=1e-7)
-      self.assertAllClose(
-          actual_mvn.prob(x).eval(),
-          fake_mvn.prob(x).eval(),
-          atol=0.,
-          rtol=1e-6)
-      self.assertAllClose(
-          actual_mvn.entropy().eval(),
-          fake_mvn.entropy().eval(),
-          atol=0.,
-          rtol=1e-6)
-      for unsupported_fn in (fake_mvn.log_cdf, fake_mvn.cdf,
-                             fake_mvn.survival_function,
-                             fake_mvn.log_survival_function):
+      def actual_mvn_log_prob(x):
+        return np.concatenate([
+            [stats.multivariate_normal(
+                actual_mean[i], actual_cov[i]).logpdf(x[:, i, :])]
+            for i in range(len(actual_cov))]).T
+      actual_mvn_entropy = np.concatenate([
+          [stats.multivariate_normal(
+              actual_mean[i], actual_cov[i]).entropy()]
+          for i in range(len(actual_cov))])
+      self.assertAllEqual([3], fake_mvn_static.get_event_shape())
+      self.assertAllEqual([2], fake_mvn_static.get_batch_shape())
+      self.assertAllEqual(tensor_shape.TensorShape(None),
+                          fake_mvn_dynamic.get_event_shape())
+      self.assertAllEqual(tensor_shape.TensorShape(None),
+                          fake_mvn_dynamic.get_batch_shape())
+      x = fake_mvn_static.sample(5, seed=0).eval()
+      for unsupported_fn in (fake_mvn_static.log_cdf,
+                             fake_mvn_static.cdf,
+                             fake_mvn_static.survival_function,
+                             fake_mvn_static.log_survival_function):
         with self.assertRaisesRegexp(NotImplementedError,
                                      not_implemented_message):
-          self.assertRaisesRegexp(unsupported_fn(x))
+          unsupported_fn(x)
+      num_samples = 5e3
+      for fake_mvn, feed_dict in ((fake_mvn_static, {}),
+                                  (fake_mvn_dynamic, feed_dict)):
+        # Ensure sample works by checking first, second moments.
+        y = fake_mvn.sample(int(num_samples), seed=0)
+        x = y[0:5, ...]
+        sample_mean = math_ops.reduce_mean(y, 0)
+        centered_y = array_ops.transpose(y - sample_mean, [1, 2, 0])
+        sample_cov = math_ops.matmul(
+            centered_y, centered_y, transpose_b=True) / num_samples
+        [
+            sample_mean_,
+            sample_cov_,
+            x_,
+            fake_event_shape_,
+            fake_batch_shape_,
+            fake_log_prob_,
+            fake_prob_,
+            fake_entropy_,
+        ] = sess.run([
+            sample_mean,
+            sample_cov,
+            x,
+            fake_mvn.event_shape(),
+            fake_mvn.batch_shape(),
+            fake_mvn.log_prob(x),
+            fake_mvn.prob(x),
+            fake_mvn.entropy(),
+        ], feed_dict=feed_dict)
+        self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
+        self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)
+        # Ensure all other functions work as intended.
+        self.assertAllEqual([5, 2, 3], x_.shape)
+        self.assertAllEqual([3], fake_event_shape_)
+        self.assertAllEqual([2], fake_batch_shape_)
+        self.assertAllClose(actual_mvn_log_prob(x_), fake_log_prob_,
+                            atol=0., rtol=1e-6)
+        self.assertAllClose(np.exp(actual_mvn_log_prob(x_)), fake_prob_,
+                            atol=0., rtol=1e-5)
+        self.assertAllClose(actual_mvn_entropy, fake_entropy_,
+                            atol=0., rtol=1e-6)
   def testScalarBatchScalarEvent(self):
     self._testMVN(
-        base_distribution=[ds.Normal, {
-            "mu": 0.,
-            "sigma": 1.
-        }],
+        base_distribution_class=ds.Normal,
+        base_distribution_kwargs={"mu": 0., "sigma": 1.},
         batch_shape=[2],
         event_shape=[3],
         not_implemented_message="not implemented when overriding event_shape")
   def testScalarBatchNonScalarEvent(self):
     self._testMVN(
-        base_distribution=[
-            ds.MultivariateNormalDiag, {
-                "mu": [0., 0., 0.],
-                "diag_stdev": [1., 1, 1]
-            }
-        ],
+        base_distribution_class=ds.MultivariateNormalDiag,
+        base_distribution_kwargs={"mu": [0., 0., 0.], "diag_stdev": [1., 1, 1]},
         batch_shape=[2],
         not_implemented_message="not implemented$")
     with self.test_session():
       # Can't override event_shape for scalar batch, non-scalar event.
-      with self.assertRaisesRegexp(ValueError, "requires scalar"):
+      with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
         ds.TransformedDistribution(
-            distribution=ds.MultivariateNormalDiag(
-                mu=[0.], diag_stdev=[1.]),
-            bijector=bs.Affine(
-                shift=self._shift, scale_tril=self._tril),
+            distribution=ds.MultivariateNormalDiag(mu=[0.], diag_stdev=[1.]),
+            bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
             batch_shape=[2],
             event_shape=[3],
             validate_args=True)
   def testNonScalarBatchScalarEvent(self):
     self._testMVN(
-        base_distribution=[ds.Normal, {
-            "mu": [0., 0],
-            "sigma": [1., 1]
-        }],
+        base_distribution_class=ds.Normal,
+        base_distribution_kwargs={"mu": [0., 0], "sigma": [1., 1]},
         event_shape=[3],
         not_implemented_message="not implemented when overriding event_shape")
     with self.test_session():
       # Can't override batch_shape for non-scalar batch, scalar event.
-      with self.assertRaisesRegexp(ValueError, "requires scalar"):
+      with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
         ds.TransformedDistribution(
-            distribution=ds.Normal(
-                mu=[0.], sigma=[1.]),
-            bijector=bs.Affine(
-                shift=self._shift, scale_tril=self._tril),
+            distribution=ds.Normal(mu=[0.], sigma=[1.]),
+            bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
             batch_shape=[2],
             event_shape=[3],
             validate_args=True)
@@ -304,12 +333,11 @@ class ScalarToMultiTest(test.TestCase):
     with self.test_session():
       # Can't override event_shape and/or batch_shape for non_scalar batch,
      # non-scalar event.
-      with self.assertRaisesRegexp(ValueError, "requires scalar"):
+      with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
         ds.TransformedDistribution(
-            distribution=ds.MultivariateNormalDiag(
-                mu=[[0.]], diag_stdev=[[1.]]),
-            bijector=bs.Affine(
-                shift=self._shift, scale_tril=self._tril),
+            distribution=ds.MultivariateNormalDiag(mu=[[0.]],
+                                                   diag_stdev=[[1.]]),
+            bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
             batch_shape=[2],
             event_shape=[3],
             validate_args=True)
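The rewritten `_testMVN` pins down two behaviors worth restating: a static override bakes the shape into the distribution's static metadata, while a dynamic override leaves `get_batch_shape()`/`get_event_shape()` unknown until run time; and either override requires the base distribution to be scalar in the corresponding dimension, now reported as "base distribution not scalar". A small sketch under the same API assumptions as above (the tests use `bs.Affine`; `bs.Exp` is substituted here only to keep the example short):

```python
import tensorflow as tf

from tensorflow.contrib import distributions as ds
from tensorflow.contrib.distributions.python.ops import bijector as bs

# Static override: shape metadata is known at graph-construction time.
static = ds.TransformedDistribution(
    distribution=ds.Normal(mu=0., sigma=1.),
    bijector=bs.Exp(event_ndims=0),
    batch_shape=[2],
    event_shape=[3],
    validate_args=True)
print(static.get_batch_shape())   # (2,)
print(static.get_event_shape())   # (3,)

# Dynamic override: static metadata is unknown; the run-time
# batch_shape()/event_shape() Tensors carry the values instead.
dynamic = ds.TransformedDistribution(
    distribution=ds.Normal(mu=0., sigma=1.),
    bijector=bs.Exp(event_ndims=0),
    batch_shape=tf.placeholder(tf.int32),
    event_shape=tf.placeholder(tf.int32),
    validate_args=True)
print(dynamic.get_batch_shape())  # <unknown>
print(dynamic.get_event_shape())  # <unknown>

# Overriding a non-scalar base batch shape is rejected up front.
try:
  ds.TransformedDistribution(
      distribution=ds.Normal(mu=[0.], sigma=[1.]),  # batch_shape is [1]
      bijector=bs.Exp(event_ndims=0),
      batch_shape=[2],
      validate_args=True)
except ValueError as e:
  print(e)  # ... base distribution not scalar ...
```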
@@ -526,19 +526,33 @@ class Distribution(_BaseDistribution):
     """
     return self._get_event_shape()
-  @property
-  def is_scalar_event(self):
-    """Indicates that `event_shape==[]`."""
-    return ops.convert_to_tensor(
-        self._is_scalar_helper(self.get_event_shape, self.event_shape),
-        name="is_scalar_event")
+  def is_scalar_event(self, name="is_scalar_event"):
+    """Indicates that `event_shape == []`.
-  @property
-  def is_scalar_batch(self):
-    """Indicates that `batch_shape==[]`."""
-    return ops.convert_to_tensor(
-        self._is_scalar_helper(self.get_batch_shape, self.batch_shape),
-        name="is_scalar_batch")
+    Args:
+      name: The name to give this op.
+    Returns:
+      is_scalar_event: `Boolean` `scalar` `Tensor`.
+    """
+    with self._name_scope(name):
+      return ops.convert_to_tensor(
+          self._is_scalar_helper(self.get_event_shape, self.event_shape),
+          name="is_scalar_event")
+  def is_scalar_batch(self, name="is_scalar_batch"):
+    """Indicates that `batch_shape == []`.
+    Args:
+      name: The name to give this op.
+    Returns:
+      is_scalar_batch: `Boolean` `scalar` `Tensor`.
+    """
+    with self._name_scope(name):
+      return ops.convert_to_tensor(
+          self._is_scalar_helper(self.get_batch_shape, self.batch_shape),
+          name="is_scalar_batch")
   def _sample_n(self, n, seed=None):
     raise NotImplementedError("sample_n is not implemented")
@@ -888,7 +902,7 @@ class Distribution(_BaseDistribution):
     """Helper function to standardize op scope."""
     with ops.name_scope(self.name):
       with ops.name_scope(name, values=(
-          (values or []) + self._graph_parents)) as scope:
+          ([] if values is None else values) + self._graph_parents)) as scope:
         yield scope
   def _expand_sample_shape_to_vector(self, x, name):
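The last hunk replaces `values or []` with an explicit `None` test in `_name_scope`. One plausible motivation, shown with a hypothetical standalone helper (the names below are illustrative, not from the TF source): `or` evaluates the truthiness of `values`, which raises for objects such as multi-element numpy arrays or Tensors, whereas the `None` test passes any supplied value straight through.

```python
import numpy as np

GRAPH_PARENTS = ["graph_parent"]  # hypothetical stand-in for self._graph_parents

def scope_values_old(values=None):
  # bool(values) is evaluated here; a multi-element numpy array (or a
  # tf.Tensor) raises "truth value ... is ambiguous".
  return (values or []) + GRAPH_PARENTS

def scope_values_new(values=None):
  # Only None means "no values"; list() is added in this sketch so the
  # concatenation also works for tuples and arrays.
  return ([] if values is None else list(values)) + GRAPH_PARENTS

print(scope_values_new(np.array([1, 2])))  # [1, 2, 'graph_parent']
# scope_values_old(np.array([1, 2]))       # would raise ValueError
```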