提交 ed6d4d22 编写于 作者: A A. Unique TensorFlower

Merge pull request #11023 from tensorflow:sineeli-patch-13

PiperOrigin-RevId: 542258048
......@@ -20,6 +20,7 @@ from typing import List, Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.projects.yolo.configs import backbones
from official.vision.configs import common
from official.vision.configs import image_classification as imc
......@@ -44,6 +45,9 @@ class Losses(hyperparams.Config):
one_hot: bool = True
label_smoothing: float = 0.0
l2_weight_decay: float = 0.0
loss_weight: float = 1.0
soft_labels: bool = False
use_binary_cross_entropy: bool = False
@dataclasses.dataclass
......@@ -56,6 +60,7 @@ class ImageClassificationTask(cfg.TaskConfig):
losses: Losses = Losses()
gradient_clip_norm: float = 0.0
logging_dir: Optional[str] = None
freeze_backbone: bool = False
@exp_factory.register_config_factory('darknet_classification')
......@@ -63,8 +68,23 @@ def darknet_classification() -> cfg.ExperimentConfig:
"""Image classification general."""
return cfg.ExperimentConfig(
task=ImageClassificationTask(),
trainer=cfg.TrainerConfig(),
trainer=cfg.TrainerConfig(
optimizer_config=optimization.OptimizationConfig({
'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
'learning_rate': {
'type': 'polynomial',
'initial_learning_rate': 0.1,
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_learning_rate': 0,
},
},
})
),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
'task.validation_data.is_training != None',
],
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册