提交 0f5bdd0e 编写于 作者: C Chen Chen 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 285285388
上级 143d09bc
......@@ -365,7 +365,12 @@ class TransformerTask(object):
def eval(self):
"""Evaluates the model."""
with distribution_utils.get_strategy_scope(self.distribution_strategy):
distribution_strategy = self.distribution_strategy if self.use_tpu else None
# We only want to create the model under DS scope for TPU case.
# When 'distribution_strategy' is None, a no-op DummyContextManager will
# be used.
with distribution_utils.get_strategy_scope(distribution_strategy):
if not self.predict_model:
self.predict_model = transformer.create_model(self.params, False)
self._load_weights_if_possible(
......@@ -375,7 +380,7 @@ class TransformerTask(object):
return evaluate_and_log_bleu(
self.predict_model, self.params, self.flags_obj.bleu_source,
self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
self.distribution_strategy if self.use_tpu else None)
distribution_strategy)
def predict(self):
"""Predicts result from the model."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册