提交 914ab440 编写于 作者: S seiriosPlus

fix UT

上级 b1cc1c3b
...@@ -64,7 +64,7 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): ...@@ -64,7 +64,7 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
sends += 1 sends += 1
if op.type == "sgd": if op.type == "sgd":
sgds += 1 sgds += 1
self.assertEqual(sends, 7) self.assertEqual(sends, 1)
self.assertEqual(sgds, 0) self.assertEqual(sgds, 0)
fleet.init_worker() fleet.init_worker()
...@@ -82,16 +82,11 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): ...@@ -82,16 +82,11 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
paddle.fluid.framework.switch_startup_program(startup_program) paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker()) fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32') x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = paddle.fluid.layers.square_error_cost(input=x, label=y)
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') avg_cost = paddle.fluid.layers.mean(cost)
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True strategy.a_sync = True
......
...@@ -56,7 +56,7 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
sends += 1 sends += 1
if op.type == "sgd": if op.type == "sgd":
sgds += 1 sgds += 1
self.assertEqual(sends, 6) self.assertEqual(sends, 0)
self.assertEqual(sgds, 0) self.assertEqual(sgds, 0)
fleet.init_worker() fleet.init_worker()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册