Commit 33c61588 authored by Hongkun Yu, committed by A. Unique TensorFlower

Move core configs to core/ folder. Leave the configs for legacy models.

PiperOrigin-RevId: 336795303
Parent c894eafe
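In practice this is an import-path move: the moved classes now live in official/core/config_definitions.py, while the legacy official/modeling/hyperparams module re-exports them through a wildcard import (see the diff below), so both paths resolve to the same class objects during the transition. A minimal sketch:

    # Editorial sketch, not part of the commit: both paths expose the moved
    # classes because the legacy module now does
    # `from official.core.config_definitions import *`.
    from official.core import config_definitions as new_cfg
    from official.modeling.hyperparams import config_definitions as legacy_cfg

    assert new_cfg.ExperimentConfig is legacy_cfg.ExperimentConfig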
@@ -25,9 +25,9 @@ import orbit
import tensorflow as tf
from official.core import base_task
+from official.core import config_definitions
from official.modeling import optimization
from official.modeling import performance
-from official.modeling.hyperparams import config_definitions
ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig
......
@@ -23,8 +23,8 @@ import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_trainer as trainer_lib
+from official.core import config_definitions as cfg
from official.core import train_lib
-from official.modeling.hyperparams import config_definitions as cfg
from official.utils.testing import mock_task
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional, Sequence, Union
import dataclasses
from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimization_config
OptimizationConfig = optimization_config.OptimizationConfig

@dataclasses.dataclass
class DataConfig(base_config.Config):
  """The base configuration for building datasets.

  Attributes:
    input_path: The path to the input. It can be either (1) a str indicating a
      file path/pattern, or (2) a str indicating multiple file paths/patterns
      separated by comma (e.g. "a, b, c" or, with no spaces, "a,b,c"), or (3) a
      list of str, each of which is a file path/pattern or multiple file
      paths/patterns separated by comma. It should not be specified when the
      following `tfds_name` is specified.
    tfds_name: The name of the TensorFlow dataset (TFDS). It should not be
      specified when the above `input_path` is specified.
    tfds_split: A str indicating which split of the data to load from TFDS. It
      is required when the above `tfds_name` is specified.
    global_batch_size: The global batch size across all replicas.
    is_training: Whether this data is used for training or not.
    drop_remainder: Whether the last batch should be dropped in the case it has
      fewer than `global_batch_size` elements.
    shuffle_buffer_size: The buffer size used for shuffling training data.
    cache: Whether to cache dataset examples. Can be used to avoid re-reading
      from disk on the second epoch. Requires significant memory overhead.
    cycle_length: The number of files that will be processed concurrently when
      interleaving files.
    block_length: The number of consecutive elements to produce from each input
      element before cycling to another input element when interleaving files.
    deterministic: A boolean controlling whether determinism should be
      enforced.
    sharding: Whether sharding is used in the input pipeline.
    enable_tf_data_service: A boolean indicating whether to enable tf.data
      service for the input pipeline.
    tf_data_service_address: The URI of a tf.data service to offload
      preprocessing onto during training. The URI should be in the format
      "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
      overridden by the `FLAGS.tf_data_service` flag in the binary.
    tf_data_service_job_name: The name of the tf.data service job. This
      argument makes it possible for multiple datasets to share the same job.
      The default behavior is that the dataset creates anonymous, exclusively
      owned jobs.
    tfds_data_dir: A str specifying the directory to read/write TFDS data.
    tfds_download: A bool indicating whether to download data using TFDS.
    tfds_as_supervised: A bool. When loading a dataset from TFDS, if True, the
      returned tf.data.Dataset will have a 2-tuple structure (input, label)
      according to builder.info.supervised_keys; if False, the default, the
      returned tf.data.Dataset will have a dictionary with all the features.
    tfds_skip_decoding_feature: A str indicating which features are skipped for
      decoding when loading a dataset from TFDS. Use comma to separate multiple
      features. The main use case is to skip the image/video decoding for
      better performance.
  """
  input_path: Union[Sequence[str], str] = ""
  tfds_name: str = ""
  tfds_split: str = ""
  global_batch_size: int = 0
  is_training: Optional[bool] = None
  drop_remainder: bool = True
  shuffle_buffer_size: int = 100
  cache: bool = False
  cycle_length: Optional[int] = None
  block_length: int = 1
  deterministic: Optional[bool] = None
  sharding: bool = True
  enable_tf_data_service: bool = False
  tf_data_service_address: Optional[str] = None
  tf_data_service_job_name: Optional[str] = None
  tfds_data_dir: str = ""
  tfds_download: bool = False
  tfds_as_supervised: bool = False
  tfds_skip_decoding_feature: str = ""
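A usage sketch (editorial illustration, not part of the diff; the file pattern and sizes are hypothetical):

    # Build a training-dataset config; values here are made up for illustration.
    from official.core import config_definitions as cfg

    train_data = cfg.DataConfig(
        input_path="/tmp/data/train*",   # hypothetical file pattern
        global_batch_size=128,
        is_training=True,
        shuffle_buffer_size=10000,
        cycle_length=8)                  # interleave 8 files concurrently
    print(train_data.as_dict())          # assumes base_config.Config's as_dict()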

@dataclasses.dataclass
class RuntimeConfig(base_config.Config):
  """High-level configurations for Runtime.

  These include parameters that are not directly related to the experiment,
  e.g. directories, accelerator type, etc.

  Attributes:
    distribution_strategy: e.g. 'mirrored', 'tpu', etc.
    enable_xla: Whether or not to enable XLA.
    per_gpu_thread_count: Thread count per GPU.
    gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
    dataset_num_private_threads: Number of threads for a private threadpool
      created for all datasets computation.
    tpu: The address of the TPU to use, if any.
    num_gpus: The number of GPUs to use, if any.
    worker_hosts: Comma-separated list of worker ip:port pairs for running
      multi-worker models with DistributionStrategy.
    task_index: If multi-worker training, the task index of this worker.
    all_reduce_alg: Defines the algorithm for performing all-reduce.
    num_packs: Sets `num_packs` in the cross-device ops used in
      MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
    mixed_precision_dtype: The dtype of the mixed precision policy. It can be
      'float32', 'float16', or 'bfloat16'.
    loss_scale: The type of loss scale, or a 'float' value. This is used when
      setting the mixed precision policy.
    run_eagerly: Whether or not to run the experiment eagerly.
    batchnorm_spatial_persistent: Whether or not to enable the spatial
      persistent mode for the CuDNN batch norm kernel for improved GPU
      performance.
  """
  distribution_strategy: str = "mirrored"
  enable_xla: bool = False
  gpu_thread_mode: Optional[str] = None
  dataset_num_private_threads: Optional[int] = None
  per_gpu_thread_count: int = 0
  tpu: Optional[str] = None
  num_gpus: int = 0
  worker_hosts: Optional[str] = None
  task_index: int = -1
  all_reduce_alg: Optional[str] = None
  num_packs: int = 1
  mixed_precision_dtype: Optional[str] = None
  loss_scale: Optional[Union[str, float]] = None
  run_eagerly: bool = False
  batchnorm_spatial_persistent: bool = False

  # Global model parallelism configurations.
  num_cores_per_replica: int = 1
  default_shard_dim: int = -1

  def model_parallelism(self):
    return dict(
        num_cores_per_replica=self.num_cores_per_replica,
        default_shard_dim=self.default_shard_dim)
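The `model_parallelism()` helper simply repackages the two sharding fields above; a minimal sketch of its behavior:

    from official.core import config_definitions as cfg

    runtime = cfg.RuntimeConfig(
        distribution_strategy="tpu",
        mixed_precision_dtype="bfloat16",
        num_cores_per_replica=2,   # 2-way model parallelism
        default_shard_dim=0)
    assert runtime.model_parallelism() == {
        "num_cores_per_replica": 2,
        "default_shard_dim": 0,
    }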

@dataclasses.dataclass
class TrainerConfig(base_config.Config):
  """Configuration for the trainer.

  Attributes:
    optimizer_config: Optimizer config; it includes optimizer, learning rate,
      and warmup schedule configs.
    train_tf_while_loop: Whether or not to use a tf while loop.
    train_tf_function: Whether or not to use tf_function for the training loop.
    eval_tf_function: Whether or not to use tf_function for eval.
    allow_tpu_summary: Whether to allow summaries to run inside the XLA program
      on TPU, through automatic outside compilation.
    steps_per_loop: Number of steps per loop.
    summary_interval: Number of steps between each summary.
    checkpoint_interval: Number of steps between checkpoints.
    max_to_keep: Maximum number of checkpoints to keep.
    continuous_eval_timeout: Maximum number of seconds to wait between
      checkpoints. If set to None, continuous eval will wait indefinitely. This
      is only used in the continuous_train_and_eval and continuous_eval modes.
      The default value is 1 hour.
    train_steps: Number of train steps.
    validation_steps: Number of eval steps. If `-1`, the entire eval dataset
      is used.
    validation_interval: Number of training steps to run between evaluations.
    best_checkpoint_export_subdir: If set, the trainer will keep track of the
      best evaluation metric, and export the corresponding best checkpoint
      under `model_dir/best_checkpoint_export_subdir`. Note that this only
      works if the mode contains eval (such as `train_and_eval`,
      `continuous_eval`, and `continuous_train_and_eval`).
    best_checkpoint_eval_metric: For exporting the best checkpoint, which
      evaluation metric the trainer should monitor. This can be any evaluation
      metric that appears on TensorBoard.
    best_checkpoint_metric_comp: For exporting the best checkpoint, how the
      trainer should compare the evaluation metrics. This can be either
      `higher` (the higher the better) or `lower` (the lower the better).
  """
  optimizer_config: OptimizationConfig = OptimizationConfig()
  # Orbit settings.
  train_tf_while_loop: bool = True
  train_tf_function: bool = True
  eval_tf_function: bool = True
  allow_tpu_summary: bool = False
  # Trainer intervals.
  steps_per_loop: int = 1000
  summary_interval: int = 1000
  checkpoint_interval: int = 1000
  # Checkpoint manager.
  max_to_keep: int = 5
  continuous_eval_timeout: int = 60 * 60
  # Train/Eval routines.
  train_steps: int = 0
  # Set validation_steps to -1 to evaluate the entire dataset.
  validation_steps: int = -1
  validation_interval: int = 1000
  # Best checkpoint export.
  best_checkpoint_export_subdir: str = ""
  best_checkpoint_eval_metric: str = ""
  best_checkpoint_metric_comp: str = "higher"
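For instance, a trainer that evaluates every 1000 steps and exports the best checkpoint could be configured as below (a sketch; the metric name is hypothetical and must match a metric that appears on TensorBoard):

    from official.core import config_definitions as cfg

    trainer = cfg.TrainerConfig(
        train_steps=10000,
        steps_per_loop=500,
        summary_interval=500,        # kept aligned with steps_per_loop
        checkpoint_interval=500,
        validation_interval=1000,
        best_checkpoint_export_subdir="best_ckpt",
        best_checkpoint_eval_metric="accuracy",   # hypothetical metric name
        best_checkpoint_metric_comp="higher")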

@dataclasses.dataclass
class TaskConfig(base_config.Config):
  """Config for a task: its model, initial checkpoint, and input pipelines."""
  init_checkpoint: str = ""
  model: Optional[base_config.Config] = None
  train_data: DataConfig = DataConfig()
  validation_data: DataConfig = DataConfig()

@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
  """Top-level configuration."""
  task: TaskConfig = TaskConfig()
  trainer: TrainerConfig = TrainerConfig()
  runtime: RuntimeConfig = RuntimeConfig()
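Putting the pieces together, a top-level config nests the three groups; overrides can then be applied via the `override()` helper that `base_config.Config` inherits from its ParamsDict base (a sketch under that assumption):

    from official.core import config_definitions as cfg

    experiment = cfg.ExperimentConfig(
        task=cfg.TaskConfig(
            train_data=cfg.DataConfig(global_batch_size=64, is_training=True)),
        trainer=cfg.TrainerConfig(train_steps=10000),
        runtime=cfg.RuntimeConfig(num_gpus=1))
    # Deep-override a single field; is_strict rejects unknown keys.
    experiment.override({"trainer": {"checkpoint_interval": 500}}, is_strict=True)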
@@ -15,8 +15,8 @@
# ==============================================================================
"""Experiment factory methods."""
+from official.core import config_definitions as cfg
from official.core import registry
-from official.modeling.hyperparams import config_definitions as cfg
_REGISTERED_CONFIGS = {}
......
@@ -21,7 +21,7 @@ from typing import Any, Callable, Optional
import tensorflow as tf
import tensorflow_datasets as tfds
-from official.modeling.hyperparams import config_definitions as cfg
+from official.core import config_definitions as cfg
def _get_random_integer():
......
@@ -27,7 +27,7 @@ import tensorflow as tf
from official.core import train_utils
from official.core import base_task
-from official.modeling.hyperparams import config_definitions
+from official.core import config_definitions
class BestCheckpointExporter:
......
@@ -27,9 +27,9 @@ import tensorflow as tf
from official.core import base_task
from official.core import base_trainer
+from official.core import config_definitions
from official.core import exp_factory
from official.modeling import hyperparams
-from official.modeling.hyperparams import config_definitions
def create_trainer(params: config_definitions.ExperimentConfig,
......
@@ -14,143 +14,16 @@
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
-from typing import Optional, Sequence, Union
+# pylint:disable=wildcard-import
import dataclasses
+from official.core.config_definitions import *
from official.modeling.hyperparams import base_config
-from official.modeling.optimization.configs import optimization_config
-OptimizationConfig = optimization_config.OptimizationConfig
[collapsed: the deleted DataConfig and RuntimeConfig class bodies, which duplicated the classes now defined in official/core/config_definitions.py, listed above]
+# TODO(hongkuny): These configs are used in models that are going to be
+# deprecated. Once those models are removed, we should delete this file to
+# avoid confusion. Users should not use this file anymore.
@dataclasses.dataclass
class TensorboardConfig(base_config.Config):
"""Configuration for Tensorboard.
@@ -183,79 +56,3 @@ class CallbacksConfig(base_config.Config):
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
[collapsed: the deleted TrainerConfig, TaskConfig, and ExperimentConfig class bodies, which duplicated the classes now defined in official/core/config_definitions.py, listed above]
@@ -18,7 +18,7 @@
import dataclasses
import tensorflow as tf
-from official.modeling.hyperparams import config_definitions as cfg
+from official.core import config_definitions as cfg
from official.nlp.data import data_loader_factory
......
@@ -18,9 +18,8 @@ from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
+from official.core import config_definitions as cfg
from official.core import input_reader
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
......
@@ -18,9 +18,8 @@ from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
+from official.core import config_definitions as cfg
from official.core import input_reader
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
......
@@ -18,9 +18,8 @@ from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
+from official.core import config_definitions as cfg
from official.core import input_reader
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
......
@@ -18,9 +18,8 @@ from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
+from official.core import config_definitions as cfg
from official.core import input_reader
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
......
@@ -41,9 +41,8 @@ from typing import Optional
import dataclasses
import tensorflow as tf
+from official.core import config_definitions as cfg
from official.core import input_reader
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
......
@@ -19,9 +19,9 @@ import dataclasses
import tensorflow as tf
from official.core import base_task
+from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
......
@@ -19,9 +19,9 @@ import dataclasses
import tensorflow as tf
from official.core import base_task
+from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
......
@@ -25,9 +25,9 @@ import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
+from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.bert import squad_evaluate_v2_0
from official.nlp.bert import tokenization
......
@@ -26,10 +26,10 @@ import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
+from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
......
@@ -25,9 +25,9 @@ import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
+from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
......
@@ -14,7 +14,6 @@
# limitations under the License.
# ==============================================================================
"""TFM continuous finetuning+eval training driver."""
-import gc
import os
import time
@@ -31,11 +30,11 @@ from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
+from official.core import config_definitions
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
-from official.modeling.hyperparams import config_definitions
FLAGS = flags.FLAGS
......
@@ -20,9 +20,9 @@ import numpy as np
import tensorflow as tf
from official.core import base_task
+from official.core import config_definitions as cfg
from official.core import exp_factory
from official.core import task_factory
-from official.modeling.hyperparams import config_definitions as cfg
class MockModel(tf.keras.Model):
......
@@ -17,10 +17,10 @@
import os
from typing import List
import dataclasses
+from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
-from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
......
@@ -14,13 +14,12 @@
# limitations under the License.
# ==============================================================================
"""Tests for image_classification."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
+from official.core import config_definitions as cfg
from official.core import exp_factory
-from official.modeling.hyperparams import config_definitions as cfg
from official.vision import beta
from official.vision.beta.configs import image_classification as exp_cfg
......
@@ -18,11 +18,10 @@
import os
from typing import List, Optional
import dataclasses
+from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
-from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.configs import decoders
......
@@ -18,11 +18,10 @@
import os
from typing import List, Optional
import dataclasses
+from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
-from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.configs import decoders
......
@@ -16,10 +16,10 @@
"""Video classification configuration definition."""
from typing import Optional, Tuple
import dataclasses
+from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
-from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import backbones_3d
from official.vision.beta.configs import common
......
@@ -19,8 +19,8 @@
from absl.testing import parameterized
import tensorflow as tf
+from official.core import config_definitions as cfg
from official.core import exp_factory
-from official.modeling.hyperparams import config_definitions as cfg
from official.vision import beta
from official.vision.beta.configs import video_classification as exp_cfg
......
@@ -14,19 +14,16 @@
# limitations under the License.
# ==============================================================================
"""Definitions for high-level configuration groups."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
from typing import Any, List, Mapping, Optional
import dataclasses
+from official.core import config_definitions
from official.modeling import hyperparams
-from official.modeling.hyperparams import config_definitions
+from official.modeling.hyperparams import config_definitions as legacy_cfg
-CallbacksConfig = config_definitions.CallbacksConfig
-TensorboardConfig = config_definitions.TensorboardConfig
+CallbacksConfig = legacy_cfg.CallbacksConfig
+TensorboardConfig = legacy_cfg.TensorboardConfig
RuntimeConfig = config_definitions.RuntimeConfig
......