Commit 8b34de66 authored by Wei Ho, committed by TensorFlower Gardener

Adds an option to pass a callable initializer function to the Variable constructor, so that variable initialization can be colocated with the device the variable is placed on, instead of always running on the chief supervisor.

Also updates variable_scope.get_variable() and create_partitioned_variables() to take advantage of this when an initializer function is passed in.
Change: 119697860
Parent c052df40
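For orientation, here is a minimal usage sketch of the behaviour this commit enables (not part of the commit itself), based on the tests added below; it assumes the TF 0.x-era API used throughout this change (tf.initialize_all_variables, tf.Session):

import tensorflow as tf

# Passing a callable as initial_value defers building the initializer ops
# until the Variable constructor runs, so they can be colocated with the
# variable's own device. dtype must be given explicitly in this case.
init_fn = lambda: tf.constant([[-42.0], [133.7]])

with tf.device("/cpu:0"):
  v = tf.Variable(init_fn, dtype=tf.float32, name="v")

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print(sess.run(v))  # [[-42.], [133.7]]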
@@ -302,6 +302,53 @@ class VariablesTestCase(tf.test.TestCase):
       self.assertEqual(var.op.device, init_op.device)
       sess.run(init_op)
 
+  def testInitializerFunction(self):
+    value = [[-42], [133.7]]
+    shape = [2, 1]
+    with self.test_session():
+      initializer = lambda: tf.constant(value)
+
+      with self.assertRaises(ValueError):
+        # Checks that dtype must be specified.
+        tf.Variable(initializer)
+
+      v1 = tf.Variable(initializer, dtype=tf.float32)
+      self.assertEqual(shape, v1.get_shape())
+      self.assertAllClose(value, v1.initial_value.eval())
+      with self.assertRaises(tf.errors.FailedPreconditionError):
+        v1.eval()
+
+      v2 = tf.Variable(tf.neg(v1.initialized_value()), dtype=tf.float32)
+      self.assertEqual(v1.get_shape(), v2.get_shape())
+      self.assertAllClose(np.negative(value), v2.initial_value.eval())
+      # Once v2.initial_value.eval() has been called, v1 has effectively been
+      # initialized.
+      self.assertAllClose(value, v1.eval())
+
+      with self.assertRaises(tf.errors.FailedPreconditionError):
+        v2.eval()
+      tf.initialize_all_variables().run()
+      self.assertAllClose(np.negative(value), v2.eval())
+
+  def testInitializerFunctionDevicePlacement(self):
+    with self.test_session():
+      initializer = lambda: tf.constant(42.0)
+      with tf.device("/cpu:100"):
+        v1 = tf.Variable(initializer, dtype=tf.float32, name="v1")
+      expected_device = "/device:CPU:100"
+      expected_group_v1 = [b"loc:@v1"]
+      self.assertEqual(expected_device, v1.op.device)
+      self.assertEqual(expected_group_v1, v1.op.colocation_groups())
+      for i in v1.initializer.inputs:
+        self.assertEqual(expected_device, i.op.device)
+        self.assertEqual(expected_group_v1, i.op.colocation_groups())
+
+      v2 = tf.Variable(initializer, dtype=tf.float32, name="v2")
+      expected_group_v2 = [b"loc:@v2"]
+      self.assertEqual(expected_group_v2, v2.op.colocation_groups())
+      for i in v2.initializer.inputs:
+        self.assertEqual(expected_group_v2, i.op.colocation_groups())
+
+
 class IsInitializedTest(tf.test.TestCase):
......
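The device-placement effect those tests check can also be seen directly; a small illustrative sketch mirroring testInitializerFunctionDevicePlacement above (device string and variable name are arbitrary):

import tensorflow as tf

with tf.device("/cpu:0"):
  # The constant built by the lambda is created inside the Variable
  # constructor, colocated with v's op rather than with the caller's device.
  v = tf.Variable(lambda: tf.constant(42.0), dtype=tf.float32, name="v")

print(v.op.device)               # /device:CPU:0
print(v.op.colocation_groups())  # [b'loc:@v']
for i in v.initializer.inputs:
  # The initializer's inputs share the variable's device and colocation group.
  print(i.op.device, i.op.colocation_groups())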
@@ -167,19 +167,22 @@ def create_partitioned_variables(
     slice_offset[slice_dim] += var_shape[slice_dim]
     if callable(initializer):
-      init_val = initializer(var_shape, dtype=dtype)
-      init_val = ops.convert_to_tensor(init_val, dtype=dtype)
+      init = initializer
+      init_shape = var_shape
     elif isinstance(initializer, ops.Tensor):
-      init_val = array_ops.slice(initializer, var_offset, var_shape)
+      init = array_ops.slice(initializer, var_offset, var_shape)
       # Use the dtype of the given tensor.
-      dtype = init_val.dtype.base_dtype
+      dtype = init.dtype.base_dtype
+      init_shape = None
     else:
-      init_val = ops.convert_to_tensor(initializer, dtype=dtype)
-      init_val = array_ops.slice(init_val, var_offset, var_shape)
+      init = ops.convert_to_tensor(initializer, dtype=dtype)
+      init = array_ops.slice(init, var_offset, var_shape)
+      init_shape = None
     var = variable_scope.get_variable(name="part_%d" % i,
+                                      shape=init_shape,
                                       dtype=dtype,
-                                      initializer=init_val,
+                                      initializer=init,
                                       trainable=trainable,
                                       collections=collections)
......
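With this change, the callable branch no longer materializes a tensor per slice; it hands the callable plus the slice shape to get_variable, which defers the call (see the next hunk). Below is a hedged sketch of a caller-side initializer with the (shape, dtype) signature this path expects; the create_partitioned_variables keyword arguments shown assume the signature of that era:

import tensorflow as tf

def zeros_init(shape, dtype=tf.float32):
  # Called once per slice with that slice's shape; the ops it builds are
  # created only when the underlying Variable is constructed.
  return tf.zeros(shape, dtype=dtype)

# Split a [4, 8] variable into 4 column slices; each part_i variable receives
# the callable plus its own slice shape instead of a pre-sliced tensor.
parts = tf.create_partitioned_variables(
    shape=[4, 8], slicing=[1, 4], initializer=zeros_init,
    dtype=tf.float32, name="w")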
@@ -144,14 +144,19 @@ class _VariableStore(object):
       with ops.control_dependencies(None):
         if initializing_from_value:
           init_val = initializer
+          variable_dtype = None
         else:
           with ops.name_scope(name + "/Initializer/"):
-            init_val = initializer(shape.as_list(), dtype=dtype)
+            init_val = lambda: initializer(shape.as_list(), dtype=dtype)
+          variable_dtype = dtype.base_dtype
 
         # Create the variable.
-        v = variables.Variable(init_val, name=name, trainable=trainable,
+        v = variables.Variable(initial_value=init_val,
+                               name=name,
+                               trainable=trainable,
                                collections=collections,
-                               caching_device=caching_device)
+                               caching_device=caching_device,
+                               dtype=variable_dtype)
         self._vars[name] = v
         logging.info("Created variable %s with shape %s and init %s", v.name,
                      format(shape), initializer)
......
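The same deferral now applies to the plain get_variable path: a callable initializer is wrapped in a lambda and only invoked inside the Variable constructor, so the ops it creates pick up the variable's colocation. A brief sketch with an illustrative custom initializer (any callable taking (shape, dtype) works, including the init_ops initializers):

import tensorflow as tf

def ones_init(shape, dtype=tf.float32):
  # Invoked lazily by the Variable constructor rather than at get_variable()
  # call time, so these ops are colocated with the variable.
  return tf.ones(shape, dtype=dtype)

with tf.variable_scope("layer1"):
  w = tf.get_variable("w", shape=[2, 3], dtype=tf.float32,
                      initializer=ones_init)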
@@ -156,9 +156,12 @@ class Variable(object):
     variable to its initial value.
 
     Args:
-      initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
-        The initial value for the Variable. Must have a shape specified unless
-        `validate_shape` is set to False.
+      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+        which is the initial value for the Variable. The initial value must have
+        a shape specified unless `validate_shape` is set to False. Can also be a
+        callable with no argument that returns the initial value when called. In
+        that case, `dtype` must be specified. (Note that initializer functions
+        from init_ops.py must first be bound to a shape before being used here.)
       trainable: If `True`, the default, also adds the variable to the graph
         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
         the default list of variables to use by the `Optimizer` classes.
@@ -211,9 +214,12 @@ class Variable(object):
     """Creates a new variable from arguments.
 
     Args:
-      initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
-        The initial value for the Variable. Must have a shape specified unless
-        `validate_shape` is set to False.
+      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+        which is the initial value for the Variable. The initial value must have
+        a shape specified unless `validate_shape` is set to False. Can also be a
+        callable with no argument that returns the initial value when called. In
+        that case, `dtype` must be specified. (Note that initializer functions
+        from init_ops.py must first be bound to a shape before being used here.)
       trainable: If `True`, the default, also adds the variable to the graph
         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
         the default list of variables to use by the `Optimizer` classes.
@@ -240,25 +246,62 @@ class Variable(object):
     """
     if initial_value is None:
       raise ValueError("initial_value must be specified.")
 
+    init_from_fn = callable(initial_value)
+    if init_from_fn and dtype is None:
+      raise ValueError(
+          "dtype must also be specified when initial_value is callable.")
     if collections is None:
       collections = [ops.GraphKeys.VARIABLES]
     if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
     with ops.control_dependencies(None):
-      with ops.op_scope([initial_value], name, "Variable") as name:
-        self._initial_value = ops.convert_to_tensor(initial_value,
-                                                    name="initial_value",
-                                                    dtype=dtype)
-        initial_value_shape = self._initial_value.get_shape()
-        if validate_shape and not initial_value_shape.is_fully_defined():
-          raise ValueError("initial_value must have a shape specified: %s"
-                           % self._initial_value)
-        shape_to_set = initial_value_shape if validate_shape else []
-        self._variable = state_ops.variable_op(
-            shape_to_set, self._initial_value.dtype.base_dtype,
-            set_shape=validate_shape, name=name)
+      with ops.op_scope(
+          [] if init_from_fn else [initial_value], name, "Variable") as name:
+
+        # Get the initial value from a callable function. The real shape of the
+        # variable will be set later, since under the init_from_fn case, the
+        # shape won't be known until after the function is invoked.
+        if init_from_fn:
+          self._variable = state_ops.variable_op(
+              [],
+              dtype.base_dtype,
+              set_shape=False,
+              name=name)
+          with ops.colocate_with(self._variable.op):
+            with ops.name_scope("Initializer"):
+              # Colocate the tensors created by the initial_value() function
+              # with the variable itself.
+              self._initial_value = ops.convert_to_tensor(initial_value(),
+                                                          name="initial_value",
+                                                          dtype=dtype)
+
+        # Or get the initial value from a Tensor or Python object.
+        else:
+          self._initial_value = ops.convert_to_tensor(initial_value,
+                                                      name="initial_value",
+                                                      dtype=dtype)
+          # In this case, the variable op can't be created until after the
+          # initial_value has been converted to a Tensor with a known type.
+          self._variable = state_ops.variable_op(
+              [],
+              self._initial_value.dtype.base_dtype,
+              set_shape=False,
+              name=name)
+
+        # Manually overrides the variable's shape with the initial value's.
+        if validate_shape:
+          initial_value_shape = self._initial_value.get_shape()
+          if not initial_value_shape.is_fully_defined():
+            raise ValueError("initial_value must have a shape specified: %s"
+                             % self._initial_value)
+          self._variable.set_shape(initial_value_shape)
+          # TODO(b/28152992): Remove the below hack modifying the node_def shape
+          # directly once set_shape() handles it.
+          self._variable.op.node_def.attr["shape"].shape.CopyFrom(
+              initial_value_shape.as_proto())
+
         # Assigns initial value.
         with ops.colocate_with(self._variable.op):
           self._initializer_op = state_ops.assign(
               self._variable, self._initial_value,