Unverified commit aac91e82, authored by Yuang Liu, committed by GitHub

Two kinds of profiler to pp/vp (#54586)

Parent 9b2bcfd6
......@@ -63,6 +63,7 @@ message PpConfig {
optional bool delay_scale_loss = 2 [ default = false ];
optional bool enable_timer = 3 [ default = false ];
optional bool sharding_comm_overlap = 4 [ default = false ];
optional bool profiling = 5 [ default = false ];
}
message HybridConfig {
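The hunk above adds a `profiling` switch to `PpConfig`; the Python changes below read it as `hybrid_configs["pp_configs"].profiling`. Here is a minimal sketch of enabling it when building the distributed strategy. The parallel degrees and the nested `pp_configs` dict form of the setter are assumptions for illustration; only the `profiling` field itself comes from this change.

```python
# Minimal sketch (not part of this diff): enable the new pp profiling flag.
# The degrees and the nested "pp_configs" dict are illustrative assumptions;
# only the "profiling" field is introduced by this change.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 4,
    "pp_configs": {
        "profiling": True,  # new field added above (PpConfig.profiling)
    },
}
fleet.init(is_collective=True, strategy=strategy)
```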
......@@ -10,6 +10,8 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import time
import warnings
import paddle
from paddle import framework
......@@ -140,6 +142,7 @@ class PipelineParallel(MetaParallelBase):
self.num_stages = self._hcg.get_pipe_parallel_world_size()
self.stage_id = self._hcg.get_stage_id()
self.global_rank = self._hcg.get_global_rank()
self.pp_group = self._hcg.get_pipe_parallel_group()
self.dp_group = self._hcg.get_data_parallel_group()
self.sharding_group = self._hcg.get_sharding_parallel_group()
......@@ -163,6 +166,30 @@ class PipelineParallel(MetaParallelBase):
self._enable_timer = self._strategy.hybrid_configs[
"pp_configs"
].enable_timer
self._profiling = self._strategy.hybrid_configs["pp_configs"].profiling
self._records = []
self._record_format = (
'"name": "{}{}", "cat": "pipeline timeline", "ph": {}, "pid": 0, "tid": '
+ str(self.stage_id + 1)
+ ', "ts": {}, "cname": "{}"'
)
self._forward_color = "thread_state_running" # RGB: 126, 200, 148
self._backward_color = "rail_idle" # RGB: 238, 142, 0
if self._profiling:
logger.info(
"If enable pp profiling, the max training steps should be restricted "
"to a reasonable value (such as 5) to avoid generating large profile files. "
"The profiler will generate a profile file 'profile_record_tmp_file_for_rank_*' "
"for each rank. Users should gather all profile files for one entire pipeline "
"to one node (rank 0 is recommended) to get the full view of the pipeline profile. "
"[DONT CHANGE THE NAME OF THE PROFILE FILES!]. "
"Then get the profile parser from this url: "
"https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/distributed/fleet/meta_parallel/pp_utils/profiler_helper.py "
"and save the script to the same directory of all profile files."
"Parse those files by this command: `python profiler_helper.py`. "
"After parsing, a new file 'pipeline_profile.json' will be generated. "
"Users can inspect this file by chrome://tracing website."
)
if self._dp_comm_overlap:
assert self.use_data_parallel and self.num_stages > 1
......@@ -306,11 +333,51 @@ class PipelineParallel(MetaParallelBase):
all_flag_names = self.timers.timers.keys()
self.timers.log(all_flag_names)
def forward_backward_pipeline(self, data, scaler=None):
def _record_stamp(self, name, step, phase, color):
if self._profiling:
paddle.device.synchronize()
self._records.append(
'{'
+ self._record_format.format(
name,
step,
phase,
int(time.time() * 1000),
color,
)
+ '}'
)
def _flush_records(self):
if self._profiling:
with open(
f'./profile_record_tmp_file_for_rank_{self.global_rank}',
'a+',
) as f:
for record in self._records:
f.write(record + '\n')
self._records = []
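Together, `_record_stamp` and `_flush_records` append one Chrome-trace-style line per forward/backward begin or end event and dump them to a per-rank file. Below is a stand-alone illustration of a single record produced by the format string above; the concrete event values (stage 0, forward micro step 3, begin phase) are made up.

```python
import time

# Illustrative only: rebuild one profiler record with the same format string
# as above, assuming stage_id == 0; the event ("F", micro step 3, begin) is
# a made-up example.
stage_id = 0
record_format = (
    '"name": "{}{}", "cat": "pipeline timeline", "ph": {}, "pid": 0, "tid": '
    + str(stage_id + 1)
    + ', "ts": {}, "cname": "{}"'
)
record = (
    "{"
    + record_format.format(
        "F", 3, '"B"', int(time.time() * 1000), "thread_state_running"
    )
    + "}"
)
print(record)
# {"name": "F3", "cat": "pipeline timeline", "ph": "B", "pid": 0, "tid": 1,
#  "ts": <milliseconds>, "cname": "thread_state_running"}
```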
def forward_backward_pipeline(
self, data, scaler=None, static_scheduler=False
):
# use the 1f1b scheduling strategy.
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
if static_scheduler:
assert (
not self._profiling
), "While _profiling, static scheduler is not available"
if data is not None:
warnings.warn(
"Static scheduler run won't real run the model, but data has been provided"
)
logger.info(
"enable static_scheduler will return the pp schedule instead of the loss"
)
schedule = ""
self.scaler = scaler
# store total loss of entire batch
......@@ -329,9 +396,15 @@ class PipelineParallel(MetaParallelBase):
micro_dataset = self._wrap_data(data)
for step_id in range(startup_steps):
if static_scheduler:
schedule += f"f{step_id};"
logger.info(f"forward step for micro step {step_id}")
continue
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
self._record_stamp("F", step_id, '"B"', self._forward_color)
output_tensor = self._forward_step(input_tensor, micro_dataset)
self._record_stamp("F", step_id, '"E"', self._forward_color)
p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
input_buffers.append(input_tensor)
......@@ -340,13 +413,25 @@ class PipelineParallel(MetaParallelBase):
if not self.is_pipeline_last_stage():
self._release_output(output_tensor)
if steady_steps > 0:
if steady_steps > 0 and not static_scheduler:
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
for i in range(steady_steps):
if static_scheduler:
schedule += f"f{startup_steps + i};"
schedule += f"b{i};"
logger.info(f"forward step for micro step {startup_steps + i}")
logger.info(f"backward step for micro step {i}")
continue
last_iter = i == (steady_steps - 1)
self._record_stamp(
"F", startup_steps + i, '"B"', self._forward_color
)
output_tensor = self._forward_step(input_tensor, micro_dataset)
self._record_stamp(
"F", startup_steps + i, '"E"', self._forward_color
)
output_tensor_grad = p2p.send_forward_recv_backward(
output_tensor, self.is_pipeline_last_stage()
......@@ -362,9 +447,11 @@ class PipelineParallel(MetaParallelBase):
0
), output_buffers.pop(0)
self._record_stamp("B", i, '"B"', self._backward_color)
input_tensor_grad = self._backward_step(
input_tensor, output_tensor, output_tensor_grad
)
self._record_stamp("B", i, '"E"', self._backward_color)
if last_iter:
input_tensor = None
......@@ -377,6 +464,10 @@ class PipelineParallel(MetaParallelBase):
)
for i in range(startup_steps):
if static_scheduler:
schedule += f"b{steady_steps + i};"
logger.info(f"backward step for micro step {steady_steps + i}")
continue
input_tensor = input_buffers.pop(0)
output_tensor = output_buffers.pop(0)
......@@ -384,11 +475,22 @@ class PipelineParallel(MetaParallelBase):
self.is_pipeline_last_stage()
)
self._record_stamp(
"B", steady_steps + i, '"B"', self._backward_color
)
input_tensor_grad = self._backward_step(
input_tensor, output_tensor, output_tensor_grad
)
self._record_stamp(
"B", steady_steps + i, '"E"', self._backward_color
)
p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())
if static_scheduler:
return schedule
self._flush_records()
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
......@@ -687,12 +789,32 @@ class PipelineParallel(MetaParallelBase):
elif can_free(output):
output._clear_dataptr()
def get_static_scheduler(self):
return self.forward_backward_pipeline(data=None, static_scheduler=True)
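`get_static_scheduler()` returns the schedule string assembled above instead of running the model. The following sketch shows what that string looks like for one stage, assuming the usual 1F1B split where `startup_steps = num_stages - stage_id - 1` (capped at the number of micro steps) and the remaining steps run as forward/backward pairs; the `f{i};` / `b{i};` token format comes from the f-strings in the hunks above.

```python
# Sketch only: reproduces the shape of the schedule string returned by
# get_static_scheduler() for a plain (non-interleaved) 1F1B stage. The
# startup/steady split used here is an assumption stated in the lead-in.
def sketch_1f1b_schedule(num_stages, stage_id, acc_steps):
    startup = min(num_stages - stage_id - 1, acc_steps)
    steady = acc_steps - startup
    schedule = "".join(f"f{i};" for i in range(startup))  # warmup forwards
    for i in range(steady):  # steady 1F1B pairs
        schedule += f"f{startup + i};b{i};"
    schedule += "".join(f"b{steady + i};" for i in range(startup))  # cooldown backwards
    return schedule

print(sketch_1f1b_schedule(num_stages=4, stage_id=0, acc_steps=8))
# f0;f1;f2;f3;b0;f4;b1;f5;b2;f6;b3;f7;b4;b5;b6;b7;
```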
class PipelineParallelWithInterleave(PipelineParallel):
# pipeline parallel with interleave scheduler
def __init__(self, layers, hcg, strategy):
super().__init__(layers=layers, hcg=hcg, strategy=strategy)
self._record_format = (
'"name": "{}{}_VP{}", "cat": "virtual pipeline timeline", "ph": {}, "pid": 0, "tid": '
+ str(self.stage_id + 1)
+ ', "ts": {}, "cname": "{}"'
)
self._forward_colors = [
"thread_state_running", # RGB: 126, 200, 148
"thread_state_unknown", # RGB: 199, 155, 125
]
self._backward_colors = [
"rail_load", # RGB: 13, 168, 97
"rail_idle", # RGB: 238, 142, 0
]
# Structures to record the micro step for each layer chunk
self._forward_micro_step_counter = {}
self._backward_micro_step_counter = {}
assert layers.get_num_virtual_stages() > 1
assert (
self.num_stages > 2
......@@ -710,6 +832,52 @@ class PipelineParallelWithInterleave(PipelineParallel):
assert len(self.model_chunks) == self.num_model_chunks
self._virtual_pp_world_size = self.num_model_chunks
self._virtual_pp_rank = 0
self._reset_counter()
def _reset_counter(self):
for i in range(self.num_model_chunks):
self._forward_micro_step_counter[i] = 0
self._backward_micro_step_counter[i] = 0
def _record_stamp(self, name, step, phase, forward=True):
if self._profiling:
paddle.device.synchronize()
virtual_pp_rank = self._get_virtual_pp_rank(step, forward=forward)
color_idx = virtual_pp_rank % 2
# Get the profile color and micro step for current layer chunk
if forward:
color = self._forward_colors[color_idx]
micro_step = self._forward_micro_step_counter[virtual_pp_rank]
if phase == '"E"':
self._forward_micro_step_counter[virtual_pp_rank] += 1
else:
color = self._backward_colors[color_idx]
micro_step = self._backward_micro_step_counter[virtual_pp_rank]
if phase == '"E"':
self._backward_micro_step_counter[virtual_pp_rank] += 1
self._records.append(
'{'
+ self._record_format.format(
name,
micro_step,
virtual_pp_rank,
phase,
int(time.time() * 1000),
color,
)
+ '}'
)
def _flush_records(self):
if self._profiling:
with open(
f'./profile_record_tmp_file_for_rank_{self.global_rank}',
'a+',
) as f:
for record in self._records:
f.write(record + '\n')
self._records = []
self._reset_counter()
def _get_virtual_pp_rank(self, micro_step, forward):
virtual_pp_stage = micro_step % (
......@@ -771,7 +939,12 @@ class PipelineParallelWithInterleave(PipelineParallel):
return input_tensor_grad
def forward_backward_pipeline(
self, data, scaler, forward_only=False, compute_loss=True
self,
data,
scaler,
forward_only=False,
compute_loss=True,
static_scheduler=False,
):
# use interleave scheduling strategy.
# this strategy is inspired by:
......@@ -781,6 +954,22 @@ class PipelineParallelWithInterleave(PipelineParallel):
not forward_only
), "compute_loss can only be set to False when forward_only is set to True"
if static_scheduler:
assert (
not forward_only
), "static_scheduler only for training not for eval"
assert (
not self._profiling
), "While _profiling, static scheduler is not available"
if data is not None:
warnings.warn(
"Static scheduler run won't real run the model, but data has been provided"
)
logger.info(
"enable static_scheduler will return the pp schedule instead of the loss"
)
schedule = ""
# init some attributes for this batch run
self.scaler = scaler
self.total_loss = None
......@@ -810,13 +999,32 @@ class PipelineParallelWithInterleave(PipelineParallel):
steady_steps = num_steps - startup_steps
self.set_virtual_pipeline_rank(0)
self.input_tensors[0].append(
p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
)
if not static_scheduler:
self.input_tensors[0].append(
p2p.recv_forward(
self.is_pipeline_first_stage(), sync_recv=False
)
)
# run startup steps
for micro_step in range(startup_steps):
if static_scheduler:
virtual_pp_rank = self._get_virtual_pp_rank(
micro_step, forward=True
)
real_micro_step = self._forward_micro_step_counter[
virtual_pp_rank
]
self._forward_micro_step_counter[virtual_pp_rank] += 1
schedule += f"f{real_micro_step}_vp{virtual_pp_rank};"
logger.info(
f"forward step for {real_micro_step} with virtual pp rank {virtual_pp_rank}"
)
continue
self._record_stamp("F", micro_step, '"B"', forward=True)
output_tensor = self._forward_step_helper(micro_dataset, micro_step)
self._record_stamp("F", micro_step, '"E"', forward=True)
# determine whether recv forward tensor or not
next_virtual_pp_rank = self._get_virtual_pp_rank(
......@@ -867,17 +1075,55 @@ class PipelineParallelWithInterleave(PipelineParallel):
# run 1f1b steady steps
for micro_step in range(steady_steps):
if static_scheduler:
forward_micro_step_id = micro_step + startup_steps
forward_virtual_pp_rank = self._get_virtual_pp_rank(
forward_micro_step_id, forward=True
)
backward_micro_step_id = micro_step
backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id, forward=False
)
real_forward_micro_step = self._forward_micro_step_counter[
forward_virtual_pp_rank
]
self._forward_micro_step_counter[forward_virtual_pp_rank] += 1
real_backward_micro_step = self._backward_micro_step_counter[
backward_virtual_pp_rank
]
self._backward_micro_step_counter[backward_virtual_pp_rank] += 1
schedule += (
f"f{real_forward_micro_step}_vp{forward_virtual_pp_rank};"
)
schedule += (
f"b{real_backward_micro_step}_vp{backward_virtual_pp_rank};"
)
logger.info(
f"forward step for {real_forward_micro_step} with virtual pp rank {forward_virtual_pp_rank}"
)
logger.info(
f"backward step for {real_backward_micro_step} with virtual pp rank {backward_virtual_pp_rank}"
)
continue
# forward
forward_micro_step_id = micro_step + startup_steps
self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
output_tensor = self._forward_step_helper(
micro_dataset, forward_micro_step_id
)
self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)
# backward
backward_micro_step_id = micro_step
self._record_stamp(
"B", backward_micro_step_id, '"B"', forward=False
)
input_tensor_grad = self._backward_step_helper(
backward_micro_step_id
)
self._record_stamp(
"B", backward_micro_step_id, '"E"', forward=False
)
# four directions comm
# send output tensor to downstream
......@@ -946,13 +1192,29 @@ class PipelineParallelWithInterleave(PipelineParallel):
)
self._release_output(output_tensor)
self._release_output(output_tensor)
if not static_scheduler:
self._release_output(output_tensor)
# remaining backward steps
if not forward_only:
for micro_step in range(steady_steps, num_steps):
if static_scheduler:
virtual_pp_rank = self._get_virtual_pp_rank(
micro_step, forward=False
)
real_micro_step = self._backward_micro_step_counter[
virtual_pp_rank
]
self._backward_micro_step_counter[virtual_pp_rank] += 1
schedule += f"b{real_micro_step}_vp{virtual_pp_rank};"
logger.info(
f"backward step for {real_micro_step} with virtual pp rank {virtual_pp_rank}"
)
continue
# cooldown loop
self._record_stamp("B", micro_step, '"B"', forward=False)
input_tensor_grad = self._backward_step_helper(micro_step)
self._record_stamp("B", micro_step, '"E"', forward=False)
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
micro_step + 1, forward=False
)
......@@ -978,12 +1240,18 @@ class PipelineParallelWithInterleave(PipelineParallel):
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()
if static_scheduler:
self._reset_counter()
return schedule
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
self._layers.allreduce_shared_weight_gradients()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").stop()
self._flush_records()
if compute_loss:
# return loss if compute loss
if self._enable_timer:
......@@ -1018,3 +1286,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
self._compute_loss = compute_loss
return self.forward_backward_pipeline(data, None, forward_only=True)
def get_static_scheduler(self):
return self.forward_backward_pipeline(
data=None, scaler=None, static_scheduler=True
)
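In the interleaved variant, each schedule token also carries the virtual-pipeline (layer chunk) index, e.g. `f3_vp1;` or `b0_vp0;`, per the f-strings above. Below is a hypothetical helper, not part of this diff, that splits such a string back into (phase, micro step, chunk) tuples; the sample schedule passed to it is illustrative only.

```python
# Hypothetical parser for the interleaved schedule string returned by
# get_static_scheduler(); the token layout ("f{step}_vp{chunk};") follows the
# f-strings in this diff, but the concrete sample schedule below is made up.
def parse_vp_schedule(schedule):
    steps = []
    for token in filter(None, schedule.split(";")):
        phase = "forward" if token.startswith("f") else "backward"
        micro_step, chunk = token[1:].split("_vp")
        steps.append((phase, int(micro_step), int(chunk)))
    return steps

print(parse_vp_schedule("f0_vp0;f0_vp1;b0_vp1;b0_vp0;"))
# [('forward', 0, 0), ('forward', 0, 1), ('backward', 0, 1), ('backward', 0, 0)]
```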
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
def main():
all_record = []
all_files = os.listdir('./')
all_files = sorted(
filter(
lambda file: file.startswith("profile_record_tmp_file_for_rank_"),
all_files,
)
)
for files in all_files:
with open(files, 'r') as f:
for line in f:
all_record.append(line.strip())
with open('pipeline_profile.json', 'w') as f:
f.write('[ ')
for i in range(len(all_record) - 1):
f.write(all_record[i] + ',\n')
f.write(all_record[-1])
f.write(' ]\n')
if __name__ == "__main__":
main()
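The new script above (profiler_helper.py, per the URL in the logger message) merges every per-rank record file into a single Chrome trace array, pipeline_profile.json, which can then be opened at chrome://tracing; note that it expects at least one record to exist, since it indexes `all_record[-1]`. A small illustrative check, assuming the script has already been run, that the merged file parses and contains events for each pipeline stage:

```python
# Illustrative only: verify the merged trace parses and count events per stage.
# Assumes pipeline_profile.json was produced by the script above.
import json
from collections import Counter

with open("pipeline_profile.json") as f:
    events = json.load(f)

per_stage = Counter(event["tid"] for event in events)  # tid = stage_id + 1
print(f"{len(events)} events across pipeline stages {sorted(per_stage)}")
```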