提交 66569e42 编写于 作者: Z Zhenyu Tan 提交者: TensorFlower Gardener

Let MeanIOU use the user-controlled dtype.

PiperOrigin-RevId: 341160520
Change-Id: Ia66eccb3ce02151db3af68f41f07db527592b0b3
上级 722bf167
......@@ -2796,13 +2796,11 @@ class MeanIoU(Metric):
super(MeanIoU, self).__init__(name=name, dtype=dtype)
self.num_classes = num_classes
# Variable to accumulate the predictions in the confusion matrix. Setting
# the type to be `float64` as required by confusion_matrix_ops.
# Variable to accumulate the predictions in the confusion matrix.
self.total_cm = self.add_weight(
'total_confusion_matrix',
shape=(num_classes, num_classes),
initializer=init_ops.zeros_initializer,
dtype=dtypes.float64)
initializer=init_ops.zeros_initializer)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates the confusion matrix statistics.
......@@ -2839,7 +2837,7 @@ class MeanIoU(Metric):
y_pred,
self.num_classes,
weights=sample_weight,
dtype=dtypes.float64)
dtype=self._dtype)
return self.total_cm.assign_add(current_cm)
def result(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册