提交 9dbe34bf 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

DistributedVariable update methods always pass keyword arguments to

_mirrored_update

In this way it's easier to modify the arugments, which is needed to make the
return type another DistributedVariable.

PiperOrigin-RevId: 306525155
Change-Id: I8ea762fe555827ff8b4061109f7a1db884a9a910
上级 de2ae7c6
......@@ -745,7 +745,7 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
class MirroredVariable(DistributedVariable, Mirrored):
"""Holds a map from replica to variables whose values are kept in sync."""
def _mirrored_update(self, update_fn, value, **kwargs):
def _mirrored_update(self, update_fn, *args, **kwargs):
"""Apply identical updates using `update_fn` to variables on each replica."""
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
......@@ -760,12 +760,12 @@ class MirroredVariable(DistributedVariable, Mirrored):
# wrapped MirroredVariables through object members, captured arguments
# , etc. This is more likely in an update_non_slot() function
# , which can update several non-slot variables in one call.
return update_fn(self._values[update_replica_id], value, **kwargs)
return update_fn(self._values[update_replica_id], *args, **kwargs)
# We are calling update on the mirrored variable in cross replica
# context, use `strategy.extended.update()` to update the variable.
return self._distribute_strategy.extended.update(
self, update_fn, args=(value,), kwargs=kwargs)
self, update_fn, args=args, kwargs=kwargs)
else:
_assert_replica_context(self._distribute_strategy)
# We are calling an update function on the mirrored variable in replica
......@@ -778,7 +778,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
raise ValueError(
_aggregation_error_msg.format(variable_type="MirroredVariable"))
def merge_fn(strategy, value, **other_kwargs):
def merge_fn(strategy, value, *other_args, **other_kwargs):
"""Aggregate across replicas and update MV with aggregated value."""
# Don't allow MEAN with non float dtype, since it may cause unexpected
# precision loss. Python3 and NumPy automatically upcast integers to
......@@ -797,71 +797,40 @@ class MirroredVariable(DistributedVariable, Mirrored):
v = _apply_aggregation(strategy, value, self._aggregation, self)
return strategy.extended.update(
self, update_fn, args=(v,), kwargs=other_kwargs)
self, update_fn, args=(v,) + other_args, kwargs=other_kwargs)
return ds_context.get_replica_context().merge_call(
merge_fn, args=(value,), kwargs=kwargs)
merge_fn, args=args, kwargs=kwargs)
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
def assign_sub(self, *args, **kwargs):
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
return self._mirrored_update(
update_fn=assign_sub_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return self._mirrored_update(assign_sub_fn, *args, **kwargs)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
def assign_add(self, *args, **kwargs):
assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
return self._mirrored_update(
update_fn=assign_add_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return self._mirrored_update(assign_add_fn, *args, **kwargs)
def assign(self, value, use_locking=False, name=None, read_value=True):
def assign(self, *args, **kwargs):
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
return self._mirrored_update(
update_fn=assign_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return self._mirrored_update(assign_fn, *args, **kwargs)
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
def scatter_sub(self, *args, **kwargs):
scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
return self._mirrored_update(
update_fn=scatter_sub_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
return self._mirrored_update(scatter_sub_fn, *args, **kwargs)
def scatter_add(self, sparse_delta, use_locking=False, name=None):
def scatter_add(self, *args, **kwargs):
scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
return self._mirrored_update(
update_fn=scatter_add_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
return self._mirrored_update(scatter_add_fn, *args, **kwargs)
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
def scatter_mul(self, *args, **kwargs):
scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
return self._mirrored_update(
update_fn=scatter_mul_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
return self._mirrored_update(scatter_mul_fn, *args, **kwargs)
def scatter_div(self, sparse_delta, use_locking=False, name=None):
def scatter_div(self, *args, **kwargs):
scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
return self._mirrored_update(
update_fn=scatter_div_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
return self._mirrored_update(scatter_div_fn, *args, **kwargs)
def scatter_min(self, sparse_delta, use_locking=False, name=None):
def scatter_min(self, *args, **kwargs):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError("scatter_min is only supported for mirrored "
......@@ -870,13 +839,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
self._aggregation)
scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
return self._mirrored_update(
update_fn=scatter_min_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
return self._mirrored_update(scatter_min_fn, *args, **kwargs)
def scatter_max(self, sparse_delta, use_locking=False, name=None):
def scatter_max(self, *args, **kwargs):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError("scatter_max is only supported for mirrored "
......@@ -885,13 +850,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
self._aggregation)
scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
return self._mirrored_update(
update_fn=scatter_max_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
return self._mirrored_update(scatter_max_fn, *args, **kwargs)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
def scatter_update(self, *args, **kwargs):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError("scatter_update is only supported for mirrored "
......@@ -900,11 +861,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
self._aggregation)
scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
return self._mirrored_update(
update_fn=scatter_update_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
return self._mirrored_update(scatter_update_fn, *args, **kwargs)
def _get_cross_replica(self):
# Return identity, to avoid directly exposing the variable to the user and
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册