提交 432a448a 编写于 作者: H Hongkun Yu 提交者: A. Unique TensorFlower

Move files to core/ and common/

PiperOrigin-RevId: 326586473
上级 a003b7c1
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The central place to define flags."""
from absl import flags
def define_flags():
"""Defines flags."""
flags.DEFINE_string(
'experiment', default=None, help='The experiment type registered.')
flags.DEFINE_enum(
'mode',
default=None,
enum_values=['train', 'eval', 'train_and_eval',
'continuous_eval', 'continuous_train_and_eval'],
help='Mode to run: `train`, `eval`, `train_and_eval`, '
'`continuous_eval`, and `continuous_train_and_eval`.')
flags.DEFINE_string(
'model_dir',
default=None,
help='The directory where the model and training/evaluation summaries'
'are stored.')
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override',
default=None,
help='a YAML/JSON string or a YAML file which specifies additional '
'overrides over the default parameters and those specified in '
'`--config_file`. Note that this is supposed to be used only to override '
'the model parameters, but not the parameters like TPU specific flags. '
'One canonical use case of `--config_file` and `--params_override` is '
'users first define a template config file using `--config_file`, then '
'use `--params_override` to adjust the minimal set of tuning parameters, '
'for example setting up different `train_batch_size`. The final override '
'order of parameters: default_model_params --> params from config_file '
'--> params in params_override. See also the help message of '
'`--config_file`.')
flags.DEFINE_multi_string(
'gin_file', default=None, help='List of paths to the config files.')
flags.DEFINE_multi_string(
'gin_params',
default=None,
help='Newline separated list of Gin parameter bindings.')
flags.DEFINE_string(
'tpu', default=None,
help='The Cloud TPU to use for training. This should be either the name '
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
'url.')
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.nlp import tasks
from official.utils.testing import mock_task
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM common training driver library."""
import os
from typing import Any, Mapping
# Import libraries
from absl import logging
import orbit
import tensorflow as tf
from official.common import train_utils
from official.core import base_task
from official.modeling.hyperparams import config_definitions
def run_experiment(distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True) -> Mapping[str, Any]:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A distribution distribution_strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True,
othewise, returns {}.
"""
with distribution_strategy.scope():
trainer = train_utils.create_trainer(
params,
task,
model_dir,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval)
if trainer.checkpoint:
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=trainer.global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize)
else:
checkpoint_manager = None
controller = orbit.Controller(
distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
global_step=trainer.global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train') if (
save_summary) else None,
eval_summary_dir=os.path.join(model_dir, 'validation') if (
save_summary) else None,
summary_interval=params.trainer.summary_interval if (
save_summary) else None)
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
if run_post_eval:
with distribution_strategy.scope():
return trainer.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
else:
return {}
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for train_ctl_lib."""
import json
import os
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
FLAGS = flags.FLAGS
tfm_flags.define_flags()
class TrainTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
self._test_config = {
'trainer': {
'checkpoint_interval': 10,
'steps_per_loop': 10,
'summary_interval': 10,
'train_steps': 10,
'validation_steps': 5,
'validation_interval': 10,
'optimizer_config': {
'optimizer': {
'type': 'sgd',
},
'learning_rate': {
'type': 'constant'
}
}
},
}
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode='eager',
flag_mode=['train', 'eval', 'train_and_eval'],
run_post_eval=[True, False]))
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval)
if run_post_eval:
self.assertNotEmpty(logs)
else:
self.assertEmpty(logs)
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
if flag_mode != 'eval':
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training utils."""
import json
import os
import pprint
from absl import logging
import tensorflow as tf
from official.core import base_trainer
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling.hyperparams import config_definitions
def create_trainer(params, task, model_dir, train, evaluate):
del model_dir
logging.info('Running default trainer.')
trainer = base_trainer.Trainer(params, task, train=train, evaluate=evaluate)
return trainer
def parse_configuration(flags_obj):
"""Parses ExperimentConfig from flags."""
# 1. Get the default config from the registered experiment.
params = exp_factory.get_exp_config(flags_obj.experiment)
params.override({
'runtime': {
'tpu': flags_obj.tpu,
}
})
# 2. Get the first level of override from `--config_file`.
# `--config_file` is typically used as a template that specifies the common
# override for a particular experiment.
for config_file in flags_obj.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
# 3. Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
if flags_obj.params_override:
params = hyperparams.override_params_dict(
params, flags_obj.params_override, is_strict=True)
params.validate()
params.lock()
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s', pp.pformat(params.as_dict()))
return params
def serialize_config(params: config_definitions.ExperimentConfig,
model_dir: str):
"""Serializes and saves the experiment config."""
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
hyperparams.save_params_dict_to_yaml(params, params_save_path)
def read_global_step_from_checkpoint(ckpt_file_path):
"""Read global step from checkpoint, or get global step from its filename."""
global_step = tf.Variable(-1, dtype=tf.int64)
ckpt = tf.train.Checkpoint(global_step=global_step)
try:
ckpt.restore(ckpt_file_path).expect_partial()
global_step_maybe_restored = global_step.numpy()
except tf.errors.InvalidArgumentError:
global_step_maybe_restored = -1
if global_step_maybe_restored == -1:
raise ValueError('global_step not found in checkpoint {}. '
'If you want to run finetune eval jobs, you need to '
'make sure that your pretrain model writes '
'global_step in its checkpoints.'.format(ckpt_file_path))
global_step_restored = global_step.numpy()
logging.info('get global_step %d from checkpoint %s',
global_step_restored, ckpt_file_path)
return global_step_restored
def write_json_summary(log_dir, global_step, eval_metrics):
"""Dump evaluation metrics to json file."""
serializable_dict = {}
for name, value in eval_metrics.items():
if hasattr(value, 'numpy'):
serializable_dict[name] = str(value.numpy())
else:
serializable_dict[name] = str(value)
output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
logging.info('Evaluation results at pretrain step %d: %s',
global_step, serializable_dict)
with tf.io.gfile.GFile(output_json, 'w') as writer:
writer.write(json.dumps(serializable_dict, indent=4) + '\n')
def write_summary(summary_writer, global_step, eval_metrics):
"""Write evaluation metrics to TF summary."""
numeric_dict = {}
for name, value in eval_metrics.items():
if hasattr(value, 'numpy'):
numeric_dict[name] = value.numpy().astype(float)
else:
numeric_dict[name] = value
with summary_writer.as_default():
for name, value in numeric_dict.items():
tf.summary.scalar(name, value, step=global_step)
summary_writer.flush()
def remove_ckpts(model_dir):
"""Remove model checkpoints, so we can restart."""
ckpts = os.path.join(model_dir, 'ckpt-*')
logging.info('removing checkpoint files %s', ckpts)
for file_to_remove in tf.io.gfile.glob(ckpts):
tf.io.gfile.rmtree(file_to_remove)
file_to_remove = os.path.join(model_dir, 'checkpoint')
if tf.io.gfile.exists(file_to_remove):
tf.io.gfile.remove(file_to_remove)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册