diff --git a/official/nlp/tasks/masked_lm.py b/official/nlp/tasks/masked_lm.py index fa867bb8682ef0e7961edadc347155255ef04d23..c9bd84a4a8bb5c9a185632070a2a8961ede70049 100644 --- a/official/nlp/tasks/masked_lm.py +++ b/official/nlp/tasks/masked_lm.py @@ -31,15 +31,25 @@ from official.nlp.modeling import models @dataclasses.dataclass class MaskedLMConfig(cfg.TaskConfig): """The model config.""" - model: bert.PretrainerConfig = bert.PretrainerConfig(cls_heads=[ - bert.ClsHeadConfig( - inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence') - ]) + model: bert.PretrainerConfig = dataclasses.field( + default_factory=lambda: bert.PretrainerConfig( # pylint: disable=g-long-lambda + cls_heads=[ + bert.ClsHeadConfig( + inner_dim=768, + num_classes=2, + dropout_rate=0.1, + name='next_sentence', + ) + ] + ) + ) # TODO(b/154564893): Mathematically, scale_loss should be True. # However, it works better with scale_loss being False. scale_loss: bool = False - train_data: cfg.DataConfig = cfg.DataConfig() - validation_data: cfg.DataConfig = cfg.DataConfig() + train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig) + validation_data: cfg.DataConfig = dataclasses.field( + default_factory=cfg.DataConfig + ) @task_factory.register_task_cls(MaskedLMConfig) diff --git a/official/projects/mosaic/configs/mosaic_config.py b/official/projects/mosaic/configs/mosaic_config.py index 1ef91edea535caaaa92b7ddb5b5759199c0e954e..23c1e47df1adda22af746e653ce75df5dba1c061 100644 --- a/official/projects/mosaic/configs/mosaic_config.py +++ b/official/projects/mosaic/configs/mosaic_config.py @@ -64,23 +64,37 @@ class MosaicSemanticSegmentationModel(hyperparams.Config): """MOSAIC semantic segmentation model config.""" num_classes: int = 19 input_size: List[int] = dataclasses.field(default_factory=list) - head: MosaicDecoderHead = MosaicDecoderHead() - backbone: backbones.Backbone = backbones.Backbone( - type='mobilenet', mobilenet=backbones.MobileNet()) - neck: MosaicEncoderNeck = MosaicEncoderNeck() + head: MosaicDecoderHead = dataclasses.field(default_factory=MosaicDecoderHead) + backbone: backbones.Backbone = dataclasses.field( + default_factory=lambda: backbones.Backbone( # pylint: disable=g-long-lambda + type='mobilenet', mobilenet=backbones.MobileNet() + ) + ) + neck: MosaicEncoderNeck = dataclasses.field(default_factory=MosaicEncoderNeck) mask_scoring_head: Optional[seg_cfg.MaskScoringHead] = None - norm_activation: common.NormActivation = common.NormActivation( - use_sync_bn=True, norm_momentum=0.99, norm_epsilon=0.001) + norm_activation: common.NormActivation = dataclasses.field( + default_factory=lambda: common.NormActivation( # pylint: disable=g-long-lambda + use_sync_bn=True, norm_momentum=0.99, norm_epsilon=0.001 + ) + ) @dataclasses.dataclass class MosaicSemanticSegmentationTask(seg_cfg.SemanticSegmentationTask): """The config for MOSAIC segmentation task.""" - model: MosaicSemanticSegmentationModel = MosaicSemanticSegmentationModel() - train_data: seg_cfg.DataConfig = seg_cfg.DataConfig(is_training=True) - validation_data: seg_cfg.DataConfig = seg_cfg.DataConfig(is_training=False) - losses: seg_cfg.Losses = seg_cfg.Losses() - evaluation: seg_cfg.Evaluation = seg_cfg.Evaluation() + model: MosaicSemanticSegmentationModel = dataclasses.field( + default_factory=MosaicSemanticSegmentationModel + ) + train_data: seg_cfg.DataConfig = dataclasses.field( + default_factory=lambda: seg_cfg.DataConfig(is_training=True) + ) + validation_data: seg_cfg.DataConfig = dataclasses.field( + default_factory=lambda: seg_cfg.DataConfig(is_training=False) + ) + losses: seg_cfg.Losses = dataclasses.field(default_factory=seg_cfg.Losses) + evaluation: seg_cfg.Evaluation = dataclasses.field( + default_factory=seg_cfg.Evaluation + ) train_input_partition_dims: List[int] = dataclasses.field( default_factory=list) eval_input_partition_dims: List[int] = dataclasses.field( @@ -88,7 +102,9 @@ class MosaicSemanticSegmentationTask(seg_cfg.SemanticSegmentationTask): init_checkpoint: Optional[str] = None init_checkpoint_modules: Union[ str, List[str]] = 'all' # all, backbone, and/or neck. - export_config: seg_cfg.ExportConfig = seg_cfg.ExportConfig() + export_config: seg_cfg.ExportConfig = dataclasses.field( + default_factory=seg_cfg.ExportConfig + ) # Cityscapes Dataset (Download and process the dataset yourself) diff --git a/official/projects/pix2seq/configs/pix2seq.py b/official/projects/pix2seq/configs/pix2seq.py index 529c8b692d386aa34bef0ba71795b64605f751bf..14cd26d51a7f2108dcceafdaece68b8a9948cbd9 100644 --- a/official/projects/pix2seq/configs/pix2seq.py +++ b/official/projects/pix2seq/configs/pix2seq.py @@ -67,7 +67,9 @@ class DataConfig(cfg.DataConfig): global_batch_size: int = 0 is_training: bool = False dtype: str = 'float32' - decoder: common.DataDecoder = common.DataDecoder() + decoder: common.DataDecoder = dataclasses.field( + default_factory=common.DataDecoder + ) shuffle_buffer_size: int = 10000 file_type: str = 'tfrecord' drop_remainder: bool = True @@ -97,10 +99,15 @@ class Pix2Seq(hyperparams.Config): shared_decoder_embedding: bool = True decoder_output_bias: bool = True input_size: List[int] = dataclasses.field(default_factory=list) - backbone: backbones.Backbone = backbones.Backbone( - type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False) + backbone: backbones.Backbone = dataclasses.field( + default_factory=lambda: backbones.Backbone( # pylint: disable=g-long-lambda + type='resnet', + resnet=backbones.ResNet(model_id=50, bn_trainable=False), + ) + ) + norm_activation: common.NormActivation = dataclasses.field( + default_factory=common.NormActivation ) - norm_activation: common.NormActivation = common.NormActivation() backbone_endpoint_name: str = '5' drop_path: float = 0.1 drop_units: float = 0.1 @@ -110,10 +117,12 @@ class Pix2Seq(hyperparams.Config): @dataclasses.dataclass class Pix2SeqTask(cfg.TaskConfig): - model: Pix2Seq = Pix2Seq() - train_data: cfg.DataConfig = cfg.DataConfig() - validation_data: cfg.DataConfig = cfg.DataConfig() - losses: Losses = Losses() + model: Pix2Seq = dataclasses.field(default_factory=Pix2Seq) + train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig) + validation_data: cfg.DataConfig = dataclasses.field( + default_factory=cfg.DataConfig + ) + losses: Losses = dataclasses.field(default_factory=Losses) init_checkpoint: Optional[str] = None init_checkpoint_modules: Union[str, List[str]] = 'all' # all, backbone annotation_file: Optional[str] = None diff --git a/official/projects/s3d/configs/s3d.py b/official/projects/s3d/configs/s3d.py index 003cfe96d690a1ff19049cb973c1a7f701a5beab..17269167bbc38f584ef0e0f4a1287356bb9e5f35 100644 --- a/official/projects/s3d/configs/s3d.py +++ b/official/projects/s3d/configs/s3d.py @@ -83,7 +83,7 @@ class Backbone3D(backbones_3d.Backbone3D): s3d: s3d backbone config. """ type: str = 's3d' - s3d: S3D = S3D() + s3d: S3D = dataclasses.field(default_factory=S3D) @dataclasses.dataclass @@ -95,4 +95,4 @@ class S3DModel(video_classification.VideoClassificationModel): backbone: backbone config. """ model_type: str = 's3d' - backbone: Backbone3D = Backbone3D() + backbone: Backbone3D = dataclasses.field(default_factory=Backbone3D) diff --git a/official/projects/simclr/configs/multitask_config.py b/official/projects/simclr/configs/multitask_config.py index 81e1cb9dcd6068c5dfd406fd348c5c618e0fe42b..95cfcc8e1d695457a134e1af820be02caa04bec6 100644 --- a/official/projects/simclr/configs/multitask_config.py +++ b/official/projects/simclr/configs/multitask_config.py @@ -31,8 +31,9 @@ class SimCLRMTHeadConfig(hyperparams.Config): """Per-task specific configs.""" task_name: str = 'task_name' # Supervised head is required for finetune, but optional for pretrain. - supervised_head: simclr_configs.SupervisedHead = simclr_configs.SupervisedHead( - num_classes=1001) + supervised_head: simclr_configs.SupervisedHead = dataclasses.field( + default_factory=lambda: simclr_configs.SupervisedHead(num_classes=1001) + ) mode: str = simclr_model.PRETRAIN @@ -40,13 +41,22 @@ class SimCLRMTHeadConfig(hyperparams.Config): class SimCLRMTModelConfig(hyperparams.Config): """Model config for multi-task SimCLR model.""" input_size: List[int] = dataclasses.field(default_factory=list) - backbone: backbones.Backbone = backbones.Backbone( - type='resnet', resnet=backbones.ResNet()) + backbone: backbones.Backbone = dataclasses.field( + default_factory=lambda: backbones.Backbone( # pylint: disable=g-long-lambda + type='resnet', resnet=backbones.ResNet() + ) + ) backbone_trainable: bool = True - projection_head: simclr_configs.ProjectionHead = simclr_configs.ProjectionHead( - proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1) - norm_activation: common.NormActivation = common.NormActivation( - norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False) + projection_head: simclr_configs.ProjectionHead = dataclasses.field( + default_factory=lambda: simclr_configs.ProjectionHead( # pylint: disable=g-long-lambda + proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1 + ) + ) + norm_activation: common.NormActivation = dataclasses.field( + default_factory=lambda: common.NormActivation( # pylint: disable=g-long-lambda + norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False + ) + ) heads: Tuple[SimCLRMTHeadConfig, ...] = () # L2 weight decay is used in the model, not in task. # Note that this can not be used together with lars optimizer. diff --git a/official/projects/teams/teams.py b/official/projects/teams/teams.py index 8e1e659502969e5b95dd2e5abceaec57f171a331..ccae68c07c6c8122334709c835f869aa0fd6220f 100644 --- a/official/projects/teams/teams.py +++ b/official/projects/teams/teams.py @@ -43,8 +43,12 @@ class TeamsPretrainerConfig(base_config.Config): num_shared_generator_hidden_layers: int = 3 # Number of bottom layers shared between different discriminator tasks. num_discriminator_task_agnostic_layers: int = 11 - generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig() - discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig() + generator: encoders.BertEncoderConfig = dataclasses.field( + default_factory=encoders.BertEncoderConfig + ) + discriminator: encoders.BertEncoderConfig = dataclasses.field( + default_factory=encoders.BertEncoderConfig + ) class TeamsEncoderConfig(encoders.BertEncoderConfig): diff --git a/official/projects/teams/teams_task.py b/official/projects/teams/teams_task.py index 0728409ef008193b6068042a8f790de58e2c1999..cc367d801a9bdd6381331e5d314de7a4f8137251 100644 --- a/official/projects/teams/teams_task.py +++ b/official/projects/teams/teams_task.py @@ -30,9 +30,13 @@ from official.projects.teams import teams_pretrainer @dataclasses.dataclass class TeamsPretrainTaskConfig(cfg.TaskConfig): """The model config.""" - model: teams.TeamsPretrainerConfig = teams.TeamsPretrainerConfig() - train_data: cfg.DataConfig = cfg.DataConfig() - validation_data: cfg.DataConfig = cfg.DataConfig() + model: teams.TeamsPretrainerConfig = dataclasses.field( + default_factory=teams.TeamsPretrainerConfig + ) + train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig) + validation_data: cfg.DataConfig = dataclasses.field( + default_factory=cfg.DataConfig + ) def _get_generator_hidden_layers(discriminator_network, num_hidden_layers, diff --git a/official/projects/text_classification_example/classification_example.py b/official/projects/text_classification_example/classification_example.py index 4fdcef174747fadd08c94d6f931a41cbed98b19b..d3c60194f6fac157b4ed1bede90bcfd4d774f3b0 100644 --- a/official/projects/text_classification_example/classification_example.py +++ b/official/projects/text_classification_example/classification_example.py @@ -34,7 +34,9 @@ from official.projects.text_classification_example import classification_data_lo @dataclasses.dataclass class ModelConfig(base_config.Config): """A base span labeler configuration.""" - encoder: encoders.EncoderConfig = encoders.EncoderConfig() + encoder: encoders.EncoderConfig = dataclasses.field( + default_factory=encoders.EncoderConfig + ) head_dropout: float = 0.1 head_initializer_range: float = 0.02 @@ -49,9 +51,11 @@ class ClassificationExampleConfig(cfg.TaskConfig): num_classes = 2 class_names = ['A', 'B'] - train_data: cfg.DataConfig = classification_data_loader.ClassificationExampleDataConfig( + train_data: cfg.DataConfig = dataclasses.field( + default_factory=classification_data_loader.ClassificationExampleDataConfig ) - validation_data: cfg.DataConfig = classification_data_loader.ClassificationExampleDataConfig( + validation_data: cfg.DataConfig = dataclasses.field( + default_factory=classification_data_loader.ClassificationExampleDataConfig ) diff --git a/official/projects/unified_detector/configs/ocr_config.py b/official/projects/unified_detector/configs/ocr_config.py index eb482e9d68494d7e6091012c022da4dc8f73e07f..3b5b40d18df158a0042bb0aa4b81e7d5a8667493 100644 --- a/official/projects/unified_detector/configs/ocr_config.py +++ b/official/projects/unified_detector/configs/ocr_config.py @@ -22,7 +22,7 @@ from official.modeling import optimization @dataclasses.dataclass class OcrTaskConfig(cfg.TaskConfig): - train_data: cfg.DataConfig = cfg.DataConfig() + train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig) model_call_needs_labels: bool = False diff --git a/official/projects/video_ssl/configs/video_ssl.py b/official/projects/video_ssl/configs/video_ssl.py index 48b6a1b3dc1765d2416e1568c0251ea52c991d1f..b674d6049c03a211987be92b2eb1cf72f0d0cb86 100644 --- a/official/projects/video_ssl/configs/video_ssl.py +++ b/official/projects/video_ssl/configs/video_ssl.py @@ -43,8 +43,11 @@ class VideoSSLModel(VideoClassificationModel): hidden_dim: int = 2048 hidden_layer_num: int = 3 projection_dim: int = 128 - hidden_norm_activation: common.NormActivation = common.NormActivation( - use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05) + hidden_norm_activation: common.NormActivation = dataclasses.field( + default_factory=lambda: common.NormActivation( + use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05 + ) + ) @dataclasses.dataclass @@ -55,21 +58,31 @@ class SSLLosses(Losses): @dataclasses.dataclass class VideoSSLPretrainTask(VideoClassificationTask): - model: VideoSSLModel = VideoSSLModel() - losses: SSLLosses = SSLLosses() - train_data: DataConfig = DataConfig(is_training=True, drop_remainder=True) - validation_data: DataConfig = DataConfig( - is_training=False, drop_remainder=False) - losses: SSLLosses = SSLLosses() + model: VideoSSLModel = dataclasses.field(default_factory=VideoSSLModel) + losses: SSLLosses = dataclasses.field(default_factory=SSLLosses) + train_data: DataConfig = dataclasses.field( + default_factory=lambda: DataConfig(is_training=True, drop_remainder=True) + ) + validation_data: DataConfig = dataclasses.field( + default_factory=lambda: DataConfig( # pylint: disable=g-long-lambda + is_training=False, drop_remainder=False + ) + ) + losses: SSLLosses = dataclasses.field(default_factory=SSLLosses) @dataclasses.dataclass class VideoSSLEvalTask(VideoClassificationTask): - model: VideoSSLModel = VideoSSLModel() - train_data: DataConfig = DataConfig(is_training=True, drop_remainder=True) - validation_data: DataConfig = DataConfig( - is_training=False, drop_remainder=False) - losses: SSLLosses = SSLLosses() + model: VideoSSLModel = dataclasses.field(default_factory=VideoSSLModel) + train_data: DataConfig = dataclasses.field( + default_factory=lambda: DataConfig(is_training=True, drop_remainder=True) + ) + validation_data: DataConfig = dataclasses.field( + default_factory=lambda: DataConfig( # pylint: disable=g-long-lambda + is_training=False, drop_remainder=False + ) + ) + losses: SSLLosses = dataclasses.field(default_factory=SSLLosses) @exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')