未验证 提交 e494b73b 编写于 作者: C caozhou 提交者: GitHub

fix reshard bug (#41106)

上级 ee8eeb45
......@@ -15,7 +15,6 @@
import copy
import time
import random
import logging
from functools import reduce
from itertools import chain, product
from collections import OrderedDict
......@@ -741,7 +740,7 @@ class MCMC(SearchAlgorithm):
return best_dist_context, min_cost
def search(self):
logging.info("Start MCMC searching.")
print("Start MCMC searching.")
start_time = time.time()
train_program = self.serial_program_info.train_program
cluster = self.serial_program_info.cluster
......@@ -757,8 +756,7 @@ class MCMC(SearchAlgorithm):
searched_pipeline_dist_context = None
pipeline_min_cost = None
for process_mesh_topology in process_mesh_topology_list:
logging.info(
"MCMC search: search process mesh {} with pipeline mode.".
print("MCMC search: search process mesh {} with pipeline mode.".
format(process_mesh_topology))
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, True)
......@@ -768,7 +766,7 @@ class MCMC(SearchAlgorithm):
best_dist_context, cost = self._search_core(valid_dist_attr_dict,
init_dist_context,
pipeline_process_meshes)
logging.info(
print(
"MCMC search: the min cost is {} in the process mesh {} with pipeline mode.".
format(cost, process_mesh_topology))
best_dist_context._dist_op_context = DistributedOperatorContext()
......@@ -784,8 +782,7 @@ class MCMC(SearchAlgorithm):
# if process_mesh_topology shape is 3, include pipeline mode by default
if len(process_mesh_topology) == 3:
continue
logging.info(
"MCMC search: search process mesh {} without pipeline mode.".
print("MCMC search: search process mesh {} without pipeline mode.".
format(process_mesh_topology))
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, False)
......@@ -795,7 +792,7 @@ class MCMC(SearchAlgorithm):
best_dist_context, cost = self._search_core(valid_dist_attr_dict,
init_dist_context,
pipeline_process_meshes)
logging.info(
print(
"MCMC search: the min cost is {} in the process mesh {} without pipeline mode.".
format(cost, process_mesh_topology))
best_dist_context._dist_op_context = DistributedOperatorContext()
......@@ -808,7 +805,7 @@ class MCMC(SearchAlgorithm):
if non_pipeline_min_cost > pipeline_min_cost:
searched_dist_context = searched_pipeline_dist_context
min_cost = pipeline_min_cost
logging.info(
print(
"Better set FLAGS_benchmark=1 to avoid hang problem in the pipeline mode."
)
else:
......@@ -820,7 +817,7 @@ class MCMC(SearchAlgorithm):
for process_mesh in searched_dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes)
end_time = time.time()
logging.info(
print(
"End MCMC searching: the min cost is {} and the search time is {}s.".
format(min_cost, end_time - start_time))
return searched_dist_context, min_cost
......
......@@ -1239,7 +1239,9 @@ class Resharder:
for item in self.has_allgather[var_name]:
if op_desc.group == item[0]:
tensor_list = [
program.global_block().vars[var_name]
get_var_with_recursion(
var_name, block,
self.auto_parallel_main_prog)
for var_name in item[1]
]
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册