提交 9a7084c1 编写于 作者: Q qiaolongfei

fix test_dist_transpiler

上级 25ab4e0d
...@@ -6,4 +6,4 @@ add_subdirectory(pybind) ...@@ -6,4 +6,4 @@ add_subdirectory(pybind)
add_subdirectory(string) add_subdirectory(string)
add_subdirectory(recordio) add_subdirectory(recordio)
# NOTE: please add subdirectory inference at last. # NOTE: please add subdirectory inference at last.
add_subdirectory(inference) #add_subdirectory(inference)
...@@ -73,9 +73,18 @@ class TranspilerTest(unittest.TestCase): ...@@ -73,9 +73,18 @@ class TranspilerTest(unittest.TestCase):
return self.transpiler return self.transpiler
def transpiler_test_impl(self):
pass
class TestBasicModel(TranspilerTest):
def test_transpiler(self): def test_transpiler(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
self.transpiler_test_impl()
class TestBasicModel(TranspilerTest):
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
pserver2, startup2 = self.get_pserver(self.pserver2_ep) pserver2, startup2 = self.get_pserver(self.pserver2_ep)
...@@ -123,7 +132,7 @@ class TestBasicModel(TranspilerTest): ...@@ -123,7 +132,7 @@ class TestBasicModel(TranspilerTest):
class TestBasicModelWithLargeBlockSize(TranspilerTest): class TestBasicModelWithLargeBlockSize(TranspilerTest):
def test_transpiler(self): def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.min_block_size = 1048576 config.min_block_size = 1048576
...@@ -148,7 +157,7 @@ class TestBasicModelWithLargeBlockSize(TranspilerTest): ...@@ -148,7 +157,7 @@ class TestBasicModelWithLargeBlockSize(TranspilerTest):
["sum", "scale", "sgd"]) ["sum", "scale", "sgd"])
# confirm startup program # confirm startup program
self.assertEqual([op.type for op in startup.global_block().ops], self.assertEqual([op.type for op in startup.global_block().ops],
["fill_constant", "fill_constant", "fill_constant"]) ["fill_constant", "fill_constant"])
# the variable #fc_w will be split into two blocks # the variable #fc_w will be split into two blocks
fc_w_var = startup2.global_block().var("fc_w") fc_w_var = startup2.global_block().var("fc_w")
self.assertEqual(fc_w_var.shape, (1000L, 1000L)) self.assertEqual(fc_w_var.shape, (1000L, 1000L))
...@@ -177,7 +186,7 @@ class TestNoSliceVar(TranspilerTest): ...@@ -177,7 +186,7 @@ class TestNoSliceVar(TranspilerTest):
def setUp(self): def setUp(self):
super(TestNoSliceVar, self).setUp() super(TestNoSliceVar, self).setUp()
def test_transpiler(self): def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.slice_var_up = False config.slice_var_up = False
...@@ -212,7 +221,7 @@ class TestLRDecay(TranspilerTest): ...@@ -212,7 +221,7 @@ class TestLRDecay(TranspilerTest):
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return return
def test_transpiler(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
trainer = self.get_trainer() trainer = self.get_trainer()
...@@ -242,7 +251,7 @@ class TestLRDecayConditional(TranspilerTest): ...@@ -242,7 +251,7 @@ class TestLRDecayConditional(TranspilerTest):
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return return
def test_transpiler(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
trainer = self.get_trainer() trainer = self.get_trainer()
...@@ -291,7 +300,7 @@ class TestL2Decay(TranspilerTest): ...@@ -291,7 +300,7 @@ class TestL2Decay(TranspilerTest):
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return return
def test_transpiler(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
trainer = self.get_trainer() trainer = self.get_trainer()
...@@ -326,7 +335,7 @@ class TestL2DecayWithPiecewise(TranspilerTest): ...@@ -326,7 +335,7 @@ class TestL2DecayWithPiecewise(TranspilerTest):
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return return
def test_transpiler(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
trainer = self.get_trainer() trainer = self.get_trainer()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册