diff --git a/python/paddle/distributed/auto_parallel/cost_model.py b/python/paddle/distributed/auto_parallel/cost_model.py index b1ff4fb0ba7c9696ea30db09277b5ad3d2836414..9252f8de905b5f6cedd834b520d8a7b9ad2e125a 100644 --- a/python/paddle/distributed/auto_parallel/cost_model.py +++ b/python/paddle/distributed/auto_parallel/cost_model.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np + import json import queue import copy from enum import Enum + +import numpy as np + import paddle +from paddle.fluid import core +from paddle.distributed.fleet.meta_optimizers.common import OpRole SUCC = 0 # successor PRED = 1 # predecessor @@ -121,8 +126,12 @@ class TensorCostNode(CostNode): batch_size=None, shared_node_id=None): super(TensorCostNode, self).__init__(node, node_type, id) - self.shape = node.shape - self.dtype = node.dtype + if node.name == "create_py_reader_0" or node.name == "double_buffer_0": + self.shape = [2, 2] + self.dtype = paddle.float32 + else: + self.shape = node.shape + self.dtype = node.dtype self.dtype_factor = 1 self.persistable = None self.shared_node_id = shared_node_id @@ -130,9 +139,10 @@ class TensorCostNode(CostNode): self.dtype_factor *= 4 elif node.dtype == paddle.int64: self.dtype_factor *= 8 + elif node.dtype == paddle.uint8: + self.dtype_factor = 1 else: raise NotImplementedError("{} not counted".format(node.dtype)) - self.batch_size = None if batch_size is not None: self.batch_size = batch_size @@ -155,9 +165,9 @@ class CompOpCostNode(CostNode): def init_comp_cost(self, cost_data): # TODO: improve fluid.CostModel for more specific cost_data - op_name = self.node.type - if op_name in cost_data.keys(): - self.cost = cost_data[op_name] + op_id = self.node.desc.id() + if op_id in cost_data.keys(): + self.cost = cost_data[op_id] else: self.cost = 0.0 @@ -215,8 +225,17 @@ class CostModel(object): program.blocks) == 1, "Program more than 1 block not supported." 
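# --- illustrative sketch (not Paddle code): the dtype_factor bookkeeping above
# maps a tensor dtype to a byte width so the cost model can turn a shape into
# an estimated memory/communication size. The dtype table here is an assumed
# stand-in covering only the dtypes the patched TensorCostNode handles.
from functools import reduce

DTYPE_FACTOR = {"float32": 4, "int64": 8, "uint8": 1}

def tensor_size_in_bytes(shape, dtype):
    if dtype not in DTYPE_FACTOR:
        raise NotImplementedError("{} not counted".format(dtype))
    num_elements = reduce(lambda x, y: x * y, shape, 1)
    return num_elements * DTYPE_FACTOR[dtype]

# an [8, 1024, 1024] float32 activation is 32 MiB when held or sent
assert tensor_size_in_bytes([8, 1024, 1024], "float32") == 32 * 1024 * 1024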
block = program.blocks[0] + var_id = "lod_tensor_blocking_queue_0" + new_var = program.global_block().create_var( + name=var_id, + dtype=paddle.float32, + type=core.VarDesc.VarType.LOD_TENSOR) + nodes[var_id] = TensorCostNode(new_var, CostNodeType.VARIABLE, + "lod_tensor_blocking_queue_0") for var in block.vars.values(): var_id = var.name + # if var.name == "create_py_reader_0" or var.name == "double_buffer_0": + # continue nodes[var_id] = TensorCostNode(var, CostNodeType.VARIABLE, var_id) graph[var_id] = [[], []] @@ -225,7 +244,10 @@ class CostModel(object): if op.type.startswith('c_') or op.type.startswith( 'send') or op.type.startswith('recv'): is_bwd = False - if op.type.startswith('c_'): + if op.type.startswith( + 'c_' + ) and op.type != "c_sync_calc_stream" and not op.type.startswith( + 'c_embedding'): ring_id = op.attr('ring_id') if ring_id not in self.ring2rank: self.ring2rank[ring_id] = set() @@ -238,7 +260,8 @@ class CostModel(object): op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id, is_bwd) else: - is_bwd = '_grad' in op.type + is_bwd = (int(op.attr('op_role')) == int(OpRole.Backward) + ) or "@GRAD" in op.input_arg_names is_optim = 'LearningRate' in op.input_names op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id, is_bwd, is_optim) @@ -258,6 +281,7 @@ class CostModel(object): comm_input_shape = var_node.shape except: continue + for i in range(len(op.output_names)): try: var_id = op.output(op.output_names[i])[0] @@ -361,7 +385,9 @@ class CostModel(object): for sub_idx in range(self.total_rank): for node_id, edges in self.op_graph[sub_idx].items(): node = self.nodes[sub_idx][node_id] - if node_id.startswith('c_'): + if node_id.startswith('c_') and not node.id.startswith( + "c_sync_calc_stream") and not node.id.startswith( + 'c_embedding'): ring_id = node.node.attr('ring_id') node.set_ranks(list(self.ring2rank[ring_id])) node.init_comm_cost(self.cluster) @@ -454,31 +480,52 @@ class CostModel(object): # delete edges and add new edges succ = None - runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[SUCC]) - if len(runtime_graph[pred_id][SUCC]) > 1: - # predecessor has more than 1 successor - # the merged_node is to inherit the rest of its successors - succ = runtime_graph[pred_id][SUCC] - succ.remove(node_id) - runtime_graph[merged_node_id][SUCC] += succ - runtime_graph[merged_node_id][PRED] = runtime_graph[pred_id][ - PRED] - for i in runtime_graph[pred_id][PRED]: - runtime_graph[i][SUCC].remove(pred_id) - runtime_graph[i][SUCC].append(merged_node_id) - - for i in edges[SUCC]: - runtime_graph[i][PRED].remove(node_id) - runtime_graph[i][PRED].append(merged_node_id) + try: + runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[ + SUCC]) + + if len(runtime_graph[pred_id][SUCC]) > 1: + # predecessor has more than 1 successor + # the merged_node is to inherit the rest of its successors + succ = runtime_graph[pred_id][SUCC] + succ.remove(node_id) + runtime_graph[merged_node_id][SUCC] += succ + runtime_graph[merged_node_id][PRED] = runtime_graph[ + pred_id][PRED] + except: + pass + try: + for i in runtime_graph[pred_id][PRED]: + try: + runtime_graph[i][SUCC].remove(pred_id) + except: + continue + runtime_graph[i][SUCC].append(merged_node_id) + except: + pass + + try: + for i in edges[SUCC]: + runtime_graph[i][PRED].remove(node_id) + runtime_graph[i][PRED].append(merged_node_id) + except: + pass if succ is not None: for i in succ: - runtime_graph[i][PRED].remove(pred_id) + try: + runtime_graph[i][PRED].remove(pred_id) + except: + continue 
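# --- illustrative sketch (not Paddle code): the patch stops guessing "backward"
# from '_grad' in the op type and instead trusts the op_role attribute, falling
# back to "@GRAD" input names. The OpRole values below are hypothetical
# stand-ins for paddle.distributed.fleet.meta_optimizers.common.OpRole.
from enum import IntEnum

class OpRole(IntEnum):
    Forward = 0
    Backward = 1
    Optimize = 2

def is_backward_op(op_role, input_arg_names):
    return int(op_role) == int(OpRole.Backward) or any(
        "@GRAD" in name for name in input_arg_names)

assert is_backward_op(OpRole.Backward, ["hidden_0"])
assert is_backward_op(OpRole.Forward, ["linear_0.w_0@GRAD"])
assert not is_backward_op(OpRole.Forward, ["hidden_0"])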
runtime_graph[i][PRED].append(merged_node_id) runtime_graph.pop(node_id) - runtime_graph.pop(pred_id) + try: + runtime_graph.pop(pred_id) + except: + continue reduct_cnt += 1 - self.eliminate_multi_edges(runtime_graph) + self.eliminate_multi_edges(runtime_graph) + break return reduct_cnt # the number of nodes that have been reduced def _merge_branch(self, nodes, runtime_graph, is_bwd=False): @@ -496,7 +543,10 @@ class CostModel(object): succ_to_elim = [] for succ_id in succ_nodes_id: for succ_2_id in succ_nodes_id: - tmp = runtime_graph[succ_2_id][SUCC] + try: + tmp = runtime_graph[succ_2_id][SUCC] + except: + continue if succ_id in tmp: succ_to_elim.append(succ_id) break @@ -506,16 +556,22 @@ class CostModel(object): reduct_cnt += 1 to_merge = True - if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]][ - SUCC]) < 1: + try: + if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]] + [SUCC]) < 1: + continue + except: continue end_node_id = runtime_graph[edges[SUCC][0]][SUCC][0] for i in succ_nodes_id: - if len(runtime_graph[i][SUCC]) != 1 or \ - runtime_graph[i][SUCC][0] != end_node_id: - to_merge = False # if branches has different end node, we don't merge them - break - if to_merge: + try: + if len(runtime_graph[i][SUCC]) != 1 or \ + runtime_graph[i][SUCC][0] != end_node_id: + to_merge = False # if branches has different end node, we don't merge them + break + except: + continue + if to_merge and len(succ_nodes_id) > 1: to_merge_node_list = [nodes[i] for i in succ_nodes_id] merged_node_id, merged_node = self._merge_node( to_merge_node_list, merge_type='branch', nodes=nodes) @@ -529,9 +585,13 @@ class CostModel(object): runtime_graph[end_node_id][PRED] = [merged_node_id] runtime_graph[node_id][SUCC] = [merged_node_id] - for i in succ_nodes_id: - runtime_graph.pop(i) - reduct_cnt += len(to_merge_node_list) - 1 + try: + for i in succ_nodes_id: + runtime_graph.pop(i) + reduct_cnt += len(to_merge_node_list) - 1 + break + except: + pass return reduct_cnt def get_runtime_cost(self): @@ -615,7 +675,7 @@ class CostModel(object): return static_mem, cur_mem, top_mem def get_pipeline_time(self): - if self.total_rank <= 1: + if self.pp2rank is None: return self.fwd_time[0] + self.bwd_time[0] + self.optim_time[0] else: return self._simulate_pipeline() diff --git a/python/paddle/distributed/auto_parallel/mapper.py b/python/paddle/distributed/auto_parallel/mapper.py index 543fa2d9681c01b7f58f56aa078ee420865254b0..f5d9c32d33eb328bf30789891d1b5ead3c4c84d2 100644 --- a/python/paddle/distributed/auto_parallel/mapper.py +++ b/python/paddle/distributed/auto_parallel/mapper.py @@ -118,11 +118,11 @@ def get_comm_volume(comm_op, src_rank, tgt_rank): return comm_volume -def analyze_comm_requirements_from_op(op, rank): +def analyze_comm_requirements_from_op(op, rank, g_process_group_map): comm_requirements_to_ranks = {} if is_collective_comm_op(op): process_group_id = op.attr("ring_id") - process_group = get_process_group(process_group_id) + process_group = get_process_group(process_group_id, g_process_group_map) if rank not in process_group.ranks: return comm_requirements_to_ranks for tgt_rank in process_group.ranks: @@ -142,7 +142,9 @@ def analyze_comm_requirements_from_op(op, rank): return comm_requirements_to_ranks -def analyze_requirements_for_program(program, rank): +def analyze_requirements_for_program(src_info, rank): + program = src_info[0] + g_process_group_map = src_info[1] resource_requirements = {} comm_requirements_to_ranks = {} # only support device_type and only support GPU for now @@ 
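# --- illustrative sketch (not Paddle code): a simplified version of the linear
# merging above. Here a node is folded into its predecessor (and the two costs
# summed) only when that predecessor has a single successor; the real pass also
# handles predecessors with extra successors and separate branch merging.
def merge_linear(costs, succ):
    pred = {n: [] for n in succ}
    for n, outs in succ.items():
        for o in outs:
            pred[o].append(n)
    changed = True
    while changed:
        changed = False
        for n in list(succ):
            if len(pred[n]) == 1 and len(succ[pred[n][0]]) == 1:
                keep = pred[n][0]
                costs[keep] += costs.pop(n)
                succ[keep] = succ.pop(n)
                for o in succ[keep]:
                    pred[o] = [keep if x == n else x for x in pred[o]]
                pred.pop(n)
                changed = True
                break
    return costs, succ

costs, succ = merge_linear({"read": 1.0, "matmul": 4.0, "relu": 1.0},
                           {"read": ["matmul"], "matmul": ["relu"], "relu": []})
assert costs == {"read": 6.0} and succ == {"read": []}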
-150,7 +152,7 @@ def analyze_requirements_for_program(program, rank): for block in program.blocks: for op in block.ops: cur_comm_requirements_to_ranks = analyze_comm_requirements_from_op( - op, rank) + op, rank, g_process_group_map) for tgt_rank, link_info in cur_comm_requirements_to_ranks.items(): if tgt_rank in comm_requirements_to_ranks: comm_requirements_to_ranks[tgt_rank][ @@ -164,9 +166,9 @@ def analyze_requirements_for_program(program, rank): def build_process_graph(distributed_program): graph = Graph() - for src_rank, src_program in distributed_program.items(): + for src_rank, src_info in distributed_program.items(): resource_requirements, comm_requirements_to_ranks = analyze_requirements_for_program( - src_program, src_rank) + src_info, src_rank) graph.add_node(src_rank, resource_requirements=resource_requirements) for tgt_rank, comm_requirements in comm_requirements_to_ranks.items(): graph.add_edge( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 7bda6a9a28348c0330a60895e46622ca4337dc2f..3a4d8412bf835515821f56c5207c0d22b4b6b399 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -308,6 +308,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): assert len(x_dims_mapping) >= len( y_dims_mapping), "now just support x dims > y dims" + if len(y_dims_mapping) != 2: + return False if len(x_dims_mapping) == len(y_dims_mapping) and len( x_dims_mapping) == 4: if x_dims_mapping[:2] != y_dims_mapping[:2]: @@ -602,6 +604,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): # for gpt2, x dims > y dims, this is a temporary solution assert len(x_dims_mapping) >= len( y_dims_mapping), "now just support x dims > y dims" + if len(y_dims_mapping) != 2: + return False if len(x_dims_mapping) == len(y_dims_mapping) and len( x_dims_mapping) == 4: if x_dims_mapping[:2] != y_dims_mapping[:2]: @@ -889,6 +893,8 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): y_dims_mapping ), "now just support x dims > y dims,but x:{0} and y:{1}".format( x_dims_mapping, y_dims_mapping) + if len(y_dims_mapping) != 2: + return False if len(x_dims_mapping) == len(y_dims_mapping) and len( x_dims_mapping) == 4: if x_dims_mapping[:2] != y_dims_mapping[:2]: @@ -1010,6 +1016,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): return False assert len(x_dims_mapping) >= len( y_dims_mapping), "now just support x dims > y dims" + if len(y_dims_mapping) != 2: + return False if len(x_dims_mapping) == len(y_dims_mapping) and len( x_dims_mapping) == 4: if x_dims_mapping[:2] != y_dims_mapping[:2]: @@ -1297,6 +1305,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) assert len(x_dims_mapping) >= len( y_dims_mapping), "now just support x dims > y dims" + if len(y_dims_mapping) != 2: + return False if len(x_dims_mapping) == len(y_dims_mapping) and len( x_dims_mapping) == 4: if x_dims_mapping[:2] != y_dims_mapping[:2]: @@ -1583,7 +1593,8 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): y_dims_mapping ), "now just support x dims > y dims,but x:{0} and y:{1}".format( x_dims_mapping, y_dims_mapping) - + if len(y_dims_mapping) != 2: + return False if len(x_dims_mapping) == len(y_dims_mapping) and len( x_dims_mapping) == 4: if x_dims_mapping[:2] != y_dims_mapping[:2]: diff --git 
a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index affb27317daafafca19c84fced393d45a873ed75..f6ddf2b9b7350677ffcfd7095f2106fd9bd7441f 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -20,6 +20,9 @@ import copy import pathlib import subprocess import logging +import pickle +import time + import paddle from paddle.distributed.utils import get_logger from paddle.distributed.fleet import cloud_utils @@ -30,13 +33,18 @@ from .dist_context import set_default_distributed_context from .completion import complete_annotation, complete_backward_annotation from .partitioner import Partitioner from .process_group import get_all_process_groups +from .process_group import get_process_group from .process_group import get_world_process_groups +from .process_group import _g_process_group_map, ProcessGroup from .utils import make_data_unshard from .utils import set_grad_var_shape -from .reshard import reshard +from .utils import SerialProgramInfo +from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER from .cluster import Cluster from .mapper import mapping -# from .auto_search import auto_search +from .dist_op import DistributedOperator +from .dist_tensor import DistributedTensor +from .planner import Planner _logger = get_logger(logging.INFO) @@ -82,12 +90,20 @@ class AutoParallelizer: if suffix in attr_name: op._remove_attr(attr_name) - def _get_dist_program(self, dist_context, rank): - # Annotation completion - completed_main_program = complete_annotation(self._main_program, - dist_context) + def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): + completed_main_program = None + if dist_context is None: + # Annotation completion + self._dist_context = DistributedContext() + _logger.info("Start annotation dist attr.") + completed_main_program = complete_annotation(self._main_program, + self._dist_context) + else: + completed_main_program = self._main_program + self._dist_context = copy.deepcopy(dist_context) + # Logical partition - partitioner = Partitioner(self._dist_strategy, dist_context, rank) + partitioner = Partitioner(self._dist_strategy, self._dist_context, rank) dist_main_prog, dist_startup_prog = partitioner.transpile_forward( completed_main_program, self._startup_program) dist_params_grads = partitioner.apply_backward( @@ -97,11 +113,21 @@ class AutoParallelizer: copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog, dist_startup_prog) - make_data_unshard(dist_main_prog, dist_startup_prog, dist_context) + set_grad_var_shape(dist_main_prog, self._dist_context) - reshard(dist_main_prog, dist_startup_prog, rank, dist_context) + make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context) - return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog + reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context) + + g_process_group_map = None + if not relaunch_phase: + g_process_group_map = copy.deepcopy(_g_process_group_map) + HAS_SENT.clear() + HAS_RECV.clear() + HAS_ALLGATHER.clear() + _g_process_group_map.clear() + _g_process_group_map[0] = ProcessGroup(0, []) + return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map def parallelize(self, loss, @@ -121,11 +147,51 @@ class AutoParallelizer: "The cluster must not be none when using auto mapping." 
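# --- illustrative sketch (not Paddle code): outside the relaunch phase,
# _get_dist_program now snapshots the module-level process-group map for the
# rank it just partitioned and then clears the global registries (HAS_SENT,
# HAS_RECV, HAS_ALLGATHER, _g_process_group_map) so the next rank starts from a
# clean slate. The toy registry and names below only show that snapshot/reset
# step and are invented for illustration.
import copy

_GLOBAL_REGISTRY = {}

def plan_for_rank(rank):
    # planning populates the global registry as a side effect...
    _GLOBAL_REGISTRY[rank] = {"ring_ids": [0, rank + 1]}
    snapshot = copy.deepcopy(_GLOBAL_REGISTRY)   # keep this rank's view
    _GLOBAL_REGISTRY.clear()                     # reset before the next rank
    return snapshot

views = {rank: plan_for_rank(rank) for rank in range(2)}
assert views[0] == {0: {"ring_ids": [0, 1]}}
assert views[1] == {1: {"ring_ids": [0, 2]}}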
dist_programs = {} world_process_group = get_world_process_groups() + dist_context = None + # auto search + if self._dist_strategy.auto_search: + logging.info("Start searching dist attr.") + serial_program_info = SerialProgramInfo( + self._main_program, self._startup_program, self._loss, + self._optimizer, self._cluster) + planner = Planner( + serial_program_info, + algorithm_config={"name": "mcmc", + "max_search_times": 5}) + dist_context, _ = planner.search() + logging.info("End searching dist attr.") + + # serialize the dist context by planner + if dist_context is not None: + logging.info("Start serialize searched dist attr") + cwd = pathlib.Path().resolve() + searched_dist_context_path = os.path.join( + cwd, f"searched_dist_context_{time.time()}.pkl") + saved_dist_context = {} + ops_dist_attr = {} + tensors_dist_attr = {} + for key, dist_op in dist_context._dist_ops_for_program.items(): + ops_dist_attr[key] = dist_op.dist_attr + for key, dist_tensor in dist_context._dist_tensors_for_program.items( + ): + tensors_dist_attr[key] = dist_tensor.dist_attr + saved_dist_context["ops_dist_attr"] = ops_dist_attr + saved_dist_context["tensors_dist_attr"] = tensors_dist_attr + saved_dist_context[ + "process_meshes"] = dist_context._process_meshes + with open(searched_dist_context_path, + "wb") as dist_context_file: + pickle.dump(saved_dist_context, dist_context_file) + os.environ[ + 'PADDLE_SEARCHED_DIST_CONTEXT_PATH'] = searched_dist_context_path + logging.info( + f"End serialize searched dist attr to {searched_dist_context_path}" + ) + for rank in world_process_group.ranks: - dist_context = DistributedContext() - dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program( - dist_context, rank) - dist_programs[rank] = dist_main_prog + dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map = self._get_dist_program( + rank, dist_context) + dist_programs[rank] = [dist_main_prog, g_process_group_map] # Do the mapping between the distributed program graph and the cluster graph rank_mapping_dict = mapping(dist_programs, self._cluster) @@ -162,9 +228,64 @@ class AutoParallelizer: else: # Parallelization after the mapping pass rank = paddle.distributed.get_rank() + dist_context = None + searched_dist_context_path = os.getenv( + "PADDLE_SEARCHED_DIST_CONTEXT_PATH", None) + if searched_dist_context_path is not None: + with open(searched_dist_context_path, + "rb") as dist_context_file: + saved_dist_context = pickle.load(dist_context_file) + dist_context = DistributedContext() + for op in self._main_program.global_block().ops: + dist_attr = saved_dist_context["ops_dist_attr"][ + op.desc.id()] + dist_op = DistributedOperator(op, dist_attr) + dist_context.add_dist_op_for_program(dist_op) + + vars = self._main_program.global_block().vars + for var in vars.values(): + dist_attr = saved_dist_context["tensors_dist_attr"][ + var.desc.id()] + dist_tensor = DistributedTensor(var, dist_attr) + dist_context.add_dist_tensor_for_program(dist_tensor) + + dist_context._process_meshes = saved_dist_context[ + "process_meshes"] + + else: + if self._dist_strategy.auto_search: + serial_program_info = SerialProgramInfo( + self._main_program, + self._startup_program, + self._loss, + self._optimizer, + cluster=self._cluster) + planner = Planner( + serial_program_info, + algorithm_config={ + "name": "mcmc", + "max_search_times": 5 + }) + dist_context, _ = planner.search() + + # rebuild g_process_group + if dist_context is not None: + pg0 = 
get_process_group(0) + for process_mesh in dist_context._process_meshes: + pg0.add_ranks(process_mesh.processes) + dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, _ = self._get_dist_program( + rank, dist_context, relaunch_phase=True) - dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program( - self._dist_context, rank) + # NOTE: This is a trick to fix hang in pipeline mode when dist context is searched by planner + if self._dist_strategy.auto_search: + is_pipeline = False + for op in dist_main_prog.global_block().ops: + if op.type == "send_v2" or op.type == "recv_v2": + is_pipeline = True + break + if is_pipeline: + with paddle.static.program_guard(dist_main_prog): + paddle.distributed.barrier() # Traverse different rank programs and traverse each op of them, # instantiate communication by process_mapping. diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py index 8d1ecf9deaa9570e6777e0c5566969712f1dba2a..7c4ce0b2435069512255813e155011ef3ee55c2c 100644 --- a/python/paddle/distributed/auto_parallel/planner.py +++ b/python/paddle/distributed/auto_parallel/planner.py @@ -32,6 +32,7 @@ from .completion import is_elementwise_like_op from .operators.common import get_distributed_operator_impl_container from .utils import update_op_dims_mapping_by_default_dist_impl from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl +from .utils import get_all_distributed_main_program from .dist_context import DistributedContext, DistributedOperatorContext from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute @@ -370,3 +371,499 @@ class PlanSpace: )] = [op_valid_dist_attrs, pipeline_stage] return valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh + + +class SearchAlgorithm: + def __init__(self, name): + self._name = name + + @property + def name(self): + self.name = name + + def search(self): + raise NotImplementedError("Please Implement this method in subclass.") + + +class MCMC(SearchAlgorithm): + def __init__(self, serial_program_info, max_search_times=5): + super(MCMC, self).__init__("mcmc") + self._serial_program_info = serial_program_info + self._max_search_times = max_search_times + + @property + def serial_program_info(self): + return self._serial_program_info + + @property + def max_search_times(self): + return self._max_search_times + + def make_special_op_unshard(self, op, ops, vars, dist_context, + valid_dist_attr_dict): + if op.type == "softmax_with_cross_entropy": + for var_name in op.input_arg_names: + dims_mapping = dist_context.get_op_dist_attr_for_program( + op).get_input_dims_mapping(var_name) + if dims_mapping != dist_context.get_tensor_dist_attr_for_program( + vars[var_name]).dims_mapping: + has_changed = False + for search_op in ops: + if var_name in search_op.output_arg_names: + op_dist_attr_list = valid_dist_attr_dict[ + search_op.desc.id()][0] + for op_dist_attr in op_dist_attr_list: + if op_dist_attr.get_output_dims_mapping( + var_name) == dims_mapping: + dist_context.set_op_dist_attr_for_program( + search_op, op_dist_attr) + tensor_dist_attr = TensorDistributedAttribute( + ) + tensor_dist_attr.process_mesh = op_dist_attr.process_mesh + tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping( + var_name) + dist_context.set_tensor_dist_attr_for_program( + vars[var_name], tensor_dist_attr) + has_changed = True + break + if has_changed: + break + if not has_changed: + raise 
ValueError( + "Change softmax_with_cross_entropy dist attr failed") + + def init_program(self, valid_dist_attr_dict, program, + pipeline_process_meshes, global_process_mesh): + ops = program.global_block().ops + vars = program.global_block().vars + new_dist_context = DistributedContext() + + for op in ops: + op_valid_dist_attr_list = valid_dist_attr_dict[op.desc.id()][0] + random_op_dist_attr = np.random.randint( + len(op_valid_dist_attr_list)) + init_op_dist_attr = op_valid_dist_attr_list[random_op_dist_attr] + new_dist_context.set_op_dist_attr_for_program(op, init_op_dist_attr) + for var_name in op.input_arg_names: + if var_name == "lod_tensor_blocking_queue_0": + continue + if new_dist_context.get_tensor_dist_attr_for_program(vars[ + var_name]) is None: + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.process_mesh = init_op_dist_attr.process_mesh + tensor_dist_attr.dims_mapping = init_op_dist_attr.get_input_dims_mapping( + var_name) + new_dist_context.set_tensor_dist_attr_for_program( + vars[var_name], tensor_dist_attr) + + for var_name in op.output_arg_names: + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.process_mesh = init_op_dist_attr.process_mesh + tensor_dist_attr.dims_mapping = init_op_dist_attr.get_output_dims_mapping( + var_name) + new_dist_context.set_tensor_dist_attr_for_program( + vars[var_name], tensor_dist_attr) + + # NOTE: this is a temporary solution to make softmax_with_cross_entropy unshard + self.make_special_op_unshard(op, ops, vars, new_dist_context, + valid_dist_attr_dict) + + # add process meshes to distributed context + if global_process_mesh is not None: + new_dist_context.add_process_mesh(global_process_mesh) + elif pipeline_process_meshes is not None: + for process_mesh in pipeline_process_meshes: + new_dist_context.add_process_mesh(process_mesh) + + return new_dist_context + + def estimate_searched_strategy_cost(self, + dist_context, + pipeline_process_meshes=None): + cost = None + # get all distributed programs + all_dist_main_program = get_all_distributed_main_program( + self.serial_program_info, dist_context) + pipeline_config = [ + process_mesh.processes for process_mesh in pipeline_process_meshes + ] if pipeline_process_meshes is not None else None + microbatch_size = 1 + for program in all_dist_main_program: + searched_batch_size = False + for var in program.list_vars(): + if var.is_data and "@RESHARD" in var.name: + microbatch_size = var.shape[0] + searched_batch_size = True + break + if searched_batch_size: + break + + from .utils import get_standalone_cost_data + standalone_cost_data = get_standalone_cost_data(all_dist_main_program) + + # cost model does not support cluster argument + cost = estimate_cost( + all_dist_main_program, + cluster=None, + pipeline_config=pipeline_config, + standalone_cost_data=standalone_cost_data, + batch_size=microbatch_size) + + return cost + + def set_tensor_dist_attr(self, op, op_dist_attr, vars, dist_context): + # set output tensor distributed attribute + for var_name in op.output_arg_names: + process_mesh = op_dist_attr.process_mesh + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.process_mesh = process_mesh + tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping( + var_name) + dist_context.set_tensor_dist_attr_for_program(vars[var_name], + tensor_dist_attr) + + # set input tensor distributed attribute if input is data or parameter + for var_name in op.input_arg_names: + if vars[var_name].is_parameter or vars[var_name].is_data: + 
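# --- illustrative sketch (not Paddle code): the parallelizer above pickles the
# searched dist attrs keyed by op/var desc ids before relaunch and looks them
# up again by the same ids afterwards. This round trip uses plain dicts as
# stand-ins for the real dist-attr objects; the file name, keys and ids are
# assumptions for illustration.
import os
import pickle
import tempfile

searched = {
    "ops_dist_attr": {101: {"process_mesh": [0, 1], "impl_idx": 0}},
    "tensors_dist_attr": {7: {"dims_mapping": [0, -1]}},
    "process_meshes": [[0, 1]],
}

path = os.path.join(tempfile.mkdtemp(), "searched_dist_context_demo.pkl")
with open(path, "wb") as f:
    pickle.dump(searched, f)

# ...after relaunch, every worker reloads the same attrs by id...
with open(path, "rb") as f:
    restored = pickle.load(f)
assert restored["ops_dist_attr"][101]["impl_idx"] == 0
assert restored["tensors_dist_attr"][7]["dims_mapping"] == [0, -1]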
process_mesh = op_dist_attr.process_mesh + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.process_mesh = process_mesh + tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping( + var_name) + dist_context.set_tensor_dist_attr_for_program(vars[var_name], + tensor_dist_attr) + + def change_process_mesh(self, op, changed_process_mesh, vars, dist_context): + dist_context.get_op_dist_attr_for_program( + op).process_mesh = changed_process_mesh + for var_name in op.output_arg_names: + dist_context.get_tensor_dist_attr_for_program(vars[ + var_name]).process_mesh = changed_process_mesh + for var_name in op.input_arg_names: + if vars[var_name].is_parameter or vars[var_name].is_data: + dist_context.get_tensor_dist_attr_for_program(vars[ + var_name]).process_mesh = changed_process_mesh + + def search_once(self, + program, + valid_dist_attr_dict, + dist_context, + pipeline_process_meshes=None): + raw_ops = program.global_block().ops + ops = [] + for op in raw_ops: + if op.type not in PlanSpace.not_enum_ops: + ops.append(op) + assert ops, "The ops of program have no distributed attributes." + vars = program.global_block().vars + new_dist_context = copy.deepcopy(dist_context) + new_dist_context._dist_op_context = DistributedOperatorContext() + new_valid_dist_attr_dict = None + random_selected_op_idx = np.random.randint(len(ops)) + selected_op = ops[random_selected_op_idx] + op_valid_dist_attr_list = valid_dist_attr_dict[selected_op.desc.id()][0] + pipeline_stage = valid_dist_attr_dict[selected_op.desc.id()][1] + random_selected_dist_attr_idx = np.random.randint( + len(op_valid_dist_attr_list)) + selected_op_dist_attr = copy.deepcopy(op_valid_dist_attr_list[ + random_selected_dist_attr_idx]) + + start_idx = ops[0].desc.id() + if pipeline_stage > -1: + # in pipeline mode, the above phase just select a dims mapping + # 0 represents not changed, 1 represents to be the same with before stage, 2 represents to be the same with the latter stage + new_valid_dist_attr_dict = copy.deepcopy(valid_dist_attr_dict) + changed_mode = np.random.randint(3) + if changed_mode == 0: + # not change the process mesh, just change dims mapping + new_dist_context.set_op_dist_attr_for_program( + selected_op, selected_op_dist_attr) + self.set_tensor_dist_attr(selected_op, selected_op_dist_attr, + vars, new_dist_context) + + elif changed_mode == 1: + changed_stage = pipeline_stage - 1 + if changed_stage == -1 or random_selected_op_idx == len(ops) - 1 or \ + (random_selected_op_idx + 1 == len(ops) - 1 and new_valid_dist_attr_dict[ops[random_selected_op_idx + 1].desc.id()][1] == pipeline_stage + 1 ): + new_dist_context.set_op_dist_attr_for_program( + selected_op, selected_op_dist_attr) + self.set_tensor_dist_attr(selected_op, + selected_op_dist_attr, vars, + new_dist_context) + + else: + selected_op_process_mesh = pipeline_process_meshes[ + pipeline_stage] + next_op_id = ops[random_selected_op_idx + 1].desc.id() + if new_valid_dist_attr_dict[next_op_id][ + 1] == pipeline_stage + 1 and random_selected_op_idx + 1 != len( + ops) - 1: + new_valid_dist_attr_dict[next_op_id][1] = pipeline_stage + for op_dist_attr in new_valid_dist_attr_dict[ + next_op_id][0]: + op_dist_attr.process_mesh = selected_op_process_mesh + # set next op dist attr in the discontext and output/input tensor process mesh + self.change_process_mesh( + ops[random_selected_op_idx + 1], + selected_op_process_mesh, vars, new_dist_context) + + # change the selected op stage and output dist attr + new_valid_dist_attr_dict[selected_op.desc.id()][ 
+ 1] = changed_stage + new_process_mesh = pipeline_process_meshes[changed_stage] + selected_op_dist_attr.process_mesh = new_process_mesh + for op_dist_attr in new_valid_dist_attr_dict[ + selected_op.desc.id()][0]: + op_dist_attr.process_mesh = new_process_mesh + new_dist_context.set_op_dist_attr_for_program( + selected_op, selected_op_dist_attr) + + self.set_tensor_dist_attr(selected_op, + selected_op_dist_attr, vars, + new_dist_context) + + # change the pre op stage + for idx in range(random_selected_op_idx - 1, -1, -1): + stage = new_valid_dist_attr_dict[ops[idx].desc.id()][1] + valid_dist_attr_list = new_valid_dist_attr_dict[ops[ + idx].desc.id()][0] + new_process_mesh = pipeline_process_meshes[ + changed_stage] + if stage == changed_stage + 1: + new_valid_dist_attr_dict[ops[idx].desc.id()][ + 1] = changed_stage + for op_dist_attr in valid_dist_attr_list: + op_dist_attr.process_mesh = new_process_mesh + new_dist_context.get_op_dist_attr_for_program(ops[ + idx]).process_mesh = new_process_mesh + # change process mesh of the output and input tensor + self.change_process_mesh(ops[idx], new_process_mesh, + vars, new_dist_context) + else: + break + + else: + changed_stage = pipeline_stage + 1 + if changed_stage == len( + pipeline_process_meshes) or random_selected_op_idx == 0 or \ + (new_valid_dist_attr_dict[ops[random_selected_op_idx - 1].desc.id()][1] == pipeline_stage - 1 and (random_selected_op_idx == 1)): + new_dist_context.set_op_dist_attr_for_program( + selected_op, selected_op_dist_attr) + self.set_tensor_dist_attr(selected_op, + selected_op_dist_attr, vars, + new_dist_context) + + else: + selected_op_process_mesh = pipeline_process_meshes[ + pipeline_stage] + pre_op_id = ops[random_selected_op_idx - 1].desc.id() + if new_valid_dist_attr_dict[pre_op_id][ + 1] == pipeline_stage - 1 and random_selected_op_idx != 1: + new_valid_dist_attr_dict[pre_op_id][1] = pipeline_stage + for op_dist_attr in new_valid_dist_attr_dict[pre_op_id][ + 0]: + op_dist_attr.process_mesh = selected_op_process_mesh + # set pre op dist attr in the discontext and output tensor process mesh + self.change_process_mesh( + ops[random_selected_op_idx - 1], + selected_op_process_mesh, vars, new_dist_context) + + # change the selected op stage and output tensor dist attr + new_valid_dist_attr_dict[selected_op.desc.id()][ + 1] = changed_stage + new_process_mesh = pipeline_process_meshes[changed_stage] + selected_op_dist_attr.process_mesh = new_process_mesh + for op_dist_attr in new_valid_dist_attr_dict[ + selected_op.desc.id()][0]: + op_dist_attr.process_mesh = new_process_mesh + new_dist_context.set_op_dist_attr_for_program( + selected_op, selected_op_dist_attr) + self.set_tensor_dist_attr(selected_op, + selected_op_dist_attr, vars, + new_dist_context) + + # change the next op stage + for idx in range(random_selected_op_idx + 1, len(ops)): + stage = new_valid_dist_attr_dict[ops[idx].desc.id()][1] + valid_dist_attr_list = new_valid_dist_attr_dict[ops[ + idx].desc.id()][0] + new_process_mesh = pipeline_process_meshes[ + changed_stage] + if stage == changed_stage - 1: + new_valid_dist_attr_dict[ops[idx].desc.id()][ + 1] = changed_stage + for op_dist_attr in valid_dist_attr_list: + op_dist_attr.process_mesh = new_process_mesh + + new_dist_context.get_op_dist_attr_for_program(ops[ + idx]).process_mesh = new_process_mesh + # change the output tensor dist attr + self.change_process_mesh(ops[idx], new_process_mesh, + vars, new_dist_context) + else: + break + else: + new_dist_context.set_op_dist_attr_for_program(selected_op, + 
selected_op_dist_attr) + self.set_tensor_dist_attr(selected_op, selected_op_dist_attr, vars, + new_dist_context) + + for op in ops: + # make softmax_with_cross_entropy unshard + if op.type == "softmax_with_cross_entropy": + self.make_special_op_unshard(op, ops, vars, new_dist_context, + valid_dist_attr_dict) + break + + if new_valid_dist_attr_dict is None: + return valid_dist_attr_dict, new_dist_context + else: + return new_valid_dist_attr_dict, new_dist_context + + def _search_core(self, + valid_dist_attr_dict, + init_dist_context, + pipeline_process_meshes=None): + times = 0 + best_dist_context = init_dist_context + cost = self.estimate_searched_strategy_cost( + init_dist_context, pipeline_process_meshes).runtime + min_cost = cost + while times < self.max_search_times: + times += 1 + new_dist_context = self.search_once( + self.serial_program_info.train_program, valid_dist_attr_dict, + best_dist_context, pipeline_process_meshes)[1] + cur_cost = self.estimate_searched_strategy_cost( + new_dist_context, pipeline_process_meshes).runtime + if (min_cost - cur_cost) > 0: + best_dist_context = copy.deepcopy(new_dist_context) + min_cost = cur_cost + times = 0 + return best_dist_context, min_cost + + def search(self): + logging.info("Start MCMC searching.") + start_time = time.time() + train_program = self.serial_program_info.train_program + cluster = self.serial_program_info.cluster + processes = paddle.distributed.get_world_size( + ) if cluster is None else len(cluster.get_all_devices("GPU")) + assert processes > 0, "Get process failed." + + process_mesh_topology_list = PlanSpace.enum_process_mesh_topology( + processes) + searched_dist_context = None + min_cost = None + + searched_pipeline_dist_context = None + pipeline_min_cost = None + for process_mesh_topology in process_mesh_topology_list: + logging.info( + "MCMC search: search process mesh {} with pipeline mode.". + format(process_mesh_topology)) + valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program( + train_program, process_mesh_topology, True) + init_dist_context = self.init_program( + valid_dist_attr_dict, train_program, pipeline_process_meshes, + global_process_mesh) + best_dist_context, cost = self._search_core(valid_dist_attr_dict, + init_dist_context, + pipeline_process_meshes) + logging.info( + "MCMC search: the min cost is {} in the process mesh {} with pipeline mode.". + format(cost, process_mesh_topology)) + best_dist_context._dist_op_context = DistributedOperatorContext() + pipeline_min_cost = cost if pipeline_min_cost is None else pipeline_min_cost + searched_pipeline_dist_context = best_dist_context if searched_pipeline_dist_context is None else searched_pipeline_dist_context + if pipeline_min_cost > cost: + searched_pipeline_dist_context = best_dist_context + pipeline_min_cost = cost + + searched_non_pipeline_dist_context = None + non_pipeline_min_cost = None + for process_mesh_topology in process_mesh_topology_list: + # if process_mesh_topology shape is 3, include pipeline mode by default + if len(process_mesh_topology) == 3: + continue + logging.info( + "MCMC search: search process mesh {} without pipeline mode.". 
+ format(process_mesh_topology)) + valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program( + train_program, process_mesh_topology, False) + init_dist_context = self.init_program( + valid_dist_attr_dict, train_program, pipeline_process_meshes, + global_process_mesh) + best_dist_context, cost = self._search_core(valid_dist_attr_dict, + init_dist_context, + pipeline_process_meshes) + logging.info( + "MCMC search: the min cost is {} in the process mesh {} without pipeline mode.". + format(cost, process_mesh_topology)) + best_dist_context._dist_op_context = DistributedOperatorContext() + non_pipeline_min_cost = cost if non_pipeline_min_cost is None else non_pipeline_min_cost + searched_non_pipeline_dist_context = best_dist_context if searched_non_pipeline_dist_context is None else searched_non_pipeline_dist_context + if non_pipeline_min_cost > cost: + searched_non_pipeline_dist_context = best_dist_context + non_pipeline_min_cost = cost + + if non_pipeline_min_cost > pipeline_min_cost: + searched_dist_context = searched_pipeline_dist_context + min_cost = pipeline_min_cost + logging.info( + "Better set FLAGS_benchmark=1 to avoid hang problem in the pipeline mode." + ) + else: + searched_dist_context = searched_non_pipeline_dist_context + min_cost = non_pipeline_min_cost + + # rebuild g_process_group + pg0 = get_process_group(0) + for process_mesh in searched_dist_context._process_meshes: + pg0.add_ranks(process_mesh.processes) + end_time = time.time() + logging.info( + "End MCMC searching: the min cost is {} and the search time is {}s.". + format(min_cost, end_time - start_time)) + return searched_dist_context, min_cost + + +class Planner: + def __init__(self, serial_program_info, algorithm_config=None): + self._serial_program_info = serial_program_info + self._algorithm_config = algorithm_config + self._algorithm_searcher = self.create_algorithm_searcher( + algorithm_config) + + @property + def serial_program_info(self): + return self._serial_program_info + + @property + def algorithm_config(self): + return self._algorithm_config + + @property + def algorithm_searcher(self): + return self._algorithm_searcher + + def create_algorithm_searcher(self, algorithm_config): + name = algorithm_config.get("name", None) + assert name is not None, "Invalid algorithm config." + + algorithm_searcher = None + if name == "mcmc": + # NOTE: Only GPU clusters are supported now. 
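# --- illustrative sketch (not Paddle code): the accept/patience loop inside
# _search_core above. A random single-op change is proposed (search_once), it
# is kept only if the estimated cost strictly drops, an improvement resets the
# patience counter, and the search stops after max_search_times proposals in a
# row without improvement. The cost function and proposal here are made up.
import copy
import random

def search_core(init_state, cost_fn, propose, max_search_times=5):
    best_state, min_cost, times = init_state, cost_fn(init_state), 0
    while times < max_search_times:
        times += 1
        candidate = propose(copy.deepcopy(best_state))
        cur_cost = cost_fn(candidate)
        if min_cost - cur_cost > 0:        # strictly better: accept, reset patience
            best_state, min_cost, times = candidate, cur_cost, 0
    return best_state, min_cost

random.seed(0)
state, cost = search_core(
    init_state=[3, 3],
    cost_fn=lambda s: sum(abs(x) for x in s),
    propose=lambda s: [x + random.choice([-1, 0, 1]) for x in s])
print(state, cost)   # drifts toward lower cost under this toy proposal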
+ max_search_times = algorithm_config.get("max_search_times", None) + algorithm_searcher = MCMC( + self.serial_program_info, + max_search_times) if max_search_times is not None else MCMC( + self.serial_program_info) + else: + raise NotImplementedError( + "Other search algorithms have not been supported now.") + + return algorithm_searcher + + def search(self): + return self.algorithm_searcher.search() diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py index 2e4d370b39435d081d9a4a796f89c0014484e393..fee52e85697dcc69f0f0e4f929c5a63adb503bb9 100644 --- a/python/paddle/distributed/auto_parallel/process_group.py +++ b/python/paddle/distributed/auto_parallel/process_group.py @@ -25,9 +25,12 @@ def get_all_process_groups(): return _g_process_group_map.values() -def get_process_group(group_id): +def get_process_group(group_id, g_process_group_map=None): global _g_process_group_map - return _g_process_group_map.get(group_id, None) + return _g_process_group_map.get( + group_id, + None) if g_process_group_map is None else g_process_group_map.get( + group_id, None) def get_world_process_groups(): diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index c9cb4200e36fb67ca7ee654968c4f0f24dff3bb9..3b392d4e088dec3a6887e1b226d475d12c563dd6 100755 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -19,9 +19,11 @@ import threading import numpy as np import warnings import logging +from functools import reduce import paddle.fluid.core as core from paddle.framework.io import _to_LodTensor +from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.fluid.io import is_parameter, is_belong_to_optimizer @@ -1258,3 +1260,95 @@ class SerialProgramInfo: @property def cluster(self): return self._cluster + + +def get_standalone_cost_data(distributed_programs): + def _compute_runtime(op_cost, op, vars): + runtime = 0 + try: + runtime = float(op_cost["op_time"]) + except: + return runtime + op_config = op_cost["config"] + total_static_input_size = 0 + total_actual_input_size = 0 + parsed_info = op_config.split("\n") + variable = "(Variable)" + for info in parsed_info: + variable = "(Variable)" if "(Variable)" in info else "(list" + if variable in info: + arg_name_lower = info[:info.find(variable) - 1] + shape_left_boundary = info.find("[") + shape_right_boundary = info.find("]") + assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed." + shape = info[shape_left_boundary + 1: + shape_right_boundary].split(",") + shape = list(map(lambda x: int(x.strip()), shape)) + dtype_factor = 1 + total_static_input_size += reduce(lambda x, y: x * y, shape) + # print(arg_name_lower) + if op.type == "c_embedding": + arg_name_lower = "w" if arg_name_lower == "weight" else "ids" + for arg_name in op.input_names: + if arg_name.lower() == arg_name_lower: + for var_name in op.input(arg_name): + var = vars[var_name] + total_actual_input_size += reduce( + lambda x, y: x * y, var.shape) + break + assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed." 
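# --- illustrative sketch (not Paddle code): Planner.create_algorithm_searcher
# above is a small name-keyed factory over the algorithm_config dict. Reduced
# to its shape, with a dummy searcher class standing in for MCMC:
class ToySearcher:
    def __init__(self, info, max_search_times=5):
        self.info = info
        self.max_search_times = max_search_times

def create_algorithm_searcher(serial_program_info, algorithm_config):
    name = algorithm_config.get("name", None)
    assert name is not None, "Invalid algorithm config."
    if name == "mcmc":
        max_search_times = algorithm_config.get("max_search_times", None)
        return (ToySearcher(serial_program_info, max_search_times)
                if max_search_times is not None
                else ToySearcher(serial_program_info))
    raise NotImplementedError("Other search algorithms are not supported yet.")

searcher = create_algorithm_searcher("program-info",
                                     {"name": "mcmc", "max_search_times": 5})
assert searcher.max_search_times == 5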
+ + actual_runtime = total_actual_input_size / total_static_input_size * runtime + return actual_runtime + + cost_model = paddle.cost_model.CostModel() + cost_model.static_cost_data() + DEFAULT_MULTIPLE = 2 + OP_NAME_MAPPING = { + "c_embedding": "embedding", + "matmul_v2": "matmul", + "transpose2": "transpose", + "reshape2": "reshape", + "unsqueeze2": "unsqueeze", + "reduce_sum": "sum", + "elementwise_div": "divide" + } + + standalone_cost_data = [] + not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"] + for distributed_program in distributed_programs: + cost_data = {} + vars = distributed_program.global_block().vars + for op in distributed_program.global_block().ops: + runtime = 0 + if op.type in not_enum_ops: + cost_data[op.desc.id()] = runtime + continue + dtype = str(vars[op.input_arg_names[0]] + .dtype) if op.input_arg_names else "float32" + if int(op.attr('op_role')) == int(OpRole.Backward): + if "_grad" in op.type: + forward_op_name = op.type[:-5] + if forward_op_name in OP_NAME_MAPPING.keys(): + forward_op_name = OP_NAME_MAPPING[forward_op_name] + op_cost = cost_model.get_static_op_time( + forward_op_name, forward=False, dtype=dtype) + if op_cost: + runtime = _compute_runtime(op_cost, op, vars) + else: + op_cost = cost_model.get_static_op_time( + forward_op_name, dtype=dtype) + if op_cost: + runtime = 2 * _compute_runtime(op_cost, op, vars) + elif int(op.attr('op_role')) == int(OpRole.Forward): + op_name = OP_NAME_MAPPING[ + op.type] if op.type in OP_NAME_MAPPING.keys() else op.type + op_cost = cost_model.get_static_op_time(op_name) + if op_cost: + runtime = _compute_runtime(op_cost, op, vars) + + cost_data[op.desc.id()] = runtime + + standalone_cost_data.append(cost_data) + + return standalone_cost_data diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 0d54a0ea5d3b1620522e097615c3adfd5f94d121..c19ee1e192761a1730445427a52b6f4c1b86eed3 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1436,7 +1436,7 @@ class Fleet(object): context["role_maker"] = self._role_maker # Use the auto-parallel's routines instead - if self._user_defined_strategy.semi_auto: + if self._user_defined_strategy.semi_auto or self._user_defined_strategy.auto_search: from ...auto_parallel.parallelizer import AutoParallelizer auto_parallelizer = AutoParallelizer(self) optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize( @@ -1586,13 +1586,13 @@ class Fleet(object): ] param_grads_fp16 = [ param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and - (param._grad_ivar().dtype == core.VarDesc.VarType.FP16) + if (param._grad_ivar() is not None) and (param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP16) ] param_grads_fp32 = [ param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and - (param._grad_ivar().dtype == core.VarDesc.VarType.FP32) + if (param._grad_ivar() is not None) and (param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP32) ] temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 4244fda0c51d9c582b44b79b659e94a99b179383..9247bf48b35925f46c89daa75e2e9e7981c8474b 100644 --- 
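# --- illustrative sketch (not Paddle code): get_standalone_cost_data above
# scales a benchmarked op time by the ratio of the real input size to the
# benchmark-config input size, and falls back to doubling the forward time when
# no measured backward ("_grad") time exists. Numbers below are invented.
def scale_runtime(benchmark_time_ms, benchmark_input_elems, actual_input_elems):
    return actual_input_elems / benchmark_input_elems * benchmark_time_ms

forward_ms = scale_runtime(benchmark_time_ms=0.8,
                           benchmark_input_elems=1024 * 1024,
                           actual_input_elems=4 * 1024 * 1024)
assert forward_ms == 3.2             # 4x the benchmarked input -> 4x the time
backward_ms = 2 * forward_ms         # no measured backward op: assume 2x forward
assert backward_ms == 6.4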
a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -3,4 +3,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS}) set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS}) + set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_planner.py b/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_planner.py new file mode 100644 index 0000000000000000000000000000000000000000..a93663cb95ed0e90aa0d55cd7ccae5ce711a0f8d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_planner.py @@ -0,0 +1,53 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.static as static +from paddle.distributed import fleet + + +def train(): + from auto_parallel_relaunch_model import mlp_pretrain_forward + from auto_parallel_relaunch_model import batch_generator_creator + dist_strategy = fleet.DistributedStrategy() + # init parallel optimizer + dist_strategy.auto_search = True + fleet.init(is_collective=True, strategy=dist_strategy) + train_program = static.Program() + start_program = static.Program() + loss, train_program, start_program, loader = mlp_pretrain_forward( + train_program, start_program) + + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + optimizer = fleet.distributed_optimizer(optimizer) + _, _, distributed_startup_program, distributed_main_program = optimizer.minimize( + loss, start_program) + + places = static.cuda_places() + loader.set_batch_generator(batch_generator_creator(), places=places) + exe = paddle.static.Executor(places[0]) + exe.run(distributed_startup_program) + + for data in loader(): + exe.run(distributed_main_program, feed=data) + + +if __name__ == "__main__": + train() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_relaunch_with_planner.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_relaunch_with_planner.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7ae87e646ada71bcd9ef7c313c7dd16be7df81 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_relaunch_with_planner.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import sys +import json +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestPlannerReLaunch(unittest.TestCase): + def test_relaunch_with_planner(self): + from test_auto_parallel_relaunch import cluster_json + file_dir = os.path.dirname(os.path.abspath(__file__)) + cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") + cluster_json_object = json.loads(cluster_json) + with open(cluster_json_path, "w") as cluster_json_file: + json.dump(cluster_json_object, cluster_json_file) + + launch_model_path = os.path.join( + file_dir, "auto_parallel_relaunch_with_planner.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "launch", "--cluster_topo_path", cluster_json_path, + "--enable_auto_mapping", "True", launch_model_path + ] + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + # Remove unnecessary files + if os.path.exists(cluster_json_path): + os.remove(cluster_json_path) + rank_mapping_json_path = os.path.join(file_dir, + "auto_parallel_rank_mapping.json") + if os.path.exists(rank_mapping_json_path): + os.remove(rank_mapping_json_path) + log_path = os.path.join(file_dir, "log") + if os.path.exists(log_path): + shutil.rmtree(log_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py index 4c9c01b99e05052b62e5aeb2a55e6c3a4510c3c6..d58c79dd72cb8929f38c8739ebb93e268ee0fa49 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py @@ -187,10 +187,10 @@ def check_empty_program_runtime(cost): def check_empty_program_memory(cost): for mem in cost.peak_mem: - if mem > 0: + if mem > 1: return False for mem in cost.static_mem: - if mem > 0: + if mem > 1: return False return True diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py index de37ac56bfbb6314a7c82258c000aaec9e86201e..4fd64dc252bcddd494a9a62e76cefc9c22dec6d9 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py @@ -529,7 +529,7 @@ class TestAutoParallelMapper(unittest.TestCase): train_program, startup_program, dist_context, rank_id) # if rank_id == 0: # print_program_with_dist_attr(dist_train_program, dist_context) - dist_programs[rank_id] = dist_train_program + dist_programs[rank_id] = [dist_train_program, None] rank_mapping = mapping(dist_programs, cluster) diff --git a/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py b/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py index 
82178e1b62dfbc60387bc84b8b852194f5bdc0df..c9cbcd1ea8efd59a2e9c978001a9086d7de09eb0 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py +++ b/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py @@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) - self.assertTrue(impls[2].is_auto_compatible( + self.assertFalse(impls[2].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) self.assertFalse(impls[2].is_auto_compatible( @@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) - self.assertTrue(impls[1].is_auto_compatible( + self.assertFalse(impls[1].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) self.assertFalse(impls[1].is_auto_compatible( @@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1]) - self.assertTrue(impls[0].is_auto_compatible( + self.assertFalse(impls[0].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1]) self.assertFalse(impls[0].is_auto_compatible(
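# --- illustrative sketch (not Paddle code): the assertTrue -> assertFalse flips
# above follow from the new guard in the matmul impls: any Y (weight) operand
# whose dims_mapping is not rank-2 now fails is_auto_compatible before the
# sharding checks run. A dims_mapping entry of -1 means "replicated"; a
# non-negative entry is the mesh axis that dimension is split along.
def passes_y_rank_guard(y_dims_mapping):
    return len(y_dims_mapping) == 2

assert passes_y_rank_guard([-1, 0])              # 2-D weight: keep checking
assert not passes_y_rank_guard([-1, -1, -1, -1]) # 4-D Y: rejected, tests now expect False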