未验证 提交 65b29d3a 编写于 作者: Y Yuang Liu 提交者: GitHub

[Hybrid Performance] Support VP DP comm overlap. (#54196)

上级 5f54a7fe
......@@ -703,6 +703,14 @@ class PipelineLayer(nn.Layer):
self.shared_layers[layer.layer_name],
)
)
# Note: the PipelineLayerChunk won't add the partial function to the sub layer,
# will introduce error when calling chunk.parameters(). Have to manually add
# this layer to the chunk's sub layer.
if self._num_virtual_pipeline_stages > 1:
run_function.add_sublayer(
layer.layer_name,
self.shared_layers[layer.layer_name],
)
elif isinstance(layer, LayerDesc):
model = layer.build_layer()
......
......@@ -71,6 +71,8 @@ class PipelineParallel(MetaParallelBase):
self._delay_scale_loss = self._strategy.hybrid_configs[
"pp_configs"
].delay_scale_loss
# TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
# This combination will trigger inplace check error during `reshape_` in funct `_split_tensors`.
self._dp_comm_overlap = self._strategy.hybrid_configs[
"pp_configs"
].dp_comm_overlap
......@@ -152,18 +154,30 @@ class PipelineParallel(MetaParallelBase):
return fused_allreduce
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
parameter_list = [p for p in model.parameters() if not p.stop_gradient]
if len(parameter_list) < 1:
return
var_groups = assign_group_by_size(parameter_list)
for group_idx, parameters in var_groups.items():
buffer = FusedAllReduceBuffer(
group_idx, parameters, comm_group, acc_steps
)
self._dp_comm_buffers.append(buffer)
for param in parameters:
param._register_backward_hook(self.bw_hook_func(buffer, param))
if model.get_num_virtual_stages() > 1:
models = model.get_model_chunks()
else:
models = [model]
for model in models:
# For virtual pipeline. Will separate parameters in different chunk into
# different groups to get the best performance.
parameter_list = [
p for p in model.parameters() if not p.stop_gradient
]
if len(parameter_list) < 1:
return
var_groups = assign_group_by_size(parameter_list)
for group_idx, parameters in var_groups.items():
buffer = FusedAllReduceBuffer(
group_idx, parameters, comm_group, acc_steps
)
self._dp_comm_buffers.append(buffer)
for param in parameters:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
)
def timer_printer(self):
if not self._enable_timer:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册