未验证 提交 9b6c7eb9 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel] Support segment for PipelineParallel (#34529)

* add layer segment

* add segment for transformer

* add utest
上级 2714fc7e
......@@ -13,6 +13,7 @@
# limitations under the License.
import math
import paddle
import re
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str
from functools import partial
......@@ -20,27 +21,6 @@ from functools import partial
__all__ = []
class SegmentLayers(object):
    """Compute boundaries that split a list of layer descriptions into parts.

    ``do_segment`` returns ``num_parts + 1`` indices into ``layers_desc``;
    consecutive entries are used as the start and end position of one part.

    Args:
        layers_desc: Sized collection of layer descriptions to partition.
        num_parts: Number of parts to produce.
        method: Segmentation strategy. Only ``"uniform"`` is supported here.

    Raises:
        AssertionError: If there are fewer layers than parts.
    """

    def __init__(self, layers_desc, num_parts, method="uniform"):
        self._layers_desc = layers_desc
        self.method = method
        self.num_parts = num_parts
        self.num_items = len(layers_desc)
        assert self.num_items >= self.num_parts, (
            "layer number should be greater than number of segments")

    def do_segment(self):
        """Return the list of ``num_parts + 1`` boundary indices.

        Raises:
            ValueError: If ``method`` is not a supported strategy. Failing
                here avoids handing ``None`` back to the caller.
        """
        if self.method == "uniform":
            return self.uniform(self.num_items, self.num_parts)
        raise ValueError(
            "seg_method '{}' is not supported, expected 'uniform'".format(
                self.method))

    def uniform(self, num_items, num_parts):
        """Split ``num_items`` into ``num_parts`` parts of equal base size.

        Every part gets ``num_items // num_parts`` items; the last boundary
        is pinned to ``num_items`` so the final part absorbs any remainder.
        """
        part_size = num_items // num_parts
        result = [0 for _ in range(num_parts + 1)]
        for i in range(num_parts):
            result[i] = int(min(part_size * i, num_items))
        result[num_parts] = num_items
        return result
class LayerDesc(object):
def __init__(self, layer_func, *inputs, **kwargs):
self.layer_func = layer_func
......@@ -73,6 +53,75 @@ class SharedLayerDesc(LayerDesc):
self.shared_weight_attr = shared_weight_attr
class SegmentLayers(object):
    """Compute boundaries that split a list of layer descriptions into parts.

    ``do_segment`` returns ``num_parts + 1`` indices into ``layers_desc``;
    consecutive entries are used as the start and end position of one part.

    Args:
        layers_desc: Sized, iterable collection of layer descriptions
            (``Layer`` instances, ``LayerDesc`` objects or other callables).
        num_parts: Number of parts to produce.
        method: ``"uniform"`` to split by item count, or
            ``"layer:<pattern>"`` to spread the layers whose name matches
            the regular expression ``<pattern>`` evenly over the parts.

    Raises:
        AssertionError: If there are fewer layers than parts.
    """

    def __init__(self, layers_desc, num_parts, method="uniform"):
        self._layers_desc = layers_desc
        self.method = method
        self.num_parts = num_parts
        self.num_items = len(layers_desc)
        assert self.num_items >= self.num_parts, (
            "layer number should be greater than number of segments")

    def do_segment(self):
        """Return the list of ``num_parts + 1`` boundary indices.

        Raises:
            AssertionError: For ``"layer:<pattern>"`` when no layer matches
                or the number of matches is not a multiple of ``num_parts``.
            ValueError: If ``method`` is not a supported strategy. Failing
                here avoids handing ``None`` back to the caller.
        """
        if self.method == "uniform":
            return self.uniform(self.num_items, self.num_parts)

        if isinstance(self.method, str) and self.method.startswith('layer:'):
            # Divide equally according to the specified layer.
            # maxsplit=1 keeps patterns that themselves contain ':' intact
            # (for example a non-capturing group "(?:A|B)").
            layername = self.method.split(':', 1)[1]

            # weights[i] is 1 when layer i matches the pattern, else 0.
            weights = [0] * len(self._layers_desc)
            for idx in self._gen_layer_weight(layername):
                weights[idx] = 1

            matched = sum(weights)
            assert matched % self.num_parts == 0, (
                "number of layers ({}) should be divided by part number({})".
                format(matched, self.num_parts))
            part_size = matched // self.num_parts

            result = [0 for _ in range(self.num_parts + 1)]
            memory_counter = 0
            result_idx = 1
            # Close a part right after its part_size-th matching layer.
            for idx, weight in enumerate(weights):
                memory_counter += weight
                if memory_counter == part_size:
                    result[result_idx] = idx + 1
                    result_idx += 1
                    memory_counter = 0
            # Pin the final boundary to the end so that layers following
            # the last matching one still land in the last part.
            result[self.num_parts] = len(weights)
            return result

        raise ValueError(
            "seg_method '{}' is not supported, expected 'uniform' or "
            "'layer:<name>'".format(self.method))

    def _gen_layer_weight(self, layername):
        """Return the indices of layers whose name matches ``layername``.

        The name is the class name for ``Layer`` instances, the name of
        ``layer_func`` for ``LayerDesc`` objects and ``__name__`` otherwise.
        Entries without a ``__name__`` are skipped. Matching is a
        case-insensitive regular-expression search.

        Raises:
            AssertionError: If no layer matches.
        """
        weight_idxs = []
        regex = re.compile(layername, re.IGNORECASE)
        for idx, layer in enumerate(self._layers_desc):
            if isinstance(layer, Layer):
                name = layer.__class__.__name__
            elif isinstance(layer, LayerDesc):
                name = layer.layer_func.__name__
            else:
                try:
                    name = layer.__name__
                except AttributeError:
                    # Not an error: objects without a name never match.
                    continue
            if regex.search(name):
                weight_idxs.append(idx)

        assert len(
            weight_idxs) > 0, "weight_idxs' length should be greater than 0"
        return weight_idxs

    def uniform(self, num_items, num_parts):
        """Split ``num_items`` into ``num_parts`` parts of equal base size.

        Every part gets ``num_items // num_parts`` items; the last boundary
        is pinned to ``num_items`` so the final part absorbs any remainder.
        """
        part_size = num_items // num_parts
        result = [0 for _ in range(num_parts + 1)]
        for i in range(num_parts):
            result[i] = int(min(part_size * i, num_items))
        result[num_parts] = num_items
        return result
class PipelineLayer(Layer):
def __init__(self,
layers,
......@@ -205,6 +254,9 @@ class PipelineLayer(Layer):
self._layers_desc, num_parts=self._num_stages, method=seg_method)
self.segment_parts = seg.do_segment()
logger.info("segment result:" + ", ".join(
str(arg) for arg in self.segment_parts))
self._start_pos = self.segment_parts[self._stage_id]
self._end_pos = self.segment_parts[self._stage_id + 1]
......
......@@ -121,13 +121,16 @@ class ModelPipe(PipelineLayer):
self.descs = []
self.descs.append(LayerDesc(EmbeddingPipe))
for x in range(5):
for x in range(6):
self.descs.append(LayerDesc(TransformerNetPipe))
self.descs.append(lambda x: x[0])
super().__init__(
layers=self.descs, loss_fn=CriterionPipe(), topology=topology)
layers=self.descs,
loss_fn=CriterionPipe(),
topology=topology,
seg_method="layer:TransformerNetPipe")
class TestDistPPTraning(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册