From 9b6c7eb9d6f002732112d56b8b449b910071f543 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Tue, 3 Aug 2021 19:40:57 +0800 Subject: [PATCH] [HybridParallel] Support segment for PipelineParallel (#34529) * add layer segment * add segement for transformer * add utest --- .../parallel_layers/pp_layers.py | 94 ++++++++++++++----- .../hybrid_parallel_pp_transformer.py | 7 +- 2 files changed, 78 insertions(+), 23 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index b31b2939695..a3c6a5b5fb6 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -13,6 +13,7 @@ # limitations under the License. import math import paddle +import re from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str from functools import partial @@ -20,27 +21,6 @@ from functools import partial __all__ = [] -class SegmentLayers(object): - def __init__(self, layers_desc, num_parts, method="uniform"): - self._layers_desc = layers_desc - self.method = method - self.num_parts = num_parts - self.num_items = len(layers_desc) - assert self.num_items >= self.num_parts, "layer number should be greater than number of segments" - - def do_segment(self): - if self.method == "uniform": - return self.uniform(self.num_items, self.num_parts) - - def uniform(self, num_items, num_parts): - result = [0 for _ in range(num_parts + 1)] - part_size = math.floor(num_items / num_parts) - for i in range(num_parts): - result[i] = int(min(part_size * i, num_items)) - result[num_parts] = num_items - return result - - class LayerDesc(object): def __init__(self, layer_func, *inputs, **kwargs): self.layer_func = layer_func @@ -73,6 +53,75 @@ class SharedLayerDesc(LayerDesc): self.shared_weight_attr = shared_weight_attr +class SegmentLayers(object): + def __init__(self, layers_desc, num_parts, method="uniform"): + self._layers_desc = layers_desc + self.method = method + self.num_parts = num_parts + self.num_items = len(layers_desc) + assert self.num_items >= self.num_parts, "layer number should be greater than number of segments" + + def do_segment(self): + if self.method == "uniform": + return self.uniform(self.num_items, self.num_parts) + + elif self.method.startswith('layer:'): + # Divide equally according to the specified layer + layername = self.method.split(':')[1] + weights = [0] * len(self._layers_desc) + weight_idxs = self._gen_layer_weight(layername) + for idx in weight_idxs: + weights[idx] = 1 + + assert sum( + weights + ) % self.num_parts == 0, "number of layers ({}) should be divided by part number({})".format( + sum(weights), self.num_parts) + part_size = sum(weights) // self.num_parts + result = [0 for _ in range(self.num_parts + 1)] + + memory_counter = 0 + result_idx = 1 + for idx, weight in enumerate(weights): + memory_counter += weight + if memory_counter == part_size: + result[result_idx] = idx + 1 + result_idx += 1 + memory_counter = 0 + result[self.num_parts] = len(weights) + return result + + def _gen_layer_weight(self, layername): + weight_idxs = [] + regex = re.compile(layername, re.IGNORECASE) + for idx, layer in enumerate(self._layers_desc): + name = None + if isinstance(layer, Layer): + name = layer.__class__.__name__ + elif isinstance(layer, LayerDesc): + name = layer.layer_func.__name__ + else: + try: + name = layer.__name__ + except AttributeError: + # it is not error + continue + if regex.search(name): + weight_idxs.append(idx) + + assert len( + weight_idxs) > 0, "weight_idxs' length should be greater than 0" + return weight_idxs + + def uniform(self, num_items, num_parts): + result = [0 for _ in range(num_parts + 1)] + part_size = math.floor(num_items / num_parts) + for i in range(num_parts): + result[i] = int(min(part_size * i, num_items)) + result[num_parts] = num_items + return result + + class PipelineLayer(Layer): def __init__(self, layers, @@ -205,6 +254,9 @@ class PipelineLayer(Layer): self._layers_desc, num_parts=self._num_stages, method=seg_method) self.segment_parts = seg.do_segment() + logger.info("segment result:" + ", ".join( + str(arg) for arg in self.segment_parts)) + self._start_pos = self.segment_parts[self._stage_id] self._end_pos = self.segment_parts[self._stage_id + 1] diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py index 84971f2bc35..b336330836a 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py @@ -121,13 +121,16 @@ class ModelPipe(PipelineLayer): self.descs = [] self.descs.append(LayerDesc(EmbeddingPipe)) - for x in range(5): + for x in range(6): self.descs.append(LayerDesc(TransformerNetPipe)) self.descs.append(lambda x: x[0]) super().__init__( - layers=self.descs, loss_fn=CriterionPipe(), topology=topology) + layers=self.descs, + loss_fn=CriterionPipe(), + topology=topology, + seg_method="layer:TransformerNetPipe") class TestDistPPTraning(unittest.TestCase): -- GitLab