提交 e412f6aa 编写于 作者: M Mustafa Ispir 提交者: TensorFlower Gardener

Added staggered start of training workers.

Change: 125803613
上级 0c1285bf
...@@ -24,8 +24,11 @@ import time ...@@ -24,8 +24,11 @@ import time
from tensorflow.contrib.learn.python.learn import monitors from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
FLAGS = flags.FLAGS
class Experiment(object): class Experiment(object):
"""Experiment is a class containing all information needed to train a model. """Experiment is a class containing all information needed to train a model.
...@@ -70,7 +73,7 @@ class Experiment(object): ...@@ -70,7 +73,7 @@ class Experiment(object):
self._train_monitors = train_monitors self._train_monitors = train_monitors
self._local_eval_frequency = local_eval_frequency self._local_eval_frequency = local_eval_frequency
def train(self, delay_secs=0): def train(self, delay_secs=None):
"""Fit the estimator using the training data. """Fit the estimator using the training data.
Train the estimator for `steps` steps, after waiting for `delay_secs` Train the estimator for `steps` steps, after waiting for `delay_secs`
...@@ -83,6 +86,12 @@ class Experiment(object): ...@@ -83,6 +86,12 @@ class Experiment(object):
The trained estimator. The trained estimator.
""" """
if delay_secs is None:
task_id = 0
if "task" in FLAGS:
task_id = FLAGS.task
delay_secs = min(60, task_id*5)
if delay_secs: if delay_secs:
logging.info("Waiting %d secs before starting training.", delay_secs) logging.info("Waiting %d secs before starting training.", delay_secs)
time.sleep(delay_secs) time.sleep(delay_secs)
...@@ -91,7 +100,7 @@ class Experiment(object): ...@@ -91,7 +100,7 @@ class Experiment(object):
max_steps=self._train_steps, max_steps=self._train_steps,
monitors=self._train_monitors) monitors=self._train_monitors)
def evaluate(self, delay_secs=0): def evaluate(self, delay_secs=120):
"""Evaluate on the evaluation data. """Evaluate on the evaluation data.
Runs evaluation on the evaluation data and returns the result. If `steps` Runs evaluation on the evaluation data and returns the result. If `steps`
...@@ -127,13 +136,13 @@ class Experiment(object): ...@@ -127,13 +136,13 @@ class Experiment(object):
input_fn=self._eval_input_fn, eval_steps=self._eval_steps, input_fn=self._eval_input_fn, eval_steps=self._eval_steps,
metrics=self._eval_metrics, every_n_steps=self._local_eval_frequency metrics=self._eval_metrics, every_n_steps=self._local_eval_frequency
)] )]
self.train() self.train(delay_secs=0)
return self.evaluate() return self.evaluate(delay_secs=0)
def _continuous_eval(self, def _continuous_eval(self,
input_fn, input_fn,
name, name,
delay_secs=0, delay_secs=120,
throttle_delay_secs=60): throttle_delay_secs=60):
"""Run continuous eval. """Run continuous eval.
...@@ -160,8 +169,7 @@ class Experiment(object): ...@@ -160,8 +169,7 @@ class Experiment(object):
metrics=self._eval_metrics, metrics=self._eval_metrics,
name=name) name=name)
except NotFittedError: except NotFittedError:
logging.warning("Estimator is not fitted yet, skipping evaluation. " logging.warning("Estimator is not fitted yet, skipping evaluation.")
"Increase 'delay_secs' to avoid this warning.")
duration = time.time() - start duration = time.time() - start
if duration < throttle_delay_secs: if duration < throttle_delay_secs:
difference = throttle_delay_secs - duration difference = throttle_delay_secs - duration
...@@ -169,13 +177,14 @@ class Experiment(object): ...@@ -169,13 +177,14 @@ class Experiment(object):
difference) difference)
time.sleep(difference) time.sleep(difference)
def continuous_eval(self, delay_secs=0, throttle_delay_secs=60): def continuous_eval(self, delay_secs=120, throttle_delay_secs=60):
self._continuous_eval(self._eval_input_fn, self._continuous_eval(self._eval_input_fn,
name="continuous", name="continuous",
delay_secs=delay_secs, delay_secs=delay_secs,
throttle_delay_secs=throttle_delay_secs) throttle_delay_secs=throttle_delay_secs)
def continuous_eval_on_train_data(self, delay_secs=0, throttle_delay_secs=60): def continuous_eval_on_train_data(
self, delay_secs=120, throttle_delay_secs=60):
self._continuous_eval(self._train_input_fn, self._continuous_eval(self._train_input_fn,
name="continuous_on_train_data", name="continuous_on_train_data",
delay_secs=delay_secs, delay_secs=delay_secs,
......
...@@ -32,5 +32,4 @@ flags.DEFINE_string('schedule', 'local_run', ...@@ -32,5 +32,4 @@ flags.DEFINE_string('schedule', 'local_run',
'instance returned by the function passed to the ' 'instance returned by the function passed to the '
'run() call') 'run() call')
# TODO(ispir): Remove once we migrated customer pipilines. flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.')
flags.DEFINE_string('execution_mode', 'all', 'Deprecated. Use FLAGS.schedule')
...@@ -20,6 +20,8 @@ from __future__ import print_function ...@@ -20,6 +20,8 @@ from __future__ import print_function
import time import time
import tensorflow as tf import tensorflow as tf
# importing to get flags.
from tensorflow.contrib.learn.python.learn import runner_flags # pylint: disable=unused-import
class TestEstimator(object): class TestEstimator(object):
...@@ -68,8 +70,19 @@ class ExperimentTest(tf.test.TestCase): ...@@ -68,8 +70,19 @@ class ExperimentTest(tf.test.TestCase):
start = time.time() start = time.time()
ex.train(delay_secs=delay) ex.train(delay_secs=delay)
duration = time.time() - start duration = time.time() - start
tf.logging.info('train duration (expected %f): %f', delay, duration) self.assertAlmostEqual(duration, delay, delta=0.5)
self.assertTrue(duration > delay - 0.5 and duration < delay + 0.5)
def test_train_default_delay(self):
est = TestEstimator()
ex = tf.contrib.learn.Experiment(est,
train_input_fn='train_input',
eval_input_fn='eval_input')
for task in [0, 1, 3]:
start = time.time()
tf.flags.FLAGS.task = task
ex.train()
duration = time.time() - start
self.assertAlmostEqual(duration, task*5, delta=0.5)
def test_evaluate(self): def test_evaluate(self):
est = TestEstimator() est = TestEstimator()
...@@ -92,7 +105,7 @@ class ExperimentTest(tf.test.TestCase): ...@@ -92,7 +105,7 @@ class ExperimentTest(tf.test.TestCase):
ex.evaluate(delay_secs=delay) ex.evaluate(delay_secs=delay)
duration = time.time() - start duration = time.time() - start
tf.logging.info('eval duration (expected %f): %f', delay, duration) tf.logging.info('eval duration (expected %f): %f', delay, duration)
self.assertTrue(duration > delay - 0.5 and duration < delay + 0.5) self.assertAlmostEqual(duration, delay, delta=0.5)
def test_continuous_eval(self): def test_continuous_eval(self):
est = TestEstimator() est = TestEstimator()
...@@ -118,7 +131,7 @@ class ExperimentTest(tf.test.TestCase): ...@@ -118,7 +131,7 @@ class ExperimentTest(tf.test.TestCase):
duration = time.time() - start duration = time.time() - start
expected = 5 * delay expected = 5 * delay
tf.logging.info('eval duration (expected %f): %f', expected, duration) tf.logging.info('eval duration (expected %f): %f', expected, duration)
self.assertTrue(duration > expected - 0.5 and duration < expected + 0.5) self.assertAlmostEqual(duration, expected, delta=0.5)
def test_run_local(self): def test_run_local(self):
est = TestEstimator() est = TestEstimator()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册