From 6fc44800ed800e78822c6af5750e202c041d9173 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 30 Sep 2017 15:20:44 +0800 Subject: [PATCH] fix unit test --- .../v2/framework/tests/test_pool_max_op.py | 72 ++++++++++++++++++- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py index ffc345198da..17028c3bf65 100644 --- a/python/paddle/v2/framework/tests/test_pool_max_op.py +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -98,6 +98,28 @@ class TestMaxPoolWithIndex_Op(OpTest): # def test_check_grad(self): # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) + def initTestCase(self): + self.global_pool = True + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase1(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase2(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = False self.op_type = "maxPool3dWithIndex" @@ -108,7 +130,18 @@ class TestMaxPoolWithIndex_Op(OpTest): self.paddings = [1, 1, 1] -class TestCase1(TestMaxPoolWithIndex_Op): +class TestCase3(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + + +class TestCase4(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True self.op_type = "maxPool3dWithIndex" @@ -116,10 +149,21 @@ class TestCase1(TestMaxPoolWithIndex_Op): self.shape = [2, 3, 5, 5, 5] self.ksize = [3, 3, 3] self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase5(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] self.paddings = [0, 0, 0] -class TestCase2(TestMaxPoolWithIndex_Op): +class TestCase6(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = False self.op_type = "maxPool2dWithIndex" @@ -130,7 +174,18 @@ class TestCase2(TestMaxPoolWithIndex_Op): self.paddings = [1, 1] -class TestCase3(TestMaxPoolWithIndex_Op): +class TestCase7(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "maxPool2dWithIndex" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +class TestCase8(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True self.op_type = "maxPool2dWithIndex" @@ -138,6 +193,17 @@ class TestCase3(TestMaxPoolWithIndex_Op): self.shape = [2, 3, 5, 5] self.ksize = [3, 3] self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase9(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "maxPool2dWithIndex" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [2, 2] self.paddings = [0, 0] -- GitLab