From 5e7b169ed1a1a7b90139f386a766da31088de7f3 Mon Sep 17 00:00:00 2001 From: whs Date: Fri, 29 Jun 2018 11:34:26 +0800 Subject: [PATCH] Fix python api of mean iou op. (#11797) (#11827) * Fix mean iou op. * Fix crop layer test. * Fix unitest * fix unitest. --- python/paddle/fluid/layers/nn.py | 10 +++++----- python/paddle/fluid/tests/unittests/test_layers.py | 11 ++++++++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 2979ff3057..6c48dc3eba 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5015,12 +5015,12 @@ def mean_iou(input, label, num_classes): out_correct = helper.create_tmp_variable(dtype='int32') helper.append_op( type="mean_iou", - inputs={"predictions": input, - "labels": label}, + inputs={"Predictions": input, + "Labels": label}, outputs={ - "out_mean_iou": out_mean_iou, - "out_wrong": out_wrong, - "out_correct": out_correct + "OutMeanIou": out_mean_iou, + "OutWrong": out_wrong, + "OutCorrect": out_correct }, attrs={"num_classes": num_classes}) return out_mean_iou, out_wrong, out_correct diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 82074955fa..9d4b2d4434 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -401,7 +401,7 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) - def test_maxout(self): + def test_crop(self): program = Program() with program_guard(program): x = layers.data(name='x', shape=[3, 5], dtype="float32") @@ -410,6 +410,15 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) + def test_mean_iou(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[16], dtype='float32') + y = layers.data(name='label', shape=[1], dtype='int64') + iou = layers.mean_iou(x, y, 2) + self.assertIsNotNone(iou) + print(str(program)) + if __name__ == '__main__': unittest.main() -- GitLab