未验证 提交 ebbd3564 编写于 作者: J JZ-LIANG 提交者: GitHub

remove unitest for auto_searcher (#38370)

上级 4d5a6064
......@@ -212,41 +212,6 @@ class TestMLPSearcher(unittest.TestCase):
self.assertTrue(
check_nonpipeline_enumerater(train_program, process_mesh_topology))
def test_get_dist_programs(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
process_mesh_topology = [4]
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, False)
from test_auto_parallel_cluster import cluster_json
cluster_json_file = ""
cluster_json_object = json.loads(cluster_json)
with open("./auto_parallel_cluster.json", "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster()
cluster.build_from_file("./auto_parallel_cluster.json")
os.remove("./auto_parallel_cluster.json")
ops = train_program.global_block().ops
vars = train_program.global_block().vars
new_dist_context = DistributedContext()
set_default_dist_attr(train_program, new_dist_context,
global_process_mesh)
serial_program_info = SerialProgramInfo(train_program, startup_program,
loss, optimizer, cluster)
result = get_all_distributed_main_program(serial_program_info,
new_dist_context)
self.assertEqual(len(result), 4)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册