提交 3cf41b2e 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Test save/restore variable from graph_callable.

PiperOrigin-RevId: 171051237
上级 cf17ec96
......@@ -81,6 +81,7 @@ cuda_py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python/eager:graph_callable",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
......
......@@ -21,10 +21,14 @@ import os
from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
from tensorflow.python.eager import graph_callable
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
......@@ -87,6 +91,53 @@ class SaverTest(test.TestCase):
with _saver.restore_variables_on_create(ckpt_prefix):
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
def testSaveRestoreGraphCallable(self):
with context.eager_mode(), ops.device(self._dev()):
@graph_callable.graph_callable(
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
def model(x):
v = variable_scope.get_variable(
'v', initializer=init_ops.zeros_initializer(), shape=())
return v + x
# Default 2 + 0 = 2
self.assertEqual(
2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
# Save the variable value 0.
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
_saver.Saver(model.variables).save(ckpt_prefix)
# update variable to 1, so that 2 + 1 = 3
model.variables[0].assign(1.)
self.assertEqual(
3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
# load the variable value 0, so that 2 + 0 = 2
_saver.Saver(model.variables).restore(ckpt_prefix)
self.assertEqual(
2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
# update checkpoint variable to 1 and memory value to 2.
model.variables[0].assign(1.)
_saver.Saver(model.variables).save(ckpt_prefix)
model.variables[0].assign(2.)
self.assertEqual(
4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
# reset the graph and reload on create, so that 1 + 2 = 3
with ops.Graph().as_default():
with _saver.restore_variables_on_create(ckpt_prefix):
@graph_callable.graph_callable(
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
def model2(x):
v = variable_scope.get_variable(
'v', initializer=init_ops.zeros_initializer(), shape=())
return v + x
self.assertEqual(
3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
if __name__ == '__main__':
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册