提交 1602ac6a 编写于 作者: I Illia Polosukhin 提交者: TensorFlower Gardener

Fix typo in the run_local call of Experiment. Added test for run_local.

Change: 125458571
上级 002e1854
......@@ -121,9 +121,10 @@ class Experiment(object):
Returns:
The result of the `evaluate` call to the `Estimator`.
"""
self._train_monitors = self._train_monitors or []
if self._local_eval_frequency:
self._train_monitors += [monitors.ValidationMonitor(
input_fn=self._eval_input_fn, steps=self._eval_steps,
input_fn=self._eval_input_fn, eval_steps=self._eval_steps,
metrics=self._eval_metrics, every_n_steps=self._local_eval_frequency
)]
self.train()
......
......@@ -27,6 +27,7 @@ class TestEstimator(object):
def __init__(self):
self.eval_count = 0
self.fit_count = 0
self.monitors = []
def evaluate(self, **kwargs):
tf.logging.info('evaluate called with args: %s' % kwargs)
......@@ -39,6 +40,8 @@ class TestEstimator(object):
def fit(self, **kwargs):
tf.logging.info('fit called with args: %s' % kwargs)
self.fit_count += 1
if 'monitors' in kwargs:
self.monitors = kwargs['monitors']
return [(key, kwargs[key]) for key in sorted(kwargs.keys())]
......@@ -115,6 +118,22 @@ class ExperimentTest(tf.test.TestCase):
tf.logging.info('eval duration (expected %f): %f', expected, duration)
self.assertTrue(duration > expected - 0.5 and duration < expected + 0.5)
def test_run_local(self):
est = TestEstimator()
ex = tf.contrib.learn.Experiment(est,
train_input_fn='train_input',
eval_input_fn='eval_input',
eval_metrics='eval_metrics',
train_steps=100,
eval_steps=100,
local_eval_frequency=10)
ex.local_run()
self.assertEquals(1, est.fit_count)
self.assertEquals(1, est.eval_count)
self.assertEquals(1, len(est.monitors))
self.assertTrue(isinstance(est.monitors[0],
tf.contrib.learn.monitors.ValidationMonitor))
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册