提交 15065c10 编写于 作者: Y yelihua

Enable showing the FP (forward-propagation start) and BP (backward-propagation end) points

上级 83c104c4
......@@ -146,6 +146,7 @@ def get_training_trace_graph():
'step_id': graph_type
}})
graph_info['summary'] = analyser.summary
graph_info['point_info'] = analyser.point_info
return jsonify(graph_info)
......
......@@ -14,11 +14,13 @@
# ============================================================================
"""The StepTraceAnalyser analyser class."""
import csv
import json
import os
from mindinsight.datavisual.utils.tools import to_int
from mindinsight.profiler.analyser.base_analyser import BaseAnalyser
from mindinsight.profiler.common.exceptions.exceptions import ProfilerParamValueErrorException, \
ProfilerFileNotFoundException, StepNumNotSupportedException
ProfilerFileNotFoundException, StepNumNotSupportedException, ProfilerRawFileException
from mindinsight.profiler.common.log import logger as log
from mindinsight.profiler.common.util import query_latest_trace_time_file, get_field_value, \
get_summary_for_step_trace, to_millisecond
......@@ -31,6 +33,7 @@ class StepTraceAnalyser(BaseAnalyser):
_attr_ui_name = 'name'
_attr_ui_start = 'start'
_attr_ui_duration = 'duration'
_point_info = {}
@property
def summary(self):
......@@ -40,6 +43,11 @@ class StepTraceAnalyser(BaseAnalyser):
summary['total_steps'] = self._size
return summary
@property
def point_info(self):
    """dict: FP/BP point info ({'fp_start': ..., 'bp_end': ...}) loaded by _load_point_info; empty if the file is absent."""
    return self._point_info
def query(self, condition=None):
"""
Query data according to the condition.
......@@ -90,6 +98,18 @@ class StepTraceAnalyser(BaseAnalyser):
self._data = list(csv_reader)
self._size = len(self._data) - 1
self._display_col_names = self._col_names[:]
self._load_point_info()
def _load_point_info(self):
    """Load FP/BP point info from `step_trace_point_info.json` in the profiling dir.

    Populates `self._point_info`. A missing file is not an error: profiling may
    run without FP/BP points, in which case the point info stays empty.

    Raises:
        ProfilerRawFileException: If the point info file exists but cannot be parsed.
    """
    file_path = os.path.join(self._profiling_dir, 'step_trace_point_info.json')
    # Guard clause: silently keep the default point info when no file was produced.
    if not os.path.isfile(file_path):
        return
    with open(file_path, 'r', encoding='utf-8') as file:
        try:
            self._point_info = json.load(file)
        except (json.JSONDecodeError, TypeError) as err:
            log.exception(err)
            # Chain the original error so the decode failure is visible in tracebacks.
            raise ProfilerRawFileException('Fail to parse point info file.') from err
def _filter(self, filter_condition):
"""
......
......@@ -14,6 +14,7 @@
# ============================================================================
"""The parser for step trace data."""
import csv
import json
import os
import stat
import struct
......@@ -41,6 +42,8 @@ class StepTraceParser:
skip_first_step (bool): Whether skip the first step or not.
"""
_event_size = 20
_fp_tag = 1
_bp_tag = 2
def __init__(self, input_dir, output_file_path, job_id=0, skip_first_step=False):
self._input_dir = input_dir
......@@ -80,6 +83,30 @@ class StepTraceParser:
else:
log.info("Finish to save intermediate result for step trace file.")
def record_point_info(self, point_info, output_path):
    """
    Record point info into json.

    Args:
        point_info (dict): The point info about tag id and relative op name.
        output_path (str): The output path for saving point info.

    Returns:
        dict, parsed point info.

    Raises:
        ProfilerIOException: If the point info file cannot be written.
    """
    # Translate raw FP/BP tag ids into the named fields the front end expects;
    # fall back to '' when a tag was not recorded for this run.
    points = {
        'fp_start': point_info.get(self._fp_tag, ''),
        'bp_end': point_info.get(self._bp_tag, '')
    }
    try:
        with open(output_path, 'w', encoding='utf-8') as json_file:
            json.dump(points, json_file)
        # Make the result read-only so it cannot be silently overwritten later.
        os.chmod(output_path, stat.S_IREAD)
    except (IOError, OSError) as err:
        log.warning('Failed to save point info. %s', err)
        # Chain the underlying OS error for easier debugging.
        raise ProfilerIOException() from err
    return points
def _get_step_trace_files(self):
"""Get step trace files."""
# step trace files may under $profiler_dir or $profiler_dir/data
......@@ -169,8 +196,8 @@ class StepTraceParser:
min_job_id = 255
step_flag: bool = lambda tag: tag > min_job_id or tag == 0
end_flag: bool = lambda tag: tag == min_job_id
fp_flag: bool = lambda tag: tag == 1
bp_flag: bool = lambda tag: tag == 2
fp_flag: bool = lambda tag: tag == self._fp_tag
bp_flag: bool = lambda tag: tag == self._bp_tag
def _on_step_event():
"""Handle step event."""
......
......@@ -245,16 +245,24 @@ class Profiler:
self._output_path,
f'step_trace_raw_{self._dev_id}_detail_time.csv'
)
point_info_file_path = os.path.join(
self._output_path,
'step_trace_point_info.json'
)
# whether keep the first step
skip_first_step_flag = framework_parser.check_op_name(INIT_OP_NAME)
point_info = framework_parser.point_info
# parser the step trace files and save the result to disk
parser = StepTraceParser(input_dir=source_path,
output_file_path=step_trace_intermediate_file_path,
job_id=self._job_id_env,
skip_first_step=skip_first_step_flag)
parser.parse_and_save()
point_info = parser.record_point_info(point_info, point_info_file_path)
# print parser result
parser.show()
logger.info("Finish saving the intermediate result: %s", step_trace_intermediate_file_path)
logger.info("The point info is: %s", point_info)
def _analyse_timeline(self):
"""
......
......@@ -74,6 +74,20 @@ class TestProfilerAnalyse(TestCase):
output_files = os.listdir(self.profiler)
assert self.step_trace_file in output_files
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_step_trace_point_info(self):
    """Test that the analyser exposes the expected FP-start and BP-end op names."""
    point_info = self.step_trace_analyser.point_info
    assert point_info == {
        'fp_start': 'Default/Cast-op6',
        'bp_end': 'Default/TransData-op7'
    }
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
......
1 Default/Cast-op6
2 Default/TransData-op7
3 Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5
4 Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册