Unverified commit 04c15e79, authored by Yuang Liu, committed by GitHub

[dygraph hybrid pp for interleave] Virtual pp stage layer split (#45402)

Parent 2a992178
@@ -91,11 +91,18 @@ class SharedLayerDesc(LayerDesc):

class SegmentLayers(object):

    def __init__(self,
                 layers_desc,
                 num_parts,
                 method="uniform",
                 num_virtual_pipeline_stage=None):
        self._layers_desc = layers_desc
        self.method = method
        self.num_parts = num_parts
        self.num_items = len(layers_desc)
        self.num_virtual_pipeline_stage = num_virtual_pipeline_stage
        if self.num_virtual_pipeline_stage is not None:
            self.total_parts = num_parts * self.num_virtual_pipeline_stage
        assert self.num_items >= self.num_parts, "layer number should be greater than number of segments"

    def do_segment(self):
@@ -110,12 +117,14 @@ class SegmentLayers(object):
            for idx in weight_idxs:
                weights[idx] = 1

            actual_num_parts = self.num_parts if self.num_virtual_pipeline_stage is None else self.total_parts
            assert sum(
                weights
            ) % actual_num_parts == 0, "number of layers ({}) should be divided by part number({})".format(
                sum(weights), actual_num_parts)
            part_size = sum(weights) // actual_num_parts
            result = [0 for _ in range(actual_num_parts + 1)]

            memory_counter = 0
            result_idx = 1
@@ -125,7 +134,7 @@ class SegmentLayers(object):
                    result[result_idx] = idx + 1
                    result_idx += 1
                    memory_counter = 0
            result[actual_num_parts] = len(weights)
            return result

    def _gen_layer_weight(self, layername):
@@ -159,6 +168,23 @@ class SegmentLayers(object):
        return result
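When num_virtual_pipeline_stage is set, the segmenter cuts the model into num_parts * num_virtual_pipeline_stage pieces instead of num_parts, so segment_parts gains extra boundaries. A minimal sketch of what the boundary list looks like for the 8-layer, 2-stage, 2-virtual-stage example used later in this patch (illustrative numbers only, not part of the change):

```python
# Illustrative only: boundaries for 8 segmentable layers split into
# num_parts * num_virtual_pipeline_stage = 2 * 2 = 4 equal parts.
num_layers = 8
num_parts = 2                    # real pipeline stages
num_virtual_pipeline_stage = 2   # virtual stages (chunks) per real stage
total_parts = num_parts * num_virtual_pipeline_stage

part_size = num_layers // total_parts
segment_parts = [i * part_size for i in range(total_parts + 1)]
print(segment_parts)  # [0, 2, 4, 6, 8]
```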
class PipelineLayerChunk(Layer):

    def __init__(self):
        super(PipelineLayerChunk, self).__init__()
        self.functions = []

    def append(self, sublayer):
        # This method is used to unify the code paths in _build_layer_impl.
        # For the 1f1b scheduler, it calls the append method of a list.
        # For the interleave scheduler, it calls the append method of this class.
        if isinstance(sublayer, Layer):
            self.add_sublayer(str(len(self.functions)), sublayer)
        self.functions.append(sublayer)

    # TODO (Yuang Liu) forward function implement
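PipelineLayerChunk behaves like the plain run_function list, except that appended Layer instances are also registered as sublayers, so their parameters are visible through the chunk. A hypothetical usage sketch, assuming the class defined above is in scope (not part of the patch):

```python
import paddle.nn as nn
import paddle.nn.functional as F

chunk = PipelineLayerChunk()
chunk.append(nn.Linear(16, 16))  # Layer: also registered as sublayer "0"
chunk.append(F.relu)             # plain callable: stored in functions only

assert len(chunk.functions) == 2
# Only the Linear layer contributes parameters (its weight and bias).
assert len(list(chunk.parameters())) == 2
```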
class PipelineLayer(Layer):

    def __init__(self,
@@ -169,11 +195,26 @@ class PipelineLayer(Layer):
                 seg_method="uniform",
                 recompute_interval=0,
                 recompute_offload=False,
                 recompute_partition=False,
                 num_virtual_pipeline_stages=None):
        super(PipelineLayer, self).__init__()
        if num_stages is None and topology is None:
            raise ValueError("should provide num_stages or topology")

        if num_virtual_pipeline_stages:
            assert isinstance(num_virtual_pipeline_stages, int), \
                "virtual_pipeline_stage should be None or an int"
            if num_virtual_pipeline_stages > 1:
                logger.info(
                    "set num_virtual_pipeline_stages > 1 means using interleave scheduler instead of 1f1b scheduler"
                )
                assert isinstance(seg_method, str), \
                    "seg_method should be a str for interleave scheduler"
                assert seg_method.startswith('layer:'), \
                    "seg_method should start with 'layer:' for interleave scheduler"

        self._num_virtual_pipeline_stages = 1 if num_virtual_pipeline_stages is None else num_virtual_pipeline_stages

        # lazy import
        import paddle.distributed as dist
        from paddle.distributed import fleet
@@ -214,28 +255,51 @@ class PipelineLayer(Layer):
        self._stage_id = self._topo.get_coord(self.global_rank).pipe
        self._num_stages = self._topo.get_dim_size("pipe")
        self._total_stages_with_virtual_stages = self._num_stages * self._num_virtual_pipeline_stages

        # initialize segment
        self._layers_desc = list(self.layers)
        self._num_layers = len(self._layers_desc)
        self.shared_layers = paddle.nn.LayerDict()
        self.shared_weight_attrs = {}

        if self._num_virtual_pipeline_stages > 1:
            # interleaving pipeline segmentation
            self._start_poss = []
            self._end_poss = []
            self._segment_network_for_interleave(seg_method)
            # The _model_chunks is a list of PipelineLayerChunk,
            # while PipelineLayerChunk is a list of Layers relating with one model chunk.
            # Therefore, the _model_chunks is something like 'list of a list of layers'.
            self._model_chunks = []
            self._build_layer_with_interleave()
        else:
            # 1f1b pipeline segmentation
            self._start_pos = 0
            self._end_pos = self._num_layers - 1
            self._segment_network(seg_method)
            # construct layer
            self.run_function = []
            self._build_layer()

        self.shared_comm = self._construct_shared_comm()
        self._synchronize_shared_weights()
    def get_stage_from_index(self, layer_idx):
        assert 0 <= layer_idx < self._num_layers, "layer_idx is out of bound"
        for virtual_pp_rank in range(self._num_virtual_pipeline_stages):
            # Mapping the virtual pipeline stage to the real pipeline stage.
            # start_idx marks the start of a new virtual pp stage.
            start_idx = virtual_pp_rank * self._num_virtual_pipeline_stages
            for stage in range(self._num_stages):
                # stage marks the real pp stage
                if self.segment_parts[start_idx + stage] <= layer_idx < self.segment_parts[
                        start_idx + stage + 1]:
                    return stage

    def get_model_chunks(self):
        return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks
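With virtual stages, a layer index can no longer be mapped to its pipeline stage by a single linear scan of segment_parts: the lookup walks each virtual chunk and checks every real-stage slot inside it. A standalone sketch of the same lookup for the 2-stage, 2-virtual-stage, 8-layer example (the helper name is illustrative, not part of the patch):

```python
def stage_of(layer_idx, segment_parts, num_stages, num_virtual):
    # Mirrors get_stage_from_index above for a concrete example.
    for virtual_pp_rank in range(num_virtual):
        start_idx = virtual_pp_rank * num_virtual
        for stage in range(num_stages):
            lo = segment_parts[start_idx + stage]
            hi = segment_parts[start_idx + stage + 1]
            if lo <= layer_idx < hi:
                return stage


segment_parts = [0, 2, 4, 6, 8]  # 8 layers, 2 stages, 2 virtual stages
assert [stage_of(i, segment_parts, 2, 2)
        for i in range(8)] == [0, 0, 1, 1, 0, 0, 1, 1]
```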
    def _construct_shared_comm(self):
        shared_comm = {}
@@ -316,6 +380,33 @@ class PipelineLayer(Layer):
                        'use_calc_stream': True
                    })
    def _segment_network_for_interleave(self, seg_method):
        logger.info("start segment network for interleave scheduler")
        seg = SegmentLayers(
            self._layers_desc,
            num_parts=self._num_stages,
            method=seg_method,
            num_virtual_pipeline_stage=self._num_virtual_pipeline_stages)
        self.segment_parts = seg.do_segment()

        logger.info("segment result:" +
                    ", ".join(str(arg) for arg in self.segment_parts))

        for i in range(self._stage_id, self._total_stages_with_virtual_stages,
                       self._num_virtual_pipeline_stages):
            # Suppose there are 2 real pp stages, 2 virtual pp stages, and the model has 8 layers.
            # Layers [0, 1] and [4, 5] will be assigned to the first real pp stage.
            # Layers [2, 3] and [6, 7] will be assigned to the second real pp stage.
            # Layers [0, 1] and [2, 3] form the first virtual pp stage in each real pp stage.
            # Layers [4, 5] and [6, 7] form the second virtual pp stage in each real pp stage.
            assert self.segment_parts[i] <= self.segment_parts[i + 1]
            self._start_poss.append(self.segment_parts[i])
            self._end_poss.append(self.segment_parts[i + 1])

        assert len(self._start_poss) == len(self._end_poss)

        self._print_segmentation_for_debug()
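For each rank, the loop above picks every num_virtual_pipeline_stages-th segment starting from its own stage id, so one rank owns several non-contiguous layer ranges. A short sketch of the resulting (start, end) pairs for the example in the comment (illustrative values only):

```python
segment_parts = [0, 2, 4, 6, 8]  # boundaries for 8 layers
num_stages, num_virtual = 2, 2
total_stages_with_virtual_stages = num_stages * num_virtual

for stage_id in range(num_stages):
    chunks = [(segment_parts[i], segment_parts[i + 1])
              for i in range(stage_id, total_stages_with_virtual_stages,
                             num_virtual)]
    print("stage", stage_id, "->", chunks)
# stage 0 -> [(0, 2), (4, 6)]   i.e. layers [0, 1] and [4, 5]
# stage 1 -> [(2, 4), (6, 8)]   i.e. layers [2, 3] and [6, 7]
```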
    def _segment_network(self, seg_method):
        logger.info("start segment network..")
        seg = SegmentLayers(self._layers_desc,
@@ -328,9 +419,12 @@ class PipelineLayer(Layer):
        self._start_pos = self.segment_parts[self._stage_id]
        self._end_pos = self.segment_parts[self._stage_id + 1]
        self._print_segmentation_for_debug()

    def _print_segmentation_for_debug(self):
        # print information for debug
        for stage in range(self._num_stages *
                           self._num_virtual_pipeline_stages):
            start = self.segment_parts[stage]
            end = self.segment_parts[stage + 1]
            logger.info("stage={}, global_rank={} ,layer_number={}".format(
@@ -339,20 +433,53 @@ class PipelineLayer(Layer):
            for index, layer in enumerate(self._layers_desc[start:end]):
                logger.info("{}: {}".format(index + start, str(layer)))

        if self._num_virtual_pipeline_stages > 1:
            for stage in range(self._num_stages):
                stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format(
                    stage)
                for i in range(stage, self._total_stages_with_virtual_stages,
                               self._num_virtual_pipeline_stages):
                    stage_to_virtual_stage_info += " {},".format(i)
                logger.info(stage_to_virtual_stage_info)

        if self._loss_fn:
            try:
                logger.info("loss: {}".format(self._loss_fn.__name__))
            except AttributeError:
                logger.info("loss: {}".format(self._loss_fn.__class__.__name__))
    def _build_layer_with_interleave(self):
        for i in range(len(self._start_poss)):
            start = self._start_poss[i]
            end = self._end_poss[i]
            # Get a model chunk
            chunk = self._build_layer_impl(start, end)
            assert isinstance(chunk, PipelineLayerChunk)
            # Add the chunk to all chunks and add this chunk to the sublayer
            self._model_chunks.append(chunk)
            self.add_sublayer(str(start), chunk)

    def _build_layer(self):
        start = self._start_pos
        end = self._end_pos
        self.run_function = self._build_layer_impl(start, end)

    def _build_layer_impl(self, start, end):
        if self._num_virtual_pipeline_stages > 1:
            # For interleave scheduler, all layers relating with one model chunk will be saved in PipelineLayerChunk
            run_function = PipelineLayerChunk()
        else:
            # For 1f1b scheduler, just use run_function list
            run_function = self.run_function

        for index, layer in enumerate(self._layers_desc[start:end]):
            layer_index = start + index
            if isinstance(layer, Layer):
                run_function.append(layer)
                if self._num_virtual_pipeline_stages == 1:
                    # Only add sublayer for 1f1b scheduler,
                    # for interleave, PipelineLayerChunk will do this
                    self.add_sublayer(str(layer_index), layer)
            elif isinstance(layer, SharedLayerDesc):
                if layer.layer_name not in self.shared_layers:
                    self.shared_layers[layer.layer_name] = layer.build_layer()
@@ -363,20 +490,24 @@ class PipelineLayer(Layer):
                        setattr(param, "is_firstly_shared", True)

                if layer.forward_func is None:
                    run_function.append(self.shared_layers[layer.layer_name])
                else:
                    run_function.append(
                        partial(layer.forward_func,
                                self.shared_layers[layer.layer_name]))
            elif isinstance(layer, LayerDesc):
                model = layer.build_layer()
                run_function.append(model)
                if self._num_virtual_pipeline_stages == 1:
                    # Only add sublayer for 1f1b scheduler,
                    # for interleave, PipelineLayerChunk will do this
                    self.add_sublayer(str(layer_index), model)
            else:
                run_function.append(layer)

        return run_function
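For the 1f1b path nothing changes from the caller's point of view: the returned list still feeds forward_function. The interleave forward pass is left as a TODO in this patch (see below), so a model chunk only collects layers for now. A hedged sketch of how a built chunk could eventually be executed, mirroring what forward_function does for the 1f1b list (purely illustrative, not the patch's implementation):

```python
# Purely illustrative; the interleave forward is not implemented in this patch.
def run_chunk(chunk, inputs):
    # chunk.functions holds Layers and plain callables in execution order.
    for fn in chunk.functions:
        inputs = fn(inputs)
    return inputs
```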
    def forward_function(self, start, end):
@@ -390,6 +521,7 @@ class PipelineLayer(Layer):
        return execute_func

    def forward(self, input):
        # TODO(Yuang Liu): forward function for interleave scheduler
        if self._recompute_interval == 0:
            input = self.forward_function(0, len(self.run_function))(input)
        else:
...
@@ -61,6 +61,8 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_no_sync)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_no_sync_gradient_check)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
list(APPEND DIST_TEST_OPS
     test_parallel_dygraph_pipeline_parallel_with_virtual_stage)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
@@ -311,6 +313,8 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
  list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_no_sync_gradient_check)
  list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel)
  list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel)
  list(REMOVE_ITEM TEST_OPS
       test_parallel_dygraph_pipeline_parallel_with_virtual_stage)
  list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel)
  list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
  list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
@@ -1577,6 +1581,9 @@ if(WITH_DISTRIBUTE
                       PROPERTIES TIMEOUT 60)
  set_tests_properties(test_parallel_dygraph_pipeline_parallel
                       PROPERTIES TIMEOUT 500)
  set_tests_properties(
    test_parallel_dygraph_pipeline_parallel_with_virtual_stage
    PROPERTIES TIMEOUT 500)
  set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT
                                                             200)
  set_tests_properties(test_parallel_dygraph_sharding_parallel
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import os
import paddle
from paddle.distributed import fleet
import paddle.nn as nn
from paddle.fluid.dygraph.layers import Layer
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
import paddle.nn.functional as F
class ReshapeHelp(Layer):

    def __init__(self, shape):
        super(ReshapeHelp, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(shape=self.shape)


class FakeAlexNetPipeDesc(PipelineLayer):

    def __init__(self, num_classes=10, **kwargs):
        self.num_classes = num_classes
        decs = [
            LayerDesc(nn.Conv2D, 1, 64, kernel_size=11, stride=4, padding=5),
            LayerDesc(nn.Conv2D, 64, 64, kernel_size=11, stride=4, padding=5),
            LayerDesc(nn.ReLU),
            LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2),
            LayerDesc(nn.Conv2D, 64, 192, kernel_size=5, padding=2),
            LayerDesc(nn.Conv2D, 192, 192, kernel_size=5, padding=2),
            F.relu,
            LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2),
            LayerDesc(nn.Conv2D, 192, 384, kernel_size=3, padding=1),
            F.relu,
            LayerDesc(nn.Conv2D, 384, 256, kernel_size=3, padding=1),
            F.relu,
            LayerDesc(nn.Conv2D, 256, 256, kernel_size=3, padding=1),
            LayerDesc(nn.Conv2D, 256, 256, kernel_size=3, padding=1),
            F.relu,
            LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2),
            LayerDesc(ReshapeHelp, shape=[-1, 256]),
            LayerDesc(nn.Linear, 256, self.num_classes),  # classifier
        ]
        super(FakeAlexNetPipeDesc, self).__init__(layers=decs,
                                                  loss_fn=nn.CrossEntropyLoss(),
                                                  **kwargs)


class TestPipeLayerAPI(unittest.TestCase):

    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": self.pipeline_parallel_size
        }
        fleet.init(is_collective=True, strategy=strategy)
        self.hcg = fleet.get_hybrid_communicate_group()

    def test_pipelayer_desc(self):
        pipe_model = FakeAlexNetPipeDesc(seg_method="layer:Conv2D",
                                         num_stages=self.pipeline_parallel_size,
                                         num_virtual_pipeline_stages=2)
        assert len(pipe_model.parameters()) > 0
        model_chunks = pipe_model.get_model_chunks()
        assert model_chunks is not None
        assert len(model_chunks) == 2
        dist_model = fleet.distributed_model(pipe_model)


if __name__ == '__main__':
    unittest.main()
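In this unit test each of the two pipeline ranks is built with num_virtual_pipeline_stages=2, so get_model_chunks() returns two PipelineLayerChunk objects per rank. A hypothetical follow-up check in the same style (the helper name is illustrative, not part of the test):

```python
# Hypothetical extra assertion for the test above.
def chunk_param_counts(pipe_model):
    return [len(list(chunk.parameters()))
            for chunk in pipe_model.get_model_chunks()]

# Each chunk owns its own slice of the model, so every count should be non-zero.
```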
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import os
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus):

    def test_hybrid_parallel_pp_layer_with_virtual_stage(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py')
        self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py',
                            eager_mode=False)


if __name__ == "__main__":
    os.environ["FLAGS_enable_eager_mode"] = "1"
    unittest.main()