Unverified commit 65b29d3a, authored by Yuang Liu, committed by GitHub

[Hybrid Performance] Support VP DP comm overlap. (#54196)

Parent 5f54a7fe
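The feature toggled here is the `dp_comm_overlap` flag under `pp_configs` in the fleet hybrid strategy, which this commit extends to models built with virtual pipeline (VP) stages. As rough orientation, here is a minimal sketch of how such a run might be configured; the parallel degrees are placeholders, and the attribute-style assignment on `hybrid_configs["pp_configs"]` is an assumption inferred from how the code in the diff below reads the flag, not an official recipe.

```python
# Sketch only: degrees and the pp_configs attribute assignment are assumptions,
# mirroring how PipelineParallel reads the flag in the diff below.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,  # data parallel, the dimension whose allreduce gets overlapped
    "mp_degree": 1,
    "pp_degree": 2,  # pipeline parallel; the model would be a PipelineLayer,
                     # optionally built with num_virtual_pipeline_stages > 1 for VP
}
strategy.hybrid_configs["pp_configs"].dp_comm_overlap = True

fleet.init(is_collective=True, strategy=strategy)
```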
@@ -703,6 +703,14 @@ class PipelineLayer(nn.Layer):
                             self.shared_layers[layer.layer_name],
                         )
                     )
+                    # Note: the PipelineLayerChunk won't add the partial function as a sub layer,
+                    # which will cause an error when calling chunk.parameters(). The shared layer
+                    # has to be added to the chunk's sub layers manually.
+                    if self._num_virtual_pipeline_stages > 1:
+                        run_function.add_sublayer(
+                            layer.layer_name,
+                            self.shared_layers[layer.layer_name],
+                        )

             elif isinstance(layer, LayerDesc):
                 model = layer.build_layer()
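The note in the hunk above concerns parameter discovery: appending `partial(layer.forward_func, shared_layer)` to a `PipelineLayerChunk` stores a plain callable, so the shared layer's parameters stay invisible to `chunk.parameters()` until the layer is also registered via `add_sublayer`. Below is a minimal sketch of that visibility rule, using a made-up `Chunk` layer rather than the classes from this commit.

```python
from functools import partial

import paddle.nn as nn


class Chunk(nn.Layer):
    def __init__(self, register_shared):
        super().__init__()
        shared = nn.Linear(4, 4)
        # Storing only a partial keeps the shared layer out of the sublayer
        # registry, so its weight/bias never show up in self.parameters().
        self.run_function = [partial(shared.forward)]
        if register_shared:
            # Explicit registration is what the added code above does for
            # shared layers when virtual pipeline stages are used.
            self.add_sublayer("shared", shared)


print(len(Chunk(register_shared=False).parameters()))  # 0
print(len(Chunk(register_shared=True).parameters()))   # 2 (weight + bias)
```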
@@ -71,6 +71,8 @@ class PipelineParallel(MetaParallelBase):
         self._delay_scale_loss = self._strategy.hybrid_configs[
             "pp_configs"
         ].delay_scale_loss
+        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
+        # This combination triggers an inplace check error during `reshape_` in the function `_split_tensors`.
         self._dp_comm_overlap = self._strategy.hybrid_configs[
             "pp_configs"
         ].dp_comm_overlap
@@ -152,18 +154,30 @@ class PipelineParallel(MetaParallelBase):
         return fused_allreduce

     def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
-        parameter_list = [p for p in model.parameters() if not p.stop_gradient]
-        if len(parameter_list) < 1:
-            return
-
-        var_groups = assign_group_by_size(parameter_list)
-        for group_idx, parameters in var_groups.items():
-            buffer = FusedAllReduceBuffer(
-                group_idx, parameters, comm_group, acc_steps
-            )
-            self._dp_comm_buffers.append(buffer)
-            for param in parameters:
-                param._register_backward_hook(self.bw_hook_func(buffer, param))
+        if model.get_num_virtual_stages() > 1:
+            models = model.get_model_chunks()
+        else:
+            models = [model]
+
+        for model in models:
+            # For virtual pipeline: separate the parameters of different chunks
+            # into different groups to get the best performance.
+            parameter_list = [
+                p for p in model.parameters() if not p.stop_gradient
+            ]
+            if len(parameter_list) < 1:
+                return
+
+            var_groups = assign_group_by_size(parameter_list)
+            for group_idx, parameters in var_groups.items():
+                buffer = FusedAllReduceBuffer(
+                    group_idx, parameters, comm_group, acc_steps
+                )
+                self._dp_comm_buffers.append(buffer)
+                for param in parameters:
+                    param._register_backward_hook(
+                        self.bw_hook_func(buffer, param)
+                    )

     def timer_printer(self):
         if not self._enable_timer:
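The reworked hook above registers one set of fused-allreduce buckets per virtual-pipeline chunk, so each chunk's data-parallel allreduce can be launched as soon as that chunk's gradients are accumulated, rather than using a single bucket layout spanning all chunks. Below is a framework-free toy of that grouping idea; the greedy size-based bucketing is a simplified stand-in for `assign_group_by_size`, and all names and thresholds are invented.

```python
from dataclasses import dataclass


@dataclass
class ToyParam:
    name: str
    numel: int
    stop_gradient: bool = False


def group_by_size(params, capacity=1024):
    """Greedy bucketing; a simplified stand-in for assign_group_by_size."""
    groups, current, used = {}, [], 0
    for p in params:
        current.append(p)
        used += p.numel
        if used >= capacity:
            groups[len(groups)] = current
            current, used = [], 0
    if current:
        groups[len(groups)] = current
    return groups


# One bucket layout per chunk: each chunk's gradients become ready at a
# different point of the interleaved schedule, so per-chunk buckets let a
# fused allreduce fire without waiting on parameters from other chunks.
chunks = [
    [ToyParam("chunk0.w", 900), ToyParam("chunk0.b", 300)],
    [ToyParam("chunk1.w", 900), ToyParam("chunk1.frozen", 300, stop_gradient=True)],
]
for idx, chunk in enumerate(chunks):
    trainable = [p for p in chunk if not p.stop_gradient]
    for group_idx, params in group_by_size(trainable).items():
        print(f"chunk {idx} group {group_idx}: {[p.name for p in params]}")
```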