提交 fe4110f2 编写于 作者: F François Chollet

Reenable Precision and Recall with TF1

上级 a1be7c3e
......@@ -1412,12 +1412,6 @@ class Precision(Metric):
raise RuntimeError(
'`top_k` argument for `Precision` metric is currently supported '
'only with TensorFlow backend.')
import tensorflow as tf
if not tf.__version__.startswith('2.'):
raise RuntimeError(
'`top_k` argument for `Precision` metric is currently '
'supported only with TensorFlow 2.0. Your version '
'of TensorFlow is ' + str(tf.__version__))
self.top_k = top_k
self.class_id = class_id
......@@ -1528,12 +1522,6 @@ class Recall(Metric):
raise RuntimeError(
'`top_k` argument for `Recall` metric is currently supported only '
'with TensorFlow backend.')
import tensorflow as tf
if not tf.__version__.startswith('2.'):
raise RuntimeError(
'`top_k` argument for `Precision` metric is currently '
'supported only with TensorFlow 2.0. Your version '
'of TensorFlow is ' + str(tf.__version__))
self.top_k = top_k
self.class_id = class_id
......@@ -1897,12 +1885,6 @@ class MeanIoU(BaseMeanIoU):
raise RuntimeError(
'`MeanIoU` metric is currently supported only '
'with TensorFlow backend and TF version >= 2.0.0.')
import tensorflow as tf
if not tf.__version__.startswith('2.'):
raise RuntimeError(
'`top_k` argument for `Precision` metric is currently '
'supported only with TensorFlow 2.0. Your version '
'of TensorFlow is ' + str(tf.__version__))
super(MeanIoU, self).__init__(num_classes, name=name, dtype=dtype)
......
......@@ -594,8 +594,6 @@ class TestAUC(object):
metrics.AUC(summation_method='Invalid')
@pytest.mark.skipif(not tf.__version__.startswith('2.'),
reason='Requires TF 2')
class TestPrecisionTest(object):
def test_config(self):
......@@ -745,8 +743,6 @@ class TestPrecisionTest(object):
assert np.isclose(0, K.eval(p_obj.false_positives))
@pytest.mark.skipif(not tf.__version__.startswith('2.'),
reason='Requires TF 2')
class TestRecall(object):
def test_config(self):
......
......@@ -77,6 +77,10 @@ def test_sensitivity_metrics():
@pytest.mark.skipif(K.backend() != 'tensorflow', reason='requires tensorflow')
def test_mean_iou():
import tensorflow as tf
if not tf.__version__.startswith('2.'):
return
model = Sequential([Dense(1, input_shape=(3,))])
model.compile('rmsprop', 'mse', metrics=[metrics.MeanIoU(2)])
x = np.random.random((10, 3))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册