From eacac49bcdd281f68d9fac2cba9dee2b245d0d17 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 1 Jun 2018 15:15:26 +0800 Subject: [PATCH] 1. update test_split_var: replace split with slice --- .../{test_split_var.py => test_slice_var.py} | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) rename python/paddle/fluid/tests/unittests/{test_split_var.py => test_slice_var.py} (85%) diff --git a/python/paddle/fluid/tests/unittests/test_split_var.py b/python/paddle/fluid/tests/unittests/test_slice_var.py similarity index 85% rename from python/paddle/fluid/tests/unittests/test_split_var.py rename to python/paddle/fluid/tests/unittests/test_slice_var.py index 157def9b56..82305b23a1 100644 --- a/python/paddle/fluid/tests/unittests/test_split_var.py +++ b/python/paddle/fluid/tests/unittests/test_slice_var.py @@ -14,14 +14,14 @@ import math import unittest -from paddle.fluid.transpiler.distribute_transpiler import split_variable +from paddle.fluid.transpiler.distribute_transpiler import slice_variable import paddle.fluid as fluid import paddle.fluid.core as core import random -class TestSplitVar(unittest.TestCase): - def check_split_output(self, shapes, expected_sizes, min_size): +class TestSliceVar(unittest.TestCase): + def check_slice_output(self, shapes, expected_sizes, min_size): var_list = [] program = fluid.Program() for shape in shapes: @@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase): # dtype=core.VarDesc.VarType.LOD_TENSOR, shape=shape) var_list.append(var) - blocks = split_variable(var_list, 10, min_size) + blocks = slice_variable(var_list, 10, min_size) all_sizes = [] for s in expected_sizes: for s2 in s: @@ -49,7 +49,7 @@ class TestSplitVar(unittest.TestCase): [1150, 1150, 1150, 1150, 1150, 1150, 1100] ] - self.check_split_output(shapes, expected_sizes, 1024) + self.check_slice_output(shapes, expected_sizes, 1024) def test_check_output_8k(self): shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10], @@ -57,7 +57,7 @@ class TestSplitVar(unittest.TestCase): expected_sizes = [[15], [1024], [10976, 10976], [8160], [8000], [35937, 35937, 35937, 35937, 35937, 35937]] - self.check_split_output(shapes, expected_sizes, 8192) + self.check_slice_output(shapes, expected_sizes, 8192) if __name__ == '__main__': -- GitLab