提交 f8f5c7a9 编写于 作者: Y yelihua

classify reduce events into different types

上级 6c82ec3e
...@@ -44,6 +44,7 @@ class StepTraceParser: ...@@ -44,6 +44,7 @@ class StepTraceParser:
_event_size = 20 _event_size = 20
_fp_tag = 1 _fp_tag = 1
_bp_tag = 2 _bp_tag = 2
_end_tag = 255
def __init__(self, input_dir, output_file_path, job_id=0, skip_first_step=False): def __init__(self, input_dir, output_file_path, job_id=0, skip_first_step=False):
self._input_dir = input_dir self._input_dir = input_dir
...@@ -53,6 +54,7 @@ class StepTraceParser: ...@@ -53,6 +54,7 @@ class StepTraceParser:
self._result = [] self._result = []
self._header = [] self._header = []
self._step_num = 0 self._step_num = 0
self._tag_map = {}
@property @property
def output_file(self): def output_file(self):
...@@ -107,6 +109,46 @@ class StepTraceParser: ...@@ -107,6 +109,46 @@ class StepTraceParser:
raise ProfilerIOException raise ProfilerIOException
return points return points
def update_tag_op_type_map(self, point_info):
    """
    Update the map from tag id to op type.

    Args:
        point_info (dict): The point info about tag id and relative op name.
    """
    # Resolve every tagged point to its op type in a single pass.
    resolved = {tag: self._get_op_type(tag, op_name)
                for tag, op_name in point_info.items()}
    log.info("Get tag types for step trace analysis: %s", resolved)
    self._tag_map = resolved
def _get_op_type(self, tag, name):
    """
    Get op type from tag and name.

    Args:
        tag (int): The tag id.
        name (str): The op name.

    Returns:
        str, the op type.
    """
    # Fixed tags map directly to a known type.
    fixed_types = {self._fp_tag: 'fp', self._bp_tag: 'bp', self._end_tag: 'end'}
    if tag in fixed_types:
        return fixed_types[tag]
    # Tags outside the reduce-tag range mark a step start.
    if tag == 0 or tag > self._end_tag:
        return 'start'
    # Reduce tag: derive the type from the last path segment of the op name.
    last_segment = name.rsplit('/', 1)[-1]
    reduce_type = last_segment.split('-')[0]
    if not reduce_type:
        log.warning("Unexpected op name:%s", name)
    return reduce_type
