提交 c14a04ab 编写于 作者: F Fan Yang 提交者: A. Unique TensorFlower

Add precision and recall metrics at predefined thresholds for image classification task.

PiperOrigin-RevId: 448320923
上级 13642b0f
......@@ -65,6 +65,8 @@ class ImageClassificationModel(hyperparams.Config):
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False
kernel_initializer: str = 'random_uniform'
# Whether to output softmax results instead of logits.
output_softmax: bool = False
@dataclasses.dataclass
......@@ -79,6 +81,7 @@ class Losses(hyperparams.Config):
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
top_k: int = 5
precision_and_recall_thresholds: Optional[List[float]] = None
@dataclasses.dataclass
......
......@@ -184,6 +184,24 @@ class ImageClassificationTask(base_task.Task):
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
if hasattr(
self.task_config.evaluation, 'precision_and_recall_thresholds'
) and self.task_config.evaluation.precision_and_recall_thresholds:
thresholds = self.task_config.evaluation.precision_and_recall_thresholds
# pylint:disable=g-complex-comprehension
metrics += [
tf.keras.metrics.Precision(
thresholds=th,
name='precision_at_threshold_{}'.format(th),
top_k=1) for th in thresholds
]
metrics += [
tf.keras.metrics.Recall(
thresholds=th,
name='recall_at_threshold_{}'.format(th),
top_k=1) for th in thresholds
]
# pylint:enable=g-complex-comprehension
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
......@@ -234,6 +252,7 @@ class ImageClassificationTask(base_task.Task):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure(
......@@ -264,6 +283,11 @@ class ImageClassificationTask(base_task.Task):
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
# Convert logits to softmax for metric computation if needed.
if hasattr(self.task_config.model,
'output_softmax') and self.task_config.model.output_softmax:
outputs = tf.nn.softmax(outputs, axis=-1)
if metrics:
self.process_metrics(metrics, labels, outputs)
elif model.compiled_metrics:
......@@ -300,6 +324,10 @@ class ImageClassificationTask(base_task.Task):
aux_losses=model.losses)
logs = {self.loss: loss}
# Convert logits to softmax for metric computation if needed.
if hasattr(self.task_config.model,
'output_softmax') and self.task_config.model.output_softmax:
outputs = tf.nn.softmax(outputs, axis=-1)
if metrics:
self.process_metrics(metrics, labels, outputs)
elif model.compiled_metrics:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册