Unverified commit 3a014783, authored by Yuang Liu and committed by GitHub

update the split logic for uniform (#47670) (#47705)

* code format change

* update the split logic for uniform (#47670)
Parent d5809836
@@ -57,7 +57,6 @@ __all__ = []
class LayerDesc(object):
def __init__(self, layer_func, *inputs, **kwargs):
self.layer_func = layer_func
self.inputs = inputs
@@ -65,25 +64,28 @@ class LayerDesc(object):
if not issubclass(layer_func, Layer):
raise TypeError(
"The input(layer_func) should be a derived class of Layer.")
"The input(layer_func) should be a derived class of Layer."
)
def build_layer(self):
return self.layer_func(*self.inputs, **self.kwargs)
def __repr__(self):
return layer_to_str(self.layer_func.__name__, *self.inputs,
**self.kwargs)
return layer_to_str(
self.layer_func.__name__, *self.inputs, **self.kwargs
)
class SharedLayerDesc(LayerDesc):
def __init__(self,
key,
layer_func,
forward_func=None,
shared_weight_attr='weight',
*inputs,
**kwargs):
def __init__(
self,
key,
layer_func,
forward_func=None,
shared_weight_attr='weight',
*inputs,
**kwargs
):
super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
self.layer_name = key
self.forward_func = forward_func
@@ -91,12 +93,13 @@ class SharedLayerDesc(LayerDesc):
class SegmentLayers(object):
def __init__(self,
layers_desc,
num_parts,
method="uniform",
num_virtual_pipeline_stage=None):
def __init__(
self,
layers_desc,
num_parts,
method="uniform",
num_virtual_pipeline_stage=None,
):
self._layers_desc = layers_desc
self.method = method
self.num_parts = num_parts
@@ -104,7 +107,9 @@ class SegmentLayers(object):
self.num_virtual_pipeline_stage = num_virtual_pipeline_stage
if self.num_virtual_pipeline_stage is not None:
self.total_parts = num_parts * self.num_virtual_pipeline_stage
assert self.num_items >= self.num_parts, "layer number should be greater than number of segments"
assert (
self.num_items >= self.num_parts
), "layer number should be greater than number of segments"
def do_segment(self):
if self.method == "uniform":
@@ -118,12 +123,17 @@ class SegmentLayers(object):
for idx in weight_idxs:
weights[idx] = 1
actual_num_parts = self.num_parts if self.num_virtual_pipeline_stage is None else self.total_parts
assert sum(
weights
) % actual_num_parts == 0, "number of layers ({}) should be divided by part number({})".format(
sum(weights), actual_num_parts)
actual_num_parts = (
self.num_parts
if self.num_virtual_pipeline_stage is None
else self.total_parts
)
assert (
sum(weights) % actual_num_parts == 0
), "number of layers ({}) should be divided by part number({})".format(
sum(weights), actual_num_parts
)
part_size = sum(weights) // actual_num_parts
result = [0 for _ in range(actual_num_parts + 1)]
@@ -156,21 +166,23 @@ class SegmentLayers(object):
if regex.search(name):
weight_idxs.append(idx)
assert len(
weight_idxs) > 0, "weight_idxs' length should be greater than 0"
assert (
len(weight_idxs) > 0
), "weight_idxs' length should be greater than 0"
return weight_idxs
def uniform(self, num_items, num_parts):
result = [0 for _ in range(num_parts + 1)]
part_size = math.floor(num_items / num_parts)
for i in range(num_parts):
result[i] = int(min(part_size * i, num_items))
extra_layers = num_items % num_parts
for i in range(1, num_parts):
offset = 1 if i > (num_parts - extra_layers) else 0
result[i] = int(min(result[i - 1] + part_size + offset, num_items))
result[num_parts] = num_items
return result
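For orientation while reading this hunk: assuming the first loop above is the removed code and the `extra_layers` loop is the added code, the behavioral change in `uniform` is that the remainder layers (`num_items % num_parts`) are now spread across the last stages instead of all landing in the final stage. A minimal standalone sketch of the new boundary computation, using a hypothetical 10-layer / 4-stage configuration (not taken from the PR), is:

import math

def uniform_boundaries(num_items, num_parts):
    # Illustrative re-implementation of the updated split logic; not the Paddle API.
    result = [0 for _ in range(num_parts + 1)]
    part_size = math.floor(num_items / num_parts)
    extra_layers = num_items % num_parts
    for i in range(1, num_parts):
        # The last `extra_layers` parts each absorb one remainder layer.
        offset = 1 if i > (num_parts - extra_layers) else 0
        result[i] = int(min(result[i - 1] + part_size + offset, num_items))
    result[num_parts] = num_items
    return result

print(uniform_boundaries(10, 4))  # [0, 2, 4, 7, 10] -> stage sizes 2, 2, 3, 3
# The removed formula (part_size * i) would give [0, 2, 4, 6, 10] -> 2, 2, 2, 4.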
class PipelineLayerChunk(Layer):
def __init__(self):
super(PipelineLayerChunk, self).__init__()
self.run_function = []
@@ -192,18 +204,19 @@ class PipelineLayerChunk(Layer):
# behavior under recompute circumstance.
raise PermissionError(
"The forward function of PipelineLayerChunk cannot be called directly. "
"Please call forward function of PipelineLayer.")
"Please call forward function of PipelineLayer."
)
class PipelineLayer(Layer):
"""PipelineLayer
Args:
layers(Iterable): A sequence of layers description to define the structure for pipeline.
num_stages(int, optional): pp degree, if not specified, 'topology' parameter must be given.
topology(CommunicateTopology, optional): topo of hybrid parallel, if it is None, 'num_stages' parameters must be given.
loss_fn(callable, optional): Loss function.
seg_method(str, optional): the method of splitting pp layer, default 'uniform', or use specific layer to split, method's name must be start with 'layer:'.
recompute_interval(int, optional): the number of layers to be used recompute, the value of 0 represents no recompute. default 0.
recompute_ctx(dict,optional): the context of recompute, when 'recompute_interval' > 0, the context must be given.
num_virtual_pipeline_stages(int, optional): the num of virtual pipeline stages for interleave pp.
Examples:
@@ -213,7 +226,7 @@ class PipelineLayer(Layer):
from paddle.fluid.dygraph.layers import Layer
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
pipeline_parallel_size = 2
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
@@ -225,19 +238,19 @@ class PipelineLayer(Layer):
"accumulate_steps": 4,
"micro_batch_size": 2
}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()
class ReshapeHelp(Layer):
def __init__(self, shape):
super(ReshapeHelp, self).__init__()
self.shape = shape
def forward(self, x):
return x.reshape(shape=self.shape)
class AlexNetPipeDesc(PipelineLayer):
def __init__(self, num_classes=10, **kwargs):
self.num_classes = num_classes
@@ -269,37 +282,46 @@ class PipelineLayer(Layer):
]
super(AlexNetPipeDesc, self).__init__(
layers=decs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
model = AlexNetPipeDesc(num_stages=pipeline_parallel_size, topology=hcg._topo)
"""
def __init__(self,
layers,
num_stages=None,
topology=None,
loss_fn=None,
seg_method="uniform",
recompute_interval=0,
recompute_ctx=None,
num_virtual_pipeline_stages=None):
def __init__(
self,
layers,
num_stages=None,
topology=None,
loss_fn=None,
seg_method="uniform",
recompute_interval=0,
recompute_ctx=None,
num_virtual_pipeline_stages=None,
):
super(PipelineLayer, self).__init__()
if num_stages is None and topology is None:
raise ValueError("should provide num_stages or topology")
if num_virtual_pipeline_stages:
assert isinstance(num_virtual_pipeline_stages, int), \
"virtual_pipeline_stage should be None or an int"
assert isinstance(
num_virtual_pipeline_stages, int
), "virtual_pipeline_stage should be None or an int"
if num_virtual_pipeline_stages > 1:
logger.info(
"set num_virtual_pipeline_stages > 1 means using interleave scheduler instead of 1f1b scheduler"
)
assert isinstance(seg_method, str), \
"seg_method should be a str for interleave scheduler"
assert seg_method.startswith('layer:'), \
"seg_method shoud be start with layer: for interleave scheduler"
self._num_virtual_pipeline_stages = 1 if num_virtual_pipeline_stages is None else num_virtual_pipeline_stages
assert isinstance(
seg_method, str
), "seg_method should be a str for interleave scheduler"
assert seg_method.startswith(
'layer:'
), "seg_method shoud be start with layer: for interleave scheduler"
self._num_virtual_pipeline_stages = (
1
if num_virtual_pipeline_stages is None
else num_virtual_pipeline_stages
)
# lazy import
import paddle.distributed as dist
@@ -313,13 +335,17 @@ class PipelineLayer(Layer):
self.recompute_ctx = recompute_ctx
if recompute_interval > 0:
assert recompute_ctx is not None, "recompute_ctx must be not None for recompute."
assert (
recompute_ctx is not None
), "recompute_ctx must be not None for recompute."
offload = recompute_ctx.get('offload', False)
partition = recompute_ctx.get('partition', False)
logger.info(
"Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}"
.format(offload, partition))
"Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}".format(
offload, partition
)
)
world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
@@ -328,22 +354,28 @@ class PipelineLayer(Layer):
self._stage_id = self._topo.get_coord(self.global_rank).pipe
self._num_stages = self._topo.get_dim_size("pipe")
if num_stages:
assert self._num_stages == num_stages, "num_stages should be equal to be %d" % (
self._num_stages)
assert (
self._num_stages == num_stages
), "num_stages should be equal to be %d" % (self._num_stages)
else:
# construct default topology
if world_size % num_stages != 0:
raise ValueError(
"should provide correct num_stages({}) "
"which can be divided by world_size({})".format(
num_stages, world_size))
num_stages, world_size
)
)
dp_num = world_size // num_stages
self._topo = fleet.CommunicateTopology(["data", "pipe", "model"],
[dp_num, num_stages, 1])
self._topo = fleet.CommunicateTopology(
["data", "pipe", "model"], [dp_num, num_stages, 1]
)
self._stage_id = self._topo.get_coord(self.global_rank).pipe
self._num_stages = self._topo.get_dim_size("pipe")
self._total_stages_with_virtual_stages = self._num_stages * self._num_virtual_pipeline_stages
self._total_stages_with_virtual_stages = (
self._num_stages * self._num_virtual_pipeline_stages
)
# initialize segment
self._layers_desc = list(self.layers)
@@ -381,16 +413,22 @@ class PipelineLayer(Layer):
start_idx = virtual_pp_rank * self._num_stages
for stage in range(self._num_stages):
# stage mark the real pp stage
if self.segment_parts[start_idx +
stage] <= layer_idx < self.segment_parts[
start_idx + stage + 1]:
if (
self.segment_parts[start_idx + stage]
<= layer_idx
< self.segment_parts[start_idx + stage + 1]
):
return stage
def get_num_virtual_stages(self):
return self._num_virtual_pipeline_stages
def get_model_chunks(self):
return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks
return (
None
if self._num_virtual_pipeline_stages == 1
else self._model_chunks
)
def _construct_shared_comm(self):
shared_comm = {}
@@ -398,17 +436,21 @@ class PipelineLayer(Layer):
return
layers_desc = self._layers_desc
shared_layer_names = set(s.layer_name for s in layers_desc
if isinstance(s, SharedLayerDesc))
shared_layer_names = set(
s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc)
)
for key in shared_layer_names:
shared_layers = []
for idx, layer in enumerate(layers_desc):
if isinstance(layer,
SharedLayerDesc) and layer.layer_name == key:
if (
isinstance(layer, SharedLayerDesc)
and layer.layer_name == key
):
shared_layers.append(idx)
shared_stages = set(
self.get_stage_from_index(idx) for idx in shared_layers)
self.get_stage_from_index(idx) for idx in shared_layers
)
self._dp_degree = self._topo.get_dim('data')
self._mp_degree = self._topo.get_dim('model')
self._sharding_degree = self._topo.get_dim('sharding')
@@ -425,7 +467,9 @@ class PipelineLayer(Layer):
pipe=s,
data=dp,
sharding=sharding,
model=mp))
model=mp,
)
)
group = paddle.distributed.new_group(ranks=shared_ranks)
if self.global_rank in shared_ranks:
@@ -434,8 +478,9 @@ class PipelineLayer(Layer):
shared_comm[key] = {
'ranks': shared_ranks,
'group': group,
'weight_attr':
self.shared_weight_attrs[key],
'weight_attr': self.shared_weight_attrs[
key
],
'layer': self.shared_layers[key],
}
return shared_comm
@@ -443,10 +488,11 @@ class PipelineLayer(Layer):
def _synchronize_shared_weights(self):
for key, comm in self.shared_comm.items():
with paddle.framework.no_grad():
paddle.distributed.broadcast(getattr(comm['layer'],
comm['weight_attr']),
src=min(comm['ranks']),
group=comm['group'])
paddle.distributed.broadcast(
getattr(comm['layer'], comm['weight_attr']),
src=min(comm['ranks']),
group=comm['group'],
)
for param in comm['layer'].parameters():
if self.global_rank != min(comm['ranks']):
@@ -458,8 +504,9 @@ class PipelineLayer(Layer):
# need use trace_op to allreduce weight
if in_dygraph_mode():
with paddle.framework.no_grad():
paddle.distributed.all_reduce(param.grad,
group=comm['group'])
paddle.distributed.all_reduce(
param.grad, group=comm['group']
)
else:
with paddle.framework.no_grad():
paddle.fluid.framework._dygraph_tracer().trace_op(
@@ -468,8 +515,9 @@ class PipelineLayer(Layer):
outputs={'Out': param._grad_ivar()},
attrs={
'ring_id': comm['group'].id,
'use_calc_stream': True
})
'use_calc_stream': True,
},
)
def _segment_network_for_interleave(self, seg_method):
logger.info("start segment network for interleave scheduler")
@@ -477,14 +525,20 @@ class PipelineLayer(Layer):
self._layers_desc,
num_parts=self._num_stages,
method=seg_method,
num_virtual_pipeline_stage=self._num_virtual_pipeline_stages)
num_virtual_pipeline_stage=self._num_virtual_pipeline_stages,
)
self.segment_parts = seg.do_segment()
logger.info("segment result:" +
", ".join(str(arg) for arg in self.segment_parts))
logger.info(
"segment result:"
+ ", ".join(str(arg) for arg in self.segment_parts)
)
for i in range(self._stage_id, self._total_stages_with_virtual_stages,
self._num_stages):
for i in range(
self._stage_id,
self._total_stages_with_virtual_stages,
self._num_stages,
):
# If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers.
# Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
# Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
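To make the comment above concrete, here is a small standalone sketch (hypothetical helper name, not the Paddle API; the 2-stage, 2-virtual-stage, 8-layer numbers come from the comment) that mirrors the stage lookup performed by get_stage_from_index:

def stage_of_layer(layer_idx, segment_parts, num_stages, num_virtual_stages):
    # Each real stage owns one slice of layers per virtual chunk; scan all chunks.
    for virtual_pp_rank in range(num_virtual_stages):
        start_idx = virtual_pp_rank * num_stages
        for stage in range(num_stages):
            if (
                segment_parts[start_idx + stage]
                <= layer_idx
                < segment_parts[start_idx + stage + 1]
            ):
                return stage

segment_parts = [0, 2, 4, 6, 8]  # 8 layers, 2 stages x 2 virtual stages
print([stage_of_layer(i, segment_parts, 2, 2) for i in range(8)])
# -> [0, 0, 1, 1, 0, 0, 1, 1]: layers [0, 1], [4, 5] on stage 0; [2, 3], [6, 7] on stage 1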
@@ -500,13 +554,15 @@ class PipelineLayer(Layer):
def _segment_network(self, seg_method):
logger.info("start segment network..")
seg = SegmentLayers(self._layers_desc,
num_parts=self._num_stages,
method=seg_method)
seg = SegmentLayers(
self._layers_desc, num_parts=self._num_stages, method=seg_method
)
self.segment_parts = seg.do_segment()
logger.info("segment result:" +
", ".join(str(arg) for arg in self.segment_parts))
logger.info(
"segment result:"
+ ", ".join(str(arg) for arg in self.segment_parts)
)
self._start_pos = self.segment_parts[self._stage_id]
self._end_pos = self.segment_parts[self._stage_id + 1]
@@ -514,22 +570,30 @@ class PipelineLayer(Layer):
def _print_segmentation_for_debug(self):
# print information for debug
for stage in range(self._num_stages *
self._num_virtual_pipeline_stages):
for stage in range(
self._num_stages * self._num_virtual_pipeline_stages
):
start = self.segment_parts[stage]
end = self.segment_parts[stage + 1]
logger.info("stage={}, global_rank={} ,layer_number={}".format(
stage, self.global_rank, end - start))
logger.info(
"stage={}, global_rank={} ,layer_number={}".format(
stage, self.global_rank, end - start
)
)
for index, layer in enumerate(self._layers_desc[start:end]):
logger.info("{}: {}".format(index + start, str(layer)))
if self._num_virtual_pipeline_stages > 1:
for stage in range(self._num_stages):
stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format(
stage)
for i in range(stage, self._total_stages_with_virtual_stages,
self._num_stages):
stage_to_virtual_stage_info = (
"stage {} contains virtual stages: ".format(stage)
)
for i in range(
stage,
self._total_stages_with_virtual_stages,
self._num_stages,
):
stage_to_virtual_stage_info += " {},".format(i)
logger.info(stage_to_virtual_stage_info)
@@ -575,9 +639,11 @@ class PipelineLayer(Layer):
if layer.layer_name not in self.shared_layers:
self.shared_layers[layer.layer_name] = layer.build_layer()
self.shared_weight_attrs[
layer.layer_name] = layer.shared_weight_attr
layer.layer_name
] = layer.shared_weight_attr
for param in self.shared_layers[
layer.layer_name].parameters():
layer.layer_name
].parameters():
setattr(param, "is_firstly_shared", True)
if layer.forward_func is None:
@@ -585,8 +651,11 @@ class PipelineLayer(Layer):
else:
run_function.append(
partial(layer.forward_func,
self.shared_layers[layer.layer_name]))
partial(
layer.forward_func,
self.shared_layers[layer.layer_name],
)
)
elif isinstance(layer, LayerDesc):
model = layer.build_layer()
@@ -615,11 +684,15 @@ class PipelineLayer(Layer):
def forward(self, input, chunk_id=None):
if chunk_id is not None:
assert isinstance(chunk_id, int), "chunk_id should be an int"
assert self._num_virtual_pipeline_stages > 1, \
"chunk_id is only valid when using virtual pipeline stage"
assert chunk_id < len(self._model_chunks), \
"The virtual pipeline only has {} chunks, " \
"but received chunk_id {}.".format(len(self._model_chunks), chunk_id)
assert (
self._num_virtual_pipeline_stages > 1
), "chunk_id is only valid when using virtual pipeline stage"
assert chunk_id < len(self._model_chunks), (
"The virtual pipeline only has {} chunks, "
"but received chunk_id {}.".format(
len(self._model_chunks), chunk_id
)
)
# Get the target model chunk.
model_chunk = self._model_chunks[chunk_id]
# Update the self.run_function to the target run functions.
@@ -637,20 +710,25 @@ class PipelineLayer(Layer):
funcs = self.run_function[start_idx:end_idx]
if not isinstance(input, tuple):
input = (input, )
input = (input,)
if self._need_recompute(funcs, input):
input = recompute_hybrid(
self.recompute_ctx,
self.forward_function(start_idx, end_idx), *input)
self.forward_function(start_idx, end_idx),
*input
)
else:
input = self.forward_function(start_idx, end_idx)(*input)
return input
def _need_recompute(self, funcs, inputs):
if not any(input_.stop_gradient == False
for input_ in inputs if isinstance(input_, paddle.Tensor)):
if not any(
input_.stop_gradient == False
for input_ in inputs
if isinstance(input_, paddle.Tensor)
):
return False
params = [f.parameters() for f in funcs if isinstance(f, Layer)]
@@ -674,11 +752,18 @@ class PipelineLayer(Layer):
if self._num_virtual_pipeline_stages > 1:
# add virtual pipeline info to the save path
assert local_chunk_id is not None
virtual_pipeline_stage_message = "-virtual_pp_stage_{:0>2d}".format(
local_chunk_id)
layer_save_path = os.path.join(ckpt_dir,
'layer_{:0>2d}'.format(idx))
layer_save_path = layer_save_path + virtual_pipeline_stage_message + rank_message + '-model_states.pdparams'
virtual_pipeline_stage_message = (
"-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
)
layer_save_path = os.path.join(
ckpt_dir, 'layer_{:0>2d}'.format(idx)
)
layer_save_path = (
layer_save_path
+ virtual_pipeline_stage_message
+ rank_message
+ '-model_states.pdparams'
)
return layer_save_path
def _save_model(run_functions, local_chunk_id=None):
@@ -701,7 +786,8 @@ class PipelineLayer(Layer):
def set_state_dir(self, path):
assert os.path.exists(
path), "{} not found, please check the path".format(path)
path
), "{} not found, please check the path".format(path)
def _load_model(run_functions, local_chunk_id=None):
for idx, layer in enumerate(run_functions):
@@ -715,21 +801,26 @@ class PipelineLayer(Layer):
pos_offset = self._start_poss[local_chunk_id]
layer_idx = idx + pos_offset
layer_save_path = os.path.join(
path, 'layer_{0:0>2d}'.format(layer_idx))
path, 'layer_{0:0>2d}'.format(layer_idx)
)
if self._num_virtual_pipeline_stages > 1:
# add virtual pipeline info to the path
assert local_chunk_id is not None
layer_save_path = layer_save_path + "-virtual_pp_stage_{:0>2d}".format(
local_chunk_id)
model_files = glob.glob(layer_save_path +
"*model_states.pdparams")
layer_save_path = (
layer_save_path
+ "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
)
model_files = glob.glob(
layer_save_path + "*model_states.pdparams"
)
model_files.sort()
mp_rank = self._topo.get_coord(self.global_rank).model
mp_world_size = self._topo.get_dim('model')
num_files = len(model_files)
load_param_path = model_files[mp_rank * num_files //
mp_world_size]
load_param_path = model_files[
mp_rank * num_files // mp_world_size
]
model_state_dict = paddle.load(load_param_path)
layer.set_state_dict(model_state_dict)