# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import hashlib
import itertools
import math
import time
from collections import defaultdict

import numpy as np

from ...process_mesh import ProcessMesh
from ..completion import Completer
from ..cost import CostEstimator
from ..dist_context import _node_id
from ..dist_op import DistributedOperator
from ..operators.common import find_compatible_distributed_operator_impls
from ..parallelizer_v2 import Parallelizer
from .trial import Trial, TrialStatus
from .tunable_space import TunableSpace
from .tunable_variable import Boolean, IntRange


class ParallelTuner:
    def __init__(
        self,
        dist_context,
        mode="train",
        max_trials=25,
        tuner_id=None,
        seed=None,
        logger=None,
        loop_count=10,
    ):
        self._loop_count = loop_count
        self._estimator = None
        self._dist_context = dist_context
        assert self._dist_context._is_initialized
        self._mode = mode
        self._cluster = self._dist_context.cluster
        self._num_machines = self._cluster.get_num_machines()
        self._num_devices_per_machine = (
            self._cluster.get_num_devices_per_machine()
        )
        self._space = TunableSpace()
        self._objective = "time"
        self._direction = "min"
        self._max_trials = max_trials
        self._tuner_id = tuner_id
        self._seed = seed if seed is not None else 9999
        print(
            "seed",
            self._seed,
            "mode",
            self._mode,
            "num_machines",
            self._num_machines,
            "num_devices_per_machine",
            self._num_devices_per_machine,
            flush=True,
        )
        self._seed_state = self._seed
        self._logger = logger
        self._max_collisions = 3
        self._tried_values = set()
        self._num_trials = 0
        self._rng = np.random.default_rng(self._seed)

        # Search only the op types in include_op_types; if it is empty,
        # search all op types. Op types in exclude_op_types are removed
        # from the search list.
        self._exclude_op_types = []
        self._include_op_types = []
        # The dist ops that will actually be searched, after applying
        # include_op_types and exclude_op_types.
        self._concerned_dist_ops = {}

        self._op_id_to_dist_attr_candidates = defaultdict(list)
        self._cached_dims_mapping_candidates = {}
        self._cached_candidates_info = defaultdict(list)

        self._special_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
            "while",
            "read_from_array",
            "write_to_array",
        ]

        # Each parallel strategy has three elements: the first one is for
        # distributed tensors, the second one is for distributed operators,
        # and the third one is for process meshes.
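        # Illustrative sketch only (assumed contents, not produced here):
        # once a strategy has been stored, the three slots may look like
        #   [0]: {tensor_id: DistributedTensor, ...}
        #   [1]: {op_id: DistributedOperator, ...}
        #   [2]: [ProcessMesh([[0, 1], [2, 3]]), ...]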
        self._init_parallel_strategy = [None, None, None]
        self._best_parallel_strategy = [None, None, None]

        self._completer = Completer(self._dist_context)

        self._parallelizer = Parallelizer(
            self._mode, self._completer, self._dist_context
        )

    def _generate_combination(
        self,
        elements,
        target,
        idx,
        partial_candidate,
        candidates,
        num_candidates=None,
    ):
        if target == 0:
            candidates.append(copy.deepcopy(partial_candidate))
            return

        if (
            target < 0
            or idx == len(elements)
            or len(candidates) > num_candidates
        ):
            return

        # Use elements[idx]
        partial_candidate.append(elements[idx])
        self._generate_combination(
            elements,
            target - elements[idx],
            idx,
            partial_candidate,
            candidates,
            num_candidates,
        )

        # Do not use elements[idx]
        partial_candidate.pop()
        self._generate_combination(
            elements,
            target,
            idx + 1,
            partial_candidate,
            candidates,
            num_candidates,
        )

    def _permute_combination(
        self,
        combination,
        target,
        check,
        partial_candidate,
        candidates,
        num_candidates=None,
        skip_prob=None,
    ):
        if num_candidates is not None and len(candidates) == num_candidates:
            return

        if len(partial_candidate) == len(combination):
            candidates.append(partial_candidate)
            return

        for i in range(len(combination)):
            if check[i] == 1:
                continue
            if self._rng.choice([True, False], p=[skip_prob, 1 - skip_prob]):
                continue
            if (
                i > 0
                and combination[i] == combination[i - 1]
                and check[i - 1] == 0
            ):
                continue
            check[i] = 1
            self._permute_combination(
                combination,
                target,
                check,
                partial_candidate + [combination[i]],
                candidates,
                num_candidates,
                skip_prob,
            )
            check[i] = 0

    def _partition_number(self, target):
        log2_target = int(math.log2(target))
        elements = [pow(2, i) for i in range(log2_target)]
        if pow(2, log2_target) == target:
            elements.append(target)

        seed_candidates = []
        num_seed_candidates = 1000
        partial_results = []
        self._generate_combination(
            elements,
            target,
            0,
            partial_results,
            seed_candidates,
            num_seed_candidates,
        )

        candidates = []
        for seed_candidate in seed_candidates:
            cur_candidates = []
            num_cur_candidates = 16
            seed_candidate.sort()
            check = [0 for i in range(len(seed_candidate))]
            if target <= 8:
                skip_prob = 0.0
            else:
                skip_prob = len(seed_candidate) / target
            self._permute_combination(
                seed_candidate,
                target,
                check,
                [],
                cur_candidates,
                num_cur_candidates,
                skip_prob,
            )
            candidates.extend(cur_candidates)
        return candidates

    def _partition_devices(self, num_machines, num_devices_per_machine):
        inter_node_partitions = self._partition_number(num_machines)
        intra_node_partitions = self._partition_number(
            num_devices_per_machine
        )
        return inter_node_partitions, intra_node_partitions

    def _generate_process_mesh_list(
        self, inter_node_partition, intra_node_partition
    ):
        process_mesh_list = []
        start_row = 0
        start_col = 0
        for m in inter_node_partition:
            start_col = 0
            for n in intra_node_partition:
                process_mesh = []
                for p in range(m):
                    start = (
                        start_row + p
                    ) * self._num_devices_per_machine + start_col
                    tmp = []
                    for q in range(n):
                        tmp.append(start + q)
                    process_mesh.append(tmp)
                process_mesh_list.append(copy.deepcopy(process_mesh))
                start_col += n
            start_row += m
        return process_mesh_list

    def _generate_dims_mapping_candidates_helper(
        self, dims_mapping, dims_list, start, visited, candidates
    ):
        if start == len(dims_mapping) or all(visited):
            candidates.append(copy.deepcopy(dims_mapping))
            return

        for idx, dim in enumerate(dims_list):
            if not visited[idx]:
                dims_mapping[start] = dim
                visited[idx] = True
                self._generate_dims_mapping_candidates_helper(
                    dims_mapping, dims_list, start + 1, visited, candidates
                )
                visited[idx] = False
        dims_mapping[start] = -1
        self._generate_dims_mapping_candidates_helper(
            dims_mapping, dims_list, start + 1, visited, candidates
        )
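
    # Worked example (for illustration): with dims_mapping_len=2 and
    # process_mesh_len=2, the helper above enumerates every injective
    # assignment of mesh dimensions to tensor dimensions, with -1 marking a
    # replicated dimension, yielding candidates such as
    #   [0, 1], [0, -1], [1, 0], [1, -1], [-1, 0], [-1, 1], [-1, -1]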

    def _generate_dims_mapping_candidates(
        self, dims_mapping_len, process_mesh_len
    ):
        assert dims_mapping_len >= 1 and process_mesh_len >= 1
        key = (dims_mapping_len, process_mesh_len)
        if key in self._cached_dims_mapping_candidates:
            return self._cached_dims_mapping_candidates[key]

        candidates = []
        dims_mapping = [-1 for i in range(dims_mapping_len)]
        dims_list = list(range(process_mesh_len))
        visited = [False for i in range(process_mesh_len)]
        self._generate_dims_mapping_candidates_helper(
            dims_mapping, dims_list, 0, visited, candidates
        )
        self._cached_dims_mapping_candidates[key] = candidates
        return candidates

    def _generate_dist_attr_candidates(self, op_id, dist_op):
        # For now, only allow process meshes with two dimensions
        process_mesh_len = 2
        serial_op = dist_op.serial_op
        op_dist_attr = dist_op.dist_attr
        if serial_op.type in self._special_ops:
            return [copy.deepcopy(op_dist_attr)]

        key = []
        key.append(serial_op.type)
        for input_name in serial_op.input_names:
            key.append(input_name)
            for input_arg_name in serial_op.input(input_name):
                key.append(
                    len(op_dist_attr.get_input_dims_mapping(input_arg_name))
                )
        for output_name in serial_op.output_names:
            key.append(output_name)
            for output_arg_name in serial_op.output(output_name):
                key.append(
                    len(op_dist_attr.get_output_dims_mapping(output_arg_name))
                )
        key = tuple(key)

        if key in self._cached_candidates_info:
            cached_dist_attr_candidates = []
            cached_input_arg_names = self._cached_candidates_info[key][0]
            cached_output_arg_names = self._cached_candidates_info[key][1]
            for cached_dist_attr in self._cached_candidates_info[key][2]:
                new_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
                i = 0
                for input_name in serial_op.input_names:
                    for input_arg_name in serial_op.input(input_name):
                        cached_dims_mapping = (
                            cached_dist_attr.get_input_dims_mapping(
                                cached_input_arg_names[i]
                            )
                        )
                        new_op_dist_attr.set_input_dims_mapping(
                            input_arg_name, cached_dims_mapping
                        )
                        i += 1
                i = 0
                for output_name in serial_op.output_names:
                    for output_arg_name in serial_op.output(output_name):
                        cached_dims_mapping = (
                            cached_dist_attr.get_output_dims_mapping(
                                cached_output_arg_names[i]
                            )
                        )
                        new_op_dist_attr.set_output_dims_mapping(
                            output_arg_name, cached_dims_mapping
                        )
                        i += 1
                cached_dist_attr_candidates.append(new_op_dist_attr)
            return cached_dist_attr_candidates

        # cached_candidates_info = []
        input_arg_names = []
        for input_name in serial_op.input_names:
            for input_arg_name in serial_op.input(input_name):
                input_arg_names.append(input_arg_name)
        self._cached_candidates_info[key].append(input_arg_names)
        # cached_candidates_info.append(input_arg_names)
        output_arg_names = []
        for output_name in serial_op.output_names:
            for output_arg_name in serial_op.output(output_name):
                output_arg_names.append(output_arg_name)
        self._cached_candidates_info[key].append(output_arg_names)
        # cached_candidates_info.append(output_arg_names)

        new_op_dist_attr = copy.deepcopy(dist_op.dist_attr)

        # Find valid dims_mapping candidates for inputs
        input_names = []
        dims_mapping_generated = []
        inputs_dist_attrs = op_dist_attr.inputs_dist_attrs
        for tensor_name, tensor_dist_attr in inputs_dist_attrs.items():
            original_dims_mapping = tensor_dist_attr.dims_mapping
            dims_mapping_len = len(original_dims_mapping)
            input_names.append(tensor_name)
            if dims_mapping_len < 1:
                dims_mapping_generated.append(
                    [copy.deepcopy(original_dims_mapping)]
                )
            else:
                dims_mapping_generated.append(
                    self._generate_dims_mapping_candidates(
                        dims_mapping_len, process_mesh_len
                    )
                )
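
        # A candidate assignment of input dims_mappings is kept below only if
        # find_compatible_distributed_operator_impls reports at least one
        # compatible implementation in the forward direction (fwd=True).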
        input_dims_mapping_candidates = []
        for dims_mapping_list in itertools.product(*dims_mapping_generated):
            dims_mapping_list = list(dims_mapping_list)
            assert len(dims_mapping_list) == len(input_names)
            for i, dims_mapping in enumerate(dims_mapping_list):
                new_op_dist_attr.set_input_dims_mapping(
                    input_names[i], dims_mapping
                )
            new_dist_op = DistributedOperator(
                dist_op.serial_op, new_op_dist_attr
            )
            dist_op_impls = find_compatible_distributed_operator_impls(
                new_dist_op, fwd=True
            )
            if dist_op_impls is not None:
                input_dims_mapping_candidates.append(dims_mapping_list)

        # Find valid dims_mapping candidates for outputs
        output_names = []
        dims_mapping_generated = []
        outputs_dist_attrs = op_dist_attr.outputs_dist_attrs
        for tensor_name, tensor_dist_attr in outputs_dist_attrs.items():
            original_dims_mapping = tensor_dist_attr.dims_mapping
            dims_mapping_len = len(original_dims_mapping)
            output_names.append(tensor_name)
            if dims_mapping_len < 1:
                dims_mapping_generated.append(
                    [copy.deepcopy(original_dims_mapping)]
                )
            else:
                dims_mapping_generated.append(
                    self._generate_dims_mapping_candidates(
                        dims_mapping_len, process_mesh_len
                    )
                )
        output_dims_mapping_candidates = []
        for dims_mapping_list in itertools.product(*dims_mapping_generated):
            dims_mapping_list = list(dims_mapping_list)
            assert len(dims_mapping_list) == len(output_names)
            for i, dims_mapping in enumerate(dims_mapping_list):
                new_op_dist_attr.set_output_dims_mapping(
                    output_names[i], dims_mapping
                )
            new_dist_op = DistributedOperator(
                dist_op.serial_op, new_op_dist_attr
            )
            dist_op_impls = find_compatible_distributed_operator_impls(
                new_dist_op, fwd=False
            )
            if dist_op_impls is not None:
                output_dims_mapping_candidates.append(dims_mapping_list)

        if (
            not input_dims_mapping_candidates
            and output_dims_mapping_candidates
        ):
            inout_dims_mapping_generated = [
                [[[-2]]],
                output_dims_mapping_candidates,
            ]
        elif (
            input_dims_mapping_candidates
            and not output_dims_mapping_candidates
        ):
            inout_dims_mapping_generated = [
                input_dims_mapping_candidates,
                [[[-2]]],
            ]
        elif (
            not input_dims_mapping_candidates
            and not output_dims_mapping_candidates
        ):
            inout_dims_mapping_generated = [[[[-2]]], [[[-2]]]]
        else:
            inout_dims_mapping_generated = [
                input_dims_mapping_candidates,
                output_dims_mapping_candidates,
            ]

        # Find valid dims_mapping generated for both inputs and outputs
        cached_dist_attr_candidates = []
        for inout_dims_mapping_list in itertools.product(
            *inout_dims_mapping_generated
        ):
            assert len(inout_dims_mapping_list) == 2
            if input_dims_mapping_candidates:
                assert len(inout_dims_mapping_list[0]) == len(input_names)
            if output_dims_mapping_candidates:
                assert len(inout_dims_mapping_list[1]) == len(output_names)
            # Set the dims_mappings for inputs
            for i, dims_mapping in enumerate(inout_dims_mapping_list[0]):
                if dims_mapping != [-2]:
                    new_op_dist_attr.set_input_dims_mapping(
                        input_names[i], dims_mapping
                    )
            # Set the dims_mappings for outputs
            for i, dims_mapping in enumerate(inout_dims_mapping_list[1]):
                if dims_mapping != [-2]:
                    new_op_dist_attr.set_output_dims_mapping(
                        output_names[i], dims_mapping
                    )
            new_dist_op = DistributedOperator(
                dist_op.serial_op, new_op_dist_attr
            )
            dist_op_impls = find_compatible_distributed_operator_impls(
                new_dist_op, partial=False
            )
            if dist_op_impls is None:
                continue
            for dist_op_impl in dist_op_impls:
                new_op_dist_attr.impl_type = dist_op_impl.type
                new_op_dist_attr.impl_idx = dist_op_impl.idx
                cached_dist_attr_candidates.append(
                    copy.deepcopy(new_op_dist_attr)
                )
        self._cached_candidates_info[key].append(cached_dist_attr_candidates)
        return self._cached_candidates_info[key][2]
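
    # Worked example (for illustration): on 2 machines with 8 devices each,
    # _partition_devices may return inter-node partitions such as [2], [1, 1]
    # and intra-node partitions such as [8], [4, 4], [4, 2, 2], ...; each
    # (inter, intra) pair sampled from the choices registered below later
    # induces a list of process meshes in _generate_process_mesh_list, e.g.
    # ([1, 1], [4, 4]) -> [[[0, 1, 2, 3]], [[4, 5, 6, 7]],
    #                      [[8, 9, 10, 11]], [[12, 13, 14, 15]]].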

    def construct_space(self):
        (
            inter_node_partitions,
            intra_node_partitions,
        ) = self._partition_devices(
            self._num_machines, self._num_devices_per_machine
        )
        self._space.choice(
            "inter_node_partitions",
            inter_node_partitions,
            default=inter_node_partitions[0],
        )
        self._space.choice(
            "intra_node_partitions",
            intra_node_partitions,
            default=intra_node_partitions[0],
        )

        dist_ops = self._dist_context._dist_ops_for_program
        for op_id, dist_op in dist_ops.items():
            op_type = dist_op.serial_op.type
            if self._include_op_types:
                if op_type in self._include_op_types:
                    self._concerned_dist_ops[op_id] = dist_op
            else:
                self._concerned_dist_ops[op_id] = dist_op

        # Iterate over a copy so deletion does not invalidate the iterator
        for op_id, dist_op in list(self._concerned_dist_ops.items()):
            op_type = dist_op.serial_op.type
            if op_type in self._exclude_op_types:
                del self._concerned_dist_ops[op_id]

        print(
            "Number of the concerned dist ops",
            len(self._concerned_dist_ops),
            flush=True,
        )
        search_space = 1
        for op_id, dist_op in self._concerned_dist_ops.items():
            op_dist_attr_candidates = self._generate_dist_attr_candidates(
                op_id, dist_op
            )
            search_space *= len(op_dist_attr_candidates)
            self._space.choice(
                str(op_id),
                op_dist_attr_candidates,
                default=op_dist_attr_candidates[0],
            )

    def _compute_values_hash(self, values):
        keys = sorted(values.keys())
        s = "".join(str(k) + "=" + str(values[k]) for k in keys)
        return hashlib.sha256(s.encode("utf-8")).hexdigest()[:32]

    def _random_values(self):
        space = TunableSpace()
        collisions = 0
        while True:
            for v in self._space.variables.values():
                space._register(v)
                space.values[v.name] = v.random(self._seed_state)
                self._seed_state += 1
            values = space.values
            values_hash = self._compute_values_hash(values)
            if values_hash in self._tried_values:
                collisions += 1
                if collisions > self._max_collisions:
                    return None
                continue
            self._tried_values.add(values_hash)
            break
        return values

    def _populate_space(self):
        values = self._random_values()
        if values is None:
            return {"status": TrialStatus.STOPPED, "values": None}
        return {"status": TrialStatus.RUNNING, "values": values}

    def _create_trial(self):
        trial_id = f"{{:0{len(str(self._max_trials))}d}}"
        trial_id = trial_id.format(self._num_trials)
        if self._max_trials and self._num_trials >= self._max_trials:
            status = TrialStatus.STOPPED
            values = None
        else:
            results = self._populate_space()
            status = results["status"]
            values = results["values"]
        space = TunableSpace()
        space.variables = self._space.variables
        space.values = values
        trial = Trial(tunable_space=space, trial_id=trial_id, status=status)
        self._num_trials += 1
        return trial
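
    # Worked example (for illustration): with 100 ops and 4 pipeline stages,
    # ops_per_stage = 25 and the initial starts are [0, 25, 50, 75, 100];
    # each interior start is then randomly nudged below by roughly up to
    # ops_per_stage // 2 positions in either direction.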

    def _generate_pipeline_starts(self, process_mesh_list):
        total_ops = len(self._dist_context._dist_ops_for_program)
        total_stages = len(process_mesh_list)
        ops_per_stage = total_ops // total_stages
        if ops_per_stage == 0:
            return None
        # Compute the initial pipeline starts;
        # pipeline_starts has total_stages + 1 items (at least 2).
        pipeline_starts = []
        start = 0
        pipeline_starts.append(0)
        for _ in process_mesh_list:
            start += ops_per_stage
            pipeline_starts.append(start)
        pipeline_starts[-1] = total_ops

        # Adjust the pipeline starts by random selection
        directions = []
        sizes = []
        half_ops_per_stage = ops_per_stage // 2
        if half_ops_per_stage > 0 and total_stages > 1:
            new_pipeline_starts = []
            # Don't change the first start
            new_pipeline_starts.append(0)
            # Consider the starts except the first and the last one
            for _ in pipeline_starts[1:-1]:
                directions.append(Boolean("direction"))
                sizes.append(
                    IntRange(
                        "size", start=0, stop=half_ops_per_stage, endpoint=True
                    )
                )
            for i, start in enumerate(pipeline_starts[1:-1]):
                direction = directions[i].random(self._seed)
                size = sizes[i].random(self._seed)
                if direction:
                    # Subtract 1 from size to avoid overlapping new starts
                    new_start = start - (size - 1)
                else:
                    new_start = start + size
                new_pipeline_starts.append(new_start)
            # Don't change the last start
            new_pipeline_starts.append(pipeline_starts[-1])
            # Validate the new starts
            print(
                "Adjusted pipeline starts",
                new_pipeline_starts,
                half_ops_per_stage,
                pipeline_starts,
                flush=True,
            )
            for i, new_start in enumerate(new_pipeline_starts[1:]):
                assert new_start > new_pipeline_starts[i]
            return new_pipeline_starts
        else:
            print(
                "Non-adjusted pipeline starts",
                pipeline_starts,
                half_ops_per_stage,
                flush=True,
            )
            return pipeline_starts

    def _apply_pipeline_partition(self, process_mesh_list):
        op_id_to_process_mesh = {}
        total_ops = len(self._dist_context._dist_ops_for_program)
        total_stages = len(process_mesh_list)
        ops_per_stage = total_ops // total_stages
        if ops_per_stage == 0:
            return None
        pipeline_starts = self._generate_pipeline_starts(process_mesh_list)
        start_idx = 1
        sorted_op_ids = sorted(self._dist_context._dist_ops_for_program.keys())
        for idx, op_id in enumerate(sorted_op_ids):
            if idx < pipeline_starts[start_idx]:
                op_id_to_process_mesh[op_id] = process_mesh_list[start_idx - 1]
            else:
                start_idx += 1
                op_id_to_process_mesh[op_id] = process_mesh_list[start_idx - 1]
        return op_id_to_process_mesh

    def _amend_dist_attr(self):
        # 1) Reshape a process mesh of [1, x] or [x, 1] to [x], and amend
        #    the corresponding dims_mapping.
        # 2) Set a dim_mapping to -1 when the tensor shape cannot be divided
        #    evenly by the corresponding number of processes.
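        # For example (illustrative): a tensor of shape [13, 512] with
        # dims_mapping [0, 1] on a process mesh of shape [2, 4] keeps the
        # second mapping (512 % 4 == 0) but resets the first to -1 because
        # 13 % 2 != 0; a mesh of shape [1, 8] is reshaped to [8], so a
        # mapping of 1 becomes 0 and a mapping of 0 becomes -1.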
        for dist_op in self._dist_context._dist_ops_for_program.values():
            dist_attr = dist_op.dist_attr
            process_mesh = dist_attr.process_mesh
            if process_mesh is None:
                continue
            assert process_mesh.ndim == 2
            dim_of_one = None
            dim_of_other = None
            if process_mesh.shape[0] == 1:
                dim_of_one = 0
                dim_of_other = 1
            elif process_mesh.shape[1] == 1:
                dim_of_one = 1
                dim_of_other = 0

            if dim_of_one is not None:
                dist_attr.process_mesh = ProcessMesh(process_mesh.process_ids)
                self._dist_context.add_process_mesh(dist_attr.process_mesh)

            for arg_name in dist_attr.inputs_dist_attrs.keys():
                new_dims_mapping = []
                dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                for dim_mapping in dims_mapping:
                    if dim_mapping == dim_of_one:
                        new_dims_mapping.append(-1)
                    elif dim_mapping == dim_of_other:
                        new_dims_mapping.append(0)
                    else:
                        new_dims_mapping.append(dim_mapping)
                dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)

                dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                # dynamic_dims = dist_attr.get_input_dynamic_dims(arg_name)
                process_mesh = dist_attr.process_mesh
                process_shape = process_mesh.shape
                tensor = dist_op.get_serial_input(arg_name)
                if dims_mapping:
                    tensor_shape = tensor.shape
                else:
                    continue
                for i, dim_mapping in enumerate(dims_mapping):
                    # if dim_mapping != -1 \
                    #     and (tensor_shape[i] % process_shape[dim_mapping] != 0 \
                    #     or dynamic_dims[i] == 1):
                    if dim_mapping != -1 and (
                        tensor_shape[i] % process_shape[dim_mapping] != 0
                    ):
                        dims_mapping[i] = -1
                    # Bug fix: also reset the mapping when the mapped mesh
                    # dimension has only one process.
                    if dim_mapping != -1 and process_shape[dim_mapping] == 1:
                        dims_mapping[i] = -1

            for arg_name in dist_attr.outputs_dist_attrs.keys():
                new_dims_mapping = []
                dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                for dim_mapping in dims_mapping:
                    if dim_mapping == dim_of_one:
                        new_dims_mapping.append(-1)
                    elif dim_mapping == dim_of_other:
                        new_dims_mapping.append(0)
                    else:
                        new_dims_mapping.append(dim_mapping)
                dist_attr.set_output_dims_mapping(arg_name, new_dims_mapping)

                dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                # dynamic_dims = dist_attr.get_output_dynamic_dims(arg_name)
                process_mesh = dist_attr.process_mesh
                process_shape = process_mesh.shape
                tensor = dist_op.get_serial_output(arg_name)
                if dims_mapping:
                    tensor_shape = tensor.shape
                else:
                    continue
                for i, dim_mapping in enumerate(dims_mapping):
                    if dim_mapping != -1 and (
                        tensor_shape[i] % process_shape[dim_mapping] != 0
                    ):
                        dims_mapping[i] = -1
                    # Bug fix: also reset the mapping when the mapped mesh
                    # dimension has only one process.
                    if dim_mapping != -1 and process_shape[dim_mapping] == 1:
                        dims_mapping[i] = -1

            dist_op_impls = find_compatible_distributed_operator_impls(
                dist_op, partial=False
            )
            serial_op_type = dist_op.serial_op.type
            if dist_op_impls is not None and (
                serial_op_type != "fused_softmax_mask_upper_triangle"
                or self._check_fused_softmax_mask_upper_triangle(dist_op)
            ):
                dist_op.dist_attr.impl_type = dist_op_impls[0].type
                dist_op.dist_attr.impl_idx = dist_op_impls[0].idx
            else:
                # Use the default dist op impl
                for arg_name in dist_attr.inputs_dist_attrs.keys():
                    dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                    for i, _ in enumerate(dims_mapping):
                        dims_mapping[i] = -1
                for arg_name in dist_attr.outputs_dist_attrs.keys():
                    dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                    for i, _ in enumerate(dims_mapping):
                        dims_mapping[i] = -1
                dist_op.dist_attr.impl_type = "default"
                dist_op.dist_attr.impl_idx = 0
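
    # For example (illustrative): for an input of shape [..., 1024, 1024] on
    # a process mesh of shape [2, 4] with the last two dims mapped to [0, 1],
    # the sharded dims are 1024 // 2 = 512 and 1024 // 4 = 256, so the check
    # below fails and the default implementation is used instead.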

    def _check_fused_softmax_mask_upper_triangle(self, dist_op):
        """The second-to-last dim of the sharded input should equal the last dim."""
        input_name = dist_op.serial_op.input_arg_names[0]
        input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
            input_name
        )
        topology = dist_op.dist_attr.process_mesh.shape
        input_tensor = dist_op.get_serial_input(input_name)
        last_but_one_dim = (
            input_tensor.shape[-2] // topology[input_dims_mapping[-2]]
            if input_dims_mapping[-2] != -1
            else input_tensor.shape[-2]
        )
        last_dim = (
            input_tensor.shape[-1] // topology[input_dims_mapping[-1]]
            if input_dims_mapping[-1] != -1
            else input_tensor.shape[-1]
        )
        if last_but_one_dim == last_dim:
            return True
        return False

    def _eval_trial(self, trial):
        if self._num_trials == 0:
            num_prev_trials = 0
        else:
            num_prev_trials = self._num_trials - 1

        results = None

        start_time = time.time()
        inter_node_partition = trial.space.values["inter_node_partitions"]
        intra_node_partition = trial.space.values["intra_node_partitions"]
        process_mesh_list = self._generate_process_mesh_list(
            inter_node_partition, intra_node_partition
        )
        print("\tprocess_mesh list", process_mesh_list, flush=True)
        op_id_to_process_mesh = self._apply_pipeline_partition(
            process_mesh_list
        )
        if op_id_to_process_mesh is None:
            print("Operators are less than pipeline stages", flush=True)
            return results

        op_id_to_dist_attr = {}
        for name, value in trial.space.values.items():
            if (
                name != "inter_node_partitions"
                and name != "intra_node_partitions"
            ):
                op_id_to_dist_attr[int(name)] = value

        end_time = time.time()
        cur_sample_time = end_time - start_time
        self._sample_time = (
            num_prev_trials * self._sample_time + cur_sample_time
        ) / self._num_trials
        print(
            "\tsample_time",
            num_prev_trials,
            self._num_trials,
            self._sample_time,
            cur_sample_time,
            flush=True,
        )

        assert len(op_id_to_process_mesh) == len(op_id_to_dist_attr)

        start_time = time.time()
        for op_id, process_mesh in op_id_to_process_mesh.items():
            dist_op = self._dist_context._dist_ops_for_program[op_id]
            dist_op.dist_attr = copy.deepcopy(op_id_to_dist_attr[op_id])
            assert (
                dist_op.dist_attr.impl_type
                == op_id_to_dist_attr[op_id].impl_type
            )
            assert (
                dist_op.dist_attr.impl_idx
                == op_id_to_dist_attr[op_id].impl_idx
            )
            dist_op.dist_attr.process_mesh = ProcessMesh(process_mesh)
        self._amend_dist_attr()

        self._completer._complete_tensor_dist_attr_by_op()

        self._dist_context.block_state.parse_forward_blocks(
            self._dist_context.serial_main_program
        )

        end_time = time.time()
        cur_complete_time = end_time - start_time
        self._complete_time = (
            num_prev_trials * self._complete_time + cur_complete_time
        ) / self._num_trials
        print(
            "\tcomplete_time",
            num_prev_trials,
            self._num_trials,
            self._complete_time,
            cur_complete_time,
            flush=True,
        )

        start_time = time.time()
        estimate_time = self._estimate_trial()
        end_time = time.time()
        cur_estimate_time = end_time - start_time
        self._estimate_time = (
            num_prev_trials * self._estimate_time + cur_estimate_time
        ) / self._num_trials
        print(
            "\testimate_time",
            num_prev_trials,
            self._num_trials,
            self._estimate_time,
            cur_estimate_time,
            estimate_time,
            flush=True,
        )

        results = {"estimate_time": estimate_time}
        return results

    def _update_trail(self, trial, metrics, step=0):
        for metric_name, metric_value in metrics.items():
            trial.recorder.update(metric_name, metric_value, step=step)
        return trial.status
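
    # _estimate_trial (below) returns math.inf when the estimated peak memory
    # exceeds the hard-coded budget of 80% of 32 GB, and otherwise the global
    # time estimated by CostEstimator; tune() compares this value against the
    # best time found so far.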

    def _estimate_trial(self):
        assert self._cluster is not None
        if self._mode == "eval":
            self._estimator = CostEstimator(
                self._dist_context.serial_main_program,
                self._cluster,
                loop_count=self._loop_count,
            )
        elif self._mode == "predict":
            self._estimator = CostEstimator(
                self._dist_context.serial_main_program,
                self._cluster,
                loop_count=self._loop_count,
            )
        elif self._mode == "train":
            # Get the serial main program with backward
            serial_main_program = self._dist_context.serial_main_program
            serial_startup_program = self._dist_context.serial_startup_program
            serial_optimizer = self._dist_context.serial_optimizer

            # Generate backward
            serial_loss = self._dist_context.serial_fetch_vars["loss"][0]
            params_grads = self._parallelizer._generate_backward(
                serial_main_program, serial_startup_program, serial_loss
            )

            # Generate optimizer
            optimizer_ops = self._parallelizer._generate_optimizer(
                serial_main_program,
                serial_startup_program,
                serial_optimizer,
                params_grads,
            )
            self._estimator = CostEstimator(
                serial_main_program, self._cluster, loop_count=self._loop_count
            )

        max_memory = self._estimator._estimate_max_memory_by_dist_op(
            self._dist_context
        )
        print("\tmax_memory", f"{max_memory:,}", flush=True)
        # The estimated max memory must be less than 80% of 32GB (hard-coded)
        if max_memory > 32 * 0.8 * 1024 * 1024 * 1024:
            return math.inf
        else:
            global_cost = self._estimator.estimate(self._dist_context)
            return global_cost.time

    def _store_init_parallel_strategy(self):
        # If there is no annotation information, use data parallelism as the
        # initial parallel strategy.
        # TODO: we need a better way to set up the initial parallel strategy.
        if (
            not self._dist_context.has_annotation
            or not self._dist_context.process_meshes
        ):
            ranks = self._num_machines * self._num_devices_per_machine
            tensor_node = self._dist_context._serial_ordered_tensor_nodes[0]
            tensor_node_id = _node_id(tensor_node)
            tensor = self._dist_context._dist_tensors_for_graph[
                tensor_node_id
            ].serial_tensor
            tensor_dist_attr = self._dist_context._dist_tensors_for_graph[
                tensor_node_id
            ].dist_attr
            tensor_dist_attr.process_mesh = ProcessMesh(list(range(ranks)))
            self._dist_context._process_meshes.append(
                tensor_dist_attr.process_mesh
            )
            tensor_dist_attr.dims_mapping = [0] + [
                -1 for _ in range(len(tensor.shape) - 1)
            ]
            tensor_dist_attr.mark_annotated("process_mesh")
            tensor_dist_attr.mark_annotated("dims_mapping")
            print("Use dp as the init parallel strategy!", flush=True)

        # Do the sharding propagation
        self._completer.complete_forward_annotation()
        self._dist_context.block_state.parse_forward_blocks(
            self._dist_context.serial_main_program
        )

        # Backup the initial parallel strategy
        self._init_parallel_strategy[0] = copy.deepcopy(
            self._dist_context._dist_tensors_for_program
        )
        self._init_parallel_strategy[1] = copy.deepcopy(
            self._dist_context._dist_ops_for_program
        )
        self._init_parallel_strategy[2] = copy.deepcopy(
            self._dist_context.process_meshes
        )

        # Initialize the best parallel strategy to the initial one
        self._best_parallel_strategy[0] = copy.deepcopy(
            self._dist_context._dist_tensors_for_program
        )
        self._best_parallel_strategy[1] = copy.deepcopy(
            self._dist_context._dist_ops_for_program
        )
        self._best_parallel_strategy[2] = copy.deepcopy(
            self._dist_context._process_meshes
        )

    def _store_best_parallel_strategy(self):
        # Swap the best and the current parallel strategy
        tmp = [None, None, None]
        tmp[0] = self._best_parallel_strategy[0]
        tmp[1] = self._best_parallel_strategy[1]
        tmp[2] = self._best_parallel_strategy[2]
        self._best_parallel_strategy[
            0
        ] = self._dist_context._dist_tensors_for_program
        self._best_parallel_strategy[
            1
        ] = self._dist_context._dist_ops_for_program
        self._best_parallel_strategy[2] = self._dist_context._process_meshes
        self._dist_context._dist_tensors_for_program = tmp[0]
        self._dist_context._dist_ops_for_program = tmp[1]
        self._dist_context._process_meshes = tmp[2]
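
    # The overall flow of tune(): back up the context, store the initial
    # (data-parallel by default) strategy and estimate its time, construct
    # the search space, then repeatedly create a trial, evaluate it, and keep
    # the best strategy; finally the best strategy is written back into the
    # distributed context.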

    def tune(self):
        global_start_time = time.time()
        self._dist_context._backup(serial=True, dist=True)
        # This store statement must follow the above backup statement
        self._store_init_parallel_strategy()
        init_time = self._estimate_trial()  # estimate_trial when init
        # We have to restore the distributed context, because the estimation
        # of one trial needs to generate the backward and update parts. Since
        # we will do the tuning process next, we only need to reset all
        # distributed information to the default.
        self._dist_context._restore(
            serial=True,
            serial_mode="to_backup",
            dist=True,
            dist_mode="to_default",
        )

        best_time = init_time
        start_time = time.time()
        self.construct_space()
        end_time = time.time()
        print(
            "construct_space time",
            self._num_trials,
            end_time - start_time,
            flush=True,
        )
        create_trial_time = 0.0
        eval_trial_time = 0.0
        self._sample_time = 0.0
        self._complete_time = 0.0
        self._estimate_time = 0.0
        while True:
            start_time = time.time()
            trial = self._create_trial()
            if self._num_trials == 0:
                num_prev_trials = 0
            else:
                num_prev_trials = self._num_trials - 1
            end_time = time.time()
            cur_create_trial_time = end_time - start_time
            create_trial_time = (
                num_prev_trials * create_trial_time + cur_create_trial_time
            ) / self._num_trials
            print(
                "create_trial time",
                num_prev_trials,
                self._num_trials,
                create_trial_time,
                cur_create_trial_time,
                flush=True,
            )
            if trial.status == TrialStatus.STOPPED:
                break
            # We need to back up the distributed context, because the
            # evaluation of one trial will generate the backward and update
            # parts, which may change the context. The distributed
            # information of the context isn't backed up, since a new one is
            # used for each trial.
            self._dist_context._backup(serial=True, dist=False)

            start_time = time.time()
            results = self._eval_trial(trial)
            end_time = time.time()
            cur_eval_trial_time = end_time - start_time
            eval_trial_time = (
                num_prev_trials * eval_trial_time + cur_eval_trial_time
            ) / self._num_trials
            print(
                "eval_trial time",
                num_prev_trials,
                self._num_trials,
                eval_trial_time,
                cur_eval_trial_time,
                "\n",
                flush=True,
            )

            cur_time = results["estimate_time"]
            if cur_time < best_time:
                self._update_trail(trial, results)
                self._store_best_parallel_strategy()
                best_time = cur_time
            # We need to restore the distributed context and reset the
            # distributed information to the default.
            self._dist_context._restore(
                serial=True,
                serial_mode="to_backup",
                dist=True,
                dist_mode="to_default",
            )
        # Select the best parallel strategy
        self._dist_context._dist_tensors_for_program = (
            self._best_parallel_strategy[0]
        )
        self._dist_context._dist_ops_for_program = (
            self._best_parallel_strategy[1]
        )
        self._dist_context._process_meshes = self._best_parallel_strategy[2]
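

# Typical usage (illustrative sketch; in practice the surrounding engine
# wires this up): given an initialized DistributedContext `dist_context`,
#     tuner = ParallelTuner(dist_context, mode="train", max_trials=25)
#     tuner.tune()
# after which dist_context holds the best parallel strategy that was found.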