提交 39c526d4 编写于 作者: M minqiyang

Port test_dist_transpiler to it

上级 efe88ab9
...@@ -21,6 +21,7 @@ import paddle.fluid as fluid ...@@ -21,6 +21,7 @@ import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import delete_ops from paddle.fluid.transpiler.distribute_transpiler import delete_ops
import traceback import traceback
import collections import collections
import six
class TranspilerTest(unittest.TestCase): class TranspilerTest(unittest.TestCase):
...@@ -644,18 +645,18 @@ class TestLoadSliceVar(TranspilerTest): ...@@ -644,18 +645,18 @@ class TestLoadSliceVar(TranspilerTest):
self.assertTrue(pserver._slice_vars_and_attrs) self.assertTrue(pserver._slice_vars_and_attrs)
self.assertTrue(pserver2._slice_vars_and_attrs) self.assertTrue(pserver2._slice_vars_and_attrs)
for idx in xrange(len(pserver._slice_vars_and_attrs)): for idx in six.moves.xrange(len(pserver._slice_vars_and_attrs)):
self.assertEqual(pserver._slice_vars_and_attrs[idx][0], self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_attrs[idx][0]) pserver2._slice_vars_and_attrs[idx][0])
total_numel = reduce(lambda x, y: x * y, total_numel = six.moves.reduce(
pserver._slice_vars_and_attrs[idx][0].shape) lambda x, y: x * y, pserver._slice_vars_and_attrs[idx][0].shape)
self.assertEqual( self.assertEqual(
total_numel, total_numel,
reduce(lambda x, y: x * y, six.moves.reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][2].shape) + reduce( pserver._slice_vars_and_attrs[idx][2].shape) +
lambda x, y: x * y, six.moves.reduce(lambda x, y: x * y,
pserver2._slice_vars_and_attrs[idx][2].shape)) pserver2._slice_vars_and_attrs[idx][2].shape))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册