From 5a837767512c81f430342c21f65e83a33c17ebd2 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 21 Aug 2018 00:58:58 +0800 Subject: [PATCH] Port test_desc_clone --- .../fluid/tests/unittests/test_desc_clone.py | 14 ++--- .../fluid/tests/unittests/test_prelu_op.py | 52 ++++++------------- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_desc_clone.py b/python/paddle/fluid/tests/unittests/test_desc_clone.py index aca2911482c..fa6b6795625 100644 --- a/python/paddle/fluid/tests/unittests/test_desc_clone.py +++ b/python/paddle/fluid/tests/unittests/test_desc_clone.py @@ -110,7 +110,7 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers): def operator_equal(a, b): - for k, v in a.__dict__.iteritems(): + for k, v in six.iteritems(a.__dict__): if isinstance(v, fluid.framework.Program) or \ isinstance(v, fluid.framework.Block): continue @@ -120,8 +120,8 @@ def operator_equal(a, b): raise ValueError("In operator_equal not equal:{0}\n".format(k)) elif isinstance(v, collections.OrderedDict): - v0 = sorted(v.iteritems(), key=lambda x: x[0]) - v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0]) + v0 = sorted(six.iteritems(v), key=lambda x: x[0]) + v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0]) if v0 != v1: raise ValueError("In operator_equal not equal:{0}\n".format(k)) @@ -133,7 +133,7 @@ def operator_equal(a, b): def block_equal(a, b): - for k, v in a.__dict__.iteritems(): + for k, v in six.iteritems(a.__dict__): if isinstance(v, core.ProgramDesc) or isinstance( v, fluid.framework.Program) or isinstance(v, core.BlockDesc): continue @@ -145,8 +145,8 @@ def block_equal(a, b): assert (len(a.ops) == len(b.ops)) elif isinstance(v, collections.OrderedDict): - v0 = sorted(v.iteritems(), key=lambda x: x[0]) - v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0]) + v0 = sorted(six.iteritems(v), key=lambda x: x[0]) + v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0]) if v0 != v1: raise ValueError("In block_equal not equal:{0}\n".format(k)) @@ -158,7 +158,7 @@ def block_equal(a, b): def program_equal(a, b): - for k, v in a.__dict__.iteritems(): + for k, v in six.iteritems(a.__dict__): if isinstance(v, core.ProgramDesc): continue diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index dfe278d3009..979be5af3bd 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -21,9 +21,6 @@ from op_test import OpTest class PReluTest(OpTest): def setUp(self): - print('setUp') - import sys - sys.stdout.flush() self.op_type = "prelu" self.initTestCase() x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32") @@ -48,39 +45,19 @@ class PReluTest(OpTest): assert out_np is not self.inputs['X'] self.outputs = {'Out': out_np} - def tearDown(self): - print('tearDown') - import sys - sys.stdout.flush() - del self.outputs - del self.inputs - def initTestCase(self): self.attrs = {'mode': "channel"} - def test_check_4_output(self): - print('test_check_0_output') - import sys - sys.stdout.flush() + def test_check_output(self): self.check_output() - def test_check_0_grad_2_ignore_x(self): - print('test_check_2_grad_2_ignore_x') - import sys - sys.stdout.flush() - self.check_grad(['Alpha'], 'Out', no_grad_set=set('X')) - - # TODO(minqiyang): remove the order of tests - def test_check_1_grad_1(self): - print('test_check_1_grad_1') - import sys - sys.stdout.flush() + def test_check_grad(self): self.check_grad(['X', 'Alpha'], 'Out') - def test_check_3_grad_3_ignore_alpha(self): - print('test_check_3_grad_3_ignore_alpha') - import sys - sys.stdout.flush() + def test_check_grad_ignore_x(self): + self.check_grad(['Alpha'], 'Out', no_grad_set=set('X')) + + def test_check_grad_ignore_alpha(self): self.check_grad(['X'], 'Out', no_grad_set=set('Alpha')) @@ -89,14 +66,15 @@ class TestCase1(PReluTest): self.attrs = {'mode': "all"} -#class TestCase2(PReluTest): -# def initTestCase(self): -# self.attrs = {'mode': "channel"} -# -# -#class TestCase3(PReluTest): -# def initTestCase(self): -# self.attrs = {'mode': "element"} +class TestCase2(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "channel"} + + +class TestCase3(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "element"} + if __name__ == "__main__": unittest.main() -- GitLab