Commit 875c1f6d authored by: S sneaxiy

add chunk timer

Parent 26a83ed1
......@@ -63,6 +63,7 @@ message PpConfig {
optional bool delay_scale_loss = 2 [ default = false ];
optional bool enable_timer = 3 [ default = false ];
optional bool sharding_comm_overlap = 4 [ default = false ];
optional bool enable_chunk_timer = 5 [ default = false ];
}
message HybridConfig {
......
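For orientation, a sketch of how the new flag could be enabled from a training script, mirroring how the existing enable_timer option in PpConfig is usually passed through hybrid_configs; the exact placement of the new key is an assumption, not confirmed by this commit.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,
    # assumption: the new flag sits next to enable_timer in the pp_configs block
    "pp_configs": {"enable_timer": True, "enable_chunk_timer": True},
}
fleet.init(is_collective=True, strategy=strategy)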
......@@ -13,6 +13,7 @@
import os
import sys
import time
from collections import defaultdict
import paddle
......@@ -36,6 +37,38 @@ __all__ = []
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
class ChunkTimer:
    def __init__(self, group):
        self.rank = group.get_group_rank(paddle.distributed.get_rank())
        self.group = group
        self.records = []
        self.reset()

    def begin(self):
        # Barrier so every rank in the group starts its timeline from a common point.
        paddle.distributed.barrier(self.group)
        self.reset()

    def reset(self):
        # Keep the timestamp in a separate attribute so it does not shadow begin().
        self.begin_time = time.time()
        self.records.clear()

    def start(self, name):
        # Synchronize the GPU so the timestamp marks the real chunk boundary.
        paddle.device.cuda.synchronize()
        t = time.time()
        self.records.append([name, t, None])

    def end(self):
        # Call sites invoke end() without a name; it closes the most recent record.
        paddle.device.cuda.synchronize()
        t = time.time()
        self.records[-1][-1] = t

    def export_info(self):
        return {
            "rank": self.rank,
            "begin": self.begin_time,
            "records": self.records,
        }
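A minimal usage sketch of the timer above, assuming an initialized distributed environment (e.g. launched via paddle.distributed.launch) and a CUDA device; pp_group here merely stands in for the real pipeline-parallel group.

import time
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
pp_group = dist.new_group(list(range(dist.get_world_size())))  # stand-in group

timer = ChunkTimer(pp_group)
timer.begin()                # barrier across the group, then reset the timeline
timer.start("forward_step")  # opens a record: ["forward_step", t_start, None]
time.sleep(0.01)             # stand-in for the actual chunk of work
timer.end()                  # closes the most recent record with its end timestamp
print(timer.export_info())
# e.g. {"rank": 0, "begin": <batch start ts>, "records": [["forward_step", t0, t1]]}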
# assume only the first stage and last stage need data, and data consumption is ordered;
# to be replaced by real micro dataset from reader
class FakeMicroDataset:
......@@ -151,6 +184,14 @@ class PipelineParallel(MetaParallelBase):
]
self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
self._enable_chunk_timer = self._strategy.hybrid_configs[
"pp_configs"
].enable_chunk_timer
if self._enable_chunk_timer:
self._chunk_timer = ChunkTimer(self._hcg.get_pipe_parallel_group())  # timing is scoped to the pipeline-parallel group
else:
self._chunk_timer = None
self.num_stages = self._hcg.get_pipe_parallel_world_size()
self.stage_id = self._hcg.get_stage_id()
self.pp_group = self._hcg.get_pipe_parallel_group()
......@@ -204,7 +245,7 @@ class PipelineParallel(MetaParallelBase):
)
# construct pipeline meta info
self._p2p_helper = p2p.P2pHelper(self._using_cache)
self._p2p_helper = p2p.P2pHelper(self._using_cache, self._chunk_timer)
self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0
......@@ -234,6 +275,12 @@ class PipelineParallel(MetaParallelBase):
self._layers, self.dp_group, self.accumulate_steps, True
)
def _export_chunk_timer_info(self):
if self._chunk_timer is not None:
return self._chunk_timer.export_info()
else:
return None
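One way a training script could persist what _export_chunk_timer_info() returns after a batch; the attribute access and the output path are illustrative only (model stands for the PipelineParallel-wrapped model).

import json
import paddle.distributed as dist

info = model._export_chunk_timer_info()  # None when enable_chunk_timer is off
if info is not None:
    # one file per rank so the timelines can be merged offline
    with open(f"chunk_timer_rank{dist.get_rank()}.json", "w") as f:
        json.dump(info, f)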
def is_pipeline_first_stage(self, ignore_virtual=False):
if not ignore_virtual:
if self._virtual_pp_world_size is not None:
......@@ -333,6 +380,8 @@ class PipelineParallel(MetaParallelBase):
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
if self._chunk_timer is not None:
self._chunk_timer.begin()
self.scaler = scaler
# store total loss of entire batch
......@@ -564,6 +613,8 @@ class PipelineParallel(MetaParallelBase):
return self.train_loss
def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
if self._chunk_timer is not None:
self._chunk_timer.start("forward_step")
if self._enable_timer:
self.timers("forward_step").start()
if self.is_pipeline_first_stage():
......@@ -601,6 +652,8 @@ class PipelineParallel(MetaParallelBase):
self.micro_batch_id += 1
if self._enable_timer:
self.timers("forward_step").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return output_tensor
def _check_micro_batch_data_valid(self, micro_batch_data):
......@@ -614,6 +667,8 @@ class PipelineParallel(MetaParallelBase):
), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"
def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
if self._chunk_timer is not None:
self._chunk_timer.start("backward_step")
if self._enable_timer:
self.timers("backward_step").start()
with paddle.amp.auto_cast(enable=False):
......@@ -647,6 +702,8 @@ class PipelineParallel(MetaParallelBase):
input_tensor_grad = input_tensor.grad
if self._enable_timer:
self.timers("backward_step").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return input_tensor_grad
def _broadcast_final_loss(self):
......@@ -884,6 +941,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
self._using_cache
), "cache should be enabled for pipeline with interleave"
if self._chunk_timer is not None:
self._chunk_timer.begin()
# init some attributes for this batch run
self.scaler = scaler
self.total_loss = None
......
......@@ -455,9 +455,10 @@ def _p2p_helper(
class P2pHelper:
def __init__(self, use_cache=True):
def __init__(self, use_cache=True, chunk_timer=None):
self._send_recv_meta = SendRecvMeta()
self._use_cache = use_cache
self._chunk_timer = chunk_timer
def _send_meta(self, output_tensor):
if not self._send_recv_meta.has_send_meta:
......@@ -473,6 +474,8 @@ class P2pHelper:
self._send_recv_meta.has_recv_meta = self._use_cache
def recv_forward(self, pp_first_stage, sync_recv=True):
if self._chunk_timer is not None:
self._chunk_timer.start("recv_forward")
global _timers
if _timers is not None:
_timers("recv_forward").start()
......@@ -491,9 +494,13 @@ class P2pHelper:
)
if _timers is not None:
_timers("recv_forward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return input_tensor
def recv_backward(self, pp_last_stage, sync_recv=True):
if self._chunk_timer is not None:
self._chunk_timer.start("recv_backward")
global _timers
if _timers is not None:
_timers("recv_backward").start()
......@@ -510,9 +517,13 @@ class P2pHelper:
)
if _timers is not None:
_timers("recv_backward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return output_tensor_grad
def send_forward(self, output_tensor, pp_last_stage):
if self._chunk_timer is not None:
self._chunk_timer.start("send_forward")
global _timers
if _timers is not None:
_timers("send_forward").start()
......@@ -528,8 +539,12 @@ class P2pHelper:
)
if _timers is not None:
_timers("send_forward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
def send_backward(self, input_tensor_grad, pp_first_stage):
if self._chunk_timer is not None:
self._chunk_timer.start("send_backward")
global _timers
if _timers is not None:
_timers("send_backward").start()
......@@ -543,8 +558,12 @@ class P2pHelper:
)
if _timers is not None:
_timers("send_backward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
def send_forward_recv_backward(self, output_tensor, pp_last_stage):
if self._chunk_timer is not None:
self._chunk_timer.start("send_forward_recv_backward")
global _timers
if _timers is not None:
_timers("send_forward_recv_backward").start()
......@@ -560,9 +579,13 @@ class P2pHelper:
)
if _timers is not None:
_timers("send_forward_recv_backward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return output_tensor_grad
def send_backward_recv_forward(self, input_tensor_grad, pp_first_stage):
if self._chunk_timer is not None:
self._chunk_timer.start("send_backward_recv_forward")
global _timers
if _timers is not None:
_timers("send_backward_recv_forward").start()
......@@ -578,11 +601,17 @@ class P2pHelper:
)
if _timers is not None:
_timers("send_backward_recv_forward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return input_tensor
def send_forward_backward_recv_forward_backward(
self, output_tensor, input_tensor_grad, recv_prev, recv_next
):
if self._chunk_timer is not None:
self._chunk_timer.start(
"send_forward_backward_recv_forward_backward"
)
# always have to send dtype info to downstream
global _timers
if _timers is not None:
......@@ -602,10 +631,14 @@ class P2pHelper:
)
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return input_tensor, output_tensor_grad
def send_forward_recv_forward(self, output_tensor, recv_prev):
# always have to send dtype info to downstream
if self._chunk_timer is not None:
self._chunk_timer.start("send_forward_recv_forward")
global _timers
if _timers is not None:
_timers("send_forward_recv_forward").start()
......@@ -624,9 +657,13 @@ class P2pHelper:
)
if _timers is not None:
_timers("send_forward_recv_forward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return input_tensor
def send_backward_recv_backward(self, input_tensor_grad, recv_next):
if self._chunk_timer is not None:
self._chunk_timer.start("send_backward_recv_backward")
global _timers
if _timers is not None:
_timers("send_backward_recv_backward").start()
......@@ -640,6 +677,8 @@ class P2pHelper:
)
if _timers is not None:
_timers("send_backward_recv_backward").stop()
if self._chunk_timer is not None:
self._chunk_timer.end()
return output_tensor_grad
def __repr__(self):
......
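Finally, the per-rank dumps can be merged into a simple per-chunk summary; a pure-Python sketch, assuming the one-JSON-file-per-rank layout from the earlier example.

import glob
import json
from collections import defaultdict

durations = defaultdict(float)  # total seconds per chunk name, summed over ranks
for path in glob.glob("chunk_timer_rank*.json"):  # illustrative file pattern
    with open(path) as f:
        info = json.load(f)
    for name, t_start, t_end in info["records"]:
        if t_end is not None:
            durations[name] += t_end - t_start

for name, total in sorted(durations.items(), key=lambda kv: -kv[1]):
    print(f"{name:45s} {total:.3f}s")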