Unverified commit af7c4a31, authored by Yuang Liu, committed by GitHub

add timer to pp (#53831) + sharding pp overlap (#54312) (#54360)

* add timer to pp (#53831)

* [Hybrid Performance] Sharding stage 1 PP/VP overlap (#54312)
Parent: e941b924
@@ -61,6 +61,8 @@ message MpConfig {
 message PpConfig {
   optional bool dp_comm_overlap = 1 [ default = false ];
   optional bool delay_scale_loss = 2 [ default = false ];
+  optional bool enable_timer = 3 [ default = false ];
+  optional bool sharding_comm_overlap = 4 [ default = false ];
 }

 message HybridConfig {
......
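Note: the two new `PpConfig` switches are consumed through `fleet.DistributedStrategy().hybrid_configs`, as the updated unit tests at the end of this diff show. A minimal sketch of turning them on (the parallel degrees and the `fleet.init` call here are illustrative, not part of this commit):

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,
    # New PpConfig options from this commit; both default to False.
    "pp_configs": {
        "enable_timer": True,            # print per-step pipeline timers
        "sharding_comm_overlap": False,  # mutually exclusive with dp_comm_overlap
    },
}
fleet.init(is_collective=True, strategy=strategy)
```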
@@ -15,6 +15,7 @@ import paddle
 from paddle import framework
 from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
+from ..utils import timer_helper as timer
 from ..utils.hybrid_parallel_util import (
     broadcast_dp_parameters,
     broadcast_mp_parameters,
@@ -24,7 +25,7 @@ from ..utils.log_util import logger
 from .meta_parallel_base import MetaParallelBase
 from .parallel_layers.pp_layers import PipelineLayer
 from .pp_utils import p2p_communication as p2p
-from .pp_utils.utils import FusedAllReduceBuffer, assign_group_by_size
+from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

 __all__ = []
@@ -61,6 +62,7 @@ class PipelineParallel(MetaParallelBase):
         self.stage_id = self._hcg.get_stage_id()
         self.pp_group = self._hcg.get_pipe_parallel_group()
         self.dp_group = self._hcg.get_data_parallel_group()
+        self.sharding_group = self._hcg.get_sharding_parallel_group()

         self._virtual_pp_world_size = None
         self._virtual_pp_rank = None
@@ -75,13 +77,38 @@ class PipelineParallel(MetaParallelBase):
         self._dp_comm_overlap = self._strategy.hybrid_configs[
             "pp_configs"
         ].dp_comm_overlap
-        self._dp_comm_buffers = []
+        self._sharding_comm_overlap = self._strategy.hybrid_configs[
+            "pp_configs"
+        ].sharding_comm_overlap
+        self._enable_timer = self._strategy.hybrid_configs[
+            "pp_configs"
+        ].enable_timer

         if self._dp_comm_overlap:
             assert self.use_data_parallel and self.num_stages > 1

+        if self._sharding_comm_overlap:
+            assert self.use_sharding_parallel and self.num_stages > 1
+
+        assert not (
+            self._dp_comm_overlap and self._sharding_comm_overlap
+        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."
+
+        self._comm_buffers = []
+        self._comm_overlap = (
+            self._dp_comm_overlap or self._sharding_comm_overlap
+        )
+
+        if self._enable_timer:
+            if not timer.is_timer_initialized():
+                timer.set_timers()
+            self.timers = timer.get_timers()
+
         p2p.initialize_p2p_groups(
-            hcg, self._using_cache, self._enable_partial_send_recv
+            hcg,
+            self._using_cache,
+            self._enable_partial_send_recv,
+            self._enable_timer,
         )

         self.global_rank = self._hcg.get_global_rank()
@@ -109,7 +136,7 @@ class PipelineParallel(MetaParallelBase):
         if self._dp_comm_overlap:
             self.register_allreduce_overlap_hook(
-                self._layers, self.dp_group, self.accumulate_steps
+                self._layers, self.dp_group, self.accumulate_steps, True
             )

     def is_pipeline_first_stage(self, ignore_virtual=False):
@@ -141,12 +168,21 @@ class PipelineParallel(MetaParallelBase):

         return fused_allreduce

-    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
+    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
         if model.get_num_virtual_stages() > 1:
             models = model.get_model_chunks()
         else:
             models = [model]

+        if not dp:
+            assert hasattr(self, "optimizer")
+            assert hasattr(self.optimizer, "_param2rank")
+            _param2rank = self.optimizer._param2rank
+
+        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
+
+        fused_parameter_group = {}
+
         for model in models:
             # For virtual pipeline. Will separate parameters in different chunk into
             # different groups to get the best performance.
@@ -156,16 +192,39 @@ class PipelineParallel(MetaParallelBase):
             if len(parameter_list) < 1:
                 return

-            var_groups = assign_group_by_size(parameter_list)
-            for group_idx, parameters in var_groups.items():
-                buffer = FusedAllReduceBuffer(
-                    group_idx, parameters, comm_group, acc_steps
-                )
-                self._dp_comm_buffers.append(buffer)
-                for param in parameters:
-                    param._register_backward_hook(
-                        self.bw_hook_func(buffer, param)
-                    )
+            if dp:
+                fused_parameter_group[-1] = parameter_list
+            else:
+                # Sort parameters for sharding, since they have different dst rank
+                for p in parameter_list:
+                    assert p.name in _param2rank
+                    dst_rank = _param2rank[p.name]
+                    if dst_rank in fused_parameter_group:
+                        fused_parameter_group[dst_rank].append(p)
+                    else:
+                        fused_parameter_group[dst_rank] = [p]
+
+            for dst in fused_parameter_group:
+                parameter_list = fused_parameter_group[dst]
+                if not dp:
+                    # parse the relative dst rank to absolute dst rank for sharding
+                    dst = comm_group.ranks[dst]
+                var_groups = assign_group_by_size(parameter_list)
+                for group_idx, parameters in var_groups.items():
+                    buffer = FusedCommBuffer(
+                        group_idx, parameters, comm_group, acc_steps, act, dst
+                    )
+                    self._comm_buffers.append(buffer)
+                    for param in parameters:
+                        param._register_backward_hook(
+                            self.bw_hook_func(buffer, param)
+                        )
+
+    def timer_printer(self):
+        if not self._enable_timer:
+            return
+        all_flag_names = self.timers.timers.keys()
+        self.timers.log(all_flag_names)

     def forward_backward_pipeline(self, data, scaler=None):
         # use the 1f1b scheduling strategy.
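The new bucketing logic above works in two modes: with `dp=True` every parameter goes into a single bucket that is all-reduced over the data-parallel group, while with `dp=False` (sharding) parameters are grouped by the rank that owns them so each bucket can be reduced to its destination. A standalone sketch of that grouping (the `param2rank` mapping and the toy `Param` class are illustrative, not part of the patch):

```python
from collections import defaultdict


class Param:
    """Toy stand-in for a Paddle parameter; only the name matters here."""

    def __init__(self, name):
        self.name = name


def group_params_by_dst(parameter_list, param2rank, dp):
    """dp=True: one bucket keyed by -1 (all-reduce, no destination).
    dp=False: one bucket per destination rank (reduce-to-owner)."""
    groups = defaultdict(list)
    if dp:
        groups[-1] = list(parameter_list)
    else:
        for p in parameter_list:
            groups[param2rank[p.name]].append(p)
    return dict(groups)


params = [Param("w1"), Param("w2"), Param("b1")]
buckets = group_params_by_dst(params, {"w1": 0, "w2": 1, "b1": 0}, dp=False)
print({dst: [p.name for p in ps] for dst, ps in buckets.items()})
# {0: ['w1', 'b1'], 1: ['w2']}
```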
@@ -245,14 +304,22 @@ class PipelineParallel(MetaParallelBase):
             )
             p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

-        if self._dp_comm_overlap:
-            assert len(self._dp_comm_buffers) > 0
-            for buffer in self._dp_comm_buffers:
+        if self._comm_overlap:
+            assert len(self._comm_buffers) > 0
+            for buffer in self._comm_buffers:
                 buffer.scale_and_split_grads()

+        if self._enable_timer:
+            self.timers("allreduce_shared_weight_gradients").start()
         self._layers.allreduce_shared_weight_gradients()
+        if self._enable_timer:
+            self.timers("allreduce_shared_weight_gradients").stop()
+            self.timers("broadcast_final_loss").start()
         with paddle.amp.auto_cast(enable=False):
             train_loss = self._broadcast_final_loss()
+        if self._enable_timer:
+            self.timers("broadcast_final_loss").stop()
+        self.timer_printer()

         return train_loss
     def _prepare_training(self, data, optimizer, lr_scheduler):
@@ -281,6 +348,11 @@ class PipelineParallel(MetaParallelBase):

         self._layers.train()

+        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
+            self.register_allreduce_overlap_hook(
+                self._layers, self.sharding_group, self.accumulate_steps, False
+            )
+
         return data

     def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
@@ -348,6 +420,8 @@ class PipelineParallel(MetaParallelBase):
         return self.train_loss

     def _forward_step(self, input_tensor, chunk_id=None):
+        if self._enable_timer:
+            self.timers("forward_step").start()
         if self.is_pipeline_first_stage():
             input_tensor = self._load_micro_batch(self.micro_batch_id)
@@ -379,9 +453,13 @@ class PipelineParallel(MetaParallelBase):
             # Only increase micro batch id at virtual first/last pp stage.
             # The micro batch id is used to load data, therefore, only increase it when load data.
             self.micro_batch_id += 1
+        if self._enable_timer:
+            self.timers("forward_step").stop()
         return output_tensor
     def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
+        if self._enable_timer:
+            self.timers("backward_step").start()
         with paddle.amp.auto_cast(enable=False):
             if self.is_pipeline_last_stage():
                 assert output_tensor_grad is None
@@ -411,6 +489,8 @@ class PipelineParallel(MetaParallelBase):
                 )
             else:
                 input_tensor_grad = input_tensor.grad
+        if self._enable_timer:
+            self.timers("backward_step").stop()
         return input_tensor_grad
     def _check_data_vaild(self, data):
@@ -816,21 +896,30 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 )
             )

-        if self._dp_comm_overlap:
-            assert len(self._dp_comm_buffers) > 0
-            for buffer in self._dp_comm_buffers:
+        if self._comm_overlap:
+            assert len(self._comm_buffers) > 0
+            for buffer in self._comm_buffers:
                 buffer.scale_and_split_grads()

+        if self._enable_timer:
+            self.timers("allreduce_shared_weight_gradients").start()
         self._layers.allreduce_shared_weight_gradients()
+        if self._enable_timer:
+            self.timers("allreduce_shared_weight_gradients").stop()

         if compute_loss:
             # return loss if compute loss
+            if self._enable_timer:
+                self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
                 train_loss = self._broadcast_final_loss()
+            if self._enable_timer:
+                self.timers("broadcast_final_loss").stop()
         else:
             # else just return all intermediate output tensor for all micro steps
             train_loss = self.output_tensors

+        self.timer_printer()
+
         return train_loss

     def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
......
@@ -25,22 +25,24 @@ from paddle.distributed.communication.group import (
     _warn_cur_rank_not_in_group,
 )

+from ...utils import timer_helper as timer
 from .utils import number_2_dtype, paddle_2_number

 _hcg = None
 _use_cache = False
 _enable_partial_send_recv = True
+_timers = None


 def initialize_p2p_groups(
-    hcg,
-    use_cache=True,
-    enable_partial_send_recv=True,
+    hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False
 ):
-    global _hcg, _use_cache, _enable_partial_send_recv
+    global _hcg, _use_cache, _enable_partial_send_recv, _timers
     _hcg = hcg
     _use_cache = use_cache
     _enable_partial_send_recv = enable_partial_send_recv
+    if enable_timer:
+        _timers = timer.get_timers()


 class SendRecvMeta:
@@ -537,6 +539,9 @@ def _p2p_helper(


 def recv_forward(pp_first_stage, sync_recv=True):
+    global _timers
+    if _timers is not None:
+        _timers("recv_forward").start()
     if pp_first_stage:
         input_tensor = None
     else:
@@ -551,10 +556,15 @@ def recv_forward(pp_first_stage, sync_recv=True):
             recv_next=False,
             sync_recv=sync_recv,
         )
+    if _timers is not None:
+        _timers("recv_forward").stop()
     return input_tensor


 def recv_backward(pp_last_stage, sync_recv=True):
+    global _timers
+    if _timers is not None:
+        _timers("recv_backward").start()
     if pp_last_stage:
         output_tensor_grad = None
     else:
@@ -565,10 +575,15 @@ def recv_backward(pp_last_stage, sync_recv=True):
             recv_next=True,
             sync_recv=sync_recv,
         )
+    if _timers is not None:
+        _timers("recv_backward").stop()
     return output_tensor_grad
 def send_forward(output_tensor, pp_last_stage):
+    global _timers
+    if _timers is not None:
+        _timers("send_forward").start()
     if not pp_last_stage:
         if not _send_recv_meta.has_send_meta:
             _send_recv_meta.set_send_message(output_tensor)
@@ -583,9 +598,14 @@ def send_forward(output_tensor, pp_last_stage):
             recv_prev=False,
             recv_next=False,
         )
+    if _timers is not None:
+        _timers("send_forward").stop()


 def send_backward(input_tensor_grad, pp_first_stage):
+    global _timers
+    if _timers is not None:
+        _timers("send_backward").start()
     if not pp_first_stage:
         _p2p_helper(
             tensor_send_next=None,
@@ -593,9 +613,14 @@ def send_backward(input_tensor_grad, pp_first_stage):
             recv_prev=False,
             recv_next=False,
         )
+    if _timers is not None:
+        _timers("send_backward").stop()
 def send_forward_recv_backward(output_tensor, pp_last_stage):
+    global _timers
+    if _timers is not None:
+        _timers("send_forward_recv_backward").start()
     if pp_last_stage:
         output_tensor_grad = None
     else:
@@ -605,10 +630,15 @@ def send_forward_recv_backward(output_tensor, pp_last_stage):
             recv_prev=False,
             recv_next=True,
         )
+    if _timers is not None:
+        _timers("send_forward_recv_backward").stop()
     return output_tensor_grad


 def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
+    global _timers
+    if _timers is not None:
+        _timers("send_backward_recv_forward").start()
     if pp_first_stage:
         input_tensor = None
     else:
@@ -618,6 +648,8 @@ def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
             recv_prev=True,
             recv_next=False,
         )
+    if _timers is not None:
+        _timers("send_backward_recv_forward").stop()
     return input_tensor
@@ -625,6 +657,9 @@ def send_forward_backward_recv_forward_backward(
     output_tensor, input_tensor_grad, recv_prev, recv_next
 ):
     # always have to send dytpe info to downstream
+    global _timers
+    if _timers is not None:
+        _timers("send_forward_backward_recv_forward_backward").start()
     if not _send_recv_meta.has_send_meta:
         _send_recv_meta.set_send_message(output_tensor)
         _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group())
@@ -639,11 +674,16 @@ def send_forward_backward_recv_forward_backward(
         recv_next=recv_next,
         sync_recv=False,
     )
+    if _timers is not None:
+        _timers("send_forward_backward_recv_forward_backward").stop()
     return input_tensor, output_tensor_grad


 def send_forward_recv_forward(output_tensor, recv_prev):
     # always have to send dytpe info to downstream
+    global _timers
+    if _timers is not None:
+        _timers("send_forward_recv_forward").start()
     if not _send_recv_meta.has_send_meta:
         _send_recv_meta.set_send_message(output_tensor)
         _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group())
@@ -659,10 +699,15 @@ def send_forward_recv_forward(output_tensor, recv_prev):
         recv_next=False,
         sync_recv=False,
     )
+    if _timers is not None:
+        _timers("send_forward_recv_forward").stop()
     return input_tensor
 def send_backward_recv_backward(input_tensor_grad, recv_next):
+    global _timers
+    if _timers is not None:
+        _timers("send_backward_recv_backward").start()
     _, output_tensor_grad = _p2p_helper(
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,
@@ -670,4 +715,6 @@ def send_backward_recv_backward(input_tensor_grad, recv_next):
         recv_next=recv_next,
         sync_recv=False,
     )
+    if _timers is not None:
+        _timers("send_backward_recv_backward").stop()
     return output_tensor_grad
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,6 +24,12 @@ from paddle.framework import base as imperative_base

 __all__ = []


+class HOOK_ACTION:
+    ALL_REDUCE = 0
+    REDUCE = 1
+
+
 FLOAT_TYPE_DICT = {
     paddle.float16: "float16",
     paddle.float32: "float32",
@@ -114,8 +120,16 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
     )
-class FusedAllReduceBuffer:
-    def __init__(self, id, params, comm_group, acc_steps=1):
+class FusedCommBuffer:
+    def __init__(
+        self,
+        id,
+        params,
+        comm_group,
+        acc_steps=1,
+        act=None,
+        dst=-1,
+    ):
         self._id = id
         self._params = params
         self._acc_steps = acc_steps
@@ -127,6 +141,17 @@ class FusedAllReduceBuffer:
         self._params_checked_in = 0
         self._coalesced_grads_and_grad_vars = []

+        self._act = act
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            assert dst == -1
+        elif self._act == HOOK_ACTION.REDUCE:
+            assert dst != -1
+        else:
+            raise ValueError(
+                "The act should be allreudce for dp or reduce for sharding."
+            )
+        self._dst = dst
+
         self._init_step_dict()
     def _init_step_dict(self):
@@ -164,10 +189,10 @@ class FusedAllReduceBuffer:
             self._params_step_dict.pop(param.name)

         if self._all_params_checked_in:
-            self._fused_allreduce_grads()
+            self._fused_comm_grads()

     @imperative_base.no_grad
-    def _fused_allreduce_grads(self):
+    def _fused_comm_grads(self):
         assert self._all_params_checked_in
         flattened_vars = []
         g_var_shapes = []
@@ -184,11 +209,18 @@ class FusedAllReduceBuffer:
             )

         for coalesced_grad, _, _ in self._coalesced_grads_and_grad_vars:
-            self._tasks.append(
-                paddle.distributed.all_reduce(
+            if self._act == HOOK_ACTION.ALL_REDUCE:
+                task = paddle.distributed.all_reduce(
                     coalesced_grad, group=self._comm_group, sync_op=False
                 )
-            )
+            elif self._act == HOOK_ACTION.REDUCE:
+                task = paddle.distributed.reduce(
+                    coalesced_grad,
+                    dst=self._dst,
+                    group=self._comm_group,
+                    sync_op=False,
+                )
+            self._tasks.append(task)

     @imperative_base.no_grad
     def scale_and_split_grads(self):
......
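The practical difference between the two hook actions is which collective is issued for each coalesced gradient. A minimal sketch of that dispatch, following the `HOOK_ACTION` convention above (the tensor, group, and destination rank would come from an initialized distributed context; this is an illustration, not the buffer implementation itself):

```python
import paddle.distributed as dist


class HOOK_ACTION:  # mirrors the constants introduced in this file
    ALL_REDUCE = 0
    REDUCE = 1


def launch_fused_comm(coalesced_grad, comm_group, act, dst=-1):
    """All-reduce for dp overlap, reduce-to-dst for sharding overlap.
    Returns the async task so the caller can wait on it later."""
    if act == HOOK_ACTION.ALL_REDUCE:
        assert dst == -1
        return dist.all_reduce(coalesced_grad, group=comm_group, sync_op=False)
    if act == HOOK_ACTION.REDUCE:
        assert dst != -1  # dst is an absolute rank inside comm_group
        return dist.reduce(
            coalesced_grad, dst=dst, group=comm_group, sync_op=False
        )
    raise ValueError("act must be HOOK_ACTION.ALL_REDUCE or HOOK_ACTION.REDUCE")
```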
@@ -71,6 +71,11 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
         ].dp_comm_overlap:
             hp_optim._dp_enable = False

+        if fleet_env._user_defined_strategy.hybrid_configs[
+            "pp_configs"
+        ].sharding_comm_overlap:
+            hp_optim._sharding_enable = False
+
         return hp_optim
     else:
         return HeterParallelOptimizer(
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import paddle

_GLOBAL_TIMERS = None


def is_timer_initialized():
    return _GLOBAL_TIMERS is not None


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is None, f"{name} has been already initialized."


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, f"{name} is not initialized."


def get_timers():
    _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
    return _GLOBAL_TIMERS


def set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
    _GLOBAL_TIMERS = Timers()


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, "timer has already started"
        paddle.device.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timers."""
        assert self.started_, "timer is not started."
        paddle.device.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If the timing in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = "time (ms)"
        for name in names:
            elapsed_time = (
                self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            )
            string += f" | {name}: {elapsed_time:.2f}"
        print(string, flush=True)
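A short usage sketch of this helper outside the pipeline classes, assuming the module is importable as `paddle.distributed.fleet.utils.timer_helper` (which is what the relative imports in this diff suggest); the timer name is an arbitrary example:

```python
from paddle.distributed.fleet.utils import timer_helper as timer

if not timer.is_timer_initialized():
    timer.set_timers()
timers = timer.get_timers()

timers("my_step").start()
# ... code being measured; start()/stop() synchronize the CUDA device ...
timers("my_step").stop()

# Prints something like: time (ms) | my_step: 12.34
timers.log(["my_step"])
```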
@@ -148,6 +148,9 @@ class TestDistPPTraning(unittest.TestCase):
             "dp_degree": self.data_parallel_size,
             "mp_degree": self.model_parallel_size,
             "pp_degree": self.pipeline_parallel_size,
+            "pp_configs": {
+                "enable_timer": True,
+            },
         }
         strategy.pipeline_configs = {
             "accumulate_steps": batch_size // micro_batch_size,
......
@@ -145,6 +145,7 @@ class TestDistPPDelayScaleLoss(TestDistPPTraning):
             "pp_degree": self.pipeline_parallel_size,
             "pp_configs": {
                 "delay_scale_loss": True,
+                "enable_timer": True,
             },
         }
         strategy.pipeline_configs = {
......