提交 99820e70 编写于 作者: E Evan Rosen 提交者: TensorFlower Gardener

Avoid serializing generated thresholds for AUC metrics.

PiperOrigin-RevId: 381160816
上级 c9b5ba76
...@@ -2116,6 +2116,7 @@ class AUC(Metric): ...@@ -2116,6 +2116,7 @@ class AUC(Metric):
summation_method, list(metrics_utils.AUCSummationMethod))) summation_method, list(metrics_utils.AUCSummationMethod)))
# Update properties. # Update properties.
self._init_from_thresholds = thresholds is not None
if thresholds is not None: if thresholds is not None:
# If specified, use the supplied thresholds. # If specified, use the supplied thresholds.
self.num_thresholds = len(thresholds) + 2 self.num_thresholds = len(thresholds) + 2
...@@ -2444,13 +2445,15 @@ class AUC(Metric): ...@@ -2444,13 +2445,15 @@ class AUC(Metric):
'num_thresholds': self.num_thresholds, 'num_thresholds': self.num_thresholds,
'curve': self.curve.value, 'curve': self.curve.value,
'summation_method': self.summation_method.value, 'summation_method': self.summation_method.value,
# We remove the endpoint thresholds as an inverse of how the thresholds
# were initialized. This ensures that a metric initialized from this
# config has the same thresholds.
'thresholds': self.thresholds[1:-1],
'multi_label': self.multi_label, 'multi_label': self.multi_label,
'label_weights': label_weights 'label_weights': label_weights
} }
# optimization to avoid serializing a large number of generated thresholds
if self._init_from_thresholds:
# We remove the endpoint thresholds as an inverse of how the thresholds
# were initialized. This ensures that a metric initialized from this
# config has the same thresholds.
config['thresholds'] = self.thresholds[1:-1]
base_config = super(AUC, self).get_config() base_config = super(AUC, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
......
...@@ -1237,6 +1237,7 @@ class AUCTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1237,6 +1237,7 @@ class AUCTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(auc_obj.summation_method, self.assertEqual(auc_obj.summation_method,
metrics_utils.AUCSummationMethod.MAJORING) metrics_utils.AUCSummationMethod.MAJORING)
old_config = auc_obj.get_config() old_config = auc_obj.get_config()
self.assertNotIn('thresholds', old_config)
self.assertDictEqual(old_config, json.loads(json.dumps(old_config))) self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
# Check save and restore config. # Check save and restore config.
...@@ -1249,6 +1250,7 @@ class AUCTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1249,6 +1250,7 @@ class AUCTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(auc_obj2.summation_method, self.assertEqual(auc_obj2.summation_method,
metrics_utils.AUCSummationMethod.MAJORING) metrics_utils.AUCSummationMethod.MAJORING)
new_config = auc_obj2.get_config() new_config = auc_obj2.get_config()
self.assertNotIn('thresholds', new_config)
self.assertDictEqual(old_config, new_config) self.assertDictEqual(old_config, new_config)
self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds) self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册