提交 62df65c7 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add dtype argument to Mean and Accuracy object-oriented metrics.

PiperOrigin-RevId: 172957714
上级 29c7b465
......@@ -198,13 +198,19 @@ class Mean(Metric):
# TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
# Or defaults to type of the input if it is tf.float32, else tf.float64?
def build(self, values, weights=None):
del values, weights # build() does not use call's arguments
def __init__(self, name=None, dtype=dtypes.float64):
super(Mean, self).__init__(name=name)
self.dtype = dtype
def build(self, *args, **kwargs):
# build() does not use call's arguments, by using *args, **kwargs
# we make it easier to inherit from Mean().
del args, kwargs
self.numer = self.add_variable(name="numer", shape=(),
dtype=dtypes.float64,
dtype=self.dtype,
initializer=init_ops.zeros_initializer)
self.denom = self.add_variable(name="denom", shape=(),
dtype=dtypes.float64,
dtype=self.dtype,
initializer=init_ops.zeros_initializer)
def call(self, values, weights=None):
......@@ -219,13 +225,13 @@ class Mean(Metric):
"""
if weights is None:
self.denom.assign_add(
math_ops.cast(array_ops.size(values), dtypes.float64))
math_ops.cast(array_ops.size(values), self.dtype))
values = math_ops.reduce_sum(values)
self.numer.assign_add(math_ops.cast(values, dtypes.float64))
self.numer.assign_add(math_ops.cast(values, self.dtype))
else:
weights = math_ops.cast(weights, dtypes.float64)
weights = math_ops.cast(weights, self.dtype)
self.denom.assign_add(math_ops.reduce_sum(weights))
values = math_ops.cast(values, dtypes.float64) * weights
values = math_ops.cast(values, self.dtype) * weights
self.numer.assign_add(math_ops.reduce_sum(values))
def result(self):
......@@ -235,9 +241,8 @@ class Mean(Metric):
class Accuracy(Mean):
"""Calculates how often `predictions` matches `labels`."""
def build(self, labels, predictions, weights=None):
del labels, predictions, weights
super(Accuracy, self).build(None) # Arguments are unused
def __init__(self, name=None, dtype=dtypes.float64):
super(Accuracy, self).__init__(name=name, dtype=dtype)
def call(self, labels, predictions, weights=None):
"""Accumulate accuracy statistics.
......
......@@ -34,6 +34,8 @@ class MetricsTest(test.TestCase):
m(1000)
m([10000.0, 100000.0])
self.assertEqual(111111.0/6, m.result().numpy())
self.assertEqual(dtypes.float64, m.dtype)
self.assertEqual(dtypes.float64, m.result().dtype)
def testWeightedMean(self):
m = metrics.Mean()
......@@ -41,6 +43,14 @@ class MetricsTest(test.TestCase):
m([500000, 5000, 500]) # weights of 1 each
self.assertNear(535521/4.5, m.result().numpy(), 0.001)
def testMeanDtype(self):
# Can override default dtype of float64.
m = metrics.Mean(dtype=dtypes.float32)
m([0, 2])
self.assertEqual(1, m.result().numpy())
self.assertEqual(dtypes.float32, m.dtype)
self.assertEqual(dtypes.float32, m.result().dtype)
def testAccuracy(self):
m = metrics.Accuracy()
m([0, 1, 2, 3], [0, 0, 0, 0]) # 1 correct
......@@ -49,6 +59,8 @@ class MetricsTest(test.TestCase):
m([6], [6]) # 1 correct
m([7], [2]) # 0 correct
self.assertEqual(3.0/8, m.result().numpy())
self.assertEqual(dtypes.float64, m.dtype)
self.assertEqual(dtypes.float64, m.result().dtype)
def testWeightedAccuracy(self):
m = metrics.Accuracy()
......@@ -60,6 +72,14 @@ class MetricsTest(test.TestCase):
m([7], [2]) # 0 correct, weight 1
self.assertEqual(2.5/5, m.result().numpy())
def testAccuracyDtype(self):
# Can override default dtype of float64.
m = metrics.Accuracy(dtype=dtypes.float32)
m([0, 0], [0, 1])
self.assertEqual(0.5, m.result().numpy())
self.assertEqual(dtypes.float32, m.dtype)
self.assertEqual(dtypes.float32, m.result().dtype)
def testTwoMeans(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册