def _get_step_trace_files(self): def _get_step_trace_files(self):
"""Get step trace files.""" """Get step trace files."""
# step trace files may under $profiler_dir or $profiler_dir/data # step trace files may under $profiler_dir or $profiler_dir/data
...@@ -207,13 +249,13 @@ class StepTraceParser: ...@@ -207,13 +249,13 @@ class StepTraceParser:
event_info['start'] = start_time event_info['start'] = start_time
event_info['reduce'] = {} event_info['reduce'] = {}
def _on_reduce_event(): def _on_reduce_event(reduce_tag_id):
"""Handle reduce event.""" """Handle reduce event."""
stream_id = next_event.stream_id stream_id = next_event.stream_id
if event_info['reduce'].get(stream_id): if event_info['reduce'].get(stream_id):
event_info['reduce'][stream_id].append(sys_count) event_info['reduce'][stream_id].append((reduce_tag_id, sys_count))
else: else:
event_info['reduce'][stream_id] = [sys_count] event_info['reduce'][stream_id] = [(reduce_tag_id, sys_count)]
tag_id = next_event.tag_id tag_id = next_event.tag_id
sys_count = next_event.sys_count sys_count = next_event.sys_count
...@@ -226,10 +268,10 @@ class StepTraceParser: ...@@ -226,10 +268,10 @@ class StepTraceParser:
elif bp_flag(tag_id): elif bp_flag(tag_id):
event_info['bp'] = sys_count event_info['bp'] = sys_count
else: else:
_on_reduce_event() _on_reduce_event(tag_id)
def _validate_tag_id(self, job_id): def _validate_tag_id(self, job_id):
"""Check the job id in source step trace file is same os user set.""" """Check the job id in source step trace file is same as user set."""
if not self._job_id: if not self._job_id:
self._job_id = job_id self._job_id = job_id
elif self._job_id != job_id: elif self._job_id != job_id:
...@@ -243,7 +285,7 @@ class StepTraceParser: ...@@ -243,7 +285,7 @@ class StepTraceParser:
fp_time = step_trace.get('fp') fp_time = step_trace.get('fp')
bp_time = step_trace.get('bp') bp_time = step_trace.get('bp')
if not (start_time and end_time and fp_time and bp_time): if not (start_time and end_time and fp_time and bp_time):
log.warning("The step %d is missing basic time.", self._step_num) log.warning("The step %d lacks basic time.", self._step_num)
return return
if start_time == '-': if start_time == '-':
start_time = fp_time start_time = fp_time
...@@ -266,8 +308,7 @@ class StepTraceParser: ...@@ -266,8 +308,7 @@ class StepTraceParser:
row_data_list = [row_data.get(header_name, 0) for header_name in self._header] row_data_list = [row_data.get(header_name, 0) for header_name in self._header]
self._result.append(row_data_list) self._result.append(row_data_list)
@staticmethod def _update_reduce_info(self, step_trace, row_data):
def _update_reduce_info(step_trace, row_data):
"""Extract reduce info.""" """Extract reduce info."""
reduce_time = step_trace.get('reduce', {}) reduce_time = step_trace.get('reduce', {})
for stream_id, time_points in reduce_time.items(): for stream_id, time_points in reduce_time.items():
...@@ -276,10 +317,39 @@ class StepTraceParser: ...@@ -276,10 +317,39 @@ class StepTraceParser:
log.warning("Stream %d has %d reduce time points.", stream_id, time_point_num) log.warning("Stream %d has %d reduce time points.", stream_id, time_point_num)
continue continue
for index, point_id in enumerate(range(0, time_point_num, 2)): for index, point_id in enumerate(range(0, time_point_num, 2)):
field_name = f'stream_{stream_id}_parallel_{index}' field_name = f'stream_{stream_id}_{index}'
row_data[field_name + '_start_point'] = time_points[point_id] reduce_info = self._get_single_reduce_event_info(
row_data[field_name + '_end_point'] = time_points[point_id + 1] field_name, time_points[point_id], time_points[point_id + 1])
row_data[field_name] = time_points[point_id + 1] - time_points[point_id] row_data.update(reduce_info)
def _get_single_reduce_event_info(self, field_name, start_point, end_point):
    """
    Get single reduce info.

    Args:
        field_name (str): The field name.
        start_point (Tuple[int, int]): Start point time info, including (tag_id, sys_count).
        end_point (Tuple[int, int]): End point time info, including (tag_id, sys_count).

    Returns:
        dict, reduce info.
    """
    start_tag, start_count = start_point
    end_tag, end_count = end_point
    # A valid pair is an odd start tag immediately followed by its even end tag.
    if end_tag - start_tag != 1 or end_tag % 2:
        log.warning("Unmatched reduce event <%s, %s>.", start_point, end_point)
        return {}
    # Append the recognized op type to the field name; fall back to '_parallel'.
    op_type = self._tag_map.get(start_tag)
    if op_type:
        full_name = field_name + '_' + op_type
    else:
        log.warning("Can't recognize the inner type for point tag: %d.", start_tag)
        full_name = field_name + '_parallel'
    return {
        full_name: end_count - start_count,
        full_name + '_start_point': start_count,
        full_name + '_end_point': end_count,
    }
def _record_average_info(self): def _record_average_info(self):
"""Calculate average info.""" """Calculate average info."""
......
...@@ -226,6 +226,7 @@ class Profiler: ...@@ -226,6 +226,7 @@ class Profiler:
output_file_path=step_trace_intermediate_file_path, output_file_path=step_trace_intermediate_file_path,
job_id=self._job_id_env, job_id=self._job_id_env,
skip_first_step=skip_first_step_flag) skip_first_step=skip_first_step_flag)
parser.update_tag_op_type_map(point_info)
parser.parse_and_save() parser.parse_and_save()
point_info = parser.record_point_info(point_info, point_info_file_path) point_info = parser.record_point_info(point_info, point_info_file_path)
# print parser result # print parser result
......
...@@ -108,7 +108,7 @@ class TestProfilerAnalyse(TestCase): ...@@ -108,7 +108,7 @@ class TestProfilerAnalyse(TestCase):
assert len(res['training_trace_graph']) == 13 assert len(res['training_trace_graph']) == 13
assert res['training_trace_graph'][-1] == [ assert res['training_trace_graph'][-1] == [
{'name': '', 'start': 0.2038, 'duration': 118.1667}, {'name': '', 'start': 0.2038, 'duration': 118.1667},
{'name': 'stream_540_parallel_0', 'start': 118.3705, 'duration': 49.281}, {'name': 'stream_540_0_parallel', 'start': 118.3705, 'duration': 49.281},
{'name': '', 'start': 167.6515, 'duration': 37.7294}] {'name': '', 'start': 167.6515, 'duration': 37.7294}]
@pytest.mark.level0 @pytest.mark.level0
......
1 Default/Cast-op6 1 Default/Cast-op6
2 Default/TransData-op7 2 Default/TransData-op7
3 Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5 3 Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/AllGather-op136
4 Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28 4 Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/AllGather-op136
5 Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/ReduceScatter-op145
6 Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/ReduceScatter-op145
7 Gradients/Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/gradReduceScatter/AllGather-op147
8 Gradients/Default/network-VirtualDatasetCellTriple/_backbone-NetWithLossClass/network-WideDeepModel/gradReduceScatter/AllGather-op147
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册