From ebbd356421598d50db78c32875dc9ab147c41c9e Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Thu, 23 Dec 2021 11:26:37 +0800 Subject: [PATCH] remove unitest for auto_searcher (#38370) --- .../unittests/test_auto_parallel_searcher.py | 35 ------------------- 1 file changed, 35 deletions(-) mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py old mode 100644 new mode 100755 index 92d11801902..ed64fa0630f --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py @@ -212,41 +212,6 @@ class TestMLPSearcher(unittest.TestCase): self.assertTrue( check_nonpipeline_enumerater(train_program, process_mesh_topology)) - def test_get_dist_programs(self): - train_program = paddle.static.Program() - startup_program = paddle.static.Program() - loss, train_program, startup_program = mlp_forward(train_program, - startup_program) - process_mesh_topology = [4] - optimizer = paddle.optimizer.Adam( - learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) - valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program( - train_program, process_mesh_topology, False) - from test_auto_parallel_cluster import cluster_json - cluster_json_file = "" - cluster_json_object = json.loads(cluster_json) - with open("./auto_parallel_cluster.json", "w") as cluster_json_file: - json.dump(cluster_json_object, cluster_json_file) - cluster = Cluster() - cluster.build_from_file("./auto_parallel_cluster.json") - os.remove("./auto_parallel_cluster.json") - - ops = train_program.global_block().ops - vars = train_program.global_block().vars - new_dist_context = DistributedContext() - set_default_dist_attr(train_program, new_dist_context, - global_process_mesh) - - serial_program_info = SerialProgramInfo(train_program, startup_program, - loss, optimizer, cluster) - result = get_all_distributed_main_program(serial_program_info, - new_dist_context) - self.assertEqual(len(result), 4) - if __name__ == "__main__": unittest.main() -- GitLab