Commit b878ae39 authored by Bruce Fontaine, committed by Geeta Chavan

Allow tf.distribute.TPUStrategy to be used with TPUEmbedding API and ensure that LossScaleOptimizer properly rejects it.

PiperOrigin-RevId: 318186211
Change-Id: Id3b9cb8288e5d28ddbaec97d5b35627ab35bc08d
Parent 890eae3e
@@ -440,7 +440,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
     if not strategy_supports_loss_scaling():
       strategy = distribution_strategy_context.get_strategy()
       if isinstance(strategy,
-                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
+                     tpu_strategy.TPUStrategyV2)):
         raise ValueError(
             'Loss scaling is not supported with TPUStrategy. Loss scaling is '
             'unnecessary with TPUs, since they support bfloat16 instead of '
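The hunk above extends an isinstance check so that the new TPUStrategyV2 class is rejected alongside the older TPU strategy classes. A minimal, TensorFlow-free sketch of that pattern is below; the strategy classes and the `check_strategy_supported` helper are hypothetical stand-ins, not the actual TensorFlow API.

```python
# Illustrative sketch only: how a loss-scaling wrapper can reject
# distribution strategies it does not support. These classes are
# hypothetical stand-ins for the tf.distribute strategy classes.

class TPUStrategy: pass
class TPUStrategyV1: pass
class TPUStrategyV2: pass
class MirroredStrategy: pass

# All TPU strategy variants must be rejected, including the new V2 class;
# forgetting one (the bug this commit fixes) would let it slip through.
_UNSUPPORTED_STRATEGIES = (TPUStrategy, TPUStrategyV1, TPUStrategyV2)

def check_strategy_supported(strategy):
    """Raise ValueError if loss scaling cannot be used with `strategy`."""
    if isinstance(strategy, _UNSUPPORTED_STRATEGIES):
        raise ValueError(
            'Loss scaling is not supported with TPUStrategy. Loss scaling is '
            'unnecessary with TPUs, since they support bfloat16 instead of '
            'float16.')
```

Because `isinstance` accepts a tuple of classes, adding a new strategy variant is a one-line change to the tuple, which is exactly what the diff does.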
@@ -265,7 +265,8 @@ class TPUEmbedding(tracking.AutoTrackable):
         Adam or Adagrad).
     """
     self._strategy = distribution_strategy_context.get_strategy()
-    self._using_tpu = isinstance(self._strategy, tpu_strategy.TPUStrategy)
+    self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
+                                                  tpu_strategy.TPUStrategyV2))
     self._pipeline_execution_with_tensor_core = (
         pipeline_execution_with_tensor_core)
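Here the same tuple-based isinstance pattern is used for the opposite purpose: TPUStrategyV2 now *enables* the TPU code path in TPUEmbedding instead of being rejected. A small sketch of this feature-flag use, again with hypothetical stand-in classes rather than the real TensorFlow ones:

```python
# Illustrative sketch only: a strategy-type check that enables the TPU
# code path. Class names are hypothetical stand-ins for tf.distribute types.

class TPUStrategy: pass
class TPUStrategyV2: pass
class OneDeviceStrategy: pass

def using_tpu(strategy):
    # Both TPU strategy variants enable the TPU embedding path;
    # any other strategy falls back to the non-TPU implementation.
    return isinstance(strategy, (TPUStrategy, TPUStrategyV2))
```

The one check thus routes between two implementations at construction time, which is why widening the tuple was sufficient to support TPUStrategyV2 in the TPUEmbedding API.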