提交 cb043911 编写于 作者: T Thomas O'Malley 提交者: TensorFlower Gardener

Add MetricsContainer.weighted_metrics and MetricsContainer.unweighted_metrics

properties to help distinguish between metrics that should and shouldn't be
passed sample_weight argument.
Note these properties are set to None before Model.fit is called, since metrics
are potentially broadcast to match the structure of data seen in Model.fit.

PiperOrigin-RevId: 339892649
Change-Id: I0abffae08efde2b8adc58014ef205d318d66a9ab
上级 6f3ef003
......@@ -292,11 +292,25 @@ class MetricsContainer(Container):
@property
def metrics(self):
"""Metrics created by this container."""
"""All metrics in this container."""
if not self._built:
return []
return self._metrics_in_order
@property
def unweighted_metrics(self):
"""Metrics in this container that should not be passed `sample_weight`."""
if not self._built:
return None
return nest.flatten(self._metrics)
@property
def weighted_metrics(self):
"""Metrics in this container that should be passed `sample_weight`."""
if not self._built:
return None
return nest.flatten(self._weighted_metrics)
def build(self, y_pred, y_true):
"""One-time setup of metric objects."""
super(MetricsContainer, self).build(y_pred)
......
......@@ -420,6 +420,18 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(acc_metric_2.result().numpy(), 0.)
self.assertEqual(acc_metric_2._fn, metrics_mod.binary_accuracy)
weighted_metrics = metric_container.weighted_metrics
self.assertLen(weighted_metrics, 2)
self.assertEqual(weighted_metrics[0].name, 'output_1_accuracy')
self.assertEqual(weighted_metrics[1].name, 'output_2_accuracy')
unweighted_metrics = metric_container.unweighted_metrics
self.assertLen(unweighted_metrics, 4)
self.assertEqual(unweighted_metrics[0].name, 'output_1_mse')
self.assertEqual(unweighted_metrics[1].name, 'output_1_mae')
self.assertEqual(unweighted_metrics[2].name, 'output_2_mse')
self.assertEqual(unweighted_metrics[3].name, 'output_2_mae')
def test_metric_dict(self):
metric_container = compile_utils.MetricsContainer(
metrics={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册