提交 2f46b74c 编写于 作者: M Mustafa Ispir 提交者: TensorFlower Gardener

Test graph initialization logic in Estimator.

Change: 150479545
上级 86fae6bd
......@@ -392,6 +392,7 @@ class Estimator(object):
with ops.Graph().as_default() as g:
training.create_global_step(g)
random_seed.set_random_seed(self._config.tf_random_seed)
serving_input_receiver = serving_input_receiver_fn()
# Call the model_fn and collect the export_outputs.
......
......@@ -186,7 +186,8 @@ class EstimatorConstructorTest(test.TestCase):
def dummy_input_fn():
return {'x': [[1], [1]]}, [[1], [1]]
return ({'x': constant_op.constant([[1], [1]])},
constant_op.constant([[1], [1]]))
def model_fn_global_step_incrementer(features, labels, mode):
......@@ -357,7 +358,7 @@ class EstimatorTrainTest(test.TestCase):
mode=mode,
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
predictions=constant_op.constant(0.))
predictions=constant_op.constant([[0.]]))
est = estimator.Estimator(model_fn=_model_fn)
est.train(_input_fn, steps=1)
......@@ -365,6 +366,21 @@ class EstimatorTrainTest(test.TestCase):
self.assertEqual(given_labels, self.labels)
self.assertEqual(model_fn_lib.ModeKeys.TRAIN, self.mode)
def test_graph_initialization_global_step_and_random_seed(self):
expected_random_seed = run_config.RunConfig().tf_random_seed
def _model_fn(features, labels, mode):
_, _, _ = features, labels, mode
self.assertIsNotNone(training.get_global_step())
self.assertEqual(expected_random_seed, ops.get_default_graph().seed)
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
predictions=constant_op.constant([[0.]]))
est = estimator.Estimator(model_fn=_model_fn)
est.train(dummy_input_fn, steps=1)
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
......@@ -552,7 +568,7 @@ class EstimatorEvaluateTest(test.TestCase):
mode=mode,
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
predictions=constant_op.constant(0.))
predictions=constant_op.constant([[0.]]))
est = estimator.Estimator(model_fn=_model_fn)
est.train(_input_fn, steps=1)
......@@ -561,6 +577,22 @@ class EstimatorEvaluateTest(test.TestCase):
self.assertEqual(given_labels, self.labels)
self.assertEqual(model_fn_lib.ModeKeys.EVAL, self.mode)
def test_graph_initialization_global_step_and_random_seed(self):
expected_random_seed = run_config.RunConfig().tf_random_seed
def _model_fn(features, labels, mode):
_, _, _ = features, labels, mode
self.assertIsNotNone(training.get_global_step())
self.assertEqual(expected_random_seed, ops.get_default_graph().seed)
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
predictions=constant_op.constant([[0.]]))
est = estimator.Estimator(model_fn=_model_fn)
est.train(dummy_input_fn, steps=1)
est.evaluate(dummy_input_fn, steps=1)
class EstimatorPredictTest(test.TestCase):
......@@ -816,6 +848,22 @@ class EstimatorPredictTest(test.TestCase):
self.assertIsNone(self.labels)
self.assertEqual(model_fn_lib.ModeKeys.PREDICT, self.mode)
def test_graph_initialization_global_step_and_random_seed(self):
expected_random_seed = run_config.RunConfig().tf_random_seed
def _model_fn(features, labels, mode):
_, _, _ = features, labels, mode
self.assertIsNotNone(training.get_global_step())
self.assertEqual(expected_random_seed, ops.get_default_graph().seed)
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
predictions=constant_op.constant([[0.]]))
est = estimator.Estimator(model_fn=_model_fn)
est.train(dummy_input_fn, steps=1)
next(est.predict(dummy_input_fn))
def _model_fn_for_export_tests(features, labels, mode):
_, _ = features, labels
......@@ -1139,6 +1187,30 @@ class EstimatorExportTest(test.TestCase):
self.assertIsNone(self.labels)
self.assertEqual(model_fn_lib.ModeKeys.PREDICT, self.mode)
def test_graph_initialization_global_step_and_random_seed(self):
expected_random_seed = run_config.RunConfig().tf_random_seed
def _model_fn(features, labels, mode):
_, _, _ = features, labels, mode
self.assertIsNotNone(training.get_global_step())
self.assertEqual(expected_random_seed, ops.get_default_graph().seed)
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
predictions=constant_op.constant([[0.]]),
export_outputs={
'test': export.ClassificationOutput(constant_op.constant([[0.]]))
})
def serving_input_receiver_fn():
return export.ServingInputReceiver(
{'test-features': constant_op.constant([[1], [1]])},
array_ops.placeholder(dtype=dtypes.string))
est = estimator.Estimator(model_fn=_model_fn)
est.train(dummy_input_fn, steps=1)
est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)
class EstimatorIntegrationTest(test.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册