未验证 提交 65b29d3a 编写于 作者: Y Yuang Liu 提交者: GitHub

[Hybrid Performance] Support VP DP comm overlap. (#54196)

上级 5f54a7fe
...@@ -703,6 +703,14 @@ class PipelineLayer(nn.Layer): ...@@ -703,6 +703,14 @@ class PipelineLayer(nn.Layer):
self.shared_layers[layer.layer_name], self.shared_layers[layer.layer_name],
) )
) )
# Note: the PipelineLayerChunk won't add the partial function to the sub layer,
# will introduce error when calling chunk.parameters(). Have to manually add
# this layer to the chunk's sub layer.
if self._num_virtual_pipeline_stages > 1:
run_function.add_sublayer(
layer.layer_name,
self.shared_layers[layer.layer_name],
)
elif isinstance(layer, LayerDesc): elif isinstance(layer, LayerDesc):
model = layer.build_layer() model = layer.build_layer()
......
...@@ -71,6 +71,8 @@ class PipelineParallel(MetaParallelBase): ...@@ -71,6 +71,8 @@ class PipelineParallel(MetaParallelBase):
self._delay_scale_loss = self._strategy.hybrid_configs[ self._delay_scale_loss = self._strategy.hybrid_configs[
"pp_configs" "pp_configs"
].delay_scale_loss ].delay_scale_loss
# TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
# This combination will trigger inplace check error during `reshape_` in funct `_split_tensors`.
self._dp_comm_overlap = self._strategy.hybrid_configs[ self._dp_comm_overlap = self._strategy.hybrid_configs[
"pp_configs" "pp_configs"
].dp_comm_overlap ].dp_comm_overlap
...@@ -152,7 +154,17 @@ class PipelineParallel(MetaParallelBase): ...@@ -152,7 +154,17 @@ class PipelineParallel(MetaParallelBase):
return fused_allreduce return fused_allreduce
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps): def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
parameter_list = [p for p in model.parameters() if not p.stop_gradient] if model.get_num_virtual_stages() > 1:
models = model.get_model_chunks()
else:
models = [model]
for model in models:
# For virtual pipeline. Will separate parameters in different chunk into
# different groups to get the best performance.
parameter_list = [
p for p in model.parameters() if not p.stop_gradient
]
if len(parameter_list) < 1: if len(parameter_list) < 1:
return return
...@@ -163,7 +175,9 @@ class PipelineParallel(MetaParallelBase): ...@@ -163,7 +175,9 @@ class PipelineParallel(MetaParallelBase):
) )
self._dp_comm_buffers.append(buffer) self._dp_comm_buffers.append(buffer)
for param in parameters: for param in parameters:
param._register_backward_hook(self.bw_hook_func(buffer, param)) param._register_backward_hook(
self.bw_hook_func(buffer, param)
)
def timer_printer(self): def timer_printer(self):
if not self._enable_timer: if not self._enable_timer:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册