提交 ff0412fb 编写于 作者: I Ian Langmore 提交者: TensorFlower Gardener

Speed up quantized_distribution_test by reducing the number of times we add to the graph or eval.

Before: 1min 33s
After 30s
Change: 133887796
上级 aa509693
......@@ -33,7 +33,7 @@ class QuantizedDistributionTest(tf.test.TestCase):
self.assertTrue(np.isfinite(array).all())
def test_quantization_of_uniform_with_cutoffs_having_no_effect(self):
with self.test_session():
with self.test_session() as sess:
# The Quantized uniform with cutoffs == None divides the real line into:
# R = ...(-1, 0](0, 1](1, 2](2, 3](3, 4]...
# j = ... 0 1 2 3 4 ...
......@@ -60,34 +60,38 @@ class QuantizedDistributionTest(tf.test.TestCase):
b=3.0)
# pmf
pmf_n1, pmf_0, pmf_1, pmf_2, pmf_3, pmf_4, pmf_5 = sess.run(
qdist.pmf([-1., 0., 1., 2., 3., 4., 5.]))
# uniform had no mass below -1.
self.assertAllClose(0., qdist.pmf(-1.).eval())
self.assertAllClose(0., pmf_n1)
# uniform had no mass below 0.
self.assertAllClose(0., qdist.pmf(0.).eval())
self.assertAllClose(0., pmf_0)
# uniform put 1/3 of its mass in each of (0, 1], (1, 2], (2, 3],
# which are the intervals j = 1, 2, 3.
self.assertAllClose(1 / 3, qdist.pmf(1.).eval())
self.assertAllClose(1 / 3, qdist.pmf(2.).eval())
self.assertAllClose(1 / 3, qdist.pmf(3.).eval())
self.assertAllClose(1 / 3, pmf_1)
self.assertAllClose(1 / 3, pmf_2)
self.assertAllClose(1 / 3, pmf_3)
# uniform had no mass in (3, 4] or (4, 5], which are j = 4, 5.
self.assertAllClose(0 / 3, qdist.pmf(4.).eval())
self.assertAllClose(0 / 3, qdist.pmf(5.).eval())
self.assertAllClose(0 / 3, pmf_4)
self.assertAllClose(0 / 3, pmf_5)
# cdf
self.assertAllClose(0., qdist.cdf(-1.).eval())
self.assertAllClose(0., qdist.cdf(0.).eval())
self.assertAllClose(1 / 3, qdist.cdf(1.).eval())
self.assertAllClose(2 / 3, qdist.cdf(2.).eval())
cdf_n1, cdf_0, cdf_1, cdf_2, cdf_2p5, cdf_3, cdf_4, cdf_5 = sess.run(
qdist.cdf([-1., 0., 1., 2., 2.5, 3., 4., 5.]))
self.assertAllClose(0., cdf_n1)
self.assertAllClose(0., cdf_0)
self.assertAllClose(1 / 3, cdf_1)
self.assertAllClose(2 / 3, cdf_2)
# Note fractional values allowed for cdfs of discrete distributions.
# And adding 0.5 makes no difference because the quantized dist has
# mass only on the integers, never in between.
self.assertAllClose(2 / 3, qdist.cdf(2.5).eval())
self.assertAllClose(3 / 3, qdist.cdf(3.).eval())
self.assertAllClose(3 / 3, qdist.cdf(4.).eval())
self.assertAllClose(3 / 3, qdist.cdf(5.).eval())
self.assertAllClose(2 / 3, cdf_2p5)
self.assertAllClose(3 / 3, cdf_3)
self.assertAllClose(3 / 3, cdf_4)
self.assertAllClose(3 / 3, cdf_5)
def test_quantization_of_uniform_with_cutoffs_in_the_middle(self):
with self.test_session():
with self.test_session() as sess:
# The uniform is supported on [-3, 3]
# Consider partitions the real line in intervals
# ...(-3, -2](-2, -1](-1, 0](0, 1](1, 2](2, 3] ...
......@@ -103,25 +107,27 @@ class QuantizedDistributionTest(tf.test.TestCase):
b=3.0)
# pmf
cdf_n3, cdf_n2, cdf_n1, cdf_0, cdf_0p5, cdf_1, cdf_10 = sess.run(
qdist.cdf([-3., -2., -1., 0., 0.5, 1.0, 10.0]))
# Uniform had no mass on (-4, -3] or (-3, -2]
self.assertAllClose(0., qdist.cdf(-3.).eval())
self.assertAllClose(0., qdist.cdf(-2.).eval())
self.assertAllClose(0., cdf_n3)
self.assertAllClose(0., cdf_n2)
# Uniform had 1/6 of its mass in each of (-3, -2], and (-2, -1], which
# were collapsed into (-infty, -1], which is now the "-1" interval.
self.assertAllClose(1 / 3, qdist.cdf(-1.).eval())
self.assertAllClose(1 / 3, cdf_n1)
# The j=0 interval contained mass from (-3, 0], which is 1/2 of the
# uniform's mass.
self.assertAllClose(1 / 2, qdist.cdf(0.).eval())
self.assertAllClose(1 / 2, cdf_0)
# Adding 0.5 makes no difference because the quantized dist has mass on
# the integers, not in between them.
self.assertAllClose(1 / 2, qdist.cdf(0.5).eval())
self.assertAllClose(1 / 2, cdf_0p5)
# After applying the cutoff, all mass was either in the interval
# (0, infty), or below. (0, infty) is the interval indexed by j=1,
# so pmf(1) should equal 1.
self.assertAllClose(1., qdist.cdf(1.0).eval())
self.assertAllClose(1., cdf_1)
# Since no mass of qdist is above 1,
# pmf(10) = P[Y <= 10] = P[Y <= 1] = pmf(1).
self.assertAllClose(1., qdist.cdf(10.0).eval())
self.assertAllClose(1., cdf_10)
def test_quantization_of_batch_of_uniforms(self):
batch_shape = (5, 5)
......@@ -231,10 +237,12 @@ class QuantizedDistributionTest(tf.test.TestCase):
# The smallest value the samples can take on is 1, which corresponds to
# the interval (0, 1]. Recall we use ceiling in the sampling definition.
self.assertLess(0.5, samps.min())
for x in range(1, 10):
x_vals = np.arange(1, 11).astype(np.float32)
pmf_vals = qdist.pmf(x_vals).eval()
for ii in range(10):
self.assertAllClose(
qdist.pmf(float(x)).eval(),
(samps == x).mean(),
pmf_vals[ii],
(samps == x_vals[ii]).mean(),
atol=std_err_bound)
def test_normal_cdf_and_survival_function(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册