提交 50b8ec05 编写于 作者: C chengduoZH

fix unit test

上级 3416f5e0
...@@ -62,6 +62,7 @@ class TestPool2d_Op(OpTest): ...@@ -62,6 +62,7 @@ class TestPool2d_Op(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
if self.pool_type != "max":
self.check_grad(set(['Input']), 'Output', max_relative_error=0.07) self.check_grad(set(['Input']), 'Output', max_relative_error=0.07)
def initTestCase(self): def initTestCase(self):
...@@ -84,15 +85,16 @@ class TestCase2(TestPool2d_Op): ...@@ -84,15 +85,16 @@ class TestCase2(TestPool2d_Op):
self.paddings = [1, 1] self.paddings = [1, 1]
# class TestCase1(TestPool2d_Op): class TestCase1(TestPool2d_Op):
# def initTestCase(self): def initTestCase(self):
# self.op_type = "pool2d" self.op_type = "pool2d"
# self.pool_type = "max" self.pool_type = "max"
# self.pool2D_forward_naive = max_pool2D_forward_naive self.pool2D_forward_naive = max_pool2D_forward_naive
# self.shape = [2, 3, 5, 5] self.shape = [2, 3, 5, 5]
# self.ksize = [3, 3] self.ksize = [3, 3]
# self.strides = [1, 1] self.strides = [1, 1]
# self.paddings = [1, 1] self.paddings = [1, 1]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -72,6 +72,7 @@ class TestPool3d_Op(OpTest): ...@@ -72,6 +72,7 @@ class TestPool3d_Op(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
if self.pool_type != "max":
self.check_grad(set(['Input']), 'Output', max_relative_error=0.07) self.check_grad(set(['Input']), 'Output', max_relative_error=0.07)
def initTestCase(self): def initTestCase(self):
...@@ -94,15 +95,16 @@ class TestCase1(TestPool3d_Op): ...@@ -94,15 +95,16 @@ class TestCase1(TestPool3d_Op):
self.paddings = [1, 1, 1] self.paddings = [1, 1, 1]
# class TestCase2(TestPool3d_Op): class TestCase2(TestPool3d_Op):
# def initTestCase(self): def initTestCase(self):
# self.op_type = "pool3d" self.op_type = "pool3d"
# self.pool_type = "max" self.pool_type = "max"
# self.pool3D_forward_naive = max_pool3D_forward_naive self.pool3D_forward_naive = max_pool3D_forward_naive
# self.shape = [2, 3, 5, 5, 5] self.shape = [2, 3, 5, 5, 5]
# self.ksize = [3, 3, 3] self.ksize = [3, 3, 3]
# self.strides = [1, 1, 1] self.strides = [1, 1, 1]
# self.paddings = [1, 1, 1] self.paddings = [1, 1, 1]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册