Commit ac96df2d authored by Sherry Moore, committed by TensorFlower Gardener

Added is_variable_initialized(variable) function.

Change: 119321281
Parent: a0bc7959
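A minimal usage sketch of the new Python entry point, assuming the graph-mode `tf.Session` API of this TensorFlow version; apart from `tf.is_variable_initialized` itself, every name used here is the era's standard API, and the variable values are made up:

import tensorflow as tf

# Build a variable and the new check op.
v = tf.Variable([1.0, 2.0], name="v")
is_init = tf.is_variable_initialized(v)  # 0-D bool tensor

with tf.Session() as sess:
  print(sess.run(is_init))                 # False: the initializer has not run yet
  sess.run(tf.initialize_all_variables())
  print(sess.run(is_init))                 # True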
......@@ -28,6 +28,8 @@ REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
                        TemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
                        DestroyTemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
                        IsVariableInitializedOp);
#if GOOGLE_CUDA
// Only register 'Variable' on GPU for the subset of types also supported by
......@@ -43,7 +45,12 @@ REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T"), \
DestroyTemporaryVariableOp);
DestroyTemporaryVariableOp); \
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("dtype") \
.HostMemory("is_initialized"), \
IsVariableInitializedOp);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
......
......@@ -158,6 +158,22 @@ class DestroyTemporaryVariableOp : public OpKernel {
  string var_name_;
};

class IsVariableInitializedOp : public OpKernel {
 public:
  IsVariableInitializedOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Get a mutable input tensor of the Ref input.
    const Tensor& input_tensor = context->mutable_input(0, false);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));
    auto output_tensor = output->tensor<bool, 0>();
    bool result = input_tensor.IsInitialized();
    output_tensor() = result;
  }
};
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_VARIABLE_OPS_H_
......@@ -40,6 +40,20 @@ shared_name: If non-empty, this variable is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
REGISTER_OP("IsVariableInitialized")
    .Output("is_initialized: bool")
    .Input("ref: Ref(dtype)")
    .Attr("dtype: type")
    .SetAllowsUninitializedInput()
    .Doc(R"doc(
Checks whether a tensor has been initialized.

Outputs boolean scalar indicating whether the tensor has been initialized.

ref: Should be from a `Variable` node. May be uninitialized.
dtype: The type of elements in the variable tensor.
)doc");

REGISTER_OP("TemporaryVariable")
    .Output("ref: Ref(dtype)")
    .Attr("shape: shape")
......
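The `SetAllowsUninitializedInput()` call above is what lets this op consume a ref whose buffer has not been assigned yet. A sketch of the observable difference, assuming the usual graph-mode behavior that reading an uninitialized variable raises `FailedPreconditionError` (variable values are made up):

import tensorflow as tf

v = tf.Variable([1.0, 2.0], name="v")

with tf.Session() as sess:
  # Reading the variable itself before initialization normally fails...
  try:
    sess.run(v)
  except tf.errors.FailedPreconditionError:
    print("v has not been initialized")
  # ...but the check op tolerates the uninitialized ref and simply reports False.
  print(sess.run(tf.is_variable_initialized(v)))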
......@@ -237,6 +237,14 @@ class VariableOpTest(tf.test.TestCase):
      result = tf.mul(var, var)
      self.assertAllClose([4.0], result.eval())

  def testIsVariableInitialized(self):
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu):
        v0 = state_ops.variable_op([1, 2], tf.float32)
        self.assertEqual(False, tf.is_variable_initialized(v0).eval())
        tf.assign(v0, [[2.0, 3.0]]).eval()
        self.assertEqual(True, tf.is_variable_initialized(v0).eval())


if __name__ == "__main__":
  tf.test.main()
......@@ -30,6 +30,7 @@ collected in the graph.
@@initialize_all_variables
@@initialize_variables
@@initialize_local_variables
@@is_variable_initialized
@@assert_variables_initialized
## Saving and Restoring Variables
......@@ -134,6 +135,8 @@ def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
# NOTE(mrry): Shapes are conditionally set in the Python wrapper.
ops.RegisterShape("Variable")(common_shapes.unknown_shape)
ops.RegisterShape("IsVariableInitialized")(common_shapes.scalar_shape)
@ops.RegisterShape("TemporaryVariable")
def _TemporaryVariableShape(op):
......
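The shape registration above maps `IsVariableInitialized` to `common_shapes.scalar_shape`, so at graph-construction time the result is statically known to be a 0-D `bool` tensor. A small sketch (variable contents are arbitrary):

import tensorflow as tf

v = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name="v")
is_init = tf.is_variable_initialized(v)

# Regardless of the variable's own shape, the check is a scalar bool.
print(is_init.get_shape())  # ()
print(is_init.dtype)        # bool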
......@@ -798,6 +798,18 @@ def initialize_local_variables():
return initialize_variables(local_variables())
def is_variable_initialized(variable):
  """Returns an Op to check if a variable has been initialized.

  Args:
    variable: A `Variable`.

  Returns:
    An operation to check whether a variable has been initialized.
  """
  return state_ops.is_variable_initialized(variable)


def assert_variables_initialized(var_list=None):
  """Returns an Op to check if variables are initialized.
......
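`is_variable_initialized` complements the existing `assert_variables_initialized`: the former yields a per-variable boolean you can evaluate and branch on, while the latter builds a single op that fails at run time if any variable is still uninitialized (and is used as a `ready_op` in the SessionManager test below). A hedged comparison sketch with made-up variables:

import tensorflow as tf

v = tf.Variable(2, name="v")
w = tf.Variable(3, name="w")

v_ready = tf.is_variable_initialized(v)       # per-variable bool scalar
ready_op = tf.assert_variables_initialized()  # fails if any variable is uninitialized

with tf.Session() as sess:
  print(sess.run(v_ready))   # False
  sess.run(v.initializer)
  print(sess.run(v_ready))   # True, even though w is still uninitialized
  # Running ready_op here would still fail, because w has no value yet.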
......@@ -71,6 +71,8 @@ class SessionManagerTest(tf.test.TestCase):
    os.rename(checkpoint_dir, checkpoint_dir2)
    gfile.MakeDirs(checkpoint_dir)
    v = tf.Variable([6.0, 7.0, 8.0], name="v")
    with self.test_session():
      self.assertEqual(False, tf.is_variable_initialized(v).eval())
    tf.train.SessionManager(ready_op=tf.assert_variables_initialized())
    saver = tf.train.Saver({"v": v})
    # This should fail as there's no checkpoint within 2 seconds.
......@@ -85,6 +87,9 @@ class SessionManagerTest(tf.test.TestCase):
    sess = sm.prepare_session("", init_op=None, saver=saver,
                              checkpoint_dir=checkpoint_dir,
                              wait_for_checkpoint=True, max_wait_secs=2)
    self.assertEqual(
        True, tf.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))

  def testRecoverSession(self):
    # Create a checkpoint.
......@@ -109,11 +114,16 @@ class SessionManagerTest(tf.test.TestCase):
    # Create a new Graph and SessionManager and recover.
    with tf.Graph().as_default():
      v = tf.Variable(2, name="v")
      with self.test_session():
        self.assertEqual(False, tf.is_variable_initialized(v).eval())
      sm2 = tf.train.SessionManager(ready_op=tf.assert_variables_initialized())
      saver = tf.train.Saver({"v": v})
      sess, initialized = sm2.recover_session("", saver=saver,
                                              checkpoint_dir=checkpoint_dir)
      self.assertTrue(initialized)
      self.assertEqual(
          True, tf.is_variable_initialized(
              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
      self.assertEquals(1, sess.run(v))

  def testWaitForSessionReturnsNoneAfterTimeout(self):
......
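For readers unfamiliar with the `ready_op` pattern exercised in this test, a sketch of how the pieces typically fit together; the checkpoint path is hypothetical and the call mirrors the keyword arguments used above:

import tensorflow as tf

with tf.Graph().as_default():
  v = tf.Variable([6.0, 7.0, 8.0], name="v")
  saver = tf.train.Saver({"v": v})
  # SessionManager runs ready_op to decide whether the model's variables are
  # usable; assert_variables_initialized() fails while any of them lacks a value.
  sm = tf.train.SessionManager(ready_op=tf.assert_variables_initialized())
  sess = sm.prepare_session("", init_op=tf.initialize_all_variables(),
                            saver=saver, checkpoint_dir="/tmp/example_ckpts")
  print(sess.run(tf.is_variable_initialized(v)))  # True once the session is prepared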