diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index f26b797db6425f487d7b189b9ef76893bb6d3a1b..d88792d70aa4e2ad0f382da90cff3c826f57b7f2 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -2796,13 +2796,11 @@ class MeanIoU(Metric): super(MeanIoU, self).__init__(name=name, dtype=dtype) self.num_classes = num_classes - # Variable to accumulate the predictions in the confusion matrix. Setting - # the type to be `float64` as required by confusion_matrix_ops. + # Variable to accumulate the predictions in the confusion matrix. self.total_cm = self.add_weight( 'total_confusion_matrix', shape=(num_classes, num_classes), - initializer=init_ops.zeros_initializer, - dtype=dtypes.float64) + initializer=init_ops.zeros_initializer) def update_state(self, y_true, y_pred, sample_weight=None): """Accumulates the confusion matrix statistics. @@ -2839,7 +2837,7 @@ class MeanIoU(Metric): y_pred, self.num_classes, weights=sample_weight, - dtype=dtypes.float64) + dtype=self._dtype) return self.total_cm.assign_add(current_cm) def result(self):