提交 5a837767 编写于 作者: M minqiyang

Port test_desc_clone

上级 50d66a07
...@@ -110,7 +110,7 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers): ...@@ -110,7 +110,7 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
def operator_equal(a, b): def operator_equal(a, b):
for k, v in a.__dict__.iteritems(): for k, v in six.iteritems(a.__dict__):
if isinstance(v, fluid.framework.Program) or \ if isinstance(v, fluid.framework.Program) or \
isinstance(v, fluid.framework.Block): isinstance(v, fluid.framework.Block):
continue continue
...@@ -120,8 +120,8 @@ def operator_equal(a, b): ...@@ -120,8 +120,8 @@ def operator_equal(a, b):
raise ValueError("In operator_equal not equal:{0}\n".format(k)) raise ValueError("In operator_equal not equal:{0}\n".format(k))
elif isinstance(v, collections.OrderedDict): elif isinstance(v, collections.OrderedDict):
v0 = sorted(v.iteritems(), key=lambda x: x[0]) v0 = sorted(six.iteritems(v), key=lambda x: x[0])
v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0]) v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0])
if v0 != v1: if v0 != v1:
raise ValueError("In operator_equal not equal:{0}\n".format(k)) raise ValueError("In operator_equal not equal:{0}\n".format(k))
...@@ -133,7 +133,7 @@ def operator_equal(a, b): ...@@ -133,7 +133,7 @@ def operator_equal(a, b):
def block_equal(a, b): def block_equal(a, b):
for k, v in a.__dict__.iteritems(): for k, v in six.iteritems(a.__dict__):
if isinstance(v, core.ProgramDesc) or isinstance( if isinstance(v, core.ProgramDesc) or isinstance(
v, fluid.framework.Program) or isinstance(v, core.BlockDesc): v, fluid.framework.Program) or isinstance(v, core.BlockDesc):
continue continue
...@@ -145,8 +145,8 @@ def block_equal(a, b): ...@@ -145,8 +145,8 @@ def block_equal(a, b):
assert (len(a.ops) == len(b.ops)) assert (len(a.ops) == len(b.ops))
elif isinstance(v, collections.OrderedDict): elif isinstance(v, collections.OrderedDict):
v0 = sorted(v.iteritems(), key=lambda x: x[0]) v0 = sorted(six.iteritems(v), key=lambda x: x[0])
v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0]) v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0])
if v0 != v1: if v0 != v1:
raise ValueError("In block_equal not equal:{0}\n".format(k)) raise ValueError("In block_equal not equal:{0}\n".format(k))
...@@ -158,7 +158,7 @@ def block_equal(a, b): ...@@ -158,7 +158,7 @@ def block_equal(a, b):
def program_equal(a, b): def program_equal(a, b):
for k, v in a.__dict__.iteritems(): for k, v in six.iteritems(a.__dict__):
if isinstance(v, core.ProgramDesc): if isinstance(v, core.ProgramDesc):
continue continue
......
...@@ -21,9 +21,6 @@ from op_test import OpTest ...@@ -21,9 +21,6 @@ from op_test import OpTest
class PReluTest(OpTest): class PReluTest(OpTest):
def setUp(self): def setUp(self):
print('setUp')
import sys
sys.stdout.flush()
self.op_type = "prelu" self.op_type = "prelu"
self.initTestCase() self.initTestCase()
x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32") x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32")
...@@ -48,39 +45,19 @@ class PReluTest(OpTest): ...@@ -48,39 +45,19 @@ class PReluTest(OpTest):
assert out_np is not self.inputs['X'] assert out_np is not self.inputs['X']
self.outputs = {'Out': out_np} self.outputs = {'Out': out_np}
def tearDown(self):
print('tearDown')
import sys
sys.stdout.flush()
del self.outputs
del self.inputs
def initTestCase(self): def initTestCase(self):
self.attrs = {'mode': "channel"} self.attrs = {'mode': "channel"}
def test_check_4_output(self): def test_check_output(self):
print('test_check_0_output')
import sys
sys.stdout.flush()
self.check_output() self.check_output()
def test_check_0_grad_2_ignore_x(self): def test_check_grad(self):
print('test_check_2_grad_2_ignore_x')
import sys
sys.stdout.flush()
self.check_grad(['Alpha'], 'Out', no_grad_set=set('X'))
# TODO(minqiyang): remove the order of tests
def test_check_1_grad_1(self):
print('test_check_1_grad_1')
import sys
sys.stdout.flush()
self.check_grad(['X', 'Alpha'], 'Out') self.check_grad(['X', 'Alpha'], 'Out')
def test_check_3_grad_3_ignore_alpha(self): def test_check_grad_ignore_x(self):
print('test_check_3_grad_3_ignore_alpha') self.check_grad(['Alpha'], 'Out', no_grad_set=set('X'))
import sys
sys.stdout.flush() def test_check_grad_ignore_alpha(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Alpha')) self.check_grad(['X'], 'Out', no_grad_set=set('Alpha'))
...@@ -89,14 +66,15 @@ class TestCase1(PReluTest): ...@@ -89,14 +66,15 @@ class TestCase1(PReluTest):
self.attrs = {'mode': "all"} self.attrs = {'mode': "all"}
#class TestCase2(PReluTest): class TestCase2(PReluTest):
# def initTestCase(self): def initTestCase(self):
# self.attrs = {'mode': "channel"} self.attrs = {'mode': "channel"}
#
#
#class TestCase3(PReluTest): class TestCase3(PReluTest):
# def initTestCase(self): def initTestCase(self):
# self.attrs = {'mode': "element"} self.attrs = {'mode': "element"}
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册