提交 f8f5c7a9 编写于 作者: Y yelihua

classify reduce events into different types

上级 6c82ec3e
......@@ -44,6 +44,7 @@ class StepTraceParser:
_event_size = 20
_fp_tag = 1
_bp_tag = 2
_end_tag = 255
def __init__(self, input_dir, output_file_path, job_id=0, skip_first_step=False):
self._input_dir = input_dir
......@@ -53,6 +54,7 @@ class StepTraceParser:
self._result = []
self._header = []
self._step_num = 0
self._tag_map = {}
@property
def output_file(self):
......@@ -107,6 +109,46 @@ class StepTraceParser:
raise ProfilerIOException
return points
def update_tag_op_type_map(self, point_info):
"""
update the map from tag id to op type.
Args:
point_info (dict): The point info about tag id and relative op name.
"""
tag_map = {}
for tag, op_name in point_info.items():
op_type = self._get_op_type(tag, op_name)
tag_map[tag] = op_type
log.info("Get tag types for step trace analysis: %s", tag_map)
self._tag_map = tag_map
def _get_op_type(self, tag, name):
"""
Get op type from tag and name.
Args:
tag (int): The tag id.
name (str): The op name.
Returns:
str, the op type.
"""
tag_map = {self._fp_tag: 'fp', self._bp_tag: 'bp', self._end_tag: 'end'}
# get solid tag type
op_type = tag_map.get(tag, '')
if op_type:
return op_type
# check if the tag is step tag.
if tag > self._end_tag or tag == 0:
return 'start'
# analyze the reduce tag
op_type = name.rsplit('/', 1)[-1].split('-')[0]
if not op_type:
log.warning("Unexpected op name:%s", name)
return op_type
def _get_step_trace_files(self):
"""Get step trace files."""
# step trace files may under $profiler_dir or $profiler_dir/data
......@@ -207,13 +249,13 @@ class StepTraceParser:
event_info['start'] = start_time
event_info['reduce'] = {}
def _on_reduce_event():
def _on_reduce_event(reduce_tag_id):
"""Handle reduce event."""
stream_id = next_event.stream_id
if event_info['reduce'].get(stream_id):
event_info['reduce'][stream_id].append(sys_count)
event_info['reduce'][stream_id].append((reduce_tag_id, sys_count))
else:
event_info['reduce'][stream_id] = [sys_count]
event_info['reduce'][stream_id] = [(reduce_tag_id, sys_count)]
tag_id = next_event.tag_id
sys_count = next_event.sys_count
......@@ -226,10 +268,10 @@ class StepTraceParser:
elif bp_flag(tag_id):
event_info['bp'] = sys_count
else:
_on_reduce_event()
_on_reduce_event(tag_id)
def _validate_tag_id(self, job_id):
"""Check the job id in source step trace file is same os user set."""
"""Check the job id in source step trace file is same as user set."""
if not self._job_id:
self._job_id = job_id
elif self._job_id != job_id:
......@@ -243,7 +285,7 @@ class StepTraceParser:
fp_time = step_trace.get('fp')
bp_time = step_trace.get('bp')
if not (start_time and end_time and fp_time and bp_time):
log.warning("The step %d is missing basic time.", self._step_num)
log.warning("The step %d lacks basic time.", self._step_num)
return
if start_time == '-':
start_time = fp_time
......@@ -266,8 +308,7 @@ class StepTraceParser:
row_data_list = [row_data.get(header_name, 0) for header_name in self._header]
self._result.append(row_data_list)
@staticmethod
def _update_reduce_info(step_trace, row_data):
def _update_reduce_info(self, step_trace, row_data):
"""Extract reduce info."""
reduce_time = step_trace.get('reduce', {})
for stream_id, time_points in reduce_time.items():
......@@ -276,10 +317,39 @@ class StepTraceParser:
log.warning("Stream %d has %d reduce time points.", stream_id, time_point_num)
continue
for index, point_id in enumerate(range(0, time_point_num, 2)):
field_name = f'stream_{stream_id}_parallel_{index}'
row_data[field_name + '_start_point'] = time_points[point_id]
row_data[field_name + '_end_point'] = time_points[point_id + 1]
row_data[field_name] = time_points[point_id + 1] - time_points[point_id]
field_name = f'stream_{stream_id}_{index}'
reduce_info = self._get_single_reduce_event_info(
field_name, time_points[point_id], time_points[point_id + 1])
row_data.update(reduce_info)
def _get_single_reduce_event_info(self, field_name, start_point, end_point):
"""
Get single reduce info.
Args:
field_name (str): The field name.
start_point (Tuple[int, int]): Start point time info, including (tag_id, sys_count).
end_point (Tuple[int, int]): End point time info, including (tag_id, sys_count).
Returns:
dict, reduce info.
"""
reduce_info = {}
if end_point[0] - start_point[0] != 1 or end_point[0] % 2:
log.warning("Unmatched reduce event <%s, %s>.", start_point, end_point)
return reduce_info
op_type = self._tag_map.get(start_point[0])
# append field name with op type.
if not op_type:
log.warning("Can't recognize the inner type for point tag: %d.", start_point[0])
field_name += '_parallel'
else:
field_name += '_' + op_type
reduce_info[field_name] = end_point[1] - start_point[1]
reduce_info[field_name + '_start_point'] = start_point[1]
reduce_info[field_name + '_end_point'] = end_point[1]
return reduce_info
def _record_average_info(self):
"""Calculate average info."""
......
......@@ -226,6 +226,7 @@ class Profiler:
output_file_path=step_trace_intermediate_file_path,
job_id=self._job_id_env,
skip_first_step=skip_first_step_flag)
parser.update_tag_op_type_map(point_info)
parser.parse_and_save()
point_info = parser.record_point_info(point_info, point_info_file_path)
# print parser result
......
......@@ -108,7 +108,7 @@ class TestProfilerAnalyse(TestCase):
assert len(res['training_trace_graph']) == 13
assert res['training_trace_graph'][-1] == [
{'name': '', 'start': 0.2038, 'duration': 118.1667},
{'name': 'stream_540_parallel_0', 'start': 118.3705, 'duration': 49.281},
{'name': 'stream_540_0_parallel', 'start': 118.3705, 'duration': 49.281},
{'name': '', 'start': 167.6515, 'duration': 37.7294}]
@pytest.mark.level0
......
1 Default/Cast-op6
2 Default/TransData-op7
3 Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5
4 Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28
3 Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/AllGather-op136
4 Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/AllGather-op136
5 Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/ReduceScatter-op145
6 Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/ReduceScatter-op145
7 Gradients/Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/gradReduceScatter/AllGather-op147
8 Gradients/Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/gradReduceScatter/AllGather-op147
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册