From 3a01478381bad3fa8eb3d7d7c325d73a9428fca3 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Mon, 7 Nov 2022 14:08:30 +0800 Subject: [PATCH] update the split logic for uniform (#47670) (#47705) * code format change * update the split logic for uniform (#47670) --- .../parallel_layers/pp_layers.py | 351 +++++++++++------- 1 file changed, 221 insertions(+), 130 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 8cd770698f9..0e14d141238 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -57,7 +57,6 @@ __all__ = [] class LayerDesc(object): - def __init__(self, layer_func, *inputs, **kwargs): self.layer_func = layer_func self.inputs = inputs @@ -65,25 +64,28 @@ class LayerDesc(object): if not issubclass(layer_func, Layer): raise TypeError( - "The input(layer_func) should be a derived class of Layer.") + "The input(layer_func) should be a derived class of Layer." + ) def build_layer(self): return self.layer_func(*self.inputs, **self.kwargs) def __repr__(self): - return layer_to_str(self.layer_func.__name__, *self.inputs, - **self.kwargs) + return layer_to_str( + self.layer_func.__name__, *self.inputs, **self.kwargs + ) class SharedLayerDesc(LayerDesc): - - def __init__(self, - key, - layer_func, - forward_func=None, - shared_weight_attr='weight', - *inputs, - **kwargs): + def __init__( + self, + key, + layer_func, + forward_func=None, + shared_weight_attr='weight', + *inputs, + **kwargs + ): super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs) self.layer_name = key self.forward_func = forward_func @@ -91,12 +93,13 @@ class SharedLayerDesc(LayerDesc): class SegmentLayers(object): - - def __init__(self, - layers_desc, - num_parts, - method="uniform", - num_virtual_pipeline_stage=None): + def __init__( + self, + layers_desc, + num_parts, + method="uniform", + num_virtual_pipeline_stage=None, + ): self._layers_desc = layers_desc self.method = method self.num_parts = num_parts @@ -104,7 +107,9 @@ class SegmentLayers(object): self.num_virtual_pipeline_stage = num_virtual_pipeline_stage if self.num_virtual_pipeline_stage is not None: self.total_parts = num_parts * self.num_virtual_pipeline_stage - assert self.num_items >= self.num_parts, "layer number should be greater than number of segments" + assert ( + self.num_items >= self.num_parts + ), "layer number should be greater than number of segments" def do_segment(self): if self.method == "uniform": @@ -118,12 +123,17 @@ class SegmentLayers(object): for idx in weight_idxs: weights[idx] = 1 - actual_num_parts = self.num_parts if self.num_virtual_pipeline_stage is None else self.total_parts - - assert sum( - weights - ) % actual_num_parts == 0, "number of layers ({}) should be divided by part number({})".format( - sum(weights), actual_num_parts) + actual_num_parts = ( + self.num_parts + if self.num_virtual_pipeline_stage is None + else self.total_parts + ) + + assert ( + sum(weights) % actual_num_parts == 0 + ), "number of layers ({}) should be divided by part number({})".format( + sum(weights), actual_num_parts + ) part_size = sum(weights) // actual_num_parts result = [0 for _ in range(actual_num_parts + 1)] @@ -156,21 +166,23 @@ class SegmentLayers(object): if regex.search(name): weight_idxs.append(idx) - assert len( - weight_idxs) > 0, "weight_idxs' length should be greater than 0" + assert ( + len(weight_idxs) > 0 + ), "weight_idxs' length should be greater than 0" return weight_idxs def uniform(self, num_items, num_parts): result = [0 for _ in range(num_parts + 1)] part_size = math.floor(num_items / num_parts) - for i in range(num_parts): - result[i] = int(min(part_size * i, num_items)) + extra_layers = num_items % num_parts + for i in range(1, num_parts): + offset = 1 if i > (num_parts - extra_layers) else 0 + result[i] = int(min(result[i - 1] + part_size + offset, num_items)) result[num_parts] = num_items return result class PipelineLayerChunk(Layer): - def __init__(self): super(PipelineLayerChunk, self).__init__() self.run_function = [] @@ -192,18 +204,19 @@ class PipelineLayerChunk(Layer): # behavior under recompute circumstance. raise PermissionError( "The forward function of PipelineLayerChunk cannot be called directly. " - "Please call forward function of PipelineLayer.") + "Please call forward function of PipelineLayer." + ) class PipelineLayer(Layer): """PipelineLayer Args: layers(Iterable): A sequence of layers description to define the structure for pipeline. - num_stages(int, optional): pp degree, if not specified, 'topology' parameter must be given. + num_stages(int, optional): pp degree, if not specified, 'topology' parameter must be given. topology(CommunicateTopology, optional): topo of hybrid parallel, if it is None, 'num_stages' parameters must be given. loss_fn(callable, optional): Loss function. seg_method(str, optional): the method of splitting pp layer, default 'uniform', or use specific layer to split, method's name must be start with 'layer:'. - recompute_interval(int, optional): the number of layers to be used recompute, the value of 0 represents no recompute. default 0. + recompute_interval(int, optional): the number of layers to be used recompute, the value of 0 represents no recompute. default 0. recompute_ctx(dict,optional): the context of recompute, when 'recompute_interval' > 0, the context must be given. num_virtual_pipeline_stages(int, optional): the num of virtual pipeline stages for interleave pp. Examples: @@ -213,7 +226,7 @@ class PipelineLayer(Layer): from paddle.fluid.dygraph.layers import Layer import paddle.nn.functional as F from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer - + pipeline_parallel_size = 2 strategy = fleet.DistributedStrategy() strategy.hybrid_configs = { @@ -225,19 +238,19 @@ class PipelineLayer(Layer): "accumulate_steps": 4, "micro_batch_size": 2 } - + fleet.init(is_collective=True, strategy=strategy) - + hcg = fleet.get_hybrid_communicate_group() - + class ReshapeHelp(Layer): def __init__(self, shape): super(ReshapeHelp, self).__init__() self.shape = shape - + def forward(self, x): return x.reshape(shape=self.shape) - + class AlexNetPipeDesc(PipelineLayer): def __init__(self, num_classes=10, **kwargs): self.num_classes = num_classes @@ -269,37 +282,46 @@ class PipelineLayer(Layer): ] super(AlexNetPipeDesc, self).__init__( layers=decs, loss_fn=nn.CrossEntropyLoss(), **kwargs) - + model = AlexNetPipeDesc(num_stages=pipeline_parallel_size, topology=hcg._topo) """ - def __init__(self, - layers, - num_stages=None, - topology=None, - loss_fn=None, - seg_method="uniform", - recompute_interval=0, - recompute_ctx=None, - num_virtual_pipeline_stages=None): + def __init__( + self, + layers, + num_stages=None, + topology=None, + loss_fn=None, + seg_method="uniform", + recompute_interval=0, + recompute_ctx=None, + num_virtual_pipeline_stages=None, + ): super(PipelineLayer, self).__init__() if num_stages is None and topology is None: raise ValueError("should provide num_stages or topology") if num_virtual_pipeline_stages: - assert isinstance(num_virtual_pipeline_stages, int), \ - "virtual_pipeline_stage should be None or an int" + assert isinstance( + num_virtual_pipeline_stages, int + ), "virtual_pipeline_stage should be None or an int" if num_virtual_pipeline_stages > 1: logger.info( "set num_virtual_pipeline_stages > 1 means using interleave scheduler instead of 1f1b scheduler" ) - assert isinstance(seg_method, str), \ - "seg_method should be a str for interleave scheduler" - assert seg_method.startswith('layer:'), \ - "seg_method shoud be start with layer: for interleave scheduler" - - self._num_virtual_pipeline_stages = 1 if num_virtual_pipeline_stages is None else num_virtual_pipeline_stages + assert isinstance( + seg_method, str + ), "seg_method should be a str for interleave scheduler" + assert seg_method.startswith( + 'layer:' + ), "seg_method shoud be start with layer: for interleave scheduler" + + self._num_virtual_pipeline_stages = ( + 1 + if num_virtual_pipeline_stages is None + else num_virtual_pipeline_stages + ) # lazy import import paddle.distributed as dist @@ -313,13 +335,17 @@ class PipelineLayer(Layer): self.recompute_ctx = recompute_ctx if recompute_interval > 0: - assert recompute_ctx is not None, "recompute_ctx must be not None for recompute." + assert ( + recompute_ctx is not None + ), "recompute_ctx must be not None for recompute." offload = recompute_ctx.get('offload', False) partition = recompute_ctx.get('partition', False) logger.info( - "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}" - .format(offload, partition)) + "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}".format( + offload, partition + ) + ) world_size = dist.get_world_size() self.global_rank = dist.get_rank() @@ -328,22 +354,28 @@ class PipelineLayer(Layer): self._stage_id = self._topo.get_coord(self.global_rank).pipe self._num_stages = self._topo.get_dim_size("pipe") if num_stages: - assert self._num_stages == num_stages, "num_stages should be equal to be %d" % ( - self._num_stages) + assert ( + self._num_stages == num_stages + ), "num_stages should be equal to be %d" % (self._num_stages) else: # construct default topology if world_size % num_stages != 0: raise ValueError( "should provide correct num_stages({}) " "which can be divided by world_size({})".format( - num_stages, world_size)) + num_stages, world_size + ) + ) dp_num = world_size // num_stages - self._topo = fleet.CommunicateTopology(["data", "pipe", "model"], - [dp_num, num_stages, 1]) + self._topo = fleet.CommunicateTopology( + ["data", "pipe", "model"], [dp_num, num_stages, 1] + ) self._stage_id = self._topo.get_coord(self.global_rank).pipe self._num_stages = self._topo.get_dim_size("pipe") - self._total_stages_with_virtual_stages = self._num_stages * self._num_virtual_pipeline_stages + self._total_stages_with_virtual_stages = ( + self._num_stages * self._num_virtual_pipeline_stages + ) # initialize segment self._layers_desc = list(self.layers) @@ -381,16 +413,22 @@ class PipelineLayer(Layer): start_idx = virtual_pp_rank * self._num_stages for stage in range(self._num_stages): # stage mark the real pp stage - if self.segment_parts[start_idx + - stage] <= layer_idx < self.segment_parts[ - start_idx + stage + 1]: + if ( + self.segment_parts[start_idx + stage] + <= layer_idx + < self.segment_parts[start_idx + stage + 1] + ): return stage def get_num_virtual_stages(self): return self._num_virtual_pipeline_stages def get_model_chunks(self): - return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks + return ( + None + if self._num_virtual_pipeline_stages == 1 + else self._model_chunks + ) def _construct_shared_comm(self): shared_comm = {} @@ -398,17 +436,21 @@ class PipelineLayer(Layer): return layers_desc = self._layers_desc - shared_layer_names = set(s.layer_name for s in layers_desc - if isinstance(s, SharedLayerDesc)) + shared_layer_names = set( + s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc) + ) for key in shared_layer_names: shared_layers = [] for idx, layer in enumerate(layers_desc): - if isinstance(layer, - SharedLayerDesc) and layer.layer_name == key: + if ( + isinstance(layer, SharedLayerDesc) + and layer.layer_name == key + ): shared_layers.append(idx) shared_stages = set( - self.get_stage_from_index(idx) for idx in shared_layers) + self.get_stage_from_index(idx) for idx in shared_layers + ) self._dp_degree = self._topo.get_dim('data') self._mp_degree = self._topo.get_dim('model') self._sharding_degree = self._topo.get_dim('sharding') @@ -425,7 +467,9 @@ class PipelineLayer(Layer): pipe=s, data=dp, sharding=sharding, - model=mp)) + model=mp, + ) + ) group = paddle.distributed.new_group(ranks=shared_ranks) if self.global_rank in shared_ranks: @@ -434,8 +478,9 @@ class PipelineLayer(Layer): shared_comm[key] = { 'ranks': shared_ranks, 'group': group, - 'weight_attr': - self.shared_weight_attrs[key], + 'weight_attr': self.shared_weight_attrs[ + key + ], 'layer': self.shared_layers[key], } return shared_comm @@ -443,10 +488,11 @@ class PipelineLayer(Layer): def _synchronize_shared_weights(self): for key, comm in self.shared_comm.items(): with paddle.framework.no_grad(): - paddle.distributed.broadcast(getattr(comm['layer'], - comm['weight_attr']), - src=min(comm['ranks']), - group=comm['group']) + paddle.distributed.broadcast( + getattr(comm['layer'], comm['weight_attr']), + src=min(comm['ranks']), + group=comm['group'], + ) for param in comm['layer'].parameters(): if self.global_rank != min(comm['ranks']): @@ -458,8 +504,9 @@ class PipelineLayer(Layer): # need use trace_op to allreduce weight if in_dygraph_mode(): with paddle.framework.no_grad(): - paddle.distributed.all_reduce(param.grad, - group=comm['group']) + paddle.distributed.all_reduce( + param.grad, group=comm['group'] + ) else: with paddle.framework.no_grad(): paddle.fluid.framework._dygraph_tracer().trace_op( @@ -468,8 +515,9 @@ class PipelineLayer(Layer): outputs={'Out': param._grad_ivar()}, attrs={ 'ring_id': comm['group'].id, - 'use_calc_stream': True - }) + 'use_calc_stream': True, + }, + ) def _segment_network_for_interleave(self, seg_method): logger.info("start segment network for interleave scheduler") @@ -477,14 +525,20 @@ class PipelineLayer(Layer): self._layers_desc, num_parts=self._num_stages, method=seg_method, - num_virtual_pipeline_stage=self._num_virtual_pipeline_stages) + num_virtual_pipeline_stage=self._num_virtual_pipeline_stages, + ) self.segment_parts = seg.do_segment() - logger.info("segment result:" + - ", ".join(str(arg) for arg in self.segment_parts)) + logger.info( + "segment result:" + + ", ".join(str(arg) for arg in self.segment_parts) + ) - for i in range(self._stage_id, self._total_stages_with_virtual_stages, - self._num_stages): + for i in range( + self._stage_id, + self._total_stages_with_virtual_stages, + self._num_stages, + ): # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers. # Layers [0, 1], [4, 5] will be assigned to the first real pp stage. # Layers [2, 3], [6, 7] will be assigned to the second real pp stage. @@ -500,13 +554,15 @@ class PipelineLayer(Layer): def _segment_network(self, seg_method): logger.info("start segment network..") - seg = SegmentLayers(self._layers_desc, - num_parts=self._num_stages, - method=seg_method) + seg = SegmentLayers( + self._layers_desc, num_parts=self._num_stages, method=seg_method + ) self.segment_parts = seg.do_segment() - logger.info("segment result:" + - ", ".join(str(arg) for arg in self.segment_parts)) + logger.info( + "segment result:" + + ", ".join(str(arg) for arg in self.segment_parts) + ) self._start_pos = self.segment_parts[self._stage_id] self._end_pos = self.segment_parts[self._stage_id + 1] @@ -514,22 +570,30 @@ class PipelineLayer(Layer): def _print_segmentation_for_debug(self): # print information for debug - for stage in range(self._num_stages * - self._num_virtual_pipeline_stages): + for stage in range( + self._num_stages * self._num_virtual_pipeline_stages + ): start = self.segment_parts[stage] end = self.segment_parts[stage + 1] - logger.info("stage={}, global_rank={} ,layer_number={}".format( - stage, self.global_rank, end - start)) + logger.info( + "stage={}, global_rank={} ,layer_number={}".format( + stage, self.global_rank, end - start + ) + ) for index, layer in enumerate(self._layers_desc[start:end]): logger.info("{}: {}".format(index + start, str(layer))) if self._num_virtual_pipeline_stages > 1: for stage in range(self._num_stages): - stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format( - stage) - for i in range(stage, self._total_stages_with_virtual_stages, - self._num_stages): + stage_to_virtual_stage_info = ( + "stage {} contains virtual stages: ".format(stage) + ) + for i in range( + stage, + self._total_stages_with_virtual_stages, + self._num_stages, + ): stage_to_virtual_stage_info += " {},".format(i) logger.info(stage_to_virtual_stage_info) @@ -575,9 +639,11 @@ class PipelineLayer(Layer): if layer.layer_name not in self.shared_layers: self.shared_layers[layer.layer_name] = layer.build_layer() self.shared_weight_attrs[ - layer.layer_name] = layer.shared_weight_attr + layer.layer_name + ] = layer.shared_weight_attr for param in self.shared_layers[ - layer.layer_name].parameters(): + layer.layer_name + ].parameters(): setattr(param, "is_firstly_shared", True) if layer.forward_func is None: @@ -585,8 +651,11 @@ class PipelineLayer(Layer): else: run_function.append( - partial(layer.forward_func, - self.shared_layers[layer.layer_name])) + partial( + layer.forward_func, + self.shared_layers[layer.layer_name], + ) + ) elif isinstance(layer, LayerDesc): model = layer.build_layer() @@ -615,11 +684,15 @@ class PipelineLayer(Layer): def forward(self, input, chunk_id=None): if chunk_id is not None: assert isinstance(chunk_id, int), "chunk_id should be an int" - assert self._num_virtual_pipeline_stages > 1, \ - "chunk_id is only valid when using virtual pipeline stage" - assert chunk_id < len(self._model_chunks), \ - "The virtual pipeline only has {} chunks, " \ - "but received chunk_id {}.".format(len(self._model_chunks), chunk_id) + assert ( + self._num_virtual_pipeline_stages > 1 + ), "chunk_id is only valid when using virtual pipeline stage" + assert chunk_id < len(self._model_chunks), ( + "The virtual pipeline only has {} chunks, " + "but received chunk_id {}.".format( + len(self._model_chunks), chunk_id + ) + ) # Get the target model chunk. model_chunk = self._model_chunks[chunk_id] # Update the self.run_function to the target run functions. @@ -637,20 +710,25 @@ class PipelineLayer(Layer): funcs = self.run_function[start_idx:end_idx] if not isinstance(input, tuple): - input = (input, ) + input = (input,) if self._need_recompute(funcs, input): input = recompute_hybrid( self.recompute_ctx, - self.forward_function(start_idx, end_idx), *input) + self.forward_function(start_idx, end_idx), + *input + ) else: input = self.forward_function(start_idx, end_idx)(*input) return input def _need_recompute(self, funcs, inputs): - if not any(input_.stop_gradient == False - for input_ in inputs if isinstance(input_, paddle.Tensor)): + if not any( + input_.stop_gradient == False + for input_ in inputs + if isinstance(input_, paddle.Tensor) + ): return False params = [f.parameters() for f in funcs if isinstance(f, Layer)] @@ -674,11 +752,18 @@ class PipelineLayer(Layer): if self._num_virtual_pipeline_stages > 1: # add virtual pipeline info to the save path assert local_chunk_id is not None - virtual_pipeline_stage_message = "-virtual_pp_stage_{:0>2d}".format( - local_chunk_id) - layer_save_path = os.path.join(ckpt_dir, - 'layer_{:0>2d}'.format(idx)) - layer_save_path = layer_save_path + virtual_pipeline_stage_message + rank_message + '-model_states.pdparams' + virtual_pipeline_stage_message = ( + "-virtual_pp_stage_{:0>2d}".format(local_chunk_id) + ) + layer_save_path = os.path.join( + ckpt_dir, 'layer_{:0>2d}'.format(idx) + ) + layer_save_path = ( + layer_save_path + + virtual_pipeline_stage_message + + rank_message + + '-model_states.pdparams' + ) return layer_save_path def _save_model(run_functions, local_chunk_id=None): @@ -701,7 +786,8 @@ class PipelineLayer(Layer): def set_state_dir(self, path): assert os.path.exists( - path), "{} not found, please check the path".format(path) + path + ), "{} not found, please check the path".format(path) def _load_model(run_functions, local_chunk_id=None): for idx, layer in enumerate(run_functions): @@ -715,21 +801,26 @@ class PipelineLayer(Layer): pos_offset = self._start_poss[local_chunk_id] layer_idx = idx + pos_offset layer_save_path = os.path.join( - path, 'layer_{0:0>2d}'.format(layer_idx)) + path, 'layer_{0:0>2d}'.format(layer_idx) + ) if self._num_virtual_pipeline_stages > 1: # add virtual pipeline info to the path assert local_chunk_id is not None - layer_save_path = layer_save_path + "-virtual_pp_stage_{:0>2d}".format( - local_chunk_id) - model_files = glob.glob(layer_save_path + - "*model_states.pdparams") + layer_save_path = ( + layer_save_path + + "-virtual_pp_stage_{:0>2d}".format(local_chunk_id) + ) + model_files = glob.glob( + layer_save_path + "*model_states.pdparams" + ) model_files.sort() mp_rank = self._topo.get_coord(self.global_rank).model mp_world_size = self._topo.get_dim('model') num_files = len(model_files) - load_param_path = model_files[mp_rank * num_files // - mp_world_size] + load_param_path = model_files[ + mp_rank * num_files // mp_world_size + ] model_state_dict = paddle.load(load_param_path) layer.set_state_dict(model_state_dict) -- GitLab