Commit 595dbabe authored by Yilei Yang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 547338367
Parent 843bdf24
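
The message says only "Internal change", but every hunk below applies the same mechanical migration: dataclass fields whose defaults were config-class instances now use dataclasses.field(default_factory=...). The likely motivation is Python 3.11, whose dataclasses module rejects any unhashable default value (earlier versions rejected only list, dict, and set), and a plain dataclass instance is unhashable. A minimal sketch of the failure and the fix, using toy classes rather than anything from this change:

    import dataclasses


    @dataclasses.dataclass
    class HeadConfig:
      dropout_rate: float = 0.1


    # On Python 3.11+ the commented-out class below fails at definition time:
    #   ValueError: mutable default <class 'HeadConfig'> for field head
    #   is not allowed: use default_factory
    # because a plain dataclass instance is unhashable, which dataclasses
    # now treats as mutable.
    #
    # @dataclasses.dataclass
    # class TaskConfig:
    #   head: HeadConfig = HeadConfig()


    # The pattern used throughout this commit: a zero-argument factory
    # that builds a fresh default for each instance.
    @dataclasses.dataclass
    class TaskConfig:
      head: HeadConfig = dataclasses.field(default_factory=HeadConfig)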
@@ -31,15 +31,25 @@ from official.nlp.modeling import models
 @dataclasses.dataclass
 class MaskedLMConfig(cfg.TaskConfig):
   """The model config."""
-  model: bert.PretrainerConfig = bert.PretrainerConfig(cls_heads=[
-      bert.ClsHeadConfig(
-          inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
-  ])
+  model: bert.PretrainerConfig = dataclasses.field(
+      default_factory=lambda: bert.PretrainerConfig(  # pylint: disable=g-long-lambda
+          cls_heads=[
+              bert.ClsHeadConfig(
+                  inner_dim=768,
+                  num_classes=2,
+                  dropout_rate=0.1,
+                  name='next_sentence',
+              )
+          ]
+      )
+  )
   # TODO(b/154564893): Mathematically, scale_loss should be True.
   # However, it works better with scale_loss being False.
   scale_loss: bool = False
-  train_data: cfg.DataConfig = cfg.DataConfig()
-  validation_data: cfg.DataConfig = cfg.DataConfig()
+  train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
+  validation_data: cfg.DataConfig = dataclasses.field(
+      default_factory=cfg.DataConfig
+  )
 
 
 @task_factory.register_task_cls(MaskedLMConfig)
......
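A second detail visible in the hunk above: default_factory takes a zero-argument callable, so the class itself serves as the factory only when the default needs no constructor arguments; when arguments are needed, the call is wrapped in a lambda, which is why the pylint: disable=g-long-lambda markers appear throughout this commit. A toy illustration (DataConfig here is a stand-in, not the real class):

    import dataclasses


    @dataclasses.dataclass
    class DataConfig:
      is_training: bool = False
      drop_remainder: bool = True


    @dataclasses.dataclass
    class TaskConfig:
      # No constructor arguments needed: the class itself is the factory.
      validation_data: DataConfig = dataclasses.field(default_factory=DataConfig)
      # Arguments needed: wrap the call in a zero-argument lambda.
      train_data: DataConfig = dataclasses.field(
          default_factory=lambda: DataConfig(is_training=True)
      )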
@@ -64,23 +64,37 @@ class MosaicSemanticSegmentationModel(hyperparams.Config):
   """MOSAIC semantic segmentation model config."""
   num_classes: int = 19
   input_size: List[int] = dataclasses.field(default_factory=list)
-  head: MosaicDecoderHead = MosaicDecoderHead()
-  backbone: backbones.Backbone = backbones.Backbone(
-      type='mobilenet', mobilenet=backbones.MobileNet())
-  neck: MosaicEncoderNeck = MosaicEncoderNeck()
+  head: MosaicDecoderHead = dataclasses.field(default_factory=MosaicDecoderHead)
+  backbone: backbones.Backbone = dataclasses.field(
+      default_factory=lambda: backbones.Backbone(  # pylint: disable=g-long-lambda
+          type='mobilenet', mobilenet=backbones.MobileNet()
+      )
+  )
+  neck: MosaicEncoderNeck = dataclasses.field(default_factory=MosaicEncoderNeck)
   mask_scoring_head: Optional[seg_cfg.MaskScoringHead] = None
-  norm_activation: common.NormActivation = common.NormActivation(
-      use_sync_bn=True, norm_momentum=0.99, norm_epsilon=0.001)
+  norm_activation: common.NormActivation = dataclasses.field(
+      default_factory=lambda: common.NormActivation(  # pylint: disable=g-long-lambda
+          use_sync_bn=True, norm_momentum=0.99, norm_epsilon=0.001
+      )
+  )
 
 
 @dataclasses.dataclass
 class MosaicSemanticSegmentationTask(seg_cfg.SemanticSegmentationTask):
   """The config for MOSAIC segmentation task."""
-  model: MosaicSemanticSegmentationModel = MosaicSemanticSegmentationModel()
-  train_data: seg_cfg.DataConfig = seg_cfg.DataConfig(is_training=True)
-  validation_data: seg_cfg.DataConfig = seg_cfg.DataConfig(is_training=False)
-  losses: seg_cfg.Losses = seg_cfg.Losses()
-  evaluation: seg_cfg.Evaluation = seg_cfg.Evaluation()
+  model: MosaicSemanticSegmentationModel = dataclasses.field(
+      default_factory=MosaicSemanticSegmentationModel
+  )
+  train_data: seg_cfg.DataConfig = dataclasses.field(
+      default_factory=lambda: seg_cfg.DataConfig(is_training=True)
+  )
+  validation_data: seg_cfg.DataConfig = dataclasses.field(
+      default_factory=lambda: seg_cfg.DataConfig(is_training=False)
+  )
+  losses: seg_cfg.Losses = dataclasses.field(default_factory=seg_cfg.Losses)
+  evaluation: seg_cfg.Evaluation = dataclasses.field(
+      default_factory=seg_cfg.Evaluation
+  )
   train_input_partition_dims: List[int] = dataclasses.field(
       default_factory=list)
   eval_input_partition_dims: List[int] = dataclasses.field(
@@ -88,7 +102,9 @@ class MosaicSemanticSegmentationTask(seg_cfg.SemanticSegmentationTask):
   init_checkpoint: Optional[str] = None
   init_checkpoint_modules: Union[
       str, List[str]] = 'all'  # all, backbone, and/or neck.
-  export_config: seg_cfg.ExportConfig = seg_cfg.ExportConfig()
+  export_config: seg_cfg.ExportConfig = dataclasses.field(
+      default_factory=seg_cfg.ExportConfig
+  )
 
 
 # Cityscapes Dataset (Download and process the dataset yourself)
......
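Besides satisfying Python 3.11, default_factory also removes a sharing hazard: a class-level instance default is a single object handed to every instance, whereas a factory builds a fresh object per instantiation. A sketch of the difference with toy classes; the real hyperparams.Config machinery may already copy defaults, so read this as plain-dataclass behavior:

    import dataclasses


    @dataclasses.dataclass
    class Backbone:
      model_id: int = 50


    @dataclasses.dataclass
    class ModelConfig:
      backbone: Backbone = dataclasses.field(default_factory=Backbone)


    a, b = ModelConfig(), ModelConfig()
    assert a.backbone is not b.backbone  # each instance gets a fresh default
    a.backbone.model_id = 101
    assert b.backbone.model_id == 50     # mutating one leaves the other intact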
@@ -67,7 +67,9 @@ class DataConfig(cfg.DataConfig):
   global_batch_size: int = 0
   is_training: bool = False
   dtype: str = 'float32'
-  decoder: common.DataDecoder = common.DataDecoder()
+  decoder: common.DataDecoder = dataclasses.field(
+      default_factory=common.DataDecoder
+  )
   shuffle_buffer_size: int = 10000
   file_type: str = 'tfrecord'
   drop_remainder: bool = True
@@ -97,10 +99,15 @@ class Pix2Seq(hyperparams.Config):
   shared_decoder_embedding: bool = True
   decoder_output_bias: bool = True
   input_size: List[int] = dataclasses.field(default_factory=list)
-  backbone: backbones.Backbone = backbones.Backbone(
-      type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False)
+  backbone: backbones.Backbone = dataclasses.field(
+      default_factory=lambda: backbones.Backbone(  # pylint: disable=g-long-lambda
+          type='resnet',
+          resnet=backbones.ResNet(model_id=50, bn_trainable=False),
+      )
   )
-  norm_activation: common.NormActivation = common.NormActivation()
+  norm_activation: common.NormActivation = dataclasses.field(
+      default_factory=common.NormActivation
+  )
   backbone_endpoint_name: str = '5'
   drop_path: float = 0.1
   drop_units: float = 0.1
@@ -110,10 +117,12 @@ class Pix2Seq(hyperparams.Config):
 
 @dataclasses.dataclass
 class Pix2SeqTask(cfg.TaskConfig):
-  model: Pix2Seq = Pix2Seq()
-  train_data: cfg.DataConfig = cfg.DataConfig()
-  validation_data: cfg.DataConfig = cfg.DataConfig()
-  losses: Losses = Losses()
+  model: Pix2Seq = dataclasses.field(default_factory=Pix2Seq)
+  train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
+  validation_data: cfg.DataConfig = dataclasses.field(
+      default_factory=cfg.DataConfig
+  )
+  losses: Losses = dataclasses.field(default_factory=Losses)
   init_checkpoint: Optional[str] = None
   init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone
   annotation_file: Optional[str] = None
......
@@ -83,7 +83,7 @@ class Backbone3D(backbones_3d.Backbone3D):
     s3d: s3d backbone config.
   """
   type: str = 's3d'
-  s3d: S3D = S3D()
+  s3d: S3D = dataclasses.field(default_factory=S3D)
 
 
 @dataclasses.dataclass
@@ -95,4 +95,4 @@ class S3DModel(video_classification.VideoClassificationModel):
     backbone: backbone config.
   """
   model_type: str = 's3d'
-  backbone: Backbone3D = Backbone3D()
+  backbone: Backbone3D = dataclasses.field(default_factory=Backbone3D)
@@ -31,8 +31,9 @@ class SimCLRMTHeadConfig(hyperparams.Config):
   """Per-task specific configs."""
   task_name: str = 'task_name'
   # Supervised head is required for finetune, but optional for pretrain.
-  supervised_head: simclr_configs.SupervisedHead = simclr_configs.SupervisedHead(
-      num_classes=1001)
+  supervised_head: simclr_configs.SupervisedHead = dataclasses.field(
+      default_factory=lambda: simclr_configs.SupervisedHead(num_classes=1001)
+  )
   mode: str = simclr_model.PRETRAIN
@@ -40,13 +41,22 @@
 class SimCLRMTModelConfig(hyperparams.Config):
   """Model config for multi-task SimCLR model."""
   input_size: List[int] = dataclasses.field(default_factory=list)
-  backbone: backbones.Backbone = backbones.Backbone(
-      type='resnet', resnet=backbones.ResNet())
+  backbone: backbones.Backbone = dataclasses.field(
+      default_factory=lambda: backbones.Backbone(  # pylint: disable=g-long-lambda
+          type='resnet', resnet=backbones.ResNet()
+      )
+  )
   backbone_trainable: bool = True
-  projection_head: simclr_configs.ProjectionHead = simclr_configs.ProjectionHead(
-      proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
-  norm_activation: common.NormActivation = common.NormActivation(
-      norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
+  projection_head: simclr_configs.ProjectionHead = dataclasses.field(
+      default_factory=lambda: simclr_configs.ProjectionHead(  # pylint: disable=g-long-lambda
+          proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1
+      )
+  )
+  norm_activation: common.NormActivation = dataclasses.field(
+      default_factory=lambda: common.NormActivation(  # pylint: disable=g-long-lambda
+          norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False
+      )
+  )
   heads: Tuple[SimCLRMTHeadConfig, ...] = ()
   # L2 weight decay is used in the model, not in task.
   # Note that this can not be used together with lars optimizer.
......
@@ -43,8 +43,11 @@ class TeamsPretrainerConfig(base_config.Config):
   num_shared_generator_hidden_layers: int = 3
   # Number of bottom layers shared between different discriminator tasks.
   num_discriminator_task_agnostic_layers: int = 11
-  generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
-  discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
+  generator: encoders.BertEncoderConfig = dataclasses.field(
+      default_factory=encoders.BertEncoderConfig
+  )
+  discriminator: encoders.BertEncoderConfig = dataclasses.field(
+      default_factory=encoders.BertEncoderConfig
+  )
 
 
 class TeamsEncoderConfig(encoders.BertEncoderConfig):
......
@@ -30,9 +30,13 @@ from official.projects.teams import teams_pretrainer
 @dataclasses.dataclass
 class TeamsPretrainTaskConfig(cfg.TaskConfig):
   """The model config."""
-  model: teams.TeamsPretrainerConfig = teams.TeamsPretrainerConfig()
-  train_data: cfg.DataConfig = cfg.DataConfig()
-  validation_data: cfg.DataConfig = cfg.DataConfig()
+  model: teams.TeamsPretrainerConfig = dataclasses.field(
+      default_factory=teams.TeamsPretrainerConfig
+  )
+  train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
+  validation_data: cfg.DataConfig = dataclasses.field(
+      default_factory=cfg.DataConfig
+  )
 
 
 def _get_generator_hidden_layers(discriminator_network, num_hidden_layers,
......
@@ -34,7 +34,9 @@ from official.projects.text_classification_example import classification_data_loader
 @dataclasses.dataclass
 class ModelConfig(base_config.Config):
   """A base span labeler configuration."""
-  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
+  encoder: encoders.EncoderConfig = dataclasses.field(
+      default_factory=encoders.EncoderConfig
+  )
   head_dropout: float = 0.1
   head_initializer_range: float = 0.02
@@ -49,9 +51,11 @@ class ClassificationExampleConfig(cfg.TaskConfig):
   num_classes = 2
   class_names = ['A', 'B']
-  train_data: cfg.DataConfig = classification_data_loader.ClassificationExampleDataConfig(
+  train_data: cfg.DataConfig = dataclasses.field(
+      default_factory=classification_data_loader.ClassificationExampleDataConfig
   )
-  validation_data: cfg.DataConfig = classification_data_loader.ClassificationExampleDataConfig(
+  validation_data: cfg.DataConfig = dataclasses.field(
+      default_factory=classification_data_loader.ClassificationExampleDataConfig
   )
......
@@ -22,7 +22,7 @@ from official.modeling import optimization
 @dataclasses.dataclass
 class OcrTaskConfig(cfg.TaskConfig):
-  train_data: cfg.DataConfig = cfg.DataConfig()
+  train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
   model_call_needs_labels: bool = False
......
@@ -43,8 +43,11 @@ class VideoSSLModel(VideoClassificationModel):
   hidden_dim: int = 2048
   hidden_layer_num: int = 3
   projection_dim: int = 128
-  hidden_norm_activation: common.NormActivation = common.NormActivation(
-      use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05)
+  hidden_norm_activation: common.NormActivation = dataclasses.field(
+      default_factory=lambda: common.NormActivation(
+          use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05
+      )
+  )
 
 
 @dataclasses.dataclass
@@ -55,21 +58,31 @@ class SSLLosses(Losses):
 @dataclasses.dataclass
 class VideoSSLPretrainTask(VideoClassificationTask):
-  model: VideoSSLModel = VideoSSLModel()
-  losses: SSLLosses = SSLLosses()
-  train_data: DataConfig = DataConfig(is_training=True, drop_remainder=True)
-  validation_data: DataConfig = DataConfig(
-      is_training=False, drop_remainder=False)
-  losses: SSLLosses = SSLLosses()
+  model: VideoSSLModel = dataclasses.field(default_factory=VideoSSLModel)
+  losses: SSLLosses = dataclasses.field(default_factory=SSLLosses)
+  train_data: DataConfig = dataclasses.field(
+      default_factory=lambda: DataConfig(is_training=True, drop_remainder=True)
+  )
+  validation_data: DataConfig = dataclasses.field(
+      default_factory=lambda: DataConfig(  # pylint: disable=g-long-lambda
+          is_training=False, drop_remainder=False
+      )
+  )
+  losses: SSLLosses = dataclasses.field(default_factory=SSLLosses)
 
 
 @dataclasses.dataclass
 class VideoSSLEvalTask(VideoClassificationTask):
-  model: VideoSSLModel = VideoSSLModel()
-  train_data: DataConfig = DataConfig(is_training=True, drop_remainder=True)
-  validation_data: DataConfig = DataConfig(
-      is_training=False, drop_remainder=False)
-  losses: SSLLosses = SSLLosses()
+  model: VideoSSLModel = dataclasses.field(default_factory=VideoSSLModel)
+  train_data: DataConfig = dataclasses.field(
+      default_factory=lambda: DataConfig(is_training=True, drop_remainder=True)
+  )
+  validation_data: DataConfig = dataclasses.field(
+      default_factory=lambda: DataConfig(  # pylint: disable=g-long-lambda
+          is_training=False, drop_remainder=False
+      )
+  )
+  losses: SSLLosses = dataclasses.field(default_factory=SSLLosses)
 
 
 @exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
......