Commit 875c1f6d authored by: S sneaxiy

add chunk timer

Parent 26a83ed1
...@@ -63,6 +63,7 @@ message PpConfig {
  optional bool delay_scale_loss = 2 [ default = false ];
  optional bool enable_timer = 3 [ default = false ];
  optional bool sharding_comm_overlap = 4 [ default = false ];
  optional bool enable_chunk_timer = 5 [ default = false ];
}

message HybridConfig {
......
...@@ -13,6 +13,7 @@
import os
import sys
import time
from collections import defaultdict

import paddle
...@@ -36,6 +37,38 @@ __all__ = []
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
class ChunkTimer:
    def __init__(self, group):
        self.rank = group.get_group_rank(paddle.distributed.get_rank())
        self.group = group
        self.records = []
        self.reset()

    def begin(self):
        # synchronize all ranks in the group so that timestamps are comparable
        paddle.distributed.barrier(self.group)
        self.reset()

    def reset(self):
        self.begin_time = time.time()
        self.records.clear()

    def start(self, name):
        # wait for pending GPU work so the timestamp marks the real chunk boundary
        paddle.device.cuda.synchronize()
        t = time.time()
        self.records.append([name, t, None])

    def end(self):
        paddle.device.cuda.synchronize()
        t = time.time()
        # fill in the end timestamp of the most recently started chunk
        self.records[-1][-1] = t

    def export_info(self):
        return {
            "rank": self.rank,
            "begin": self.begin_time,
            "records": self.records,
        }
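For reference, here is a minimal sketch (not part of this commit) of how the dictionary returned by export_info() could be post-processed into per-chunk totals; summarize_chunk_records is a hypothetical helper and assumes the [name, start, end] record layout used by the class above.

# Hypothetical helper, not part of this commit: post-processes the dict
# returned by ChunkTimer.export_info(), assuming [name, start, end] records.
from collections import defaultdict


def summarize_chunk_records(info):
    totals = defaultdict(float)
    for name, start, end in info["records"]:
        if end is None:
            # chunk was started but never ended; skip it
            continue
        totals[name] += end - start
    return {
        "rank": info["rank"],
        "begin": info["begin"],
        "total_seconds": dict(totals),
    }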
# assume only the first stage and last stage need data, and data consumption is ordered;
# to be replaced by real micro dataset from reader
class FakeMicroDataset:
...@@ -151,6 +184,14 @@ class PipelineParallel(MetaParallelBase):
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
        self._enable_chunk_timer = self._strategy.pipeline_configs[
            'enable_chunk_timer'
        ]
        if self._enable_chunk_timer:
            self._chunk_timer = ChunkTimer(self._hcg.get_pipe_parallel_group())
        else:
            self._chunk_timer = None
        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
...@@ -204,7 +245,7 @@ class PipelineParallel(MetaParallelBase):
        )
        # construct pipeline meta info
        self._p2p_helper = p2p.P2pHelper(self._using_cache, self._chunk_timer)
        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0
...@@ -234,6 +275,12 @@ class PipelineParallel(MetaParallelBase):
                self._layers, self.dp_group, self.accumulate_steps, True
            )
    def _export_chunk_timer_info(self):
        if self._chunk_timer is not None:
            return self._chunk_timer.export_info()
        else:
            return None
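A hypothetical usage sketch (not part of this commit) of how a training script might persist the exported records for offline inspection; it assumes `model` is the pipeline-parallel model object exposing _export_chunk_timer_info() and writes one JSON file per rank.

# Hypothetical usage sketch, not part of this commit.
import json

import paddle


def dump_chunk_timer_info(model, step):
    info = model._export_chunk_timer_info()
    if info is None:
        # enable_chunk_timer was not set, nothing to dump
        return
    rank = paddle.distributed.get_rank()
    with open(f"chunk_timer_rank{rank}_step{step}.json", "w") as f:
        json.dump(info, f)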
    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
...@@ -333,6 +380,8 @@ class PipelineParallel(MetaParallelBase):
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if self._chunk_timer is not None:
            self._chunk_timer.begin()
        self.scaler = scaler

        # store total loss of entire batch
...@@ -564,6 +613,8 @@ class PipelineParallel(MetaParallelBase):
        return self.train_loss

    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
        if self._chunk_timer is not None:
            self._chunk_timer.start("forward_step")
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
...@@ -601,6 +652,8 @@ class PipelineParallel(MetaParallelBase):
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return output_tensor
    def _check_micro_batch_data_valid(self, micro_batch_data):
...@@ -614,6 +667,8 @@ class PipelineParallel(MetaParallelBase):
        ), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._chunk_timer is not None:
            self._chunk_timer.start("backward_step")
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
...@@ -647,6 +702,8 @@ class PipelineParallel(MetaParallelBase):
            input_tensor_grad = input_tensor.grad
        if self._enable_timer:
            self.timers("backward_step").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return input_tensor_grad

    def _broadcast_final_loss(self):
...@@ -884,6 +941,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
            self._using_cache
        ), "cache should be enabled for pipeline with interleave"
        if self._chunk_timer is not None:
            self._chunk_timer.begin()

        # init some attributes for this batch run
        self.scaler = scaler
        self.total_loss = None
......
...@@ -455,9 +455,10 @@ def _p2p_helper(
class P2pHelper:
    def __init__(self, use_cache=True, chunk_timer=None):
        self._send_recv_meta = SendRecvMeta()
        self._use_cache = use_cache
        self._chunk_timer = chunk_timer

    def _send_meta(self, output_tensor):
        if not self._send_recv_meta.has_send_meta:
...@@ -473,6 +474,8 @@ class P2pHelper:
        self._send_recv_meta.has_recv_meta = self._use_cache

    def recv_forward(self, pp_first_stage, sync_recv=True):
        if self._chunk_timer is not None:
            self._chunk_timer.start("recv_forward")
        global _timers
        if _timers is not None:
            _timers("recv_forward").start()
...@@ -491,9 +494,13 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("recv_forward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return input_tensor

    def recv_backward(self, pp_last_stage, sync_recv=True):
        if self._chunk_timer is not None:
            self._chunk_timer.start("recv_backward")
        global _timers
        if _timers is not None:
            _timers("recv_backward").start()
...@@ -510,9 +517,13 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("recv_backward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return output_tensor_grad

    def send_forward(self, output_tensor, pp_last_stage):
        if self._chunk_timer is not None:
            self._chunk_timer.start("send_forward")
        global _timers
        if _timers is not None:
            _timers("send_forward").start()
...@@ -528,8 +539,12 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("send_forward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()

    def send_backward(self, input_tensor_grad, pp_first_stage):
        if self._chunk_timer is not None:
            self._chunk_timer.start("send_backward")
        global _timers
        if _timers is not None:
            _timers("send_backward").start()
...@@ -543,8 +558,12 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("send_backward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()

    def send_forward_recv_backward(self, output_tensor, pp_last_stage):
        if self._chunk_timer is not None:
            self._chunk_timer.start("send_forward_recv_backward")
        global _timers
        if _timers is not None:
            _timers("send_forward_recv_backward").start()
...@@ -560,9 +579,13 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("send_forward_recv_backward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return output_tensor_grad

    def send_backward_recv_forward(self, input_tensor_grad, pp_first_stage):
        if self._chunk_timer is not None:
            self._chunk_timer.start("send_backward_recv_forward")
        global _timers
        if _timers is not None:
            _timers("send_backward_recv_forward").start()
...@@ -578,11 +601,17 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("send_backward_recv_forward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return input_tensor

    def send_forward_backward_recv_forward_backward(
        self, output_tensor, input_tensor_grad, recv_prev, recv_next
    ):
        if self._chunk_timer is not None:
            self._chunk_timer.start(
                "send_forward_backward_recv_forward_backward"
            )
        # always have to send dtype info to downstream
        global _timers
        if _timers is not None:
...@@ -602,10 +631,14 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("send_forward_backward_recv_forward_backward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return input_tensor, output_tensor_grad

    def send_forward_recv_forward(self, output_tensor, recv_prev):
        # always have to send dtype info to downstream
        if self._chunk_timer is not None:
            self._chunk_timer.start("send_forward_recv_forward")
        global _timers
        if _timers is not None:
            _timers("send_forward_recv_forward").start()
...@@ -624,9 +657,13 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("send_forward_recv_forward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return input_tensor

    def send_backward_recv_backward(self, input_tensor_grad, recv_next):
        if self._chunk_timer is not None:
            self._chunk_timer.start("send_backward_recv_backward")
        global _timers
        if _timers is not None:
            _timers("send_backward_recv_backward").start()
...@@ -640,6 +677,8 @@ class P2pHelper:
        )
        if _timers is not None:
            _timers("send_backward_recv_backward").stop()
        if self._chunk_timer is not None:
            self._chunk_timer.end()
        return output_tensor_grad

    def __repr__(self):
......
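As an aside (not part of this patch), the repeated `if self._chunk_timer is not None: start(...) ... end()` guards around every p2p call could be folded into a small context manager; a minimal sketch under that assumption:

# Illustrative sketch only: times a named chunk when a ChunkTimer is present,
# and is a no-op otherwise.
from contextlib import contextmanager


@contextmanager
def chunk_timed(chunk_timer, name):
    if chunk_timer is not None:
        chunk_timer.start(name)
    try:
        yield
    finally:
        if chunk_timer is not None:
            chunk_timer.end()


# e.g. inside P2pHelper.recv_forward:
#     with chunk_timed(self._chunk_timer, "recv_forward"):
#         ...  # existing recv logic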