提交 28c4443c 编写于 作者: Q qiaolongfei

add test for remove duplicated init op

上级 8e9f338c
......@@ -6,4 +6,4 @@ add_subdirectory(pybind)
add_subdirectory(string)
add_subdirectory(recordio)
# NOTE: please add subdirectory inference at last.
add_subdirectory(inference)
#add_subdirectory(inference)
......@@ -72,9 +72,10 @@ def _is_inited_by(block, var, init_op_types):
def _is_duplicated_init_op(op1, op2):
if op1.block == op2.block and \
op1.type == op2.type and \
op1.input_arg_names == op2.output_arg_names and \
op1.input_arg_names == op2.input_arg_names and \
op1.output_arg_names == op2.output_arg_names and \
op1.idx != op2.idx and \
op1.all_attrs == op2.all_attrs:
op1.all_attrs() == op2.all_attrs():
return True
return False
......
......@@ -27,6 +27,7 @@ class TestConstantInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -43,6 +44,7 @@ class TestConstantInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -61,6 +63,7 @@ class TestUniformInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -80,6 +83,7 @@ class TestUniformInitializer(unittest.TestCase):
program = framework.Program()
program.random_seed = 123
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -102,6 +106,7 @@ class TestUniformInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -115,6 +120,32 @@ class TestUniformInitializer(unittest.TestCase):
self.assertAlmostEqual(init_op.attr('max'), 3.1, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 123)
def test_uniform_initializer_two_op(self):
"""Test uniform initializer with supplied attributes
"""
program = framework.Program()
block = program.global_block()
for i in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer(-4.2, float(i), 123))
self.assertEqual(len(block.ops), 2)
init_op0 = block.ops[0]
self.assertEqual(init_op0.type, 'uniform_random')
self.assertAlmostEqual(init_op0.attr('min'), -4.2, delta=DELTA)
self.assertAlmostEqual(init_op0.attr('max'), 1.0, delta=DELTA)
self.assertEqual(init_op0.attr('seed'), 123)
self.assertEqual(len(block.ops), 2)
init_op1 = block.ops[1]
self.assertEqual(init_op1.type, 'uniform_random')
self.assertAlmostEqual(init_op1.attr('min'), -4.2, delta=DELTA)
self.assertAlmostEqual(init_op1.attr('max'), 0.0, delta=DELTA)
self.assertEqual(init_op1.attr('seed'), 123)
class TestNormalInitializer(unittest.TestCase):
def test_normal_initializer_default_value(self):
......@@ -122,6 +153,7 @@ class TestNormalInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -140,6 +172,7 @@ class TestNormalInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -161,6 +194,7 @@ class TestXavierInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
param = block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -181,6 +215,7 @@ class TestXavierInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
param = block.create_parameter(
dtype="float32",
shape=[5, 10, 15, 20],
......@@ -203,6 +238,7 @@ class TestXavierInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
param = block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -223,6 +259,7 @@ class TestXavierInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
param = block.create_parameter(
dtype="float32",
shape=[5, 10, 15, 20],
......@@ -244,6 +281,7 @@ class TestXavierInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -267,6 +305,7 @@ class TestMSRAInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
param = block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -287,6 +326,7 @@ class TestMSRAInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
param = block.create_parameter(
dtype="float32",
shape=[5, 10, 15, 20],
......@@ -308,6 +348,7 @@ class TestMSRAInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
param = block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -328,6 +369,7 @@ class TestMSRAInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
param = block.create_parameter(
dtype="float32",
shape=[5, 10, 15, 20],
......@@ -348,6 +390,7 @@ class TestMSRAInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[5, 10],
......@@ -370,6 +413,7 @@ class TestMSRAInitializer(unittest.TestCase):
"""
program = framework.Program()
block = program.global_block()
for _ in range(2):
block.create_parameter(
dtype="float32",
shape=[8, 1, 3, 3],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册