未验证 提交 93e25301 编写于 作者: W whs 提交者: GitHub

Fix python api of mean iou op. (#11797)

* Fix mean iou op.

* Fix crop layer test.

* Fix unitest

* fix unitest.
上级 ef68b913
......@@ -5078,12 +5078,12 @@ def mean_iou(input, label, num_classes):
out_correct = helper.create_tmp_variable(dtype='int32')
helper.append_op(
type="mean_iou",
inputs={"predictions": input,
"labels": label},
inputs={"Predictions": input,
"Labels": label},
outputs={
"out_mean_iou": out_mean_iou,
"out_wrong": out_wrong,
"out_correct": out_correct
"OutMeanIou": out_mean_iou,
"OutWrong": out_wrong,
"OutCorrect": out_correct
},
attrs={"num_classes": num_classes})
return out_mean_iou, out_wrong, out_correct
......
......@@ -401,7 +401,7 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_maxout(self):
def test_crop(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 5], dtype="float32")
......@@ -410,6 +410,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_mean_iou(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[16], dtype='float32')
y = layers.data(name='label', shape=[1], dtype='int64')
iou = layers.mean_iou(x, y, 2)
self.assertIsNotNone(iou)
print(str(program))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册