提交 7e4e52cd 编写于 作者: Y Yin Cui 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 338348738
上级 f3641f23
......@@ -50,6 +50,17 @@ class DataConfig(cfg.DataConfig):
min_image_size: int = 256
def kinetics400(is_training):
"""Generated Kinectics 400 dataset configs."""
return DataConfig(
name='kinetics400',
num_classes=400,
is_training=is_training,
split='train' if is_training else 'valid',
num_examples=215570 if is_training else 17706,
feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))
def kinetics600(is_training):
"""Generated Kinectics 600 dataset configs."""
return DataConfig(
......@@ -153,9 +164,35 @@ def video_classification() -> cfg.ExperimentConfig:
])
@exp_factory.register_config_factory('video_classification_kinetics400')
def video_classification_kinetics400() -> cfg.ExperimentConfig:
"""Video classification on Kinectics 400 with resnet."""
train_dataset = kinetics400(is_training=True)
validation_dataset = kinetics400(is_training=False)
task = VideoClassificationTask(
model=VideoClassificationModel(
backbone=backbones_3d.Backbone3D(
type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5)),
losses=Losses(l2_weight_decay=1e-4),
train_data=train_dataset,
validation_data=validation_dataset)
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=task,
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
'task.train_data.num_classes == task.validation_data.num_classes',
])
add_trainer(config, train_batch_size=1024, eval_batch_size=64)
return config
@exp_factory.register_config_factory('video_classification_kinetics600')
def video_classification_kinetics600() -> cfg.ExperimentConfig:
"""Video classification on Videonet with resnet."""
"""Video classification on Kinectics 600 with resnet."""
train_dataset = kinetics600(is_training=True)
validation_dataset = kinetics600(is_training=False)
task = VideoClassificationTask(
......@@ -176,5 +213,4 @@ def video_classification_kinetics600() -> cfg.ExperimentConfig:
'task.train_data.num_classes == task.validation_data.num_classes',
])
add_trainer(config, train_batch_size=1024, eval_batch_size=64)
return config
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册