提交 7a69f962 编写于 作者: C Chris Jones 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 285974579
上级 a19f2f8b
......@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import values
from official.transformer.utils import tokenizer
_EXTRA_DECODE_LENGTH = 100
......@@ -144,12 +143,15 @@ def translate_file(model,
text = np.reshape(text, [num_replicas, local_batch_size, -1])
# Add tag to the input of each replica with the reordering logic after
# outputs, to ensure the output order matches the input order.
text = [
[tf.convert_to_tensor(tag), tf.convert_to_tensor(per_replica_text)]
for tag, per_replica_text in enumerate(text)
]
# pylint: disable=protected-access
text = values.PerReplica(distribution_strategy.extended._device_map, text)
text = tf.constant(text)
@tf.function
def text_as_per_replica():
replica_context = tf.distribute.get_replica_context()
replica_id = replica_context.replica_id_in_sync_group
return replica_id, text[replica_id]
text = distribution_strategy.experimental_run_v2(text_as_per_replica)
outputs = distribution_strategy.experimental_local_results(
predict_step(text))
tags, unordered_val_outputs = outputs[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册