提交 15065c10 编写于 作者: Y yelihua

Enable showing the FP (forward-propagation) and BP (back-propagation) points

上级 83c104c4
...@@ -146,6 +146,7 @@ def get_training_trace_graph(): ...@@ -146,6 +146,7 @@ def get_training_trace_graph():
'step_id': graph_type 'step_id': graph_type
}}) }})
graph_info['summary'] = analyser.summary graph_info['summary'] = analyser.summary
graph_info['point_info'] = analyser.point_info
return jsonify(graph_info) return jsonify(graph_info)
......
...@@ -14,11 +14,13 @@ ...@@ -14,11 +14,13 @@
# ============================================================================ # ============================================================================
"""The StepTraceAnalyser analyser class.""" """The StepTraceAnalyser analyser class."""
import csv import csv
import json
import os
from mindinsight.datavisual.utils.tools import to_int from mindinsight.datavisual.utils.tools import to_int
from mindinsight.profiler.analyser.base_analyser import BaseAnalyser from mindinsight.profiler.analyser.base_analyser import BaseAnalyser
from mindinsight.profiler.common.exceptions.exceptions import ProfilerParamValueErrorException, \ from mindinsight.profiler.common.exceptions.exceptions import ProfilerParamValueErrorException, \
ProfilerFileNotFoundException, StepNumNotSupportedException ProfilerFileNotFoundException, StepNumNotSupportedException, ProfilerRawFileException
from mindinsight.profiler.common.log import logger as log from mindinsight.profiler.common.log import logger as log
from mindinsight.profiler.common.util import query_latest_trace_time_file, get_field_value, \ from mindinsight.profiler.common.util import query_latest_trace_time_file, get_field_value, \
get_summary_for_step_trace, to_millisecond get_summary_for_step_trace, to_millisecond
...@@ -31,6 +33,7 @@ class StepTraceAnalyser(BaseAnalyser): ...@@ -31,6 +33,7 @@ class StepTraceAnalyser(BaseAnalyser):
_attr_ui_name = 'name' _attr_ui_name = 'name'
_attr_ui_start = 'start' _attr_ui_start = 'start'
_attr_ui_duration = 'duration' _attr_ui_duration = 'duration'
# Class-level default for the parsed point info. NOTE(review): this is a shared
# mutable class attribute; it is safe only because _load_point_info rebinds
# self._point_info rather than mutating the dict in place — confirm no caller
# mutates it.
_point_info = {}
@property @property
def summary(self): def summary(self):
...@@ -40,6 +43,11 @@ class StepTraceAnalyser(BaseAnalyser): ...@@ -40,6 +43,11 @@ class StepTraceAnalyser(BaseAnalyser):
summary['total_steps'] = self._size summary['total_steps'] = self._size
return summary return summary
@property
def point_info(self):
    """dict: FP/BP point info parsed from the step trace point-info file."""
    return self._point_info
