diff --git a/official/core/train_lib.py b/official/core/train_lib.py
index ba66c2a75627213f6c50bcdab555aa36b668d21b..32898a4d35972396e08af7dc773d941b9d0a5e40 100644
--- a/official/core/train_lib.py
+++ b/official/core/train_lib.py
@@ -124,9 +124,11 @@ def run_experiment(
   else:
     raise NotImplementedError('The mode is not implemented: %s' % mode)
 
-  if hasattr(trainer.model, 'count_params'):
+  num_params = train_utils.try_count_params(trainer.model)
+  if num_params is not None:
     logging.info('Number of trainable params in model: %f Millions.',
-                 trainer.model.count_params() / 10.**6)
+                 num_params / 10.**6)
+
   if run_post_eval:
     with distribution_strategy.scope():
       return trainer.model, trainer.evaluate(
diff --git a/official/core/train_utils.py b/official/core/train_utils.py
index 32a3cd7be6d8ac87b30428aa5e5cd15ca1ab802a..c3d2a811cbbd0ad9a4856b596ae1f280857443a4 100644
--- a/official/core/train_utils.py
+++ b/official/core/train_utils.py
@@ -367,3 +367,24 @@ def remove_ckpts(model_dir):
   file_to_remove = os.path.join(model_dir, 'checkpoint')
   if tf.io.gfile.exists(file_to_remove):
     tf.io.gfile.remove(file_to_remove)
+
+
+def try_count_params(model: tf.keras.Model):
+  """Counts the number of parameters in a model if possible.
+
+  Args:
+    model: The model whose parameters should be counted.
+
+  Returns:
+    The number of parameters, or None if it cannot be determined (e.g. the
+    model has no `count_params` method, or its layers were never built).
+  """
+  if hasattr(model, 'count_params'):
+    try:
+      return model.count_params()
+    except ValueError:
+      logging.info('Number of trainable params unknown, because the build() '
+                   'methods in keras layers were not called. This is probably '
+                   'because the model was not fed any input, e.g., the max '
+                   'train step was already reached before this run.')
+      return None
+  return None
diff --git a/official/modeling/multitask/train_lib.py b/official/modeling/multitask/train_lib.py
index b2fa9a0e76fdee2d734e1c050f8201fcab1e7ede..0580964a61e5e7f06e8ab158e1dfdee6af92b399 100644
--- a/official/modeling/multitask/train_lib.py
+++ b/official/modeling/multitask/train_lib.py
@@ -15,6 +15,7 @@
 """Multitask training driver library."""
 # pytype: disable=attribute-error
 import os
+from typing import Optional
 from absl import logging
 import orbit
 import tensorflow as tf
@@ -139,7 +140,8 @@ def run_experiment_with_multitask_eval(
     params: configs.MultiEvalExperimentConfig,
     model_dir: str,
     run_post_eval: bool = False,
-    save_summary: bool = True) -> tf.keras.Model:
+    save_summary: bool = True,
+    trainer: Optional[core_lib.Trainer] = None) -> tf.keras.Model:
   """Runs train/eval configured by the experiment params.
 
   Args:
@@ -153,6 +155,9 @@
     run_post_eval: Whether to run post eval once after training, metrics logs
       are returned.
     save_summary: Whether to save train and validation summary.
+    trainer: The core_lib.Trainer instance. It should be created within the
+      strategy.scope(). If not provided, an instance will be created by default
+      if `mode` contains 'train'.
 
   Returns:
     model: `tf.keras.Model` instance.
@@ -161,19 +166,19 @@ def run_experiment_with_multitask_eval(
   is_training = 'train' in mode
   is_eval = 'eval' in mode
   with distribution_strategy.scope():
-    optimizer = train_task.create_optimizer(params.trainer.optimizer_config,
-                                            params.runtime)
-    model = train_task.build_model()
     if is_training:
-      trainer = core_lib.Trainer(
+      trainer = trainer or core_lib.Trainer(
           config=params,
           task=train_task,
-          model=model,
-          optimizer=optimizer,
+          model=train_task.build_model(),
+          optimizer=train_task.create_optimizer(
+              params.trainer.optimizer_config, params.runtime),
           train=True,
           evaluate=False)
     else:
       trainer = None
+    model = trainer.model if trainer else train_task.build_model()
+
     if is_eval:
       evaluator = evaluator_lib.MultiTaskEvaluator(
           task=eval_tasks,