diff --git a/official/vision/beta/configs/video_classification.py b/official/vision/beta/configs/video_classification.py index e097788106228bb2e4442c25451037c5e7ebb012..496c51ef962b4313394d2c7da188ff33bba6281f 100644 --- a/official/vision/beta/configs/video_classification.py +++ b/official/vision/beta/configs/video_classification.py @@ -50,6 +50,17 @@ class DataConfig(cfg.DataConfig): min_image_size: int = 256 +def kinetics400(is_training): + """Generated Kinetics 400 dataset configs.""" + return DataConfig( + name='kinetics400', + num_classes=400, + is_training=is_training, + split='train' if is_training else 'valid', + num_examples=215570 if is_training else 17706, + feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3)) + + def kinetics600(is_training): """Generated Kinectics 600 dataset configs.""" return DataConfig( @@ -153,9 +164,35 @@ def video_classification() -> cfg.ExperimentConfig: ]) +@exp_factory.register_config_factory('video_classification_kinetics400') +def video_classification_kinetics400() -> cfg.ExperimentConfig: + """Video classification on Kinetics 400 with resnet.""" + train_dataset = kinetics400(is_training=True) + validation_dataset = kinetics400(is_training=False) + task = VideoClassificationTask( + model=VideoClassificationModel( + backbone=backbones_3d.Backbone3D( + type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()), + norm_activation=common.NormActivation( + norm_momentum=0.9, norm_epsilon=1e-5)), + losses=Losses(l2_weight_decay=1e-4), + train_data=train_dataset, + validation_data=validation_dataset) + config = cfg.ExperimentConfig( + runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'), + task=task, + restrictions=[ + 'task.train_data.is_training != None', + 'task.validation_data.is_training != None', + 'task.train_data.num_classes == task.validation_data.num_classes', + ]) + add_trainer(config, train_batch_size=1024, eval_batch_size=64) + return config + + 
@exp_factory.register_config_factory('video_classification_kinetics600') def video_classification_kinetics600() -> cfg.ExperimentConfig: - """Video classification on Videonet with resnet.""" + """Video classification on Kinetics 600 with resnet.""" train_dataset = kinetics600(is_training=True) validation_dataset = kinetics600(is_training=False) task = VideoClassificationTask( @@ -176,5 +213,4 @@ def video_classification_kinetics600() -> cfg.ExperimentConfig: 'task.train_data.num_classes == task.validation_data.num_classes', ]) add_trainer(config, train_batch_size=1024, eval_batch_size=64) - return config