def query(self, condition=None): def query(self, condition=None):
""" """
Query data according to the condition. Query data according to the condition.
...@@ -90,6 +98,18 @@ class StepTraceAnalyser(BaseAnalyser): ...@@ -90,6 +98,18 @@ class StepTraceAnalyser(BaseAnalyser):
self._data = list(csv_reader) self._data = list(csv_reader)
self._size = len(self._data) - 1 self._size = len(self._data) - 1
self._display_col_names = self._col_names[:] self._display_col_names = self._col_names[:]
self._load_point_info()
def _load_point_info(self):
    """Load FP/BP point info from `step_trace_point_info.json` if present.

    Reads the JSON file under `self._profiling_dir` and rebinds
    `self._point_info` with its contents. Missing file is not an error:
    the attribute keeps its default value.

    Raises:
        ProfilerRawFileException: If the file exists but cannot be parsed
            as JSON.
    """
    file_path = os.path.join(self._profiling_dir, 'step_trace_point_info.json')
    if os.path.isfile(file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            try:
                self._point_info = json.load(file)
            except (json.JSONDecodeError, TypeError) as err:
                log.exception(err)
                # Chain the original error so the JSON parse failure is not
                # lost from the traceback (original code dropped the cause).
                raise ProfilerRawFileException('Fail to parse point info file.') from err
def _filter(self, filter_condition): def _filter(self, filter_condition):
""" """
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""The parser for step trace data.""" """The parser for step trace data."""
import csv import csv
import json
import os import os
import stat import stat
import struct import struct
...@@ -41,6 +42,8 @@ class StepTraceParser: ...@@ -41,6 +42,8 @@ class StepTraceParser:
skip_first_step (bool): Whether skip the first step or not. skip_first_step (bool): Whether skip the first step or not.
""" """
_event_size = 20 _event_size = 20
_fp_tag = 1
_bp_tag = 2
def __init__(self, input_dir, output_file_path, job_id=0, skip_first_step=False): def __init__(self, input_dir, output_file_path, job_id=0, skip_first_step=False):
self._input_dir = input_dir self._input_dir = input_dir
...@@ -80,6 +83,30 @@ class StepTraceParser: ...@@ -80,6 +83,30 @@ class StepTraceParser:
else: else:
log.info("Finish to save intermediate result for step trace file.") log.info("Finish to save intermediate result for step trace file.")
def record_point_info(self, point_info, output_path):
    """
    Record point info into json.

    Args:
        point_info (dict): The point info about tag id and relative op name.
        output_path (str): The output path for saving point info.

    Returns:
        dict, parsed point info.

    Raises:
        ProfilerIOException: If the point info file cannot be written.
    """
    # Map the raw tag ids to stable field names; a missing tag yields ''.
    points = {
        'fp_start': point_info.get(self._fp_tag, ''),
        'bp_end': point_info.get(self._bp_tag, '')
    }
    try:
        with open(output_path, 'w') as json_file:
            json.dump(points, json_file)
        # NOTE(review): marking the file read-only means a second profiling
        # run over the same output dir cannot overwrite it — confirm this is
        # intended before rerunning in place.
        os.chmod(output_path, stat.S_IREAD)
    except (IOError, OSError) as err:
        log.warning('Failed to save point info. %s', err)
        # Chain the original error so the underlying OS failure stays in the
        # traceback (original code raised without a cause).
        raise ProfilerIOException() from err
    return points
def _get_step_trace_files(self): def _get_step_trace_files(self):
"""Get step trace files.""" """Get step trace files."""
# step trace files may under $profiler_dir or $profiler_dir/data # step trace files may under $profiler_dir or $profiler_dir/data
...@@ -169,8 +196,8 @@ class StepTraceParser: ...@@ -169,8 +196,8 @@ class StepTraceParser:
min_job_id = 255 min_job_id = 255
step_flag: bool = lambda tag: tag > min_job_id or tag == 0 step_flag: bool = lambda tag: tag > min_job_id or tag == 0
end_flag: bool = lambda tag: tag == min_job_id end_flag: bool = lambda tag: tag == min_job_id
fp_flag: bool = lambda tag: tag == 1 fp_flag: bool = lambda tag: tag == self._fp_tag
bp_flag: bool = lambda tag: tag == 2 bp_flag: bool = lambda tag: tag == self._bp_tag
def _on_step_event(): def _on_step_event():
"""Handle step event.""" """Handle step event."""
......
...@@ -245,16 +245,24 @@ class Profiler: ...@@ -245,16 +245,24 @@ class Profiler:
self._output_path, self._output_path,
f'step_trace_raw_{self._dev_id}_detail_time.csv' f'step_trace_raw_{self._dev_id}_detail_time.csv'
) )
point_info_file_path = os.path.join(
self._output_path,
'step_trace_point_info.json'
)
# whether keep the first step # whether keep the first step
skip_first_step_flag = framework_parser.check_op_name(INIT_OP_NAME) skip_first_step_flag = framework_parser.check_op_name(INIT_OP_NAME)
point_info = framework_parser.point_info
# parser the step trace files and save the result to disk # parser the step trace files and save the result to disk
parser = StepTraceParser(input_dir=source_path, parser = StepTraceParser(input_dir=source_path,
output_file_path=step_trace_intermediate_file_path, output_file_path=step_trace_intermediate_file_path,
job_id=self._job_id_env, job_id=self._job_id_env,
skip_first_step=skip_first_step_flag) skip_first_step=skip_first_step_flag)
parser.parse_and_save() parser.parse_and_save()
point_info = parser.record_point_info(point_info, point_info_file_path)
# print parser result # print parser result
parser.show() parser.show()
logger.info("Finish saving the intermediate result: %s", step_trace_intermediate_file_path)
logger.info("The point info is: %s", point_info)
def _analyse_timeline(self): def _analyse_timeline(self):
""" """
......
...@@ -74,6 +74,20 @@ class TestProfilerAnalyse(TestCase): ...@@ -74,6 +74,20 @@ class TestProfilerAnalyse(TestCase):
output_files = os.listdir(self.profiler) output_files = os.listdir(self.profiler)
assert self.step_trace_file in output_files assert self.step_trace_file in output_files
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_step_trace_point_info(self):
    """Test that the analyser exposes the expected FP-start and BP-end points."""
    expected_points = {
        'fp_start': 'Default/Cast-op6',
        'bp_end': 'Default/TransData-op7'
    }
    assert self.step_trace_analyser.point_info == expected_points
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
......
1 Default/Cast-op6
2 Default/TransData-op7
3 Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5
4 Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册