diff --git a/official/nlp/transformer/beam_search.py b/official/nlp/transformer/beam_search.py
index fa1ae52f0b1fec6962f9d8b0ee36ee6bf61eb1b2..a4c1127535e6ae805f6619819737c379cadca6f2 100644
--- a/official/nlp/transformer/beam_search.py
+++ b/official/nlp/transformer/beam_search.py
@@ -17,7 +17,6 @@
 import tensorflow as tf
 
 from official.nlp.transformer import beam_search_v1 as v1
-from official.nlp.transformer import misc
 
 _StateKeys = v1._StateKeys  # pylint: disable=protected-access
 
@@ -52,8 +51,8 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
     # Account for corner case where there are no finished sequences for a
     # particular batch item. In that case, return alive sequences for that batch
     # item.
-    finished_seq = tf.compat.v2.where(seq_cond, finished_seq, alive_seq)
-    finished_scores = tf.compat.v2.where(
+    finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
+    finished_scores = tf.where(
         score_cond, finished_scores, alive_log_probs)
     return finished_seq, finished_scores
 
@@ -102,14 +101,9 @@ def sequence_beam_search(symbols_to_logits_fn,
   batch_size = (
       initial_ids.shape.as_list()[0] if padded_decode else
       tf.shape(initial_ids)[0])
-  if misc.is_v2():
-    sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
-                               beam_size, alpha, max_decode_length, eos_id,
-                               padded_decode, dtype)
-  else:
-    sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                                beam_size, alpha, max_decode_length, eos_id,
-                                padded_decode, dtype)
+  sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
+                             beam_size, alpha, max_decode_length, eos_id,
+                             padded_decode, dtype)
   return sbs.search(initial_ids, initial_cache)
 
 
diff --git a/official/nlp/transformer/data_pipeline.py b/official/nlp/transformer/data_pipeline.py
index 84b98fea8a577b8ac5b9784e8fd5c323512441d0..cedd2c309d3194a07841610f8f1039a1a1e7ac51 100644
--- a/official/nlp/transformer/data_pipeline.py
+++ b/official/nlp/transformer/data_pipeline.py
@@ -56,7 +56,6 @@ import os
 
 from absl import logging
 import tensorflow as tf
-from official.nlp.transformer import misc
 from official.utils.misc import model_helpers
 
 # Buffer size for reading records from a TFRecord file. Each training file is
@@ -313,9 +312,5 @@ def eval_input_fn(params, ctx=None):
 def map_data_for_transformer_fn(x, y):
   """Maps data for training, and handles weird behaviors for different versions."""
   # Will transform input x and targets y into tuple(x, y) as new model inputs.
-  if misc.is_v2():
-    # For TF v2, the 2nd parameter is omitted to make Keras training work.
-    return ((x, y),)
-  else:
-    # For TF v1, Keras requires a dummy placeholder as the 2nd parameter.
-    return ((x, y), tf.constant(0.0))
+  # For TF v2, the 2nd parameter is omitted to make Keras training work.
+  return ((x, y),)
diff --git a/official/nlp/transformer/misc.py b/official/nlp/transformer/misc.py
index c20f012dce81a9f06b63bbf94fa1de3492a78387..45b47741bc9cc93eabe92a73e2c59dfa9ddf812d 100644
--- a/official/nlp/transformer/misc.py
+++ b/official/nlp/transformer/misc.py
@@ -22,10 +22,6 @@ from __future__ import print_function
 
 from absl import flags
 import tensorflow as tf
-# TODO(tianlin) Import internal library. Remove this when some functions for
-# different TF versions are fixed.
-from tensorflow.python import tf2 as tf2_internal
-
 from official.nlp.transformer import model_params
 from official.utils.flags import core as flags_core
 from official.utils.misc import keras_utils
@@ -39,11 +35,6 @@ PARAMS_MAP = {
 }
 
 
-def is_v2():
-  """Returns whether it is v2."""
-  return tf2_internal.enabled()
-
-
 def get_model_params(param_set, num_gpus):
   """Gets predefined model params."""
   if num_gpus > 1:
@@ -78,17 +69,6 @@ def define_transformer_flags():
       fp16_implementation=True
   )
 
-  # Additional performance flags
-  # TODO(b/76028325): Remove when generic layout optimizer is ready.
-  flags.DEFINE_boolean(
-      name='enable_grappler_layout_optimizer',
-      default=True,
-      help='Enable Grappler layout optimizer. Currently Grappler can '
-      'de-optimize fp16 graphs by forcing NCHW layout for all '
-      'convolutions and batch normalizations, and this flag allows to '
-      'disable it.'
-  )
-
   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)