未验证 提交 e494b73b 编写于 作者: C caozhou 提交者: GitHub

fix reshard bug (#41106)

上级 ee8eeb45
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import copy import copy
import time import time
import random import random
import logging
from functools import reduce from functools import reduce
from itertools import chain, product from itertools import chain, product
from collections import OrderedDict from collections import OrderedDict
...@@ -741,7 +740,7 @@ class MCMC(SearchAlgorithm): ...@@ -741,7 +740,7 @@ class MCMC(SearchAlgorithm):
return best_dist_context, min_cost return best_dist_context, min_cost
def search(self): def search(self):
logging.info("Start MCMC searching.") print("Start MCMC searching.")
start_time = time.time() start_time = time.time()
train_program = self.serial_program_info.train_program train_program = self.serial_program_info.train_program
cluster = self.serial_program_info.cluster cluster = self.serial_program_info.cluster
...@@ -757,9 +756,8 @@ class MCMC(SearchAlgorithm): ...@@ -757,9 +756,8 @@ class MCMC(SearchAlgorithm):
searched_pipeline_dist_context = None searched_pipeline_dist_context = None
pipeline_min_cost = None pipeline_min_cost = None
for process_mesh_topology in process_mesh_topology_list: for process_mesh_topology in process_mesh_topology_list:
logging.info( print("MCMC search: search process mesh {} with pipeline mode.".
"MCMC search: search process mesh {} with pipeline mode.". format(process_mesh_topology))
format(process_mesh_topology))
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program( valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, True) train_program, process_mesh_topology, True)
init_dist_context = self.init_program( init_dist_context = self.init_program(
...@@ -768,7 +766,7 @@ class MCMC(SearchAlgorithm): ...@@ -768,7 +766,7 @@ class MCMC(SearchAlgorithm):
best_dist_context, cost = self._search_core(valid_dist_attr_dict, best_dist_context, cost = self._search_core(valid_dist_attr_dict,
init_dist_context, init_dist_context,
pipeline_process_meshes) pipeline_process_meshes)
logging.info( print(
"MCMC search: the min cost is {} in the process mesh {} with pipeline mode.". "MCMC search: the min cost is {} in the process mesh {} with pipeline mode.".
format(cost, process_mesh_topology)) format(cost, process_mesh_topology))
best_dist_context._dist_op_context = DistributedOperatorContext() best_dist_context._dist_op_context = DistributedOperatorContext()
...@@ -784,9 +782,8 @@ class MCMC(SearchAlgorithm): ...@@ -784,9 +782,8 @@ class MCMC(SearchAlgorithm):
# if process_mesh_topology shape is 3, include pipeline mode by default # if process_mesh_topology shape is 3, include pipeline mode by default
if len(process_mesh_topology) == 3: if len(process_mesh_topology) == 3:
continue continue
logging.info( print("MCMC search: search process mesh {} without pipeline mode.".
"MCMC search: search process mesh {} without pipeline mode.". format(process_mesh_topology))
format(process_mesh_topology))
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program( valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, False) train_program, process_mesh_topology, False)
init_dist_context = self.init_program( init_dist_context = self.init_program(
...@@ -795,7 +792,7 @@ class MCMC(SearchAlgorithm): ...@@ -795,7 +792,7 @@ class MCMC(SearchAlgorithm):
best_dist_context, cost = self._search_core(valid_dist_attr_dict, best_dist_context, cost = self._search_core(valid_dist_attr_dict,
init_dist_context, init_dist_context,
pipeline_process_meshes) pipeline_process_meshes)
logging.info( print(
"MCMC search: the min cost is {} in the process mesh {} without pipeline mode.". "MCMC search: the min cost is {} in the process mesh {} without pipeline mode.".
format(cost, process_mesh_topology)) format(cost, process_mesh_topology))
best_dist_context._dist_op_context = DistributedOperatorContext() best_dist_context._dist_op_context = DistributedOperatorContext()
...@@ -808,7 +805,7 @@ class MCMC(SearchAlgorithm): ...@@ -808,7 +805,7 @@ class MCMC(SearchAlgorithm):
if non_pipeline_min_cost > pipeline_min_cost: if non_pipeline_min_cost > pipeline_min_cost:
searched_dist_context = searched_pipeline_dist_context searched_dist_context = searched_pipeline_dist_context
min_cost = pipeline_min_cost min_cost = pipeline_min_cost
logging.info( print(
"Better set FLAGS_benchmark=1 to avoid hang problem in the pipeline mode." "Better set FLAGS_benchmark=1 to avoid hang problem in the pipeline mode."
) )
else: else:
...@@ -820,7 +817,7 @@ class MCMC(SearchAlgorithm): ...@@ -820,7 +817,7 @@ class MCMC(SearchAlgorithm):
for process_mesh in searched_dist_context._process_meshes: for process_mesh in searched_dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes) pg0.add_ranks(process_mesh.processes)
end_time = time.time() end_time = time.time()
logging.info( print(
"End MCMC searching: the min cost is {} and the search time is {}s.". "End MCMC searching: the min cost is {} and the search time is {}s.".
format(min_cost, end_time - start_time)) format(min_cost, end_time - start_time))
return searched_dist_context, min_cost return searched_dist_context, min_cost
......
...@@ -1239,7 +1239,9 @@ class Resharder: ...@@ -1239,7 +1239,9 @@ class Resharder:
for item in self.has_allgather[var_name]: for item in self.has_allgather[var_name]:
if op_desc.group == item[0]: if op_desc.group == item[0]:
tensor_list = [ tensor_list = [
program.global_block().vars[var_name] get_var_with_recursion(
var_name, block,
self.auto_parallel_main_prog)
for var_name in item[1] for var_name in item[1]
] ]
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册