提交 9939ad5c 编写于 作者: L Le Hou 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 373623867
上级 4b820994
......@@ -124,9 +124,11 @@ def run_experiment(
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
if hasattr(trainer.model, 'count_params'):
num_params = train_utils.try_count_params(trainer.model)
if num_params is not None:
logging.info('Number of trainable params in model: %f Millions.',
trainer.model.count_params() / 10.**6)
num_params / 10.**6)
if run_post_eval:
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
......
......@@ -367,3 +367,24 @@ def remove_ckpts(model_dir):
file_to_remove = os.path.join(model_dir, 'checkpoint')
if tf.io.gfile.exists(file_to_remove):
tf.io.gfile.remove(file_to_remove)
def try_count_params(model: tf.keras.Model):
  """Counts the number of parameters in `model`, if counting is possible.

  Counting only succeeds after the model has been built (i.e., all keras
  layers have had their `build()` methods called, typically by feeding the
  model an input at least once).

  Args:
    model: Try to count the number of params in this model.

  Returns:
    The number of parameters as an int, or None if the model does not
    support counting or has not been built yet.
  """
  if hasattr(model, 'count_params'):
    try:
      return model.count_params()
    except ValueError:
      # count_params() raises ValueError when any layer is unbuilt; this is
      # an expected condition (not an error), so log at info level.
      logging.info('Number of trainable params unknown, because the build() '
                   'methods in keras layers were not called. This is probably '
                   'because the model was not fed any input, e.g., the max '
                   'train step already reached before this run.')
      return None
  return None
......@@ -15,6 +15,7 @@
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from typing import Optional
from absl import logging
import orbit
import tensorflow as tf
......@@ -139,7 +140,8 @@ def run_experiment_with_multitask_eval(
params: configs.MultiEvalExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True) -> tf.keras.Model:
save_summary: bool = True,
trainer: Optional[core_lib.Trainer] = None) -> tf.keras.Model:
"""Runs train/eval configured by the experiment params.
Args:
......@@ -153,6 +155,9 @@ def run_experiment_with_multitask_eval(
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
trainer: the core_lib.Trainer instance. It should be created within the
strategy.scope(). If not provided, an instance will be created by default
if `mode` contains 'train'.
Returns:
model: `tf.keras.Model` instance.
......@@ -161,19 +166,19 @@ def run_experiment_with_multitask_eval(
is_training = 'train' in mode
is_eval = 'eval' in mode
with distribution_strategy.scope():
optimizer = train_task.create_optimizer(params.trainer.optimizer_config,
params.runtime)
model = train_task.build_model()
if is_training:
trainer = core_lib.Trainer(
trainer = trainer or core_lib.Trainer(
config=params,
task=train_task,
model=model,
optimizer=optimizer,
model=train_task.build_model(),
optimizer=train_task.create_optimizer(
params.trainer.optimizer_config, params.runtime),
train=True,
evaluate=False)
else:
trainer = None
model = trainer.model if trainer else train_task.build_model()
if is_eval:
evaluator = evaluator_lib.MultiTaskEvaluator(
task=eval_tasks,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册