Commit 68104ce3 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 284805626
Parent f7fd59b8
@@ -44,7 +44,6 @@ from official.utils.logs import logger
from official.utils.misc import keras_utils
from official.utils.misc import distribution_utils

INF = int(1e9)
BLEU_DIR = "bleu"
_SINGLE_SAMPLE = 1
@@ -158,6 +157,7 @@ class TransformerTask(object):
    params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
    params["repeat_dataset"] = None
    params["dtype"] = flags_core.get_tf_dtype(flags_obj)
    params["enable_tensorboard"] = flags_obj.enable_tensorboard
    params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
    params["steps_between_evals"] = flags_obj.steps_between_evals
@@ -183,8 +183,8 @@ class TransformerTask(object):
      # like this. What if multiple instances of TransformerTask are created?
      # We should have a better way in the tf.keras.mixed_precision API of doing
      # this.
      loss_scale = flags_core.get_loss_scale(flags_obj,
                                             default_for_fp16="dynamic")
      loss_scale = flags_core.get_loss_scale(
          flags_obj, default_for_fp16="dynamic")
      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
          "mixed_float16", loss_scale=loss_scale)
      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
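The hunk above only reflows the `get_loss_scale` call; the behavior is unchanged. A minimal, self-contained sketch of the same mixed-precision setup, assuming TF 2.x with the experimental API (the real loss scale comes from flags via `flags_core.get_loss_scale`):

```python
import tensorflow as tf

# Install a global mixed_float16 policy with a dynamic loss scale; layers
# created afterwards compute in float16 while keeping float32 variables.
loss_scale = "dynamic"  # stands in for flags_core.get_loss_scale(...)
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
    "mixed_float16", loss_scale=loss_scale)
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
```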
@@ -206,8 +206,7 @@ class TransformerTask(object):
    params = self.params
    flags_obj = self.flags_obj
    # Sets config options.
    keras_utils.set_session_config(
        enable_xla=flags_obj.enable_xla)
    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

    _ensure_dir(flags_obj.model_dir)
    with distribution_utils.get_strategy_scope(self.distribution_strategy):
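`keras_utils.set_session_config(enable_xla=...)` is a repo helper. As an assumption about what it boils down to in TF 2.x, enabling XLA amounts to switching on JIT compilation, roughly:

```python
import tensorflow as tf

def set_session_config(enable_xla=False):
  # Rough approximation of the helper, not its exact implementation:
  # enable XLA JIT compilation globally when requested.
  if enable_xla:
    tf.config.optimizer.set_jit(True)
```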
@@ -225,6 +224,14 @@ class TransformerTask(object):
      if params["use_ctl"]:
        train_loss_metric = tf.keras.metrics.Mean(
            "training_loss", dtype=tf.float32)
        if params["enable_tensorboard"]:
          summary_writer = tf.compat.v2.summary.create_file_writer(
              flags_obj.model_dir)
        else:
          summary_writer = tf.compat.v2.summary.create_noop_writer()
        train_metrics = [train_loss_metric]
        if params["enable_metrics_in_training"]:
          train_metrics = train_metrics + model.metrics
      else:
        model.compile(opt)
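The added lines pick a real event-file writer when TensorBoard is enabled and a no-op writer otherwise, so the training loop can emit summaries unconditionally. A minimal sketch of that selection (`make_summary_writer` is a hypothetical helper, not part of this change):

```python
import tensorflow as tf

def make_summary_writer(enable_tensorboard, model_dir):
  """Returns a file writer when enabled, otherwise a no-op writer."""
  if enable_tensorboard:
    # Event files land under model_dir and show up in TensorBoard.
    return tf.compat.v2.summary.create_file_writer(model_dir)
  # The no-op writer accepts the same calls but records nothing.
  return tf.compat.v2.summary.create_noop_writer()
```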
@@ -303,17 +310,23 @@ class TransformerTask(object):
          raise NotImplementedError(
              "Custom training loop on GPUs is not implemented.")
        # Runs training steps.
        train_steps(train_ds_iterator,
                    tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
        current_step += train_steps_per_eval
        train_loss = train_loss_metric.result().numpy().astype(float)
        logging.info("Train Step: %d/%d / loss = %s",
                     current_step, flags_obj.train_steps, train_loss)
        with summary_writer.as_default():
          train_steps(
              train_ds_iterator,
              tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
          current_step += train_steps_per_eval
          train_loss = train_loss_metric.result().numpy().astype(float)
          logging.info("Train Step: %d/%d / loss = %s", current_step,
                       flags_obj.train_steps, train_loss)
          if params["enable_tensorboard"]:
            for metric_obj in train_metrics:
              tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
                                          current_step)
        checkpoint_name = checkpoint.save(
            os.path.join(
                flags_obj.model_dir,
                "ctl_step_{}.ckpt".format(current_step)))
            os.path.join(flags_obj.model_dir,
                         "ctl_step_{}.ckpt".format(current_step)))
        logging.info("Saved checkpoint to %s", checkpoint_name)
      else:
        if self.use_tpu:
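The training steps now run inside `summary_writer.as_default()`, and when TensorBoard is enabled each training metric is written as a scalar at the current global step. A standalone sketch of the same pattern (the log directory and values are placeholders):

```python
import tensorflow as tf

writer = tf.compat.v2.summary.create_file_writer("/tmp/transformer_logs")
loss_metric = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
loss_metric.update_state(2.5)  # placeholder loss value

current_step = 100
with writer.as_default():
  # Summaries emitted inside as_default() go to this writer; the step
  # argument positions the point on the TensorBoard x-axis.
  tf.compat.v2.summary.scalar(loss_metric.name, loss_metric.result(),
                              current_step)
writer.flush()
```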
@@ -391,8 +404,9 @@ class TransformerTask(object):
    callbacks = misc.get_callbacks(params["steps_between_evals"])
    callbacks.append(scheduler_callback)
    ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                                        save_weights_only=True))
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(
            ckpt_full_path, save_weights_only=True))
    return callbacks

  def _load_weights_if_possible(self, model, init_weight_path=None):
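The callback hunk only reflows the `ModelCheckpoint` call. For reference, a minimal sketch of how such a checkpoint callback is wired into `model.fit` (the directory and model are placeholders):

```python
import tensorflow as tf

ckpt_full_path = "/tmp/transformer_logs/cp-{epoch:04d}.ckpt"  # placeholder dir
callbacks = [
    # Writes weights only (no full SavedModel) at the end of each epoch;
    # Keras substitutes the epoch number into {epoch:04d}.
    tf.keras.callbacks.ModelCheckpoint(ckpt_full_path, save_weights_only=True),
]
# model.fit(train_dataset, epochs=3, callbacks=callbacks)
```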
@@ -426,8 +440,9 @@ class TransformerTask(object):
    if params["dtype"] == tf.float16:
      opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
                                                    default_for_fp16="dynamic"))
          opt,
          loss_scale=flags_core.get_loss_scale(
              self.flags_obj, default_for_fp16="dynamic"))
    if self.flags_obj.fp16_implementation == "graph_rewrite":
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
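The final hunk reflows the `LossScaleOptimizer` wrapping used for float16 training. A minimal sketch of the same wrapping with the experimental API (the real loss scale again comes from flags):

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adam()
# Scale the loss up before backprop and unscale gradients before the update;
# "dynamic" adjusts the scale automatically when overflow is detected.
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    opt, loss_scale="dynamic")
```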