Commit ac8ee841 authored by Mustafa Ispir, committed by TensorFlower Gardener

Run SyncReplicasOptimizer with MonitoredSession.

User code will look as follows:
  opt = tf.SyncReplicasOptimizer(...)
  train_op = opt.minimize(total_loss, global_step=global_step)
  sync_rep_hook = opt.make_session_run_hook(is_chief)
  with training.MonitoredTrainingSession(master=master, is_chief=is_chief, hooks=[sync_rep_hook]) as mon_sess:
    while not mon_sess.should_stop():
      mon_sess.run(train_op)
Change: 144353039
Parent 4c620cc8
@@ -497,7 +497,7 @@ class _MonitoredSession(object):
queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
# Inform the hooks that a new session has been created.
for hook in self._hooks:
hook.after_create_session(self.tf_sess)
hook.after_create_session(self.tf_sess, self.coord)
return _CoordinatedSession(
_HookedSession(self.tf_sess, self._hooks), self.coord)
@@ -179,7 +179,7 @@ class FakeHook(session_run_hook.SessionRunHook):
def begin(self):
self.call_counter['begin'] += 1
def after_create_session(self, session): # pylint: disable=unused-argument
def after_create_session(self, session, coord): # pylint: disable=unused-argument
self.call_counter['after_create_session'] += 1
def before_run(self, run_context):
@@ -421,44 +421,6 @@ class SessionManager(object):
pass
# pylint: enable=broad-except
def _ready(self, op, sess, msg):
"""Checks if the model is ready or not, as determined by op.
Args:
op: An op, either _ready_op or _ready_for_local_init_op, which defines the
readiness of the model.
sess: A `Session`.
msg: A message to log to warning if not ready
Returns:
A tuple (is_ready, msg), where is_ready is True if ready and False
otherwise, and msg is `None` if the model is ready, a `String` with the
reason why it is not ready otherwise.
"""
if op is None:
return True, None
else:
try:
ready_value = sess.run(op)
# The model is considered ready if ready_op returns an empty 1-D tensor.
# Also compare to `None` and dtype being int32 for backward
# compatibility.
if (ready_value is None or ready_value.dtype == np.int32 or
ready_value.size == 0):
return True, None
else:
# TODO(sherrym): If a custom ready_op returns other types of tensor,
# or strings other than variable names, this message could be
# confusing.
non_initialized_varnames = ", ".join(
[i.decode("utf-8") for i in ready_value])
return False, "Variables not initialized: " + non_initialized_varnames
except errors.FailedPreconditionError as e:
if "uninitialized" not in str(e):
logging.warning("%s : error [%s]", msg, str(e))
raise e
return False, str(e)
def _model_ready(self, sess):
"""Checks if the model is ready or not.
@@ -470,7 +432,7 @@ class SessionManager(object):
otherwise, and msg is `None` if the model is ready, a `String` with the
reason why it is not ready otherwise.
"""
return self._ready(self._ready_op, sess, "Model not ready")
return _ready(self._ready_op, sess, "Model not ready")
def _model_ready_for_local_init(self, sess):
"""Checks if the model is ready to run local_init_op.
@@ -484,7 +446,7 @@ class SessionManager(object):
ready to run local_init_op, a `String` with the reason why it is not ready
otherwise.
"""
return self._ready(self._ready_for_local_init_op, sess,
return _ready(self._ready_for_local_init_op, sess,
"Model not ready for local init")
def _try_run_local_init_op(self, sess):
@@ -509,6 +471,45 @@ class SessionManager(object):
return True, None
def _ready(op, sess, msg):
"""Checks if the model is ready or not, as determined by op.
Args:
op: An op, either _ready_op or _ready_for_local_init_op, which defines the
readiness of the model.
sess: A `Session`.
msg: A message to log to warning if not ready
Returns:
A tuple (is_ready, msg), where is_ready is True if ready and False
otherwise, and msg is `None` if the model is ready, a `String` with the
reason why it is not ready otherwise.
"""
if op is None:
return True, None
else:
try:
ready_value = sess.run(op)
# The model is considered ready if ready_op returns an empty 1-D tensor.
# Also compare to `None` and dtype being int32 for backward
# compatibility.
if (ready_value is None or ready_value.dtype == np.int32 or
ready_value.size == 0):
return True, None
else:
# TODO(sherrym): If a custom ready_op returns other types of tensor,
# or strings other than variable names, this message could be
# confusing.
non_initialized_varnames = ", ".join(
[i.decode("utf-8") for i in ready_value])
return False, "Variables not initialized: " + non_initialized_varnames
except errors.FailedPreconditionError as e:
if "uninitialized" not in str(e):
logging.warning("%s : error [%s]", msg, str(e))
raise e
return False, str(e)
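(For orientation, a minimal sketch of how the now module-level `_ready` helper behaves, assuming a TF 1.x runtime; the `ready_op` below is built with `report_uninitialized_variables()`, the conventional choice, and the import of the private helper is shown for illustration only:)

```python
import tensorflow as tf
from tensorflow.python.training.session_manager import _ready  # private helper

with tf.Graph().as_default():
  v = tf.Variable(0.0, name="v")
  ready_op = tf.report_uninitialized_variables()
  with tf.Session() as sess:
    # Before initialization the op returns [b"v"], so _ready reports not ready.
    print(_ready(ready_op, sess, "Model not ready"))
    # -> (False, "Variables not initialized: v")
    sess.run(tf.global_variables_initializer())
    # Afterwards the op returns an empty 1-D tensor, so the model is ready.
    print(_ready(ready_op, sess, "Model not ready"))
    # -> (True, None)
```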
class _CountDownTimer(object):
def __init__(self, duration_secs):
@@ -98,7 +98,7 @@ class SessionRunHook(object):
"""
pass
def after_create_session(self, session): # pylint: disable=unused-argument
def after_create_session(self, session, coord): # pylint: disable=unused-argument
"""Called when new TensorFlow session is created.
This is called to signal the hooks that a new session has been created. This
@@ -111,6 +111,7 @@ class SessionRunHook(object):
Args:
session: A TensorFlow Session that has been created.
coord: A Coordinator object which keeps track of all threads.
"""
pass
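(A minimal sketch of a hook written against the new two-argument signature; the hook name and its behaviour are illustrative, not part of this change. The point of the added `coord` argument is that a hook can hand its own threads to the session's coordinator so they are stopped and joined together with everything else:)

```python
import threading

from tensorflow.python.training import session_run_hook


class HeartbeatHook(session_run_hook.SessionRunHook):
  """Illustrative hook that runs a background thread under the coordinator."""

  def after_create_session(self, session, coord):
    thread = threading.Thread(target=self._loop, args=(coord,))
    thread.daemon = True
    # Registering with the coordinator lets MonitoredSession stop and join the
    # thread when the session shuts down.
    coord.register_thread(thread)
    thread.start()

  def _loop(self, coord):
    while not coord.should_stop():
      # Wake up periodically; returns early as soon as a stop is requested.
      coord.wait_for_stop(timeout=5.0)
```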
@@ -28,6 +28,8 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer
from tensorflow.python.training import queue_runner
from tensorflow.python.training import session_manager
from tensorflow.python.training import session_run_hook
# Please note that the gradients from replicas are averaged instead of summed
@@ -104,43 +106,22 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
# Now you can call `minimize()` or `compute_gradients()` and
# `apply_gradients()` normally
grads = opt.minimize(total_loss, global_step=self.global_step)
training_op = opt.minimize(total_loss, global_step=self.global_step)
# You can now call get_init_tokens_op() and get_chief_queue_runner().
# Note that get_init_tokens_op() must be called before creating session
# because it modifies the graph by adding new nodes.
init_token_op = opt.get_init_tokens_op()
chief_queue_runner = opt.get_chief_queue_runner()
# You can create the hook which handles initialization and queues.
sync_replicas_hook = opt.make_session_run_hook(is_chief)
```
In the training program, every worker will run the train_op as if not
synchronized. But one worker (usually the chief) will need to execute the
chief_queue_runner and get_init_tokens_op from this optimizer.
synchronized.
```python
# When you create the supervisor, you need to add the local_init_op and
# ready_for_local_init_op to make sure the local_step is initialized to the
# global_step. Here is an example:
if is_chief:
local_init_op = opt.chief_init_op
else:
local_init_op = opt.local_step_init_op
ready_for_local_init_op = opt.ready_for_local_init_op
sv = tf.Supervisor(graph=g,
is_chief=is_chief,
# This initialize local step.
local_init_op=local_init_op,
# This makes sure global step is initialized before using.
ready_for_local_init_op=ready_for_local_init_op,
saver=model.saver)
# After the session is created by the Supervisor and before the main while
# loop:
if is_chief and FLAGS.sync_replicas:
sv.start_queue_runners(sess, [chief_queue_runner])
# Insert initial tokens to the queue.
sess.run(init_token_op)
with training.MonitoredTrainingSession(
master=workers[worker_id].target, is_chief=is_chief,
hooks=[sync_replicas_hook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(training_op)
```
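(The docstring excerpt above starts after the optimizer has already been built; for completeness, a typical construction, with illustrative replica counts, looks roughly like:)

```python
# Wrap any optimizer; gradients from `replicas_to_aggregate` workers are
# averaged before a single update is applied.
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
opt = tf.SyncReplicasOptimizer(opt,
                               replicas_to_aggregate=num_workers,
                               total_num_replicas=num_workers)
```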
@@__init__
@@ -440,3 +421,51 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
init_tokens = control_flow_ops.no_op(name="no_init_tokens")
return init_tokens
def make_session_run_hook(self, is_chief, num_tokens=-1):
"""Creates a hook to handle SyncReplicasHook ops such as initialization."""
if is_chief:
return _SyncReplicasOptimizerHook(self.chief_init_op,
self.ready_for_local_init_op,
self.get_chief_queue_runner(),
self.get_init_tokens_op(num_tokens))
return _SyncReplicasOptimizerHook(self.local_step_init_op,
self.ready_for_local_init_op, None, None)
class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
"""A SessionRunHook handles ops related to SyncReplicasOptimizer."""
def __init__(self, local_init_op, ready_for_local_init_op, q_runner,
init_tokens_op):
"""Creates hook to handle SyncReplicaOptimizer initialization ops.
Args:
local_init_op: Either `SyncReplicasOptimizer.chief_init_op` or
`SyncReplicasOptimizer.local_step_init_op`.
ready_for_local_init_op: `SyncReplicasOptimizer.ready_for_local_init_op`
q_runner: Either `SyncReplicasOptimizer.get_chief_queue_runner` or `None`
init_tokens_op: `SyncReplicasOptimizer.get_init_tokens_op` or None
"""
self._local_init_op = local_init_op
self._ready_for_local_init_op = ready_for_local_init_op
self._q_runner = q_runner
self._init_tokens_op = init_tokens_op
def after_create_session(self, session, coord):
"""Runs SyncReplicasOptimizer initialization ops."""
local_init_success, msg = session_manager._ready( # pylint: disable=protected-access
self._ready_for_local_init_op, session,
"Model is not ready for SyncReplicasOptimizer local init.")
if not local_init_success:
raise RuntimeError(
"Init operations did not make model ready for SyncReplicasOptimizer "
"local_init. Init op: %s, error: %s" %
(self._local_init_op.name, msg))
session.run(self._local_init_op)
if self._init_tokens_op is not None:
session.run(self._init_tokens_op)
if self._q_runner is not None:
self._q_runner.create_threads(
session, coord=coord, daemon=True, start=True)
@@ -28,7 +28,6 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import server_lib
from tensorflow.python.training import supervisor as supervisor_lib
from tensorflow.python.training import training
@@ -92,33 +91,14 @@ def get_workers(num_workers, replicas_to_aggregate, workers):
[var_0, var_1, var_sparse]),
global_step=global_step)
]
sync_replicas_hook = sync_rep_opt.make_session_run_hook(
is_chief, num_tokens=num_workers)
init_op = variables.global_variables_initializer()
# Needed ops from the sync_rep optimizer. This is mainly for the
# local_step initialization.
local_init_op = sync_rep_opt.local_step_init_op
if is_chief:
local_init_op = sync_rep_opt.chief_init_op
ready_for_local_init_op = sync_rep_opt.ready_for_local_init_op
# Chief_queue_runner
chief_queue_runner = sync_rep_opt.get_chief_queue_runner()
sync_init_op = sync_rep_opt.get_init_tokens_op(num_workers)
# Creates session for chief.
supervisor = supervisor_lib.Supervisor(
graph=graph,
# Creates MonitoredSession
session = training.MonitoredTrainingSession(
master=workers[worker_id].target,
is_chief=is_chief,
recovery_wait_secs=1,
init_op=init_op,
local_init_op=local_init_op,
ready_for_local_init_op=ready_for_local_init_op)
session = supervisor.prepare_or_wait_for_session(workers[worker_id].target)
# Chief should execute the sync_init_op and start the chief queue runner.
if is_chief:
session.run(sync_init_op)
supervisor.StartQueueRunners(session, [chief_queue_runner])
hooks=[sync_replicas_hook])
sessions.append(session)
graphs.append(graph)
@@ -146,9 +126,9 @@ class SyncReplicasOptimizerTest(test.TestCase):
var_0_g_0 = graphs[0].get_tensor_by_name("v0:0")
var_1_g_0 = graphs[0].get_tensor_by_name("v1:0")
local_step_0 = graphs[0].get_tensor_by_name("sync_rep_local_step:0")
self.assertAllEqual(0.0, var_0_g_0.eval(session=sessions[0]))
self.assertAllEqual(1.0, var_1_g_0.eval(session=sessions[0]))
self.assertAllEqual(0, local_step_0.eval(session=sessions[0]))
self.assertAllEqual(0.0, sessions[0].run(var_0_g_0))
self.assertAllEqual(1.0, sessions[0].run(var_1_g_0))
self.assertAllEqual(0, sessions[0].run(local_step_0))
# Will just use session 1 to verify all the variables later.
var_0_g_1 = graphs[1].get_tensor_by_name("v0:0")
@@ -158,10 +138,9 @@ class SyncReplicasOptimizerTest(test.TestCase):
global_step = graphs[1].get_tensor_by_name("global_step:0")
# The steps should also be initialized.
self.assertAllEqual(0, global_step.eval(session=sessions[1]))
self.assertAllEqual(0, local_step_1.eval(session=sessions[1]))
self.assertAllClose(
[[3.0], [4.0]], var_sparse_g_1.eval(session=sessions[1]))
self.assertAllEqual(0, sessions[1].run(global_step))
self.assertAllEqual(0, sessions[1].run(local_step_1))
self.assertAllClose([[3.0], [4.0]], sessions[1].run(var_sparse_g_1))
# We have initial tokens in the queue so we can call this one by one. After
# the first step, this will no longer work as there will be no more extra
@@ -171,16 +150,13 @@ class SyncReplicasOptimizerTest(test.TestCase):
# The global step should have been updated and the variables should now have
# the new values after the average of the gradients are applied.
while global_step.eval(session=sessions[1]) != 1:
while sessions[1].run(global_step) != 1:
time.sleep(0.01)
self.assertAllClose(
0 - (0.1 + 0.3) / 2 * 2.0, var_0_g_1.eval(session=sessions[1]))
self.assertAllClose(
1 - (0.9 + 1.1) / 2 * 2.0, var_1_g_1.eval(session=sessions[1]))
self.assertAllClose(
[[3.0], [4.0 - (0.1 + 0.3) / 2 * 2.0]],
var_sparse_g_1.eval(session=sessions[1]))
self.assertAllClose(0 - (0.1 + 0.3) / 2 * 2.0, sessions[1].run(var_0_g_1))
self.assertAllClose(1 - (0.9 + 1.1) / 2 * 2.0, sessions[1].run(var_1_g_1))
self.assertAllClose([[3.0], [4.0 - (0.1 + 0.3) / 2 * 2.0]],
sessions[1].run(var_sparse_g_1))
# The local step for both workers should still be 0 because the initial
# tokens in the token queue are 0s. This means that the following
@@ -188,20 +164,18 @@ class SyncReplicasOptimizerTest(test.TestCase):
# the current global step. However, this only happens once when the system
# just starts and this is necessary to make the system robust for the case
# when chief gets restarted by errors/preemption/...
self.assertAllEqual(0, local_step_0.eval(session=sessions[0]))
self.assertAllEqual(0, local_step_1.eval(session=sessions[1]))
self.assertAllEqual(0, sessions[0].run(local_step_0))
self.assertAllEqual(0, sessions[1].run(local_step_1))
sessions[0].run(train_ops[0])
sessions[1].run(train_ops[1])
# Although the global step should still be 1 as explained above, the local
# step should now be updated to 1. The variables are still the same.
self.assertAllEqual(1, global_step.eval(session=sessions[1]))
self.assertAllEqual(1, local_step_0.eval(session=sessions[0]))
self.assertAllEqual(1, local_step_1.eval(session=sessions[1]))
self.assertAllClose(
0 - (0.1 + 0.3) / 2 * 2.0, var_0_g_1.eval(session=sessions[1]))
self.assertAllClose(
1 - (0.9 + 1.1) / 2 * 2.0, var_1_g_1.eval(session=sessions[1]))
self.assertAllEqual(1, sessions[1].run(global_step))
self.assertAllEqual(1, sessions[0].run(local_step_0))
self.assertAllEqual(1, sessions[1].run(local_step_1))
self.assertAllClose(0 - (0.1 + 0.3) / 2 * 2.0, sessions[1].run(var_0_g_1))
self.assertAllClose(1 - (0.9 + 1.1) / 2 * 2.0, sessions[1].run(var_1_g_1))
# At this step, the token queue is empty. So the 2 workers need to work
# together to proceed.
@@ -221,11 +195,11 @@ class SyncReplicasOptimizerTest(test.TestCase):
# The global step should now be 2 and the gradients should have been
# applied twice.
self.assertAllEqual(2, global_step.eval(session=sessions[1]))
self.assertAllClose(
0 - 2 * (0.1 + 0.3) / 2 * 2.0, var_0_g_1.eval(session=sessions[1]))
self.assertAllClose(
1 - 2 * (0.9 + 1.1) / 2 * 2.0, var_1_g_1.eval(session=sessions[1]))
self.assertAllEqual(2, sessions[1].run(global_step))
self.assertAllClose(0 - 2 * (0.1 + 0.3) / 2 * 2.0,
sessions[1].run(var_0_g_1))
self.assertAllClose(1 - 2 * (0.9 + 1.1) / 2 * 2.0,
sessions[1].run(var_1_g_1))
# 3 workers and one of them is backup.
def test3Workers1Backup(self):
@@ -245,8 +219,8 @@ class SyncReplicasOptimizerTest(test.TestCase):
global_step = graphs[1].get_tensor_by_name("global_step:0")
# The steps should also be initialized.
self.assertAllEqual(0, global_step.eval(session=sessions[1]))
self.assertAllEqual(0, local_step_1.eval(session=sessions[1]))
self.assertAllEqual(0, sessions[1].run(global_step))
self.assertAllEqual(0, sessions[1].run(local_step_1))
# We have initial tokens in the queue so we can call this one by one. After
# the token queue becomes empty, they should be called concurrently.
@@ -257,14 +231,12 @@ class SyncReplicasOptimizerTest(test.TestCase):
# The global step should have been updated since we only need to collect 2
# gradients. The variables should now have the new values after the average
# of the gradients from worker 0/2 are applied.
while global_step.eval(session=sessions[1]) != 1:
while sessions[1].run(global_step) != 1:
time.sleep(0.01)
self.assertAllEqual(1, global_step.eval(session=sessions[1]))
self.assertAllClose(
0 - (0.1 + 0.5) / 2 * 2.0, var_0_g_1.eval(session=sessions[1]))
self.assertAllClose(
1 - (0.9 + 1.3) / 2 * 2.0, var_1_g_1.eval(session=sessions[1]))
self.assertAllEqual(1, sessions[1].run(global_step))
self.assertAllClose(0 - (0.1 + 0.5) / 2 * 2.0, sessions[1].run(var_0_g_1))
self.assertAllClose(1 - (0.9 + 1.3) / 2 * 2.0, sessions[1].run(var_1_g_1))
# Worker 1 finished later and its gradients will now be dropped as it is
# stale.
@@ -278,8 +250,8 @@ class SyncReplicasOptimizerTest(test.TestCase):
# Although the global step should still be 1 as explained above, the local
# step should now be updated to 1. Just check worker 1 as an example.
self.assertAllEqual(1, global_step.eval(session=sessions[1]))
self.assertAllEqual(1, local_step_1.eval(session=sessions[1]))
self.assertAllEqual(1, sessions[1].run(global_step))
self.assertAllEqual(1, sessions[1].run(local_step_1))
thread_0 = self.checkedThread(
target=self._run, args=(train_ops[0], sessions[0]))
@@ -290,7 +262,7 @@ class SyncReplicasOptimizerTest(test.TestCase):
# It will wait as we need 2 workers to finish this step and the global step
# should be still 1.
thread_0.start()
self.assertAllEqual(1, global_step.eval(session=sessions[1]))
self.assertAllEqual(1, sessions[1].run(global_step))
# Starts worker 1.
thread_1.start()
@@ -298,11 +270,11 @@ class SyncReplicasOptimizerTest(test.TestCase):
# The global step should now be 2 and the gradients should have been
# applied again.
self.assertAllEqual(2, global_step.eval(session=sessions[1]))
self.assertAllClose(
-0.6 - (0.1 + 0.3) / 2 * 2.0, var_0_g_1.eval(session=sessions[1]))
self.assertAllClose(
-1.2 - (0.9 + 1.1) / 2 * 2.0, var_1_g_1.eval(session=sessions[1]))
self.assertAllEqual(2, sessions[1].run(global_step))
self.assertAllClose(-0.6 - (0.1 + 0.3) / 2 * 2.0,
sessions[1].run(var_0_g_1))
self.assertAllClose(-1.2 - (0.9 + 1.1) / 2 * 2.0,
sessions[1].run(var_1_g_1))
if __name__ == "__main__":