From 432a448a2a1a49cf3a3ec8bba9f883877a3015f2 Mon Sep 17 00:00:00 2001 From: Hongkun Yu Date: Thu, 13 Aug 2020 21:26:13 -0700 Subject: [PATCH] Move files to core/ and common/ PiperOrigin-RevId: 326586473 --- official/common/__init__.py | 1 + official/common/flags.py | 77 +++++++++++++++ official/common/registry_imports.py | 19 ++++ official/core/train_lib.py | 112 ++++++++++++++++++++++ official/core/train_lib_test.py | 106 +++++++++++++++++++++ official/core/train_utils.py | 142 ++++++++++++++++++++++++++++ 6 files changed, 457 insertions(+) create mode 100644 official/common/__init__.py create mode 100644 official/common/flags.py create mode 100644 official/common/registry_imports.py create mode 100644 official/core/train_lib.py create mode 100644 official/core/train_lib_test.py create mode 100644 official/core/train_utils.py diff --git a/official/common/__init__.py b/official/common/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/official/common/__init__.py @@ -0,0 +1 @@ + diff --git a/official/common/flags.py b/official/common/flags.py new file mode 100644 index 000000000..2e065e070 --- /dev/null +++ b/official/common/flags.py @@ -0,0 +1,77 @@ +# Lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The central place to define flags.""" + +from absl import flags + + +def define_flags(): + """Defines flags.""" + flags.DEFINE_string( + 'experiment', default=None, help='The experiment type registered.') + + flags.DEFINE_enum( + 'mode', + default=None, + enum_values=['train', 'eval', 'train_and_eval', + 'continuous_eval', 'continuous_train_and_eval'], + help='Mode to run: `train`, `eval`, `train_and_eval`, ' + '`continuous_eval`, and `continuous_train_and_eval`.') + + flags.DEFINE_string( + 'model_dir', + default=None, + help='The directory where the model and training/evaluation summaries ' + 'are stored.') + + flags.DEFINE_multi_string( + 'config_file', + default=None, + help='YAML/JSON files which specify overrides. The override order ' + 'follows the order of args. Note that each file ' + 'can be used as an override template to override the default parameters ' + 'specified in Python. If the same parameter is specified in both ' + '`--config_file` and `--params_override`, `config_file` will be used ' + 'first, followed by params_override.') + + flags.DEFINE_string( + 'params_override', + default=None, + help='A YAML/JSON string or a YAML file which specifies additional ' + 'overrides over the default parameters and those specified in ' + '`--config_file`. Note that this is supposed to be used only to override ' + 'the model parameters, but not the parameters like TPU specific flags. '
'One canonical use case of `--config_file` and `--params_override` is ' + 'users first define a template config file using `--config_file`, then ' + 'use `--params_override` to adjust the minimal set of tuning parameters, ' + 'for example setting up different `train_batch_size`. The final override ' + 'order of parameters: default_model_params --> params from config_file ' + '--> params in params_override. See also the help message of ' + '`--config_file`.') + + flags.DEFINE_multi_string( + 'gin_file', default=None, help='List of paths to the config files.') + + flags.DEFINE_multi_string( + 'gin_params', + default=None, + help='Newline separated list of Gin parameter bindings.') + + flags.DEFINE_string( + 'tpu', default=None, + help='The Cloud TPU to use for training. This should be either the name ' + 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 ' + 'url.') diff --git a/official/common/registry_imports.py b/official/common/registry_imports.py new file mode 100644 index 000000000..c7e3cc974 --- /dev/null +++ b/official/common/registry_imports.py @@ -0,0 +1,19 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""All necessary imports for registration.""" + +# pylint: disable=unused-import +from official.nlp import tasks +from official.utils.testing import mock_task diff --git a/official/core/train_lib.py b/official/core/train_lib.py new file mode 100644 index 000000000..342bf0f02 --- /dev/null +++ b/official/core/train_lib.py @@ -0,0 +1,112 @@ +# Lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFM common training driver library.""" + +import os +from typing import Any, Mapping + +# Import libraries +from absl import logging +import orbit +import tensorflow as tf + +from official.core import base_task +from official.core import train_utils +from official.modeling.hyperparams import config_definitions + + +def run_experiment(distribution_strategy: tf.distribute.Strategy, + task: base_task.Task, + mode: str, + params: config_definitions.ExperimentConfig, + model_dir: str, + run_post_eval: bool = False, + save_summary: bool = True) -> Mapping[str, Any]: + """Runs train/eval configured by the experiment params.
+ + Args: + distribution_strategy: A distribution strategy. + task: A Task instance. + mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval' + or 'continuous_eval'. + params: ExperimentConfig instance. + model_dir: A 'str', a path to store model checkpoints and summaries. + run_post_eval: Whether to run post eval once after training; metrics logs + are returned. + save_summary: Whether to save train and validation summary. + + Returns: + eval logs: returns eval metrics logs when run_post_eval is set to True, + otherwise, returns {}. + """ + + with distribution_strategy.scope(): + trainer = train_utils.create_trainer( + params, + task, + model_dir, + train='train' in mode, + evaluate=('eval' in mode) or run_post_eval) + + if trainer.checkpoint: + checkpoint_manager = tf.train.CheckpointManager( + trainer.checkpoint, + directory=model_dir, + max_to_keep=params.trainer.max_to_keep, + step_counter=trainer.global_step, + checkpoint_interval=params.trainer.checkpoint_interval, + init_fn=trainer.initialize) + else: + checkpoint_manager = None + + controller = orbit.Controller( + distribution_strategy, + trainer=trainer if 'train' in mode else None, + evaluator=trainer, + global_step=trainer.global_step, + steps_per_loop=params.trainer.steps_per_loop, + checkpoint_manager=checkpoint_manager, + summary_dir=os.path.join(model_dir, 'train') if ( + save_summary) else None, + eval_summary_dir=os.path.join(model_dir, 'validation') if ( + save_summary) else None, + summary_interval=params.trainer.summary_interval if ( + save_summary) else None) + + logging.info('Starts to execute mode: %s', mode) + with distribution_strategy.scope(): + if mode == 'train': + controller.train(steps=params.trainer.train_steps) + elif mode == 'train_and_eval': + controller.train_and_evaluate( + train_steps=params.trainer.train_steps, + eval_steps=params.trainer.validation_steps, + eval_interval=params.trainer.validation_interval) + elif mode == 'eval': + controller.evaluate(steps=params.trainer.validation_steps) + elif mode == 'continuous_eval': + controller.evaluate_continuously( + steps=params.trainer.validation_steps, + timeout=params.trainer.continuous_eval_timeout) + else: + raise NotImplementedError('The mode is not implemented: %s' % mode) + + if run_post_eval: + with distribution_strategy.scope(): + return trainer.evaluate( + tf.convert_to_tensor(params.trainer.validation_steps)) + else: + return {} diff --git a/official/core/train_lib_test.py b/official/core/train_lib_test.py new file mode 100644 index 000000000..6b5248e5f --- /dev/null +++ b/official/core/train_lib_test.py @@ -0,0 +1,106 @@ +# Lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Tests for train_lib.""" +import json +import os + +from absl import flags +from absl.testing import flagsaver +from absl.testing import parameterized +import tensorflow as tf + +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import strategy_combinations +from official.common import flags as tfm_flags +# pylint: disable=unused-import +from official.common import registry_imports +# pylint: enable=unused-import +from official.core import task_factory +from official.core import train_lib +from official.core import train_utils + +FLAGS = flags.FLAGS + +tfm_flags.define_flags() + + +class TrainTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super(TrainTest, self).setUp() + self._test_config = { + 'trainer': { + 'checkpoint_interval': 10, + 'steps_per_loop': 10, + 'summary_interval': 10, + 'train_steps': 10, + 'validation_steps': 5, + 'validation_interval': 10, + 'optimizer_config': { + 'optimizer': { + 'type': 'sgd', + }, + 'learning_rate': { + 'type': 'constant' + } + } + }, + } + + @combinations.generate( + combinations.combine( + distribution_strategy=[ + strategy_combinations.default_strategy, + strategy_combinations.tpu_strategy, + strategy_combinations.one_device_strategy_gpu, + ], + mode='eager', + flag_mode=['train', 'eval', 'train_and_eval'], + run_post_eval=[True, False])) + def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval): + model_dir = self.get_temp_dir() + flags_dict = dict( + experiment='mock', + mode=flag_mode, + model_dir=model_dir, + params_override=json.dumps(self._test_config)) + with flagsaver.flagsaver(**flags_dict): + params = train_utils.parse_configuration(flags.FLAGS) + train_utils.serialize_config(params, model_dir) + with distribution_strategy.scope(): + task = task_factory.get_task(params.task, logging_dir=model_dir) + + logs = train_lib.run_experiment( + distribution_strategy=distribution_strategy, + task=task, + mode=flag_mode, + params=params, + model_dir=model_dir, + run_post_eval=run_post_eval) + + if run_post_eval: + self.assertNotEmpty(logs) + else: + self.assertEmpty(logs) + self.assertNotEmpty( + tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml'))) + if flag_mode != 'eval': + self.assertNotEmpty( + tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint'))) + + +if __name__ == '__main__': + tf.test.main() diff --git a/official/core/train_utils.py b/official/core/train_utils.py new file mode 100644 index 000000000..b304aacab --- /dev/null +++ b/official/core/train_utils.py @@ -0,0 +1,142 @@ +# Lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Training utils.""" + +import json +import os +import pprint + +from absl import logging +import tensorflow as tf + +from official.core import base_trainer +from official.core import exp_factory +from official.modeling import hyperparams +from official.modeling.hyperparams import config_definitions + + +def create_trainer(params, task, model_dir, train, evaluate): + del model_dir + logging.info('Running default trainer.') + trainer = base_trainer.Trainer(params, task, train=train, evaluate=evaluate) + return trainer + + +def parse_configuration(flags_obj): + """Parses ExperimentConfig from flags.""" + + # 1. Get the default config from the registered experiment. + params = exp_factory.get_exp_config(flags_obj.experiment) + params.override({ + 'runtime': { + 'tpu': flags_obj.tpu, + } + }) + + # 2. Get the first level of override from `--config_file`. + # `--config_file` is typically used as a template that specifies the common + # override for a particular experiment. + for config_file in flags_obj.config_file or []: + params = hyperparams.override_params_dict( + params, config_file, is_strict=True) + + # 3. Get the second level of override from `--params_override`. + # `--params_override` is typically used as a further override over the + # template. For example, one may define a particular template for training + # ResNet50 on ImageNet in a config file and pass it via `--config_file`, + # then define different learning rates and pass it via `--params_override`. + if flags_obj.params_override: + params = hyperparams.override_params_dict( + params, flags_obj.params_override, is_strict=True) + + params.validate() + params.lock() + + pp = pprint.PrettyPrinter() + logging.info('Final experiment parameters: %s', pp.pformat(params.as_dict())) + + return params + + +def serialize_config(params: config_definitions.ExperimentConfig, + model_dir: str): + """Serializes and saves the experiment config.""" + params_save_path = os.path.join(model_dir, 'params.yaml') + logging.info('Saving experiment configuration to %s', params_save_path) + tf.io.gfile.makedirs(model_dir) + hyperparams.save_params_dict_to_yaml(params, params_save_path) + + +def read_global_step_from_checkpoint(ckpt_file_path): + """Read global step from checkpoint, or get global step from its filename.""" + global_step = tf.Variable(-1, dtype=tf.int64) + ckpt = tf.train.Checkpoint(global_step=global_step) + try: + ckpt.restore(ckpt_file_path).expect_partial() + global_step_maybe_restored = global_step.numpy() + except tf.errors.InvalidArgumentError: + global_step_maybe_restored = -1 + + if global_step_maybe_restored == -1: + raise ValueError('global_step not found in checkpoint {}. 
' + 'If you want to run finetune eval jobs, you need to ' + 'make sure that your pretrain model writes ' + 'global_step in its checkpoints.'.format(ckpt_file_path)) + global_step_restored = global_step.numpy() + logging.info('get global_step %d from checkpoint %s', + global_step_restored, ckpt_file_path) + return global_step_restored + + +def write_json_summary(log_dir, global_step, eval_metrics): + """Dump evaluation metrics to json file.""" + serializable_dict = {} + for name, value in eval_metrics.items(): + if hasattr(value, 'numpy'): + serializable_dict[name] = str(value.numpy()) + else: + serializable_dict[name] = str(value) + output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step)) + logging.info('Evaluation results at pretrain step %d: %s', + global_step, serializable_dict) + with tf.io.gfile.GFile(output_json, 'w') as writer: + writer.write(json.dumps(serializable_dict, indent=4) + '\n') + + +def write_summary(summary_writer, global_step, eval_metrics): + """Write evaluation metrics to TF summary.""" + numeric_dict = {} + for name, value in eval_metrics.items(): + if hasattr(value, 'numpy'): + numeric_dict[name] = value.numpy().astype(float) + else: + numeric_dict[name] = value + with summary_writer.as_default(): + for name, value in numeric_dict.items(): + tf.summary.scalar(name, value, step=global_step) + summary_writer.flush() + + +def remove_ckpts(model_dir): + """Remove model checkpoints, so we can restart.""" + ckpts = os.path.join(model_dir, 'ckpt-*') + logging.info('removing checkpoint files %s', ckpts) + for file_to_remove in tf.io.gfile.glob(ckpts): + tf.io.gfile.rmtree(file_to_remove) + + file_to_remove = os.path.join(model_dir, 'checkpoint') + if tf.io.gfile.exists(file_to_remove): + tf.io.gfile.remove(file_to_remove) -- GitLab
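For orientation, the modules added in this patch are meant to be wired together by a small driver script. The following is a minimal, hypothetical sketch of such a driver (it is not part of this patch); the `main` wiring and the choice of `tf.distribute.MirroredStrategy` are illustrative assumptions, not code introduced by this change.

# Illustrative driver sketch; assumes the modules created in this patch are
# importable and that the experiment named by --experiment is registered via
# official.common.registry_imports.
from absl import app
from absl import flags
import tensorflow as tf

from official.common import flags as tfm_flags
from official.common import registry_imports  # pylint: disable=unused-import
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils

FLAGS = flags.FLAGS


def main(_):
  # Build the experiment config from --experiment, --config_file and
  # --params_override, then save it next to the checkpoints.
  params = train_utils.parse_configuration(FLAGS)
  train_utils.serialize_config(params, FLAGS.model_dir)

  # A single-host strategy is assumed here; a TPU or multi-worker strategy
  # could be constructed from --tpu and the runtime config instead.
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=FLAGS.model_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=FLAGS.model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)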