提交 c44de4b0 编写于 作者: M miguelCalado

Changed PR folders

上级 011a5804
......@@ -36,6 +36,7 @@ from official.modeling import hyperparams
from official.modeling import performance
from official.utils import hyperparams_flags
from official.utils.misc import keras_utils
from official.vision.image_classification.vgg16 import vgg_model
def get_models() -> Mapping[str, tf.keras.Model]:
......@@ -43,6 +44,7 @@ def get_models() -> Mapping[str, tf.keras.Model]:
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
'vgg': vgg_model.vgg16,
}
......
......@@ -53,6 +53,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset=[
'imagenet',
......@@ -149,6 +150,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset='imagenet',
dtype='float16',
......@@ -193,6 +195,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset='imagenet',
dtype='bfloat16',
......
......@@ -24,6 +24,7 @@ from official.legacy.image_classification import dataset_factory
from official.legacy.image_classification.configs import base_configs
from official.legacy.image_classification.efficientnet import efficientnet_config
from official.legacy.image_classification.resnet import resnet_config
from official.vision.image_classification.vgg16 import vgg_config
@dataclasses.dataclass
......@@ -92,12 +93,44 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
@dataclasses.dataclass
class VGGImagenetConfig(base_configs.ExperimentConfig):
"""Base configuration to train vgg-16 on ImageNet."""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
train_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='train',
one_hot=False,
mean_subtract=True,
standardize=True)
validation_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation',
one_hot=False,
mean_subtract=True,
standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
steps=None,
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = vgg_config.VGGModelConfig()
def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
"""Given model and dataset names, return the ExperimentConfig."""
dataset_model_config_map = {
'imagenet': {
'efficientnet': EfficientNetImageNetConfig(),
'resnet': ResNetImagenetConfig(),
'vgg': VGGImagenetConfig(),
}
}
try:
......
......@@ -49,5 +49,5 @@ class VGGModelConfig(base_configs.ModelConfig):
examples_per_epoch=1281167,
boundaries=[30, 60],
warmup_epochs=0,
scale_by_batch_size=1. / 128.,
scale_by_batch_size=1. / 256.,
multipliers=[0.01 / 256, 0.001 / 256, 0.0001 / 256]))
......@@ -174,17 +174,23 @@ def vgg16(num_classes,
x = layers.Flatten(name='flatten')(x)
x = layers.Dense(4096,
#kernel_initializer=tf.initializers.random_normal(stddev=0.01),
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
#bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc1')(x)
x = layers.Activation('relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(4096,
#kernel_initializer=tf.initializers.random_normal(stddev=0.01),
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
#bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc2')(x)
x = layers.Activation('relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(num_classes,
#kernel_initializer=tf.initializers.random_normal(stddev=0.01),
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
#bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc1000')(x)
# A softmax that is followed by the model loss must be done cannot be done
......
......@@ -36,7 +36,6 @@ from official.vision.image_classification.configs import configs
from official.vision.image_classification.efficientnet import efficientnet_model
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_model
from official.vision.image_classification.vgg16 import vgg_model
def get_models() -> Mapping[str, tf.keras.Model]:
......@@ -44,7 +43,6 @@ def get_models() -> Mapping[str, tf.keras.Model]:
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
'vgg': vgg_model.vgg16,
}
......
......@@ -53,7 +53,6 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset=[
'imagenet',
......@@ -150,7 +149,6 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset='imagenet',
dtype='float16',
......@@ -195,7 +193,6 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset='imagenet',
dtype='bfloat16',
......
......@@ -91,36 +91,6 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
@dataclasses.dataclass
class VGGImagenetConfig(base_configs.ExperimentConfig):
"""Base configuration to train vgg-16 on ImageNet."""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
train_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='train',
one_hot=False,
mean_subtract=True,
standardize=True)
validation_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation',
one_hot=False,
mean_subtract=True,
standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
steps=None,
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = vgg_config.VGGModelConfig()
def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
"""Given model and dataset names, return the ExperimentConfig."""
......@@ -128,7 +98,6 @@ def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
'imagenet': {
'efficientnet': EfficientNetImageNetConfig(),
'resnet': ResNetImagenetConfig(),
'vgg': VGGImagenetConfig()
}
}
try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册