提交 041f6976 编写于 作者: P Pengchong Jin 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 341687182
上级 d2501e46
......@@ -57,6 +57,11 @@ class Losses(hyperparams.Config):
l2_weight_decay: float = 0.0
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
top_k: int = 5
@dataclasses.dataclass
class ImageClassificationTask(cfg.TaskConfig):
"""The task config."""
......@@ -64,6 +69,7 @@ class ImageClassificationTask(cfg.TaskConfig):
train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses()
evaluation: Evaluation = Evaluation()
gradient_clip_norm: float = 0.0
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
......
......@@ -123,15 +123,17 @@ class ImageClassificationTask(base_task.Task):
def build_metrics(self, training=True):
"""Gets streaming metrics for training/validation."""
k = self.task_config.evaluation.top_k
if self.task_config.losses.one_hot:
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')]
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=5, name='top_5_accuracy')]
k=k, name='top_{}_accuracy'.format(k))]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册