diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index d081c40285841e460e8e2f54e9c760b240c89eaf..01cad8eb916cdfd3180aac56f57a3e94c419a7f1 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -24,8 +24,11 @@ import time from tensorflow.contrib.learn.python.learn import monitors from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError +from tensorflow.python.platform import flags from tensorflow.python.platform import tf_logging as logging +FLAGS = flags.FLAGS + class Experiment(object): """Experiment is a class containing all information needed to train a model. @@ -70,7 +73,7 @@ class Experiment(object): self._train_monitors = train_monitors self._local_eval_frequency = local_eval_frequency - def train(self, delay_secs=0): + def train(self, delay_secs=None): """Fit the estimator using the training data. Train the estimator for `steps` steps, after waiting for `delay_secs` @@ -83,6 +86,12 @@ class Experiment(object): The trained estimator. """ + if delay_secs is None: + task_id = 0 + if "task" in FLAGS: + task_id = FLAGS.task + delay_secs = min(60, task_id*5) + if delay_secs: logging.info("Waiting %d secs before starting training.", delay_secs) time.sleep(delay_secs) @@ -91,7 +100,7 @@ class Experiment(object): max_steps=self._train_steps, monitors=self._train_monitors) - def evaluate(self, delay_secs=0): + def evaluate(self, delay_secs=120): """Evaluate on the evaluation data. Runs evaluation on the evaluation data and returns the result. If `steps` @@ -127,13 +136,13 @@ class Experiment(object): input_fn=self._eval_input_fn, eval_steps=self._eval_steps, metrics=self._eval_metrics, every_n_steps=self._local_eval_frequency )] - self.train() - return self.evaluate() + self.train(delay_secs=0) + return self.evaluate(delay_secs=0) def _continuous_eval(self, input_fn, name, - delay_secs=0, + delay_secs=120, throttle_delay_secs=60): """Run continuous eval. @@ -160,8 +169,7 @@ class Experiment(object): metrics=self._eval_metrics, name=name) except NotFittedError: - logging.warning("Estimator is not fitted yet, skipping evaluation. " - "Increase 'delay_secs' to avoid this warning.") + logging.warning("Estimator is not fitted yet, skipping evaluation.") duration = time.time() - start if duration < throttle_delay_secs: difference = throttle_delay_secs - duration @@ -169,13 +177,14 @@ class Experiment(object): difference) time.sleep(difference) - def continuous_eval(self, delay_secs=0, throttle_delay_secs=60): + def continuous_eval(self, delay_secs=120, throttle_delay_secs=60): self._continuous_eval(self._eval_input_fn, name="continuous", delay_secs=delay_secs, throttle_delay_secs=throttle_delay_secs) - def continuous_eval_on_train_data(self, delay_secs=0, throttle_delay_secs=60): + def continuous_eval_on_train_data( + self, delay_secs=120, throttle_delay_secs=60): self._continuous_eval(self._train_input_fn, name="continuous_on_train_data", delay_secs=delay_secs, diff --git a/tensorflow/contrib/learn/python/learn/runner_flags.py b/tensorflow/contrib/learn/python/learn/runner_flags.py index 95b20f5e418e8187eb8084c69d609aa8b96412c6..26f906e78100b9282fc7e7b149217b0714cee399 100644 --- a/tensorflow/contrib/learn/python/learn/runner_flags.py +++ b/tensorflow/contrib/learn/python/learn/runner_flags.py @@ -32,5 +32,4 @@ flags.DEFINE_string('schedule', 'local_run', 'instance returned by the function passed to the ' 'run() call') -# TODO(ispir): Remove once we migrated customer pipilines. -flags.DEFINE_string('execution_mode', 'all', 'Deprecated. Use FLAGS.schedule') +flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.') diff --git a/tensorflow/contrib/learn/python/learn/tests/experiment_test.py b/tensorflow/contrib/learn/python/learn/tests/experiment_test.py index d359bb80244966f5d37464597a9f5c63fe93048c..b3671a6c4049a63027296be1c476323ff33cec9f 100644 --- a/tensorflow/contrib/learn/python/learn/tests/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/experiment_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import time import tensorflow as tf +# importing to get flags. +from tensorflow.contrib.learn.python.learn import runner_flags # pylint: disable=unused-import class TestEstimator(object): @@ -68,8 +70,19 @@ class ExperimentTest(tf.test.TestCase): start = time.time() ex.train(delay_secs=delay) duration = time.time() - start - tf.logging.info('train duration (expected %f): %f', delay, duration) - self.assertTrue(duration > delay - 0.5 and duration < delay + 0.5) + self.assertAlmostEqual(duration, delay, delta=0.5) + + def test_train_default_delay(self): + est = TestEstimator() + ex = tf.contrib.learn.Experiment(est, + train_input_fn='train_input', + eval_input_fn='eval_input') + for task in [0, 1, 3]: + start = time.time() + tf.flags.FLAGS.task = task + ex.train() + duration = time.time() - start + self.assertAlmostEqual(duration, task*5, delta=0.5) def test_evaluate(self): est = TestEstimator() @@ -92,7 +105,7 @@ class ExperimentTest(tf.test.TestCase): ex.evaluate(delay_secs=delay) duration = time.time() - start tf.logging.info('eval duration (expected %f): %f', delay, duration) - self.assertTrue(duration > delay - 0.5 and duration < delay + 0.5) + self.assertAlmostEqual(duration, delay, delta=0.5) def test_continuous_eval(self): est = TestEstimator() @@ -118,7 +131,7 @@ class ExperimentTest(tf.test.TestCase): duration = time.time() - start expected = 5 * delay tf.logging.info('eval duration (expected %f): %f', expected, duration) - self.assertTrue(duration > expected - 0.5 and duration < expected + 0.5) + self.assertAlmostEqual(duration, expected, delta=0.5) def test_run_local(self): est = TestEstimator()