提交 e7a87506 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 481231095
上级 26afeefb
......@@ -244,7 +244,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
dtype=tf.float32)
else:
self._build_coco_metrics()
if self.task_config.use_coco_metrics:
self._build_coco_metrics()
rescale_predictions = (not self.task_config.validation_data.parser
.segmentation_resize_eval_groundtruth)
......@@ -366,24 +367,25 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
training=False)
logs = {self.loss: 0}
coco_model_outputs = {
'detection_masks': outputs['detection_masks'],
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
'detection_classes': outputs['detection_classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
if self._task_config.use_coco_metrics:
coco_model_outputs = {
'detection_masks': outputs['detection_masks'],
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
'detection_classes': outputs['detection_classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
logs.update(
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
segmentation_labels = {
'masks': labels['groundtruths']['gt_segmentation_mask'],
'valid_masks': labels['groundtruths']['gt_segmentation_valid_mask'],
'image_info': labels['image_info']
}
logs.update(
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
self.segmentation_perclass_iou_metric.update_state(
segmentation_labels, outputs['segmentation_outputs'])
......@@ -400,15 +402,18 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
def aggregate_logs(self, state=None, step_outputs=None):
if state is None:
self.coco_metric.reset_states()
self.segmentation_perclass_iou_metric.reset_states()
state = [self.coco_metric, self.segmentation_perclass_iou_metric]
state = [self.segmentation_perclass_iou_metric]
if self.task_config.use_coco_metrics:
self.coco_metric.reset_states()
state.append(self.coco_metric)
if self.task_config.model.generate_panoptic_masks:
state += [self.panoptic_quality_metric]
self.panoptic_quality_metric.reset_states()
state.append(self.panoptic_quality_metric)
self.coco_metric.update_state(
step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
if self.task_config.use_coco_metrics:
self.coco_metric.update_state(step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
if self.task_config.model.generate_panoptic_masks:
self.panoptic_quality_metric.update_state(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册