Commit 92f425bb authored by mindspore-ci-bot, committed by Gitee

!326 profiler: fixed the issues of timeline AllReduce info display

Merge pull request !326 from zhangyunshu/zys_timeline_fixed_allreduce
@@ -69,7 +69,7 @@ class StepTraceAnalyser(BaseAnalyser):
         return self._result
 
-    def query_for_all_reduce(self):
+    def query_for_all_reduce(self, min_cycle_counter):
         """
         Query for all reduce info.
@@ -81,8 +81,9 @@ class StepTraceAnalyser(BaseAnalyser):
         reduce_infos = []
         for row_info in self._data[:-1]:
             row_info_dict = self._get_info_dict_from_row_data(row_info, 'systime')
-            reduce_info = self._get_reduce_time_in_order(row_info_dict)
-            reduce_infos.append(reduce_info)
+            reduce_info = self._sort_reduce_by_time(row_info_dict, min_cycle_counter)
+            if reduce_info:
+                reduce_infos.append(reduce_info)
 
         return reduce_infos
@@ -251,6 +252,42 @@ class StepTraceAnalyser(BaseAnalyser):
         reduce_events.sort(key=lambda elem: elem[1])
         return reduce_info
 
+    def _sort_reduce_by_time(self, row_info_dict, min_cycle_counter):
+        """
+        Sort reduce info by time.
+
+        Args:
+            row_info_dict (dict): Step trace information.
+            min_cycle_counter (int): The minimum cycle counter.
+
+        Returns:
+            list, including the all reduce info sorted by start time only.
+            [
+                [reduce_field, stream_id, reduce_start, reduce_duration],
+                [...],
+                [...]
+            ]
+        """
+        factor = 1e5  # convert time unit from 10ns to 1ms
+        reduce_pid = 10000
+        reduce_info = []
+        reduce_fields = [field_name for field_name in self.__column__
+                         if field_name.startswith('stream_') and not field_name.endswith('point')]
+        for reduce_field in reduce_fields:
+            reduce_start = row_info_dict.get(reduce_field + '_start_point')
+            reduce_start = (reduce_start - min_cycle_counter) / factor \
+                if reduce_start else 0
+            reduce_duration = row_info_dict.get(reduce_field)
+            reduce_duration = reduce_duration / factor if reduce_duration else 0
+            if not (reduce_start and reduce_duration):
+                log.info("Reduce event missing value.")
+                continue
+            cur_stream_id = reduce_field.split('_', 2)[1]
+            reduce_info = [reduce_field, int(cur_stream_id), reduce_start,
+                           reduce_duration, reduce_pid]
+
+        return reduce_info
+
     def _construct_reduce_lines(self, row_info_dict):
         """
         Contruct first line in detailed graph.
......
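For readers tracing the unit math in the new _sort_reduce_by_time: the raw step-trace counters are 10 ns ticks, so dividing by factor = 1e5 yields milliseconds, and subtracting min_cycle_counter rebases every stream onto a common zero. Below is a minimal, standalone sketch of that conversion; all counter values are invented for illustration.

    # Standalone sketch of the 10ns-tick -> millisecond conversion above.
    # All values are invented; only the arithmetic mirrors the diff.
    FACTOR = 1e5  # 1 ms == 1e5 ticks of 10 ns

    min_cycle_counter = 1_000_000     # earliest counter across all streams
    reduce_start_counter = 1_250_000  # hypothetical 'stream_6_0_start_point'
    reduce_duration_ticks = 50_000    # hypothetical 'stream_6_0' duration

    start_ms = (reduce_start_counter - min_cycle_counter) / FACTOR  # 2.5 ms
    duration_ms = reduce_duration_ticks / FACTOR                    # 0.5 ms
    print(start_ms, duration_ms)  # 2.5 0.5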
@@ -47,8 +47,6 @@ class TimelineAnalyser(BaseAnalyser):
 
     def _load(self):
         """Load data according to the parsed profiling files."""
-        self.load_timeline_data()
-        self._timeline_summary['op_exe_times'] = len(self._timeline_meta)
 
     def _filter(self, filter_condition):
         """
@@ -122,6 +120,7 @@ class TimelineAnalyser(BaseAnalyser):
     def write_timeline(self):
         """Load data according to the parsed profiling files."""
         # Write timeline to file.
+        logger.info('Writing timeline file...')
         file_size = self.write_timeline_to_json()
 
         # If the file size is larger than 20MB, open a new file and
@@ -131,6 +130,8 @@ class TimelineAnalyser(BaseAnalyser):
             # write to json file for display
             self.write_timeline_to_json_by_limitation()
 
+        logger.info('Finished file writing!')
+
     def write_timeline_to_json(self):
         """Write timeline to json."""
         timeline_filename = self._timeline_filename.format(self._device_id)
@@ -197,7 +198,7 @@ class TimelineAnalyser(BaseAnalyser):
             logger.error('Error occurred when write timeline summary file: %s', err)
             raise ProfilerIOException
 
-    def load_timeline_data(self):
+    def _load_timeline_data(self):
         """Load timeline data from file."""
         file_path = os.path.join(
             self._profiling_dir,
@@ -210,34 +211,37 @@ class TimelineAnalyser(BaseAnalyser):
             logger.error("Failed to find parsed timeline file.")
             raise ProfilerFileNotFoundException('parsed timeline file')
 
-        stream_count_dict = {}
+        timeline_list = []
        try:
             with open(file_path, 'r') as f_obj:
                 for line in f_obj:
                     if not line.startswith('op_name'):
                         line_list = line.strip('\n').split(',')
-                        self._parse_timeline_data(line_list)
-                        self._update_num_of_streams(line_list, stream_count_dict)
+                        timeline_list.append(line_list)
         except (IOError, OSError) as err:
             logger.error('Error occurred when read timeline intermediate file: %s', err)
             raise ProfilerIOException
 
-        # Update timeline summary info
-        self._timeline_summary['num_of_streams'] = len(stream_count_dict.keys())
+        return timeline_list
 
     def _parse_timeline_data(self, line_list):
         """Parse timeline data."""
+        # factor to convert the time unit from 1ms to 1us for timeline display
         factor = 1000
         op_meta = TimelineContainer(line_list)
         timeline_dict = {}
         timeline_dict['name'] = op_meta.op_name
         timeline_dict['ph'] = 'X'
-        timeline_dict['pid'] = int(self._device_id)
         timeline_dict['tid'] = op_meta.stream_id
         timeline_dict['ts'] = op_meta.start_time * factor
         dur = op_meta.duration * factor
         timeline_dict['dur'] = dur
+        if op_meta.pid == 10000:  # AllReduce PID
+            timeline_dict['pid'] = 10000
+        else:
+            timeline_dict['pid'] = int(self._device_id)
+
+        # Update total time of operator execution.
+        self._timeline_summary['total_time'] += dur
-        self._timeline_summary['total_time'] += dur
         self._timeline_meta.append(timeline_dict)
 
     @staticmethod
@@ -249,7 +253,7 @@ class TimelineAnalyser(BaseAnalyser):
         else:
             stream_count_dict[stream_id] += 1
 
-    def get_min_cycle_counter_from_file(self):
+    def get_min_cycle_counter(self):
         """
         Get minimum cycle counter.
@@ -280,48 +284,50 @@ class TimelineAnalyser(BaseAnalyser):
         return min_cycle_counter
 
-    def add_all_reduce_info(self, all_reduce_info):
+    def init_timeline(self, all_reduce_info, framework_info):
         """
-        Add all reduce info into timeline metadata.
+        Init timeline metadata, adding all collected info.
 
         Args:
-            all_reduce_info (list<dict>): The metadata of AllReduce operator.
-                [
-                    {
-                        'stream_id_1': [(start_time, end_time, duration, field_name)],
-                        ...
-                    },
-                    {...}
-                ]
+            all_reduce_info (list[list]): The metadata of AllReduce operator.
+            framework_info (dict): The framework metadata.
         """
-        logger.info('Adding AllReduce info...')
-        factor = 100
-        min_cycle_counter = self.get_min_cycle_counter_from_file()
-        for step_meta in all_reduce_info:
-            for stream_id, time_info_list in step_meta.items():
-                for time_info in time_info_list:
-                    start, _, dur, name = time_info
-                    all_reduce_dict = {}
-                    all_reduce_dict['name'] = name
-                    all_reduce_dict['ph'] = 'X'
-                    # Using 10000 to represent AllReduce
-                    all_reduce_dict['pid'] = 10000
-                    all_reduce_dict['tid'] = int(stream_id)
-                    all_reduce_dict['ts'] = (start - min_cycle_counter) / factor
-                    all_reduce_dict['dur'] = dur / factor
-                    self._timeline_meta.append(all_reduce_dict)
-                    self._timeline_summary['total_time'] += all_reduce_dict['dur']
-
-    def add_framework_info(self, framework_info):
+        logger.info('Initiating timeline...')
+        timeline_list = self._load_timeline_data()
+        self._timeline_summary['op_exe_times'] = len(timeline_list)
+
+        # Add AllReduce info to timeline temp list and sort by start time.
+        if all_reduce_info:
+            logger.debug('AllReduce info found. Start adding info into timeline...')
+            timeline_list.extend(all_reduce_info)
+            timeline_list.sort(key=lambda x: float(x[2]))
+
+        # Init a dict for counting the num of streams.
+        stream_count_dict = {}
+        for timeline in timeline_list:
+            self._parse_timeline_data(timeline)
+            # Updating the collection of streams.
+            if len(timeline) == 4:
+                self._update_num_of_streams(timeline, stream_count_dict)
+
+        # Get framework metadata.
+        framework_obj_list = framework_info.get('object')
+        # The length of list is the number of operators.
+        self._timeline_summary['num_of_ops'] = len(framework_obj_list)
+        self._add_framework_info(framework_obj_list)
+        logger.info('Finished adding info into timeline...')
+
+        # Update timeline summary info
+        self._timeline_summary['num_of_streams'] = len(stream_count_dict.keys())
+
+    def _add_framework_info(self, framework_obj_list):
         """
         Add framework info into timeline metadata.
 
         Args:
-            framework_info (dict): The framework metadata.
+            framework_obj_list (list): The framework metadata.
         """
-        logger.info('Adding framework info...')
-        framework_obj_list = framework_info.get('object')
-        self._timeline_summary['num_of_ops'] = len(framework_obj_list)
+        logger.debug('Start adding framework info into timeline...')
         for framework_obj in framework_obj_list:
             op_name = framework_obj[0]
             op_type = framework_obj[1]
@@ -335,3 +341,5 @@ class TimelineAnalyser(BaseAnalyser):
                     'fullname': op_full_name
                 }
                 timeline_obj['args'].update(op_info)
+
+        logger.debug('Finished adding framework info into timeline...')
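The merge in init_timeline works because both row layouts keep the start time at index 2: parsed operator rows are 4-field lists, while the AllReduce rows built by _sort_reduce_by_time carry a fifth element, the pid 10000. A minimal, self-contained sketch of that merge-and-sort step, with invented rows:

    # Sketch of the merge in init_timeline; all rows below are invented.
    # Operator rows:  [op_name, stream_id, start_time, duration]
    # AllReduce rows: [field, stream_id, start_time, duration, pid]
    timeline_list = [
        ['Conv2D-op1', '5', '0.10', '0.40'],
        ['MatMul-op2', '5', '0.60', '0.20'],
    ]
    all_reduce_info = [
        ['stream_6_0', 6, 0.25, 0.30, 10000],
    ]

    timeline_list.extend(all_reduce_info)
    # Index 2 is the start time in both layouts, so one key sorts the merge.
    timeline_list.sort(key=lambda x: float(x[2]))
    for row in timeline_list:
        # Length-5 rows are AllReduce events; length-4 rows count toward streams.
        kind = 'allreduce' if len(row) == 5 else 'op'
        print(kind, row)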
@@ -69,6 +69,9 @@ class TimelineContainer:
         self._stream_id = int(split_list[1])
         self._start_time = float(split_list[2])
         self._duration = float(split_list[3])
+        self._pid = None
+        if len(split_list) == 5:
+            self._pid = int(split_list[4])
 
     @property
     def op_name(self):
@@ -89,3 +92,8 @@ class TimelineContainer:
     def duration(self):
         """Get the duration of the operator execution."""
         return self._duration
+
+    @property
+    def pid(self):
+        """Get the pid of the operator execution."""
+        return self._pid
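To see what the optional fifth field does end to end, here is a tiny stand-in that mirrors only the fields TimelineContainer parses in the diff above; it is not the real class, just an illustration of how a 4-field operator row and a 5-field AllReduce row differ.

    # Stand-in mirroring the parsing above; invented for illustration only.
    class FakeTimelineContainer:
        def __init__(self, split_list):
            self._op_name = split_list[0]
            self._stream_id = int(split_list[1])
            self._start_time = float(split_list[2])
            self._duration = float(split_list[3])
            self._pid = None               # absent for ordinary operators
            if len(split_list) == 5:       # AllReduce rows carry a pid column
                self._pid = int(split_list[4])

        @property
        def pid(self):
            return self._pid

    print(FakeTimelineContainer(['MatMul-op2', '5', '0.60', '0.20']).pid)            # None
    print(FakeTimelineContainer(['stream_6_0', '6', '0.25', '0.30', '10000']).pid)   # 10000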
@@ -194,7 +194,10 @@ class Profiler:
                 logger.warning(err.message)
 
         # analyse timeline info
-        self._analyse_timeline()
+        try:
+            self._analyse_timeline()
+        except (ProfilerIOException, ProfilerFileNotFoundException, ValidationError) as err:
+            logger.warning('Fail to write timeline data: %s', err)
 
     def _analyse_step_trace(self, source_path, framework_parser):
         """
@@ -233,6 +236,11 @@ class Profiler:
         """
         Analyse and parse timeline info.
         """
+        timeline_analyser = AnalyserFactory.instance().get_analyser(
+            'timeline', self._output_path, self._dev_id
+        )
+        min_cycle_counter = timeline_analyser.get_min_cycle_counter()
+
         # Get framework info
         aicoredetail_analyser = AnalyserFactory.instance().get_analyser(
             'aicore_detail', self._output_path, self._dev_id
@@ -243,19 +251,16 @@ class Profiler:
         step_trace_analyser = AnalyserFactory.instance().get_analyser(
             'step_trace', self._output_path, self._dev_id
         )
-        all_reduce_info = step_trace_analyser.query_for_all_reduce()
+        all_reduce_info = step_trace_analyser.query_for_all_reduce(min_cycle_counter)
 
         # Get timeline info
-        timeline_analyser = AnalyserFactory.instance().get_analyser(
-            'timeline', self._output_path, self._dev_id
-        )
-        timeline_analyser.add_framework_info(framework_info)
-        timeline_analyser.add_all_reduce_info(all_reduce_info)
-        try:
-            timeline_analyser.write_timeline()
-            timeline_analyser.write_timeline_summary()
-        except (ProfilerIOException, ProfilerFileNotFoundException, ValidationError) as err:
-            logger.warning('Fail to write timeline data: %s', err)
+        logger.info('Start writing timeline info...')
+        logger.info('Warm Prompt: It could take a few minutes if you are training '
+                    'with a complex network or more than 10 steps.')
+        # Add AllReduce and framework info into timeline
+        timeline_analyser.init_timeline(all_reduce_info, framework_info)
+        timeline_analyser.write_timeline()
+        timeline_analyser.write_timeline_summary()
 
     def __del__(self):
         """Disable the profiling collection service, called after training."""
......
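The dicts that _parse_timeline_data builds (name, ph, pid, tid, ts, dur) match the complete-event shape ('ph': 'X') of the Chrome Trace Event Format, which is presumably what lets the written JSON be viewed in chrome://tracing; the fixed pid of 10000 simply groups AllReduce events into their own process row. A hand-written sample of what the output looks like, with invented values:

    # Hand-written sample of the trace events the analyser emits; values invented.
    # 'X' marks a "complete event" in the Chrome Trace Event Format.
    import json

    events = [
        {'name': 'Conv2D-op1', 'ph': 'X', 'pid': 0, 'tid': 5, 'ts': 100.0, 'dur': 400.0},
        # pid 10000 puts AllReduce events on their own process row in the viewer.
        {'name': 'stream_6_0', 'ph': 'X', 'pid': 10000, 'tid': 6, 'ts': 250.0, 'dur': 300.0},
    ]
    print(json.dumps(events, indent=2))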