提交 e412f6aa 编写于 作者: M Mustafa Ispir 提交者: TensorFlower Gardener

Added staggered start of training workers.

Change: 125803613
上级 0c1285bf
......@@ -24,8 +24,11 @@ import time
from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
FLAGS = flags.FLAGS
class Experiment(object):
"""Experiment is a class containing all information needed to train a model.
......@@ -70,7 +73,7 @@ class Experiment(object):
self._train_monitors = train_monitors
self._local_eval_frequency = local_eval_frequency
def train(self, delay_secs=0):
def train(self, delay_secs=None):
"""Fit the estimator using the training data.
Train the estimator for `steps` steps, after waiting for `delay_secs`
......@@ -83,6 +86,12 @@ class Experiment(object):
The trained estimator.
"""
if delay_secs is None:
task_id = 0
if "task" in FLAGS:
task_id = FLAGS.task
delay_secs = min(60, task_id*5)
if delay_secs:
logging.info("Waiting %d secs before starting training.", delay_secs)
time.sleep(delay_secs)
......@@ -91,7 +100,7 @@ class Experiment(object):
max_steps=self._train_steps,
monitors=self._train_monitors)
def evaluate(self, delay_secs=0):
def evaluate(self, delay_secs=120):
"""Evaluate on the evaluation data.
Runs evaluation on the evaluation data and returns the result. If `steps`
......@@ -127,13 +136,13 @@ class Experiment(object):
input_fn=self._eval_input_fn, eval_steps=self._eval_steps,
metrics=self._eval_metrics, every_n_steps=self._local_eval_frequency
)]
self.train()
return self.evaluate()
self.train(delay_secs=0)
return self.evaluate(delay_secs=0)
def _continuous_eval(self,
input_fn,
name,
delay_secs=0,
delay_secs=120,
throttle_delay_secs=60):
"""Run continuous eval.
......@@ -160,8 +169,7 @@ class Experiment(object):
metrics=self._eval_metrics,
name=name)
except NotFittedError:
logging.warning("Estimator is not fitted yet, skipping evaluation. "
"Increase 'delay_secs' to avoid this warning.")
logging.warning("Estimator is not fitted yet, skipping evaluation.")
duration = time.time() - start
if duration < throttle_delay_secs:
difference = throttle_delay_secs - duration
......@@ -169,13 +177,14 @@ class Experiment(object):
difference)
time.sleep(difference)
def continuous_eval(self, delay_secs=0, throttle_delay_secs=60):
def continuous_eval(self, delay_secs=120, throttle_delay_secs=60):
self._continuous_eval(self._eval_input_fn,
name="continuous",
delay_secs=delay_secs,
throttle_delay_secs=throttle_delay_secs)
def continuous_eval_on_train_data(self, delay_secs=0, throttle_delay_secs=60):
def continuous_eval_on_train_data(
self, delay_secs=120, throttle_delay_secs=60):
self._continuous_eval(self._train_input_fn,
name="continuous_on_train_data",
delay_secs=delay_secs,
......
......@@ -32,5 +32,4 @@ flags.DEFINE_string('schedule', 'local_run',
'instance returned by the function passed to the '
'run() call')
# TODO(ispir): Remove once we have migrated customer pipelines.
flags.DEFINE_string('execution_mode', 'all', 'Deprecated. Use FLAGS.schedule')
flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.')
......@@ -20,6 +20,8 @@ from __future__ import print_function
import time
import tensorflow as tf
# importing to get flags.
from tensorflow.contrib.learn.python.learn import runner_flags # pylint: disable=unused-import
class TestEstimator(object):
......@@ -68,8 +70,19 @@ class ExperimentTest(tf.test.TestCase):
start = time.time()
ex.train(delay_secs=delay)
duration = time.time() - start
tf.logging.info('train duration (expected %f): %f', delay, duration)
self.assertTrue(duration > delay - 0.5 and duration < delay + 0.5)
self.assertAlmostEqual(duration, delay, delta=0.5)
def test_train_default_delay(self):
  """Checks that `train()` derives its start delay from `FLAGS.task`.

  With no explicit `delay_secs`, each task is expected to wait
  `task * 5` seconds before training starts. The global `task` flag is
  restored afterwards so the override does not leak into other tests.
  """
  est = TestEstimator()
  ex = tf.contrib.learn.Experiment(est,
                                   train_input_fn='train_input',
                                   eval_input_fn='eval_input')
  # Remember the current value; fall back to 0 if the flag is not defined.
  original_task = getattr(tf.flags.FLAGS, 'task', 0)
  try:
    for task in [0, 1, 3]:
      tf.flags.FLAGS.task = task
      start = time.time()
      ex.train()
      duration = time.time() - start
      self.assertAlmostEqual(duration, task*5, delta=0.5)
  finally:
    # Undo the override so later tests see the original task id.
    tf.flags.FLAGS.task = original_task
def test_evaluate(self):
est = TestEstimator()
......@@ -92,7 +105,7 @@ class ExperimentTest(tf.test.TestCase):
ex.evaluate(delay_secs=delay)
duration = time.time() - start
tf.logging.info('eval duration (expected %f): %f', delay, duration)
self.assertTrue(duration > delay - 0.5 and duration < delay + 0.5)
self.assertAlmostEqual(duration, delay, delta=0.5)
def test_continuous_eval(self):
est = TestEstimator()
......@@ -118,7 +131,7 @@ class ExperimentTest(tf.test.TestCase):
duration = time.time() - start
expected = 5 * delay
tf.logging.info('eval duration (expected %f): %f', expected, duration)
self.assertTrue(duration > expected - 0.5 and duration < expected + 0.5)
self.assertAlmostEqual(duration, expected, delta=0.5)
def test_run_local(self):
est = TestEstimator()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册