未验证 提交 3a014783 编写于 作者: Y Yuang Liu 提交者: GitHub

update the split logic for uniform (#47670) (#47705)

* code format change

* update the split logic for uniform (#47670)
上级 d5809836
...@@ -57,7 +57,6 @@ __all__ = [] ...@@ -57,7 +57,6 @@ __all__ = []
class LayerDesc(object): class LayerDesc(object):
def __init__(self, layer_func, *inputs, **kwargs): def __init__(self, layer_func, *inputs, **kwargs):
self.layer_func = layer_func self.layer_func = layer_func
self.inputs = inputs self.inputs = inputs
...@@ -65,25 +64,28 @@ class LayerDesc(object): ...@@ -65,25 +64,28 @@ class LayerDesc(object):
if not issubclass(layer_func, Layer): if not issubclass(layer_func, Layer):
raise TypeError( raise TypeError(
"The input(layer_func) should be a derived class of Layer.") "The input(layer_func) should be a derived class of Layer."
)
def build_layer(self): def build_layer(self):
return self.layer_func(*self.inputs, **self.kwargs) return self.layer_func(*self.inputs, **self.kwargs)
def __repr__(self): def __repr__(self):
return layer_to_str(self.layer_func.__name__, *self.inputs, return layer_to_str(
**self.kwargs) self.layer_func.__name__, *self.inputs, **self.kwargs
)
class SharedLayerDesc(LayerDesc): class SharedLayerDesc(LayerDesc):
def __init__(
def __init__(self, self,
key, key,
layer_func, layer_func,
forward_func=None, forward_func=None,
shared_weight_attr='weight', shared_weight_attr='weight',
*inputs, *inputs,
**kwargs): **kwargs
):
super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs) super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
self.layer_name = key self.layer_name = key
self.forward_func = forward_func self.forward_func = forward_func
...@@ -91,12 +93,13 @@ class SharedLayerDesc(LayerDesc): ...@@ -91,12 +93,13 @@ class SharedLayerDesc(LayerDesc):
class SegmentLayers(object): class SegmentLayers(object):
def __init__(
def __init__(self, self,
layers_desc, layers_desc,
num_parts, num_parts,
method="uniform", method="uniform",
num_virtual_pipeline_stage=None): num_virtual_pipeline_stage=None,
):
self._layers_desc = layers_desc self._layers_desc = layers_desc
self.method = method self.method = method
self.num_parts = num_parts self.num_parts = num_parts
...@@ -104,7 +107,9 @@ class SegmentLayers(object): ...@@ -104,7 +107,9 @@ class SegmentLayers(object):
self.num_virtual_pipeline_stage = num_virtual_pipeline_stage self.num_virtual_pipeline_stage = num_virtual_pipeline_stage
if self.num_virtual_pipeline_stage is not None: if self.num_virtual_pipeline_stage is not None:
self.total_parts = num_parts * self.num_virtual_pipeline_stage self.total_parts = num_parts * self.num_virtual_pipeline_stage
assert self.num_items >= self.num_parts, "layer number should be greater than number of segments" assert (
self.num_items >= self.num_parts
), "layer number should be greater than number of segments"
def do_segment(self): def do_segment(self):
if self.method == "uniform": if self.method == "uniform":
...@@ -118,12 +123,17 @@ class SegmentLayers(object): ...@@ -118,12 +123,17 @@ class SegmentLayers(object):
for idx in weight_idxs: for idx in weight_idxs:
weights[idx] = 1 weights[idx] = 1
actual_num_parts = self.num_parts if self.num_virtual_pipeline_stage is None else self.total_parts actual_num_parts = (
self.num_parts
if self.num_virtual_pipeline_stage is None
else self.total_parts
)
assert sum( assert (
weights sum(weights) % actual_num_parts == 0
) % actual_num_parts == 0, "number of layers ({}) should be divided by part number({})".format( ), "number of layers ({}) should be divided by part number({})".format(
sum(weights), actual_num_parts) sum(weights), actual_num_parts
)
part_size = sum(weights) // actual_num_parts part_size = sum(weights) // actual_num_parts
result = [0 for _ in range(actual_num_parts + 1)] result = [0 for _ in range(actual_num_parts + 1)]
...@@ -156,21 +166,23 @@ class SegmentLayers(object): ...@@ -156,21 +166,23 @@ class SegmentLayers(object):
if regex.search(name): if regex.search(name):
weight_idxs.append(idx) weight_idxs.append(idx)
assert len( assert (
weight_idxs) > 0, "weight_idxs' length should be greater than 0" len(weight_idxs) > 0
), "weight_idxs' length should be greater than 0"
return weight_idxs return weight_idxs
def uniform(self, num_items, num_parts): def uniform(self, num_items, num_parts):
result = [0 for _ in range(num_parts + 1)] result = [0 for _ in range(num_parts + 1)]
part_size = math.floor(num_items / num_parts) part_size = math.floor(num_items / num_parts)
for i in range(num_parts): extra_layers = num_items % num_parts
result[i] = int(min(part_size * i, num_items)) for i in range(1, num_parts):
offset = 1 if i > (num_parts - extra_layers) else 0
result[i] = int(min(result[i - 1] + part_size + offset, num_items))
result[num_parts] = num_items result[num_parts] = num_items
return result return result
class PipelineLayerChunk(Layer): class PipelineLayerChunk(Layer):
def __init__(self): def __init__(self):
super(PipelineLayerChunk, self).__init__() super(PipelineLayerChunk, self).__init__()
self.run_function = [] self.run_function = []
...@@ -192,7 +204,8 @@ class PipelineLayerChunk(Layer): ...@@ -192,7 +204,8 @@ class PipelineLayerChunk(Layer):
# behavior under recompute circumstance. # behavior under recompute circumstance.
raise PermissionError( raise PermissionError(
"The forward function of PipelineLayerChunk cannot be called directly. " "The forward function of PipelineLayerChunk cannot be called directly. "
"Please call forward function of PipelineLayer.") "Please call forward function of PipelineLayer."
)
class PipelineLayer(Layer): class PipelineLayer(Layer):
...@@ -274,7 +287,8 @@ class PipelineLayer(Layer): ...@@ -274,7 +287,8 @@ class PipelineLayer(Layer):
""" """
def __init__(self, def __init__(
self,
layers, layers,
num_stages=None, num_stages=None,
topology=None, topology=None,
...@@ -282,24 +296,32 @@ class PipelineLayer(Layer): ...@@ -282,24 +296,32 @@ class PipelineLayer(Layer):
seg_method="uniform", seg_method="uniform",
recompute_interval=0, recompute_interval=0,
recompute_ctx=None, recompute_ctx=None,
num_virtual_pipeline_stages=None): num_virtual_pipeline_stages=None,
):
super(PipelineLayer, self).__init__() super(PipelineLayer, self).__init__()
if num_stages is None and topology is None: if num_stages is None and topology is None:
raise ValueError("should provide num_stages or topology") raise ValueError("should provide num_stages or topology")
if num_virtual_pipeline_stages: if num_virtual_pipeline_stages:
assert isinstance(num_virtual_pipeline_stages, int), \ assert isinstance(
"virtual_pipeline_stage should be None or an int" num_virtual_pipeline_stages, int
), "virtual_pipeline_stage should be None or an int"
if num_virtual_pipeline_stages > 1: if num_virtual_pipeline_stages > 1:
logger.info( logger.info(
"set num_virtual_pipeline_stages > 1 means using interleave scheduler instead of 1f1b scheduler" "set num_virtual_pipeline_stages > 1 means using interleave scheduler instead of 1f1b scheduler"
) )
assert isinstance(seg_method, str), \ assert isinstance(
"seg_method should be a str for interleave scheduler" seg_method, str
assert seg_method.startswith('layer:'), \ ), "seg_method should be a str for interleave scheduler"
"seg_method shoud be start with layer: for interleave scheduler" assert seg_method.startswith(
'layer:'
self._num_virtual_pipeline_stages = 1 if num_virtual_pipeline_stages is None else num_virtual_pipeline_stages ), "seg_method shoud be start with layer: for interleave scheduler"
self._num_virtual_pipeline_stages = (
1
if num_virtual_pipeline_stages is None
else num_virtual_pipeline_stages
)
# lazy import # lazy import
import paddle.distributed as dist import paddle.distributed as dist
...@@ -313,13 +335,17 @@ class PipelineLayer(Layer): ...@@ -313,13 +335,17 @@ class PipelineLayer(Layer):
self.recompute_ctx = recompute_ctx self.recompute_ctx = recompute_ctx
if recompute_interval > 0: if recompute_interval > 0:
assert recompute_ctx is not None, "recompute_ctx must be not None for recompute." assert (
recompute_ctx is not None
), "recompute_ctx must be not None for recompute."
offload = recompute_ctx.get('offload', False) offload = recompute_ctx.get('offload', False)
partition = recompute_ctx.get('partition', False) partition = recompute_ctx.get('partition', False)
logger.info( logger.info(
"Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}" "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}".format(
.format(offload, partition)) offload, partition
)
)
world_size = dist.get_world_size() world_size = dist.get_world_size()
self.global_rank = dist.get_rank() self.global_rank = dist.get_rank()
...@@ -328,22 +354,28 @@ class PipelineLayer(Layer): ...@@ -328,22 +354,28 @@ class PipelineLayer(Layer):
self._stage_id = self._topo.get_coord(self.global_rank).pipe self._stage_id = self._topo.get_coord(self.global_rank).pipe
self._num_stages = self._topo.get_dim_size("pipe") self._num_stages = self._topo.get_dim_size("pipe")
if num_stages: if num_stages:
assert self._num_stages == num_stages, "num_stages should be equal to be %d" % ( assert (
self._num_stages) self._num_stages == num_stages
), "num_stages should be equal to be %d" % (self._num_stages)
else: else:
# construct default topology # construct default topology
if world_size % num_stages != 0: if world_size % num_stages != 0:
raise ValueError( raise ValueError(
"should provide correct num_stages({}) " "should provide correct num_stages({}) "
"which can be divided by world_size({})".format( "which can be divided by world_size({})".format(
num_stages, world_size)) num_stages, world_size
)
)
dp_num = world_size // num_stages dp_num = world_size // num_stages
self._topo = fleet.CommunicateTopology(["data", "pipe", "model"], self._topo = fleet.CommunicateTopology(
[dp_num, num_stages, 1]) ["data", "pipe", "model"], [dp_num, num_stages, 1]
)
self._stage_id = self._topo.get_coord(self.global_rank).pipe self._stage_id = self._topo.get_coord(self.global_rank).pipe
self._num_stages = self._topo.get_dim_size("pipe") self._num_stages = self._topo.get_dim_size("pipe")
self._total_stages_with_virtual_stages = self._num_stages * self._num_virtual_pipeline_stages self._total_stages_with_virtual_stages = (
self._num_stages * self._num_virtual_pipeline_stages
)
# initialize segment # initialize segment
self._layers_desc = list(self.layers) self._layers_desc = list(self.layers)
...@@ -381,16 +413,22 @@ class PipelineLayer(Layer): ...@@ -381,16 +413,22 @@ class PipelineLayer(Layer):
start_idx = virtual_pp_rank * self._num_stages start_idx = virtual_pp_rank * self._num_stages
for stage in range(self._num_stages): for stage in range(self._num_stages):
# stage mark the real pp stage # stage mark the real pp stage
if self.segment_parts[start_idx + if (
stage] <= layer_idx < self.segment_parts[ self.segment_parts[start_idx + stage]
start_idx + stage + 1]: <= layer_idx
< self.segment_parts[start_idx + stage + 1]
):
return stage return stage
def get_num_virtual_stages(self): def get_num_virtual_stages(self):
return self._num_virtual_pipeline_stages return self._num_virtual_pipeline_stages
def get_model_chunks(self): def get_model_chunks(self):
return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks return (
None
if self._num_virtual_pipeline_stages == 1
else self._model_chunks
)
def _construct_shared_comm(self): def _construct_shared_comm(self):
shared_comm = {} shared_comm = {}
...@@ -398,17 +436,21 @@ class PipelineLayer(Layer): ...@@ -398,17 +436,21 @@ class PipelineLayer(Layer):
return return
layers_desc = self._layers_desc layers_desc = self._layers_desc
shared_layer_names = set(s.layer_name for s in layers_desc shared_layer_names = set(
if isinstance(s, SharedLayerDesc)) s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc)
)
for key in shared_layer_names: for key in shared_layer_names:
shared_layers = [] shared_layers = []
for idx, layer in enumerate(layers_desc): for idx, layer in enumerate(layers_desc):
if isinstance(layer, if (
SharedLayerDesc) and layer.layer_name == key: isinstance(layer, SharedLayerDesc)
and layer.layer_name == key
):
shared_layers.append(idx) shared_layers.append(idx)
shared_stages = set( shared_stages = set(
self.get_stage_from_index(idx) for idx in shared_layers) self.get_stage_from_index(idx) for idx in shared_layers
)
self._dp_degree = self._topo.get_dim('data') self._dp_degree = self._topo.get_dim('data')
self._mp_degree = self._topo.get_dim('model') self._mp_degree = self._topo.get_dim('model')
self._sharding_degree = self._topo.get_dim('sharding') self._sharding_degree = self._topo.get_dim('sharding')
...@@ -425,7 +467,9 @@ class PipelineLayer(Layer): ...@@ -425,7 +467,9 @@ class PipelineLayer(Layer):
pipe=s, pipe=s,
data=dp, data=dp,
sharding=sharding, sharding=sharding,
model=mp)) model=mp,
)
)
group = paddle.distributed.new_group(ranks=shared_ranks) group = paddle.distributed.new_group(ranks=shared_ranks)
if self.global_rank in shared_ranks: if self.global_rank in shared_ranks:
...@@ -434,8 +478,9 @@ class PipelineLayer(Layer): ...@@ -434,8 +478,9 @@ class PipelineLayer(Layer):
shared_comm[key] = { shared_comm[key] = {
'ranks': shared_ranks, 'ranks': shared_ranks,
'group': group, 'group': group,
'weight_attr': 'weight_attr': self.shared_weight_attrs[
self.shared_weight_attrs[key], key
],
'layer': self.shared_layers[key], 'layer': self.shared_layers[key],
} }
return shared_comm return shared_comm
...@@ -443,10 +488,11 @@ class PipelineLayer(Layer): ...@@ -443,10 +488,11 @@ class PipelineLayer(Layer):
def _synchronize_shared_weights(self): def _synchronize_shared_weights(self):
for key, comm in self.shared_comm.items(): for key, comm in self.shared_comm.items():
with paddle.framework.no_grad(): with paddle.framework.no_grad():
paddle.distributed.broadcast(getattr(comm['layer'], paddle.distributed.broadcast(
comm['weight_attr']), getattr(comm['layer'], comm['weight_attr']),
src=min(comm['ranks']), src=min(comm['ranks']),
group=comm['group']) group=comm['group'],
)
for param in comm['layer'].parameters(): for param in comm['layer'].parameters():
if self.global_rank != min(comm['ranks']): if self.global_rank != min(comm['ranks']):
...@@ -458,8 +504,9 @@ class PipelineLayer(Layer): ...@@ -458,8 +504,9 @@ class PipelineLayer(Layer):
# need use trace_op to allreduce weight # need use trace_op to allreduce weight
if in_dygraph_mode(): if in_dygraph_mode():
with paddle.framework.no_grad(): with paddle.framework.no_grad():
paddle.distributed.all_reduce(param.grad, paddle.distributed.all_reduce(
group=comm['group']) param.grad, group=comm['group']
)
else: else:
with paddle.framework.no_grad(): with paddle.framework.no_grad():
paddle.fluid.framework._dygraph_tracer().trace_op( paddle.fluid.framework._dygraph_tracer().trace_op(
...@@ -468,8 +515,9 @@ class PipelineLayer(Layer): ...@@ -468,8 +515,9 @@ class PipelineLayer(Layer):
outputs={'Out': param._grad_ivar()}, outputs={'Out': param._grad_ivar()},
attrs={ attrs={
'ring_id': comm['group'].id, 'ring_id': comm['group'].id,
'use_calc_stream': True 'use_calc_stream': True,
}) },
)
def _segment_network_for_interleave(self, seg_method): def _segment_network_for_interleave(self, seg_method):
logger.info("start segment network for interleave scheduler") logger.info("start segment network for interleave scheduler")
...@@ -477,14 +525,20 @@ class PipelineLayer(Layer): ...@@ -477,14 +525,20 @@ class PipelineLayer(Layer):
self._layers_desc, self._layers_desc,
num_parts=self._num_stages, num_parts=self._num_stages,
method=seg_method, method=seg_method,
num_virtual_pipeline_stage=self._num_virtual_pipeline_stages) num_virtual_pipeline_stage=self._num_virtual_pipeline_stages,
)
self.segment_parts = seg.do_segment() self.segment_parts = seg.do_segment()
logger.info("segment result:" + logger.info(
", ".join(str(arg) for arg in self.segment_parts)) "segment result:"
+ ", ".join(str(arg) for arg in self.segment_parts)
)
for i in range(self._stage_id, self._total_stages_with_virtual_stages, for i in range(
self._num_stages): self._stage_id,
self._total_stages_with_virtual_stages,
self._num_stages,
):
# If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers. # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers.
# Layers [0, 1], [4, 5] will be assigned to the first real pp stage. # Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
# Layers [2, 3], [6, 7] will be assigned to the second real pp stage. # Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
...@@ -500,13 +554,15 @@ class PipelineLayer(Layer): ...@@ -500,13 +554,15 @@ class PipelineLayer(Layer):
def _segment_network(self, seg_method): def _segment_network(self, seg_method):
logger.info("start segment network..") logger.info("start segment network..")
seg = SegmentLayers(self._layers_desc, seg = SegmentLayers(
num_parts=self._num_stages, self._layers_desc, num_parts=self._num_stages, method=seg_method
method=seg_method) )
self.segment_parts = seg.do_segment() self.segment_parts = seg.do_segment()
logger.info("segment result:" + logger.info(
", ".join(str(arg) for arg in self.segment_parts)) "segment result:"
+ ", ".join(str(arg) for arg in self.segment_parts)
)
self._start_pos = self.segment_parts[self._stage_id] self._start_pos = self.segment_parts[self._stage_id]
self._end_pos = self.segment_parts[self._stage_id + 1] self._end_pos = self.segment_parts[self._stage_id + 1]
...@@ -514,22 +570,30 @@ class PipelineLayer(Layer): ...@@ -514,22 +570,30 @@ class PipelineLayer(Layer):
def _print_segmentation_for_debug(self): def _print_segmentation_for_debug(self):
# print information for debug # print information for debug
for stage in range(self._num_stages * for stage in range(
self._num_virtual_pipeline_stages): self._num_stages * self._num_virtual_pipeline_stages
):
start = self.segment_parts[stage] start = self.segment_parts[stage]
end = self.segment_parts[stage + 1] end = self.segment_parts[stage + 1]
logger.info("stage={}, global_rank={} ,layer_number={}".format( logger.info(
stage, self.global_rank, end - start)) "stage={}, global_rank={} ,layer_number={}".format(
stage, self.global_rank, end - start
)
)
for index, layer in enumerate(self._layers_desc[start:end]): for index, layer in enumerate(self._layers_desc[start:end]):
logger.info("{}: {}".format(index + start, str(layer))) logger.info("{}: {}".format(index + start, str(layer)))
if self._num_virtual_pipeline_stages > 1: if self._num_virtual_pipeline_stages > 1:
for stage in range(self._num_stages): for stage in range(self._num_stages):
stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format( stage_to_virtual_stage_info = (
stage) "stage {} contains virtual stages: ".format(stage)
for i in range(stage, self._total_stages_with_virtual_stages, )
self._num_stages): for i in range(
stage,
self._total_stages_with_virtual_stages,
self._num_stages,
):
stage_to_virtual_stage_info += " {},".format(i) stage_to_virtual_stage_info += " {},".format(i)
logger.info(stage_to_virtual_stage_info) logger.info(stage_to_virtual_stage_info)
...@@ -575,9 +639,11 @@ class PipelineLayer(Layer): ...@@ -575,9 +639,11 @@ class PipelineLayer(Layer):
if layer.layer_name not in self.shared_layers: if layer.layer_name not in self.shared_layers:
self.shared_layers[layer.layer_name] = layer.build_layer() self.shared_layers[layer.layer_name] = layer.build_layer()
self.shared_weight_attrs[ self.shared_weight_attrs[
layer.layer_name] = layer.shared_weight_attr layer.layer_name
] = layer.shared_weight_attr
for param in self.shared_layers[ for param in self.shared_layers[
layer.layer_name].parameters(): layer.layer_name
].parameters():
setattr(param, "is_firstly_shared", True) setattr(param, "is_firstly_shared", True)
if layer.forward_func is None: if layer.forward_func is None:
...@@ -585,8 +651,11 @@ class PipelineLayer(Layer): ...@@ -585,8 +651,11 @@ class PipelineLayer(Layer):
else: else:
run_function.append( run_function.append(
partial(layer.forward_func, partial(
self.shared_layers[layer.layer_name])) layer.forward_func,
self.shared_layers[layer.layer_name],
)
)
elif isinstance(layer, LayerDesc): elif isinstance(layer, LayerDesc):
model = layer.build_layer() model = layer.build_layer()
...@@ -615,11 +684,15 @@ class PipelineLayer(Layer): ...@@ -615,11 +684,15 @@ class PipelineLayer(Layer):
def forward(self, input, chunk_id=None): def forward(self, input, chunk_id=None):
if chunk_id is not None: if chunk_id is not None:
assert isinstance(chunk_id, int), "chunk_id should be an int" assert isinstance(chunk_id, int), "chunk_id should be an int"
assert self._num_virtual_pipeline_stages > 1, \ assert (
"chunk_id is only valid when using virtual pipeline stage" self._num_virtual_pipeline_stages > 1
assert chunk_id < len(self._model_chunks), \ ), "chunk_id is only valid when using virtual pipeline stage"
"The virtual pipeline only has {} chunks, " \ assert chunk_id < len(self._model_chunks), (
"but received chunk_id {}.".format(len(self._model_chunks), chunk_id) "The virtual pipeline only has {} chunks, "
"but received chunk_id {}.".format(
len(self._model_chunks), chunk_id
)
)
# Get the target model chunk. # Get the target model chunk.
model_chunk = self._model_chunks[chunk_id] model_chunk = self._model_chunks[chunk_id]
# Update the self.run_function to the target run functions. # Update the self.run_function to the target run functions.
...@@ -637,20 +710,25 @@ class PipelineLayer(Layer): ...@@ -637,20 +710,25 @@ class PipelineLayer(Layer):
funcs = self.run_function[start_idx:end_idx] funcs = self.run_function[start_idx:end_idx]
if not isinstance(input, tuple): if not isinstance(input, tuple):
input = (input, ) input = (input,)
if self._need_recompute(funcs, input): if self._need_recompute(funcs, input):
input = recompute_hybrid( input = recompute_hybrid(
self.recompute_ctx, self.recompute_ctx,
self.forward_function(start_idx, end_idx), *input) self.forward_function(start_idx, end_idx),
*input
)
else: else:
input = self.forward_function(start_idx, end_idx)(*input) input = self.forward_function(start_idx, end_idx)(*input)
return input return input
def _need_recompute(self, funcs, inputs): def _need_recompute(self, funcs, inputs):
if not any(input_.stop_gradient == False if not any(
for input_ in inputs if isinstance(input_, paddle.Tensor)): input_.stop_gradient == False
for input_ in inputs
if isinstance(input_, paddle.Tensor)
):
return False return False
params = [f.parameters() for f in funcs if isinstance(f, Layer)] params = [f.parameters() for f in funcs if isinstance(f, Layer)]
...@@ -674,11 +752,18 @@ class PipelineLayer(Layer): ...@@ -674,11 +752,18 @@ class PipelineLayer(Layer):
if self._num_virtual_pipeline_stages > 1: if self._num_virtual_pipeline_stages > 1:
# add virtual pipeline info to the save path # add virtual pipeline info to the save path
assert local_chunk_id is not None assert local_chunk_id is not None
virtual_pipeline_stage_message = "-virtual_pp_stage_{:0>2d}".format( virtual_pipeline_stage_message = (
local_chunk_id) "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
layer_save_path = os.path.join(ckpt_dir, )
'layer_{:0>2d}'.format(idx)) layer_save_path = os.path.join(
layer_save_path = layer_save_path + virtual_pipeline_stage_message + rank_message + '-model_states.pdparams' ckpt_dir, 'layer_{:0>2d}'.format(idx)
)
layer_save_path = (
layer_save_path
+ virtual_pipeline_stage_message
+ rank_message
+ '-model_states.pdparams'
)
return layer_save_path return layer_save_path
def _save_model(run_functions, local_chunk_id=None): def _save_model(run_functions, local_chunk_id=None):
...@@ -701,7 +786,8 @@ class PipelineLayer(Layer): ...@@ -701,7 +786,8 @@ class PipelineLayer(Layer):
def set_state_dir(self, path): def set_state_dir(self, path):
assert os.path.exists( assert os.path.exists(
path), "{} not found, please check the path".format(path) path
), "{} not found, please check the path".format(path)
def _load_model(run_functions, local_chunk_id=None): def _load_model(run_functions, local_chunk_id=None):
for idx, layer in enumerate(run_functions): for idx, layer in enumerate(run_functions):
...@@ -715,21 +801,26 @@ class PipelineLayer(Layer): ...@@ -715,21 +801,26 @@ class PipelineLayer(Layer):
pos_offset = self._start_poss[local_chunk_id] pos_offset = self._start_poss[local_chunk_id]
layer_idx = idx + pos_offset layer_idx = idx + pos_offset
layer_save_path = os.path.join( layer_save_path = os.path.join(
path, 'layer_{0:0>2d}'.format(layer_idx)) path, 'layer_{0:0>2d}'.format(layer_idx)
)
if self._num_virtual_pipeline_stages > 1: if self._num_virtual_pipeline_stages > 1:
# add virtual pipeline info to the path # add virtual pipeline info to the path
assert local_chunk_id is not None assert local_chunk_id is not None
layer_save_path = layer_save_path + "-virtual_pp_stage_{:0>2d}".format( layer_save_path = (
local_chunk_id) layer_save_path
model_files = glob.glob(layer_save_path + + "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
"*model_states.pdparams") )
model_files = glob.glob(
layer_save_path + "*model_states.pdparams"
)
model_files.sort() model_files.sort()
mp_rank = self._topo.get_coord(self.global_rank).model mp_rank = self._topo.get_coord(self.global_rank).model
mp_world_size = self._topo.get_dim('model') mp_world_size = self._topo.get_dim('model')
num_files = len(model_files) num_files = len(model_files)
load_param_path = model_files[mp_rank * num_files // load_param_path = model_files[
mp_world_size] mp_rank * num_files // mp_world_size
]
model_state_dict = paddle.load(load_param_path) model_state_dict = paddle.load(load_param_path)
layer.set_state_dict(model_state_dict) layer.set_state_dict(model_state_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册