提交 4f7503a8 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

K-FAC: Support for registering multiple minibatches with register_fully_connected()

PiperOrigin-RevId: 173121735
上级 2845bfcd
......@@ -282,6 +282,73 @@ class LayerCollectionTest(test.TestCase):
single_loss = sess.run(lc.total_loss())
self.assertAlmostEqual(7.6983433, single_loss)
def testRegisterFullyConnectedReuse(self):
  """Verifies how register_fully_connected() handles each 'reuse' setting.

  The same 'params' are registered twice, each time with a different
  minibatch. Whether the second registration succeeds depends on the value
  of 'reuse', on the enclosing variable scope and on the requested 'approx'.
  """
  with ops.Graph().as_default():
    # Two minibatches (batch sizes 2 and 5) flowing through the same layer.
    first_inputs = array_ops.ones([2, 10])
    second_inputs = array_ops.zeros([5, 10])
    first_outputs = array_ops.zeros([2, 5])
    second_outputs = array_ops.ones([5, 5])
    weights = variable_scope.get_variable('w', [10, 5])
    bias = variable_scope.get_variable('b', [5])
    params = (weights, bias)

    def collection_with_first_minibatch(**kwargs):
      """Returns a new LayerCollection with only the first minibatch."""
      collection = layer_collection.LayerCollection()
      collection.register_fully_connected(
          params, first_inputs, first_outputs, **kwargs)
      return collection

    def register_second_minibatch(collection, **kwargs):
      """Registers the second minibatch for 'params' on 'collection'."""
      collection.register_fully_connected(
          params, second_inputs, second_outputs, **kwargs)

    # reuse=False: a second registration for the same params is rejected.
    lc = collection_with_first_minibatch()
    self.assertRaises(
        ValueError, register_second_minibatch, lc, reuse=False)

    # reuse=True: the second registration is accepted.
    lc = collection_with_first_minibatch()
    register_second_minibatch(lc, reuse=True)

    # reuse=VARIABLE_SCOPE outside of a reusing variable scope: rejected.
    lc = collection_with_first_minibatch()
    self.assertRaises(
        ValueError,
        register_second_minibatch,
        lc,
        reuse=layer_collection.VARIABLE_SCOPE)

    # reuse=VARIABLE_SCOPE inside a reusing variable scope: accepted.
    lc = collection_with_first_minibatch()
    current_scope = variable_scope.get_variable_scope()
    with variable_scope.variable_scope(current_scope, reuse=True):
      register_second_minibatch(lc, reuse=layer_collection.VARIABLE_SCOPE)

    # Asking to reuse a block with a different approximation is rejected.
    lc = collection_with_first_minibatch(
        approx=layer_collection.APPROX_KRONECKER_NAME)
    self.assertRaises(
        ValueError,
        register_second_minibatch,
        lc,
        approx=layer_collection.APPROX_DIAGONAL_NAME,
        reuse=True)

    # reuse=True on a collection with no FisherBlock for params is rejected.
    lc = layer_collection.LayerCollection()
    self.assertRaises(
        KeyError,
        lc.register_fully_connected,
        params,
        first_inputs,
        first_outputs,
        reuse=True)
def testMakeOrGetFactor(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
......
......@@ -39,10 +39,15 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
# Names for various approximations that can be requested for Fisher blocks.
# These strings are the accepted values of the 'approx' keyword argument of
# registration methods such as register_fully_connected().
APPROX_KRONECKER_NAME = "kron"
APPROX_DIAGONAL_NAME = "diagonal"
APPROX_FULL_NAME = "full"
# Possible value for 'reuse' keyword argument. Sets 'reuse' to
# tf.get_variable_scope().reuse, i.e. the decision to reuse an existing
# FisherBlock is taken from the variable scope active at registration time.
VARIABLE_SCOPE = "VARIABLE_SCOPE"
# TODO(jamesmartens): need to add find_canonical_output back into this somewhere
......@@ -254,18 +259,57 @@ class LayerCollection(object):
params,
inputs,
outputs,
approx=APPROX_KRONECKER_NAME):
approx=APPROX_KRONECKER_NAME,
reuse=VARIABLE_SCOPE):
"""Registers a fully connected layer.
Args:
params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
this layer. Weight matrix should have shape [input_size, output_size].
Bias should have shape [output_size].
inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
outputs: Tensor of shape [batch_size, output_size]. Preactivations
produced by layer.
approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME.
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If VARIABLE_SCOPE, use
tf.get_variable_scope().reuse.
Raises:
ValueError: For improper value to 'approx'.
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
approx_to_block_types = {
APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
}
if approx not in approx_to_block_types:
raise ValueError("Bad value {} for approx.".format(approx))
block_type = approx_to_block_types[approx]
has_bias = isinstance(params, (tuple, list))
if approx == APPROX_KRONECKER_NAME:
block = fb.FullyConnectedKFACBasicFB(self, has_bias)
block.register_additional_minibatch(inputs, outputs)
self.register_block(params, block)
elif approx == APPROX_DIAGONAL_NAME:
block = fb.FullyConnectedDiagonalFB(self, has_bias)
block.register_additional_minibatch(inputs, outputs)
self.register_block(params, block)
if reuse == VARIABLE_SCOPE:
reuse = variable_scope.get_variable_scope().reuse
if reuse:
block = self.fisher_blocks.get(params, None)
if block is None:
raise KeyError(
"Reuse requested but no FisherBlock found for params {}.".format(
params))
if not isinstance(block, block_type):
raise ValueError(
"Requested block of type {} but block of type {} already exists "
"for params {}.".format(block_type, type(block), params))
else:
raise ValueError("Bad value {} for approx.".format(approx))
block = block_type(self, has_bias)
self.register_block(params, block)
block.register_additional_minibatch(inputs, outputs)
def register_conv2d(self, params, strides, padding, inputs, outputs,
approx=APPROX_KRONECKER_NAME):
......
......@@ -35,6 +35,7 @@ _allowed_symbols = [
"APPROX_KRONECKER_NAME",
"APPROX_DIAGONAL_NAME",
"APPROX_FULL_NAME",
"VARIABLE_SCOPE",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册