Commit b29fe6b7, authored by Hongkun Yu, committed by A. Unique TensorFlower

[Clean up]: remove is_v2() check inside transformer.

PiperOrigin-RevId: 312988874
Parent 09d3c74a
@@ -17,7 +17,6 @@
 import tensorflow as tf
 
 from official.nlp.transformer import beam_search_v1 as v1
-from official.nlp.transformer import misc
 
 _StateKeys = v1._StateKeys  # pylint: disable=protected-access
@@ -52,8 +51,8 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
     # Account for corner case where there are no finished sequences for a
     # particular batch item. In that case, return alive sequences for that batch
     # item.
-    finished_seq = tf.compat.v2.where(seq_cond, finished_seq, alive_seq)
-    finished_scores = tf.compat.v2.where(
+    finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
+    finished_scores = tf.where(
         score_cond, finished_scores, alive_log_probs)
     return finished_seq, finished_scores
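
For context: in TF 2.x, tf.where with three arguments is the broadcasting select that the tf.compat.v2.where alias was reaching for, so the alias can be dropped. A minimal sketch with invented shapes and values (the tensor names mirror the hunk above, nothing here is taken from the real beam-search state):

import tensorflow as tf

# Invented [batch] condition and [batch, length] sequences standing in for
# the real beam-search tensors.
seq_cond = tf.constant([True, False])
finished_seq = tf.constant([[1, 2, 3], [4, 5, 6]])
alive_seq = tf.constant([[7, 8, 9], [10, 11, 12]])

# tf.where(cond, x, y) selects elementwise and broadcasts the condition,
# so expanding cond to [batch, 1] picks whole rows.
result = tf.where(seq_cond[:, tf.newaxis], finished_seq, alive_seq)
print(result.numpy())  # [[1 2 3] [10 11 12]]
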
@@ -102,14 +101,9 @@ def sequence_beam_search(symbols_to_logits_fn,
   batch_size = (
       initial_ids.shape.as_list()[0] if padded_decode else
       tf.shape(initial_ids)[0])
-  if misc.is_v2():
-    sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
-                               beam_size, alpha, max_decode_length, eos_id,
-                               padded_decode, dtype)
-  else:
-    sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                                beam_size, alpha, max_decode_length, eos_id,
-                                padded_decode, dtype)
+  sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
+                             beam_size, alpha, max_decode_length, eos_id,
+                             padded_decode, dtype)
   return sbs.search(initial_ids, initial_cache)
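
With the branch gone, every caller gets SequenceBeamSearchV2. A hedged usage sketch follows; the dummy symbols_to_logits_fn and all argument values are illustrative assumptions, with the parameter names read off the hunk above and the search call at the end of it:

import tensorflow as tf
from official.nlp.transformer import beam_search

VOCAB_SIZE = 8

def symbols_to_logits_fn(ids, i, cache):
  # A real model would run one decoder step here; uniform logits are
  # enough to exercise the search loop.
  logits = tf.zeros([tf.shape(ids)[0], VOCAB_SIZE])
  return logits, cache

decoded_ids, scores = beam_search.sequence_beam_search(
    symbols_to_logits_fn=symbols_to_logits_fn,
    initial_ids=tf.zeros([2], dtype=tf.int32),
    initial_cache={},
    vocab_size=VOCAB_SIZE,
    beam_size=4,
    alpha=0.6,
    max_decode_length=5,
    eos_id=1)
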
...
@@ -56,7 +56,6 @@ import os
 from absl import logging
 import tensorflow as tf
 
-from official.nlp.transformer import misc
 from official.utils.misc import model_helpers
 
 # Buffer size for reading records from a TFRecord file. Each training file is
@@ -313,9 +312,5 @@ def eval_input_fn(params, ctx=None):
 def map_data_for_transformer_fn(x, y):
   """Maps data for training, and handles weird behaviors for different vers."""
   # Will transform input x and targets y into tuple(x, y) as new model inputs.
-  if misc.is_v2():
-    # For TF v2, the 2nd parameter is omitted to make Keras training work.
-    return ((x, y),)
-  else:
-    # For TF v1, Keras requires a dummy placeholder as the 2nd parameter.
-    return ((x, y), tf.constant(0.0))
+  # For TF v2, the 2nd parameter is omitted to make Keras training work.
+  return ((x, y),)
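
The single-element tuple matters because Keras treats ((x, y),) as one positional input with no separate label, which suits models that compute their loss internally. A minimal sketch with synthetic tensors (the shapes and the dataset are invented, only the mapper comes from the hunk above):

import tensorflow as tf

def map_data_for_transformer_fn(x, y):
  # For TF v2, the 2nd element is omitted to make Keras training work.
  return ((x, y),)

# Synthetic source/target token ids standing in for the real TFRecord data.
ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([8, 10], tf.int64), tf.zeros([8, 12], tf.int64)))
train_ds = ds.map(map_data_for_transformer_fn).batch(4)

for (inputs,) in train_ds.take(1):
  src, tgt = inputs
  print(src.shape, tgt.shape)  # (4, 10) (4, 12)
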
@@ -22,10 +22,6 @@ from __future__ import print_function
 from absl import flags
 import tensorflow as tf
 
-# TODO(tianlin) Import internal library. Remove this when some functions for
-# different TF versions are fixed.
-from tensorflow.python import tf2 as tf2_internal
-
 from official.nlp.transformer import model_params
 from official.utils.flags import core as flags_core
 from official.utils.misc import keras_utils
@@ -39,11 +35,6 @@ PARAMS_MAP = {
 }
 
 
-def is_v2():
-  """Returns whether it is v2."""
-  return tf2_internal.enabled()
-
-
 def get_model_params(param_set, num_gpus):
   """Gets predefined model params."""
   if num_gpus > 1:
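
For downstream code that still needs a version switch after this deletion, the public version string is a rough substitute. This is an assumption on our part, not part of the commit; unlike the removed tf2_internal.enabled(), it does not reflect v2-behavior toggles inside TF 1.x:

import tensorflow as tf

# Approximate stand-in for the removed helper. Note tf2.enabled() also
# reflected tf.compat.v1.enable_v2_behavior() on TF 1.15, which a plain
# version-string check cannot see.
def is_v2():
  return tf.version.VERSION.startswith('2.')
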
@@ -78,17 +69,6 @@ def define_transformer_flags():
       fp16_implementation=True
   )
 
-  # Additional performance flags
-  # TODO(b/76028325): Remove when generic layout optimizer is ready.
-  flags.DEFINE_boolean(
-      name='enable_grappler_layout_optimizer',
-      default=True,
-      help='Enable Grappler layout optimizer. Currently Grappler can '
-           'de-optimize fp16 graphs by forcing NCHW layout for all '
-           'convolutions and batch normalizations, and this flag allows to '
-           'disable it.'
-  )
-
   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)
 
...