提交 ee983d8a 编写于 作者: E Eugene Brevdo 提交者: TensorFlower Gardener

Add VIMCO advantage function to bayesflow.

Change: 139853413
上级 fff98d1d
......@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
st = tf.contrib.bayesflow.stochastic_tensor
......@@ -25,6 +26,31 @@ sge = tf.contrib.bayesflow.stochastic_gradient_estimators
dists = tf.contrib.distributions
def _vimco(loss):
"""Python implementation of VIMCO."""
n = loss.shape[0]
log_loss = np.log(loss)
geometric_mean = []
for j in range(n):
geometric_mean.append(
np.exp(np.mean([log_loss[i, :] for i in range(n) if i != j], 0)))
geometric_mean = np.array(geometric_mean)
learning_signal = []
for j in range(n):
learning_signal.append(
np.sum([loss[i, :] for i in range(n) if i != j], 0))
learning_signal = np.array(learning_signal)
local_learning_signal = np.log(1/n * (learning_signal + geometric_mean))
# log_mean - local_learning_signal
log_mean = np.log(np.mean(loss, 0))
advantage = log_mean - local_learning_signal
return advantage
class StochasticGradientEstimatorsTest(tf.test.TestCase):
def setUp(self):
......@@ -97,6 +123,56 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
self._testScoreFunction(
sge.get_score_function_with_advantage(advantage_fn), expected)
def testVIMCOAdvantageFn(self):
# simple_loss: (3, 2) with 3 samples, batch size 2
simple_loss = np.array(
[[1.0, 1.5],
[1e-6, 1e4],
[2.0, 3.0]])
# random_loss: (100, 50, 64) with 100 samples, batch shape (50, 64)
random_loss = 100*np.random.rand(100, 50, 64)
advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=False)
with self.test_session() as sess:
for loss in [simple_loss, random_loss]:
expected = _vimco(loss)
loss_t = tf.constant(loss, dtype=tf.float32)
advantage_t = advantage_fn(None, loss_t) # ST is not used
advantage = sess.run(advantage_t)
self.assertEqual(expected.shape, advantage_t.get_shape())
self.assertAllClose(expected, advantage, atol=5e-5)
def testVIMCOAdvantageGradients(self):
loss = np.log(
[[1.0, 1.5],
[1e-6, 1e4],
[2.0, 3.0]])
advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True)
with self.test_session():
loss_t = tf.constant(loss, dtype=tf.float64)
advantage_t = advantage_fn(None, loss_t) # ST is not used
gradient_error = tf.test.compute_gradient_error(
loss_t, loss_t.get_shape().as_list(),
advantage_t, advantage_t.get_shape().as_list(),
x_init_value=loss)
self.assertLess(gradient_error, 1e-3)
def testVIMCOAdvantageWithSmallProbabilities(self):
theta_value = np.random.rand(10, 100000)
# Test with float16 dtype to ensure stability even in this extreme case.
theta = tf.constant(theta_value, dtype=tf.float16)
advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True)
with self.test_session() as sess:
log_loss = -tf.reduce_sum(theta, [1])
advantage_t = advantage_fn(None, log_loss)
grad_t = tf.gradients(advantage_t, theta)[0]
advantage, grad = sess.run((advantage_t, grad_t))
self.assertTrue(np.all(np.isfinite(advantage)))
self.assertTrue(np.all(np.isfinite(grad)))
def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
ema_decay = 0.8
x = st.StochasticTensor(
......
......@@ -56,6 +56,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
......@@ -194,4 +196,122 @@ def get_mean_baseline(ema_decay=0.99, name=None):
return mean_baseline
def get_vimco_advantage_fn(have_log_loss=False):
"""VIMCO (Variational Inference for Monte Carlo Objectives) baseline.
Implements VIMCO baseline from the article of the same name:
https://arxiv.org/pdf/1602.06725v2.pdf
Given a `loss` tensor (containing non-negative probabilities or ratios),
calculates the advantage VIMCO advantage via Eq. 9 of the above paper.
The tensor `loss` should be shaped `[n, ...]`, with rank at least 1. Here,
the first axis is considered the single sampling dimension and `n` must
be at least 2. Specifically, the `StochasticTensor` is assumed to have
used the `SampleValue(n)` value type with `n > 1`.
Args:
have_log_loss: Python `Boolean`. If `True`, the loss is assumed to be the
log loss. If `False` (the default), it is assumed to be a nonnegative
probability or probability ratio.
Returns:
Callable baseline function that takes the `StochasticTensor` (unused) and
the downstream `loss`, and returns the VIMCO baseline for the loss.
"""
def vimco_advantage_fn(_, loss, name=None):
"""Internal VIMCO function.
Args:
_: ignored `StochasticTensor`.
loss: The loss `Tensor`.
name: Python string, the name scope to use.
Returns:
The advantage `Tensor`.
"""
with ops.name_scope(name, "VIMCOAdvantage", values=[loss]):
loss = ops.convert_to_tensor(loss)
loss_shape = loss.get_shape()
loss_num_elements = loss_shape[0].value
n = math_ops.cast(
loss_num_elements or array_ops.shape(loss)[0], dtype=loss.dtype)
if have_log_loss:
log_loss = loss
else:
log_loss = math_ops.log(loss)
# Calculate L_hat, Eq. (4) -- stably
log_mean = math_ops.reduce_logsumexp(log_loss, [0]) - math_ops.log(n)
# expand_dims: Expand shape [a, b, c] to [a, 1, b, c]
log_loss_expanded = array_ops.expand_dims(log_loss, [1])
# divide: log_loss_sub with shape [a, a, b, c], where
#
# log_loss_sub[i] = log_loss - log_loss[i]
#
# = [ log_loss[j] - log_loss[i] for rows j = 0 ... i - 1 ]
# [ zeros ]
# [ log_loss[j] - log_loss[i] for rows j = i + 1 ... a - 1 ]
#
log_loss_sub = log_loss - log_loss_expanded
# reduce_sum: Sums each row across all the sub[i]'s; result is:
# reduce_sum[j] = (n - 1) * log_loss[j] - (sum_{i != j} loss[i])
# divide by (n - 1) to get:
# geometric_reduction[j] =
# log_loss[j] - (sum_{i != j} log_loss[i]) / (n - 1)
geometric_reduction = math_ops.reduce_sum(log_loss_sub, [0]) / (n - 1)
# subtract this from the original log_loss to get the baseline:
# geometric_mean[j] = exp((sum_{i != j} log_loss[i]) / (n - 1))
log_geometric_mean = log_loss - geometric_reduction
## Equation (9)
# Calculate sum_{i != j} loss[i] -- via exp(reduce_logsumexp(.))
# reduce_logsumexp: log-sum-exp each row across all the
# -sub[i]'s, result is:
#
# exp(reduce_logsumexp[j]) =
# 1 + sum_{i != j} exp(log_loss[i] - log_loss[j])
log_local_learning_reduction = math_ops.reduce_logsumexp(
-log_loss_sub, [0])
# convert local_learning_reduction to the sum-exp of the log-sum-exp
# (local_learning_reduction[j] - 1) * exp(log_loss[j])
# = sum_{i != j} exp(log_loss[i])
local_learning_log_sum = (
_logexpm1(log_local_learning_reduction) + log_loss)
# Add (logaddexp) the local learning signals (Eq. 9)
local_learning_signal = (
math_ops.reduce_logsumexp(
array_ops.stack((local_learning_log_sum, log_geometric_mean)),
[0])
- math_ops.log(n))
advantage = log_mean - local_learning_signal
return advantage
return vimco_advantage_fn
def _logexpm1(x):
"""Stably calculate log(exp(x)-1)."""
with ops.name_scope("logsumexp1"):
eps = np.finfo(x.dtype.as_numpy_dtype).eps
# Choose a small offset that makes gradient calculations stable for
# float16, float32, and float64.
safe_log = lambda y: math_ops.log(y + eps / 1e8) # For gradient stability
return array_ops.where(
math_ops.abs(x) < eps,
safe_log(x) + x/2 + x*x/24, # small x approximation to log(expm1(x))
safe_log(math_ops.exp(x) - 1))
__all__ = make_all(__name__)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册