提交 eacac49b 编写于 作者: M minqiyang

1. update test_split_var: replace split with slice

上级 b33ea7be
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
import math import math
import unittest import unittest
from paddle.fluid.transpiler.distribute_transpiler import split_variable from paddle.fluid.transpiler.distribute_transpiler import slice_variable
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import random import random
class TestSplitVar(unittest.TestCase): class TestSliceVar(unittest.TestCase):
def check_split_output(self, shapes, expected_sizes, min_size): def check_slice_output(self, shapes, expected_sizes, min_size):
var_list = [] var_list = []
program = fluid.Program() program = fluid.Program()
for shape in shapes: for shape in shapes:
...@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR, # dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape) shape=shape)
var_list.append(var) var_list.append(var)
blocks = split_variable(var_list, 10, min_size) blocks = slice_variable(var_list, 10, min_size)
all_sizes = [] all_sizes = []
for s in expected_sizes: for s in expected_sizes:
for s2 in s: for s2 in s:
...@@ -49,7 +49,7 @@ class TestSplitVar(unittest.TestCase): ...@@ -49,7 +49,7 @@ class TestSplitVar(unittest.TestCase):
[1150, 1150, 1150, 1150, 1150, 1150, 1100] [1150, 1150, 1150, 1150, 1150, 1150, 1100]
] ]
self.check_split_output(shapes, expected_sizes, 1024) self.check_slice_output(shapes, expected_sizes, 1024)
def test_check_output_8k(self): def test_check_output_8k(self):
shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10], shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10],
...@@ -57,7 +57,7 @@ class TestSplitVar(unittest.TestCase): ...@@ -57,7 +57,7 @@ class TestSplitVar(unittest.TestCase):
expected_sizes = [[15], [1024], [10976, 10976], [8160], [8000], expected_sizes = [[15], [1024], [10976, 10976], [8160], [8000],
[35937, 35937, 35937, 35937, 35937, 35937]] [35937, 35937, 35937, 35937, 35937, 35937]]
self.check_split_output(shapes, expected_sizes, 8192) self.check_slice_output(shapes, expected_sizes, 8192)